[洛谷U40581]树上统计treecnt
题目大意:
给定一棵\(n(n\le10^5)\)个点的树。
定义\(Tree[l,r]\)表示为了使得\(l\sim r\)号点两两连通,最少需要选择的边的数量。
求\(\sum_{l=1}^n\sum_{r=l}^nTree[l,r]\)。
思路:
对于每个边考虑贡献,若我们将出现在子树内的点记作\(1\),出现在子树外的点记作\(0\),那么答案就是\(\frac{n(n-1)}2-\)全\(0\)、全\(1\)串的个数。线段树合并,维护前缀/后缀最长全\(0\)/全\(1\)串即可。
时间复杂度\(\mathcal O(n\log n)\)。
源代码:
#include<cstdio>
#include<cctype>
#include<vector>
inline int getint() {
register char ch;
while(!isdigit(ch=getchar()));
register int x=ch^'0';
while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0');
return x;
}
typedef long long int64;
const int N=1e5+1,logN=18;
int n;
int64 ans;
std::vector<int> e[N];
inline void add_edge(const int &u,const int &v) {
e[u].push_back(v);
e[v].push_back(u);
}
inline int64 calc(const int &n) {
return 1ll*n*(n-1)/2;
}
struct Node {
int pre[2],suf[2],len;
int64 sum;
Node() {}
Node(const int &l,const bool &v) {
pre[!v]=suf[!v]=0;
pre[v]=suf[v]=len=l;
sum=calc(l);
}
friend Node operator + (const Node &l,const Node &r) {
Node ret;
ret.pre[0]=l.pre[0]+r.pre[0]*(l.pre[0]==l.len);
ret.pre[1]=l.pre[1]+r.pre[1]*(l.pre[1]==l.len);
ret.suf[0]=r.suf[0]+l.suf[0]*(r.suf[0]==r.len);
ret.suf[1]=r.suf[1]+l.suf[1]*(r.suf[1]==r.len);
ret.len=l.len+r.len;
ret.sum=l.sum+r.sum+1ll*l.suf[0]*r.pre[0]+1ll*l.suf[1]*r.pre[1];
return ret;
}
};
class SegmentTree {
#define mid ((b+e)>>1)
private:
Node node[N*logN];
int left[N*logN],right[N*logN];
int sz,new_node() {
return ++sz;
}
int len(const int &b,const int &e) {
return e-b+1;
}
void push_up(const int &p,const int &b,const int &e) {
if(!left[p]) node[p]=Node(len(b,mid),0)+node[right[p]];
if(!right[p]) node[p]=node[left[p]]+Node(len(mid+1,e),0);
if(left[p]&&right[p]) {
node[p]=node[left[p]]+node[right[p]];
}
}
public:
int root[N];
void insert(int &p,const int &b,const int &e,const int &x) {
if(!p) p=new_node();
if(b==e) {
node[p]=Node(1,1);
return;
}
if(x<=mid) insert(left[p],b,mid,x);
if(x>mid) insert(right[p],mid+1,e,x);
push_up(p,b,e);
}
void merge(int &p,const int &q,const int &b,const int &e) {
if(!p||!q) {
p=p|q;
return;
}
if(b==e) return;
merge(left[p],left[q],b,mid);
merge(right[p],right[q],mid+1,e);
push_up(p,b,e);
}
int64 query(const int &p) const {
return node[p].sum;
}
#undef mid
};
SegmentTree t;
void dfs(const int &x,const int &par) {
t.insert(t.root[x],1,n,x);
for(auto &y:e[x]) {
if(y==par) continue;
dfs(y,x);
t.merge(t.root[x],t.root[y],1,n);
}
if(x!=1) ans-=t.query(t.root[x]);
}
int main() {
n=getint();
ans=calc(n)*(n-1);
for(register int i=1;i<n;i++) {
add_edge(getint(),getint());
}
dfs(1,0);
printf("%lld\n",ans);
return 0;
}