Algorithm/Baekjoon

[백준 7469 : JAVA] K번째 수 / MergeSort Tree

팡트루야 2021. 4. 2. 22:17

문제


 

 

 

풀이


 

어떤 식으로 K번째 수를 찾을 수 있을까 고민을 많이 했던 문제다...

 

i~j 구간에서 k번째 수가 X라고 하자. 우리는 이 X가 뭔지 모르기 때문에 찾아나가야 한다.

MergeSort Tree 특성상 각 노드가 담당하는  구간은 정렬된 상태의 배열이기 때문에 k번째 수인 X보다 작은 수의 갯수는 k-1개가 된다.

 

문제의 조건상 X가 될 수 있는 값의 범위는 $-10^{9}$~$10^{9}$ 이니까, 이를 토대로 이분탐색을 진행해나간다.

즉, MergeSort Tree의 쿼리 함수가 담당하는 기능은 해당 구간에서 파라미터로 들어온 값보다 작은 수의 갯수이다.

 

[1, 2, 3, 8] 수열이 있고, 2~4 구간에서 3번째 값을 찾는다고 했을 때, X가 될 수 있는 후보는 4~8이다. 즉, X보다 작은 수의 갯수가 k-1개가 되는 수 중 가장 큰 수가 답이 된다.

 

추가로 upperBound() 함수도 주의깊게 봐둬야할 필요가 있다. (C++의 lower_bound(), upper_bound() 함수에서 유례)

upper_bound() : 해당 배열에서 X보다 큰 첫 번째 원소의 인덱스 반환한다.

lower_bound() : 해당 배열에서 X보다 크거나 같은 첫 번째 원소의 인덱스를 반환한다.

 

참고로, upper_bound() - lower_bound() 로 배열에서 값이 같은 원소의 갯수도 구할 수 있다.

 

 

 

코드


import java.io.*;
import java.util.*;

// 특정 구간의 정렬된 수열에서 k번째 수 찾기.
public class Main {

    private static int N, M;
    private static int[] nums;
    private static MergeSortTree tree;

    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringBuilder sb = new StringBuilder();
        StringTokenizer st = new StringTokenizer(br.readLine());
        N = Integer.parseInt(st.nextToken());
        M = Integer.parseInt(st.nextToken());

        nums = Arrays.stream(br.readLine().split(" ")).mapToInt(Integer::parseInt).toArray();
        tree = new MergeSortTree(nums);

        int i, j, k;
        while (M-- > 0) {
            st = new StringTokenizer(br.readLine());
            i = Integer.parseInt(st.nextToken());
            j = Integer.parseInt(st.nextToken());
            k = Integer.parseInt(st.nextToken());

            // k번째 수가 X라면 i~j 구간에서 X보다 작은 수의 갯수는 k-1개다.
            // 가능한 X의 범위는 [-10^9, 10^9] 이다. X를 이분탐색으로 찾아나간다.
            int left = (int) -1e9, right = (int) 1e9;
            while (left <= right) {
                int mid = (left + right) / 2;
                int result = tree.query(mid, i, j, 1, 1, N);
                if (result < k) {
                    left = mid + 1;
                } else {
                    right = mid - 1;
                }
            }
            sb.append(left).append("\n");
        }
        System.out.println(sb.toString());
    }

    private static class MergeSortTree {

        int[] arr;
        List<Integer>[] tree;

        public MergeSortTree(int[] arr) {
            this.arr = arr;
            int k = (int) Math.ceil(Math.log(N) / Math.log(2));
            int height = k + 1;
            int size = (int) Math.pow(2, height);
            tree = new List[size];
            init(1, 1, N);
        }

        public List<Integer> init(int node, int start, int end) {
            if (start == end) {
                tree[node] = new ArrayList<>();
                tree[node].add(arr[start - 1]);
                return tree[node];
            }
            int mid = (start + end) / 2;
            return tree[node] = merge(init(node * 2, start, mid), init(node * 2 + 1, mid + 1, end));
        }

        public List<Integer> merge(List<Integer> left, List<Integer> right) {
            List<Integer> result = new ArrayList<>();
            int i = 0, j = 0;
            while (i < left.size() && j < right.size()) {
                if (left.get(i) <= right.get(j)) {
                    result.add(left.get(i++));
                } else {
                    result.add(right.get(j++));
                }
            }
            while (i < left.size()) {
                result.add(left.get(i++));
            }
            while (j < right.size()) {
                result.add(right.get(j++));
            }
            return result;
        }

        public int query(int x, int left, int right, int node, int start, int end) {
            if (right < start || end < left) return 0;
            if (left <= start && end <= right) return upperBound(tree[node], x);
            int mid = (start + end) / 2;
            return query(x, left, right, node * 2, start, mid) + query(x, left, right, node * 2 + 1, mid + 1, end);
        }

        // x보다 큰 원소가 나오는 첫 번째 위치를 반환한다.
        public int upperBound(List<Integer> node, int x) {
            int length = node.size();
            int left = 0, right = length - 1, mid = 0;
            while (left < right) {
                if (node.get(mid) <= x) left = mid + 1;
                else right = mid;
                mid = (left + right) / 2;
                if (mid == right) {
                    if (node.get(mid) <= x) return length;
                    else return right;
                }
            }
            if (node.get(left) > x) return 0;
            return left + 1;
        }
    }
}