我对段树很陌生,并希望通过对段树做更多的练习来让自己忙碌起来。

问题实际上更像 ACM 并且具有以下条件:
有n个数字和m个操作,n,m 1.通过减去一个数字x更新一个区间,x每次都可以不同
2.查询一个区间,找出区间内有多少个数字
在这里构建段树和更新显然可以在 O(nlog n)/O(log n) 中完成
但是我不知道如何在 O(log n) 中进行查询,谁能给我一些建议/提示?
任何的意见都将会有帮助!谢谢!

TL;DR:

给定 n 个数字和 2 种类型的操作:

  • 将 x 添加到 [a,b] 中的所有元素,x 每次都可以不同
  • 查询[a,b]中元素个数
    如何使操作 1 和 2 都可以在 O(log n) 中完成?

    最佳答案

    好问题:)

    我想了一会儿,但仍然无法解决段树的这个问题,但我已经尝试使用“ Bucket Method ”来解决这个问题。

    我们可以将初始的 n 个数字分成 B 个桶,对每个桶中的数字进行排序,并保持每个桶中的总加值。然后对于每个查询:

  • "添加"更新间隔 [a, b] 与 c

    我们最多只需要重建两个bucket并将c添加到(b - a)/BUCKET_SIZE buckets
  • "Query"查询间隔 [a, b]
    我们只需要一一扫描每个值最多两个桶,然后快速通过 (b-a)/BUCKET_SIZE 桶进行二分查找

  • 对于每个查询,它应该在 O( N/BUCKET_SIZE * log(BUCKET_SIZE, 2)) 中运行,这比蛮力方法 (O(N)) 小。虽然它比 O(logN) 大,但在大多数情况下可能就足够了。

    下面是测试代码:
    #include <iostream>
    #include <cstdio>
    #include <cstdlib>
    #include <string>
    #include <cstring>
    #include <cmath>
    #include <algorithm>
    #include <vector>
    #include <set>
    #include <map>
    #include <ctime>
    #include <cassert>
    
    using namespace std;
    
    struct Query {
        //A a b c  add c in [a, b] of arr
        //Q a b c  Query number of i in [a, b] which arr[i] <= c
        char ty;
        int a, b, c;
        Query(char _ty, int _a, int _b, int _c):ty(_ty), a(_a), b(_b), c(_c){}
    };
    
    int n, m;
    vector<int> arr;
    vector<Query> queries;
    
    vector<int> bruteforce() {
        vector<int> ret;
        vector<int> numbers = arr;
        for (int i = 0; i < m; i++) {
            Query q = queries[i];
            if (q.ty == 'A') {
                for (int i = q.a; i <= q.b; i++) {
                    numbers[i] += q.c;
                }
                ret.push_back(-1);
            } else {
                int tmp = 0;
                for(int i = q.a; i <= q.b; i++) {
                    tmp += numbers[i] <= q.c;
                }
                ret.push_back(tmp);
            }
        }
        return ret;
    }
    
    struct Bucket {
        vector<int> numbers;
        vector<int> numbers_sorted;
        int add;
        Bucket() {
            add = 0;
            numbers_sorted.clear();
            numbers.clear();
        }
        int query(int pos) {
            return numbers[pos] + add;
        }
        void add_pos(int pos, int val) {
            numbers[pos] += val;
        }
        void build() {
            numbers_sorted = numbers;
            sort(numbers_sorted.begin(), numbers_sorted.end());
        }
    };
    
    vector<int> bucket_count(int bucket_size) {
        vector<int> ret;
    
        vector<Bucket> buckets;
        buckets.resize(int(n / bucket_size) + 5);
        for (int i = 0; i < n; i++) {
            buckets[i / bucket_size].numbers.push_back(arr[i]);
        }
    
        for (int i = 0; i <= n / bucket_size; i++) {
            buckets[i].build();
        }
    
        for (int i = 0; i < m; i++) {
            Query q = queries[i];
            char ty = q.ty;
            int a, b, c;
            a = q.a, b = q.b, c = q.c;
            if (ty == 'A') {
                set<int> affect_buckets;
                while (a < b && a % bucket_size != 0) buckets[a/ bucket_size].add_pos(a % bucket_size, c), affect_buckets.insert(a/bucket_size), a++;
                while (a < b && b % bucket_size != 0) buckets[b/ bucket_size].add_pos(b % bucket_size, c), affect_buckets.insert(b/bucket_size), b--;
                while (a < b) {
                    buckets[a/bucket_size].add += c;
                    a += bucket_size;
                }
                buckets[a/bucket_size].add_pos(a % bucket_size, c), affect_buckets.insert(a / bucket_size);
                for (set<int>::iterator it = affect_buckets.begin(); it != affect_buckets.end(); it++) {
                    int id = *it;
                    buckets[id].build();
                }
                ret.push_back(-1);
            } else {
                int tmp = 0;
                while (a < b && a % bucket_size != 0) tmp += (buckets[a/ bucket_size].query(a % bucket_size) <=c), a++;
                while (a < b && b % bucket_size != 0) tmp += (buckets[b/ bucket_size].query(b % bucket_size) <=c), b--;
                while (a < b) {
                    int pos = a / bucket_size;
                    tmp += upper_bound(buckets[pos].numbers_sorted.begin(), buckets[pos].numbers_sorted.end(), c - buckets[pos].add) - buckets[pos].numbers_sorted.begin();
                    a += bucket_size;
                }
                tmp += (buckets[a / bucket_size].query(a % bucket_size) <= c);
                ret.push_back(tmp);
            }
        }
    
        return ret;
    }
    
    void process(int cas) {
    
        clock_t begin_t=clock();
    
        vector<int> bf_ans = bruteforce();
        clock_t  bf_end_t =clock();
        double bf_sec = ((1.0 * bf_end_t - begin_t)) / CLOCKS_PER_SEC;
    
        //bucket_size is important
        int bucket_size = 200;
        vector<int> ans = bucket_count(bucket_size);
    
        clock_t  bucket_end_t =clock();
        double bucket_sec = ((1.0 * bucket_end_t - bf_end_t)) / CLOCKS_PER_SEC;
    
        bool correct = true;
        for (int i = 0; i < ans.size(); i++) {
            if (ans[i] != bf_ans[i]) {
                cout << "query " << i + 1 << " bf = " << bf_ans[i] << " bucket  = " << ans[i] << "  bucket size = " <<  bucket_size << " " << n << " " << m <<  endl;
                correct = false;
            }
        }
        printf("Case #%d:%s bf_sec = %.9lf, bucket_sec = %.9lf\n", cas, correct ? "YES":"NO", bf_sec, bucket_sec);
    }
    
    void read() {
        cin >> n >> m;
        arr.clear();
        for (int i = 0; i < n; i++) {
            int val;
            cin >> val;
            arr.push_back(val);
        }
        queries.clear();
        for (int i = 0; i < m; i++) {
            char ty;
            int a, b, c;
            // a, b, c in [0, n - 1], a <= b
            cin >> ty >> a >> b >> c;
            queries.push_back(Query(ty, a, b, c));
        }
    }
    
    void run(int cas) {
        read();
        process(cas);
    }
    
    int main() {
        freopen("bucket.in", "r", stdin);
        //freopen("bucket.out", "w", stdout);
        int T;
        scanf("%d", &T);
        for (int cas  = 1; cas <= T; cas++) {
            run(cas);
        }
        return 0;
    }
    

    这是数据生成代码:
    #coding=utf8
    
    import random
    import math
    
    def gen_buckets(f):
        t = random.randint(10, 20)
        print >> f, t
        nlimit = 100000
        mlimit = 10000
        limit = 100000
        for i in xrange(t):
            n = random.randint(1, nlimit)
            m = random.randint(1, mlimit)
            print >> f, n, m
    
            for i in xrange(n):
                val = random.randint(1, limit)
                print >> f, val ,
            print >> f
            for i in xrange(m):
                ty = random.randint(1, 2)
                a = random.randint(0, n - 1)
                b = random.randint(a, n - 1)
                #a = 0
                #b = n - 1
                c = random.randint(-limit, limit)
                print >> f, 'A' if ty == 1 else 'Q', a, b, c
    
    
    f = open("bucket.in", "w")
    gen_buckets(f)
    

    关于algorithm - (ACM) 如何使用线段树计算 [a,b] 中有多少元素小于给定的常数?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/18800058/

    10-11 21:59