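I am given two integers: the first is the number of segments (Xi, Xj), and the second is the number of points that may or may not lie inside those segments.
For example, the input could be: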

2 3
0 5
8 10
1 6 11

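In the first line, 2 means "2 segments" and 3 means "3 points".
The two segments are "0 to 5" and "8 to 10", and the points to look up are 1, 6, and 11.
The output is: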
1 0 0

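Here point 1 lies in the segment "0 to 5", while points 6 and 11 lie in no segment. If a point lies in more than one segment, its output is the number of segments that contain it (for example, 2 for a point covered by two segments).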
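My original code was just a double loop that checks every point against every segment. To speed up the double loop I used Java's Arrays quicksort, modified so that sorting the segment end points also reorders the start points (so that start[i] and end[i] still belong to the same segment i), but that is not enough.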
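To illustrate the paired sort described above without a hand-written quicksort, here is a minimal sketch that packs each segment into an int[] and lets Arrays.sort reorder the pairs. It is not the code from the post: the class name PairedSortSketch and the method sortByEnd are hypothetical, and it assumes both arrays have the same length.

```java
import java.util.Arrays;

public class PairedSortSketch {

    // Sorts segments by end point (ties broken by start) while keeping each
    // start paired with its own end. Both input arrays are rewritten in place.
    public static void sortByEnd(int[] starts, int[] ends) {
        int n = starts.length;
        int[][] segments = new int[n][];
        for (int i = 0; i < n; i++) {
            segments[i] = new int[] {starts[i], ends[i]};
        }
        Arrays.sort(segments, (s, t) -> s[1] != t[1]
                ? Integer.compare(s[1], t[1])
                : Integer.compare(s[0], t[0]));
        for (int i = 0; i < n; i++) {
            starts[i] = segments[i][0];
            ends[i] = segments[i][1];
        }
    }

    public static void main(String[] args) {
        int[] starts = {8, 0, 3};
        int[] ends = {10, 5, 4};
        sortByEnd(starts, ends);
        System.out.println(Arrays.toString(starts)); // [3, 0, 8]
        System.out.println(Arrays.toString(ends));   // [4, 5, 10]
    }
}
```

Sorting the pairs costs O(n log n), but on its own it does not remove the O(n * m) cost of the double loop that follows.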
The following code works correctly, but it becomes very slow when there are too many segments:
import java.util.Scanner;

public class PointsAndSegments {

    private static int[] fastCountSegments(int[] starts, int[] ends, int[] points) {
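        // Sort both arrays in tandem, keyed on the segment ends (the second argument).
        sort(starts, ends);
        // Note: CountSegments is called here, but its body is not included in this post.
        int[] cnt2 = CountSegments(starts,ends,points);
        return cnt2;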
    }

    private static void dualPivotQuicksort(int[] a, int[] b, int left,int right, int div) {
    int len = right - left;
    if (len < 27) { // insertion sort for tiny array
        for (int i = left + 1; i <= right; i++) {
            for (int j = i; j > left && b[j] < b[j - 1]; j--) {
                swap(a, b, j, j - 1);
            }
        }
        return;
    }
    int third = len / div;
    // "medians"
    int m1 = left  + third;
    int m2 = right - third;
    if (m1 <= left) {
        m1 = left + 1;
    }
    if (m2 >= right) {
        m2 = right - 1;
    }
    // Compare the sort keys in b (not a) so that pivot1 <= pivot2 holds below.
    if (b[m1] < b[m2]) {
        swap(a, b, m1, left);
        swap(a, b, m2, right);
    }
    else {
        swap(a, b, m1, right);
        swap(a, b, m2, left);
    }
    // pivots
    int pivot1 = b[left];
    int pivot2 = b[right];
    // pointers
    int less  = left  + 1;
    int great = right - 1;
    // sorting
    for (int k = less; k <= great; k++) {
        if (b[k] < pivot1) {
            swap(a, b, k, less++);
        }
        else if (b[k] > pivot2) {
            while (k < great && b[great] > pivot2) {
                great--;
            }
            swap(a, b, k, great--);
            if (b[k] < pivot1) {
                swap(a, b, k, less++);
            }
        }
    }
    // swaps
    int dist = great - less;
    if (dist < 13) {
       div++;
    }
    swap(a, b, less  - 1, left);
    swap(a, b, great + 1, right);
    // subarrays
    dualPivotQuicksort(a, b, left,   less - 2, div);
    dualPivotQuicksort(a, b, great + 2, right, div);

    // equal elements
    if (dist > len - 13 && pivot1 != pivot2) {
        for (int k = less; k <= great; k++) {
            if (b[k] == pivot1) {
                swap(a, b, k, less++);
            }
            else if (b[k] == pivot2) {
                swap(a, b, k, great--);
                if (b[k] == pivot1) {
                    swap(a, b, k, less++);
                }
            }
        }
    }
    // subarray
    if (pivot1 < pivot2) {
        dualPivotQuicksort(a, b, less, great, div);
    }
    }

    public static void sort(int[] a, int[] b) {
        sort(a, b, 0, b.length);
    }

    public static void sort(int[] a, int[] b, int fromIndex, int toIndex) {
        rangeCheck(a.length, fromIndex, toIndex);
        dualPivotQuicksort(a, b, fromIndex, toIndex - 1, 3);
    }

    private static void rangeCheck(int length, int fromIndex, int toIndex) {
        if (fromIndex > toIndex) {
            throw new IllegalArgumentException("fromIndex > toIndex");
        }
        if (fromIndex < 0) {
            throw new ArrayIndexOutOfBoundsException(fromIndex);
        }
        if (toIndex > length) {
            throw new ArrayIndexOutOfBoundsException(toIndex);
        }
    }

    private static void swap(int[] a, int[] b, int i, int j) {
        int swap1 = a[i];
        int swap2 = b[i];
        a[i] = a[j];
        b[i] = b[j];
        a[j] = swap1;
        b[j] = swap2;
    }

    private static int[] naiveCountSegments(int[] starts, int[] ends, int[] points) {
        int[] cnt = new int[points.length];
        for (int i = 0; i < points.length; i++) {
            for (int j = 0; j < starts.length; j++) {
                if (starts[j] <= points[i] && points[i] <= ends[j]) {
                    cnt[i]++;
                }
            }
        }
        return cnt;
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int n, m;
        n = scanner.nextInt();
        m = scanner.nextInt();
        int[] starts = new int[n];
        int[] ends = new int[n];
        int[] points = new int[m];
        for (int i = 0; i < n; i++) {
            starts[i] = scanner.nextInt();
            ends[i] = scanner.nextInt();
        }
        for (int i = 0; i < m; i++) {
            points[i] = scanner.nextInt();
        }
        //use fastCountSegments
        int[] cnt = fastCountSegments(starts, ends, points);
        for (int x : cnt) {
            System.out.print(x + " ");
        }
    }
}

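I believe the problem is in the CountSegments() method, but I am not sure how else to solve it. Supposedly I should use a divide-and-conquer algorithm, but after 4 days I am open to any solution.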
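I found a similar problem in CodeForces, but the output is different and most of the solutions are in C++. Since I have only been learning Java for three months, I think I have reached the limit of my knowledge.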
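One way to read the divide-and-conquer hint is binary search over sorted end points: for closed segments, the number of segments containing p equals the number of starts that are <= p minus the number of ends that are < p. The following is a minimal sketch under that assumption, not the CountSegments method from the post. The names BinarySearchCountSketch, countAtMost, and countSegments are hypothetical, and every segment is assumed to satisfy start <= end.

```java
import java.util.Arrays;

public class BinarySearchCountSketch {

    // Number of elements in the sorted array a that are <= key.
    static int countAtMost(int[] a, long key) {
        int lo = 0, hi = a.length;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (a[mid] <= key) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }

    public static int[] countSegments(int[] starts, int[] ends, int[] points) {
        // Starts and ends are sorted independently; the pairing is not needed here.
        int[] s = starts.clone();
        int[] e = ends.clone();
        Arrays.sort(s);
        Arrays.sort(e);
        int[] cnt = new int[points.length];
        for (int i = 0; i < points.length; i++) {
            long p = points[i]; // long so that p - 1 cannot overflow
            // Segments that started at or before p, minus those that ended strictly before p.
            cnt[i] = countAtMost(s, p) - countAtMost(e, p - 1);
        }
        return cnt;
    }

    public static void main(String[] args) {
        int[] cnt = countSegments(new int[] {0, 8}, new int[] {5, 10}, new int[] {1, 6, 11});
        System.out.println(Arrays.toString(cnt)); // [1, 0, 0]
    }
}
```

With n segments and m points this takes O((n + m) log n) time, compared with O(n * m) for the double loop in naiveCountSegments.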

Best answer

Here is an implementation similar to @Shole's idea:

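import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Scanner; // only needed by the Scanner-based main

public class SegmentsAlgorithm {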

    private PriorityQueue<int[]> remainSegments = new PriorityQueue<>((o0, o1) -> Integer.compare(o0[0], o1[0]));
    private SegmentWeight[] arraySegments;

    public void addSegment(int begin, int end) {
        remainSegments.add(new int[]{begin, end});
    }

    public void prepareArrayCache() {
        List<SegmentWeight> preCalculate = new ArrayList<>();
        PriorityQueue<int[]> currentSegmentsByEnds = new PriorityQueue<>((o0, o1) -> Integer.compare(o0[1], o1[1]));
        int begin = remainSegments.peek()[0];
        while (!remainSegments.isEmpty() && remainSegments.peek()[0] == begin) {
            currentSegmentsByEnds.add(remainSegments.poll());
        }
        preCalculate.add(new SegmentWeight(begin, currentSegmentsByEnds.size()));
        int next;
        while (!remainSegments.isEmpty()) {
            if (currentSegmentsByEnds.isEmpty()) {
                next = remainSegments.peek()[0];
            } else {
                next = Math.min(currentSegmentsByEnds.peek()[1], remainSegments.peek()[0]);
            }
            while (!currentSegmentsByEnds.isEmpty() && currentSegmentsByEnds.peek()[1] == next) {
                currentSegmentsByEnds.poll();
            }
            while (!remainSegments.isEmpty() && remainSegments.peek()[0] == next) {
                currentSegmentsByEnds.add(remainSegments.poll());
            }
            preCalculate.add(new SegmentWeight(next, currentSegmentsByEnds.size()));
        }
        while (!currentSegmentsByEnds.isEmpty()) {
            next = currentSegmentsByEnds.peek()[1];
            while (!currentSegmentsByEnds.isEmpty() && currentSegmentsByEnds.peek()[1] == next) {
                currentSegmentsByEnds.poll();
            }
            preCalculate.add(new SegmentWeight(next, currentSegmentsByEnds.size()));
        }
        SegmentWeight[] arraySearch = new SegmentWeight[preCalculate.size()];
        int i = 0;
        for (SegmentWeight l : preCalculate) {
            arraySearch[i++] = l;
        }
        this.arraySegments = arraySearch;
    }

    public int searchPoint(int p) {
        int result = 0;
        if (arraySegments != null && arraySegments.length > 0 && arraySegments[0].begin <= p) {
            int index = Arrays.binarySearch(arraySegments, new SegmentWeight(p, 0), (o0, o1) -> Integer.compare(o0.begin, o1.begin));
            if (index < 0){
                // Not found: binarySearch returned -(insertionPoint) - 1,
                // so step back to the last breakpoint that is <= p.
                index = - 2 - index;
            }
            if (index >= 0 && index < arraySegments.length) { // Protection added
                result = arraySegments[index].weight;
            }
        }
        return result;
    }

    public static void main(String[] args) {
        SegmentsAlgorithm algorithm = new SegmentsAlgorithm();
        int[][] segments = {{0, 5},{3, 10},{8, 9},{14, 20},{12, 28}};
        for (int[] segment : segments) {
            algorithm.addSegment(segment[0], segment[1]);
        }
        algorithm.prepareArrayCache();

        int[] points = {-1, 2, 4, 6, 11, 28};

        for (int point: points) {
            System.out.println(point + ": " + algorithm.searchPoint(point));
        }
    }

    public static class SegmentWeight {

        int begin;
        int weight;

        public SegmentWeight(int begin, int weight) {
            this.begin = begin;
            this.weight = weight;
        }
    }
}

It prints:
-1: 0
2: 1
4: 2
6: 1
11: 0
28: 0
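
The `- 2 - index` adjustment in searchPoint relies on the documented contract of Arrays.binarySearch: when the key is absent it returns -(insertionPoint) - 1, so the last element that is <= the key sits at insertionPoint - 1, which is -2 - index. The sketch below isolates that step. The class FloorIndexSketch and the method floorIndex are hypothetical names, and the two arrays are the breakpoints I traced by hand from prepareArrayCache() for the sample segments, so treat them as illustrative.

```java
import java.util.Arrays;

public class FloorIndexSketch {

    // Index of the last element <= key in a sorted array of distinct values,
    // or -1 if every element is greater than key.
    public static int floorIndex(int[] sorted, int key) {
        int index = Arrays.binarySearch(sorted, key);
        // A miss is reported as -(insertionPoint) - 1; the predecessor is at insertionPoint - 1.
        return index >= 0 ? index : -2 - index;
    }

    public static void main(String[] args) {
        // Breakpoints and weights for {0,5},{3,10},{8,9},{14,20},{12,28} (hand-traced).
        int[] begins = {0, 3, 5, 8, 9, 10, 12, 14, 20, 28};
        int[] weights = {1, 2, 1, 2, 1, 0, 1, 2, 1, 0};
        for (int p : new int[] {2, 4, 6, 11, 28}) {
            System.out.println(p + ": " + weights[floorIndex(begins, p)]);
        }
    }
}
```

Note that the weight stored at an end point already excludes the segments ending there, so this lookup treats each segment end as exclusive (hence 28: 0), whereas naiveCountSegments in the question counts the end point as inside.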

Edit: a main method that reads input in the question's format:
public static void main(String[] args) {
    SegmentsAlgorithm algorithm = new SegmentsAlgorithm();
    Scanner scanner = new Scanner(System.in);
    int n = scanner.nextInt();
    int m = scanner.nextInt();
    for (int i = 0; i < n; i++) {
        algorithm.addSegment(scanner.nextInt(), scanner.nextInt());
    }
    algorithm.prepareArrayCache();
    for (int i = 0; i < m; i++) {
        System.out.print(algorithm.searchPoint(scanner.nextInt())+ " ");
    }
    System.out.println();
}
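
When all points are known up front, as in the Scanner-based main above, a similar sweep can answer the queries directly by sorting starts, points, and ends into a single event list. This is a sketch of that alternative rather than the answer's method. It assumes closed segments, as in naiveCountSegments from the question, and the names SweepLineSketch and countSegments are hypothetical.

```java
import java.util.Arrays;

public class SweepLineSketch {

    // Event kinds, ordered so that at equal coordinates a start is processed
    // before a point and a point before an end (closed segments).
    private static final int START = 0, POINT = 1, END = 2;

    public static int[] countSegments(int[] starts, int[] ends, int[] points) {
        int n = starts.length, m = points.length;
        // Each event is {coordinate, kind, index of the query point or -1}.
        int[][] events = new int[2 * n + m][];
        int k = 0;
        for (int i = 0; i < n; i++) {
            events[k++] = new int[] {starts[i], START, -1};
            events[k++] = new int[] {ends[i], END, -1};
        }
        for (int i = 0; i < m; i++) {
            events[k++] = new int[] {points[i], POINT, i};
        }
        Arrays.sort(events, (a, b) -> a[0] != b[0]
                ? Integer.compare(a[0], b[0])
                : Integer.compare(a[1], b[1]));
        int[] cnt = new int[m];
        int open = 0; // segments covering the current sweep position
        for (int[] ev : events) {
            if (ev[1] == START) {
                open++;
            } else if (ev[1] == END) {
                open--;
            } else {
                cnt[ev[2]] = open;
            }
        }
        return cnt;
    }

    public static void main(String[] args) {
        int[] cnt = countSegments(new int[] {0, 8}, new int[] {5, 10}, new int[] {1, 6, 11});
        System.out.println(Arrays.toString(cnt)); // [1, 0, 0]
    }
}
```

The sort dominates the cost at O((n + m) log(n + m)), and the results come back in the original order of the points because each point event carries its index.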

This question, "java - How to improve this algorithm to optimize running time (finding points in segments)", comes from Stack Overflow: https://stackoverflow.com/questions/38090501/
