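题目大意: 给定长度为 n 的字符串 s 和每个位置的权值 a[i]。对每个 r = 0, 1, …, n-1，求满足 \(LCP(suf(i),suf(j)) \ge r\) 的后缀对 (i,j)（i<j）的个数，以及这些后缀对中 \(a_i \cdot a_j\) 的最大值（不存在这样的对时输出 0）。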
题解:
这道题可以在后缀树上做树形 DP：把串反转后建后缀自动机，其 parent 树就是原串的后缀树。
我们知道:\(LCP(suf(i),suf(j)) = len(lca(i,j))\)
所以对后缀树上的每个节点 u 做一次树形 DP，求出 LCA 恰为 u 的后缀对数（这些对的 LCP 恰好等于 len(u)）。
同时dp出子树中存在的权的最大值,次大值,最小值,次小值
最后按 r 从大到小做后缀累加：对数直接相加，乘积取最大值。由于权值可以为负，子树内的最大乘积只可能是"最大值×次大值"或"最小值×次小值"。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// 64-bit integer alias: pair counts (siz*siz) and products of two int weights
// can exceed the 32-bit range, so they are accumulated in this type.
typedef long long ll;
// Reads the next (optionally negative) decimal integer from stdin into x.
// Leading whitespace/control characters (anything below '!') are skipped and
// the single character that terminates the number is consumed.
// If end of input is reached before a number starts, x is left as 0 and the
// function returns instead of looping forever.
inline void read(int &x){
    x = 0;
    bool negative = false;
    // int, not char: keeps EOF distinguishable from every real character.
    int ch = getchar();
    while(ch != EOF && ch < '!') ch = getchar();
    if(ch == EOF) return;
    if(ch == '-'){
        negative = true;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9'){
        x = 10*x + (ch - '0');
        ch = getchar();
    }
    if(negative) x = -x;
}
// Capacity used for every global array in this file.
const int maxn = 1000010;

// Forward-star adjacency list. Edge ids start at 1, so 0 means "no edge":
// head[u] == 0 means u has no outgoing edges yet.
struct Edge{
    int to;     // target vertex
    int next;   // id of the previous edge out of the same source (0 = end)
}G[maxn];
int head[maxn];
int cnt;        // number of edges stored so far (also the id of the newest)

// Appends the directed edge u -> v at the front of u's edge list.
void add(int u,int v){
    Edge &e = G[++cnt];
    e.to = v;
    e.next = head[u];
    head[u] = cnt;
}
// Suffix-automaton state: transitions on 'a'..'z' (0 = no transition),
// length of the longest string in the state, and the suffix-link parent.
struct Node{
    int nx[26];
    int len;
    int fa;
}T[maxn];

int last;          // state reached by the whole prefix inserted so far
int nodecnt = 0;   // highest state id in use; state 0 is the root
int n;             // length of the input string

// a: weight of each string position (indexed by position, not by state).
int a[maxn];
// siz: number of prefix states (i.e. suffixes) counted in a state's subtree.
int siz[maxn];
// mx/cmx: largest and second-largest weight seen in a state's subtree.
int mx[maxn];
int cmx[maxn];
// mn/cmn: smallest and second-smallest weight seen in a state's subtree.
int mn[maxn];
int cmn[maxn];
// Extends the suffix automaton by one character and attaches weight a[i]
// to the newly created prefix state (siz = 1, mx = mn = a[i]).
// cha is expected to be in 'a'..'z' since it indexes nx[26].
// Clones created while splitting a state copy q's transitions and link but
// receive no weight and no siz contribution.
inline void insert(char cha,int i){
    const int ch = cha - 'a';
    const int np = ++nodecnt;
    T[np].len = T[last].len + 1;

    // Walk the suffix links of the old last state, adding the missing edge.
    int p = last;
    while(p != -1 && T[p].nx[ch] == 0){
        T[p].nx[ch] = np;
        p = T[p].fa;
    }

    if(p == -1){
        T[np].fa = 0;
    }else{
        const int q = T[p].nx[ch];
        if(T[p].len + 1 == T[q].len){
            T[np].fa = q;
        }else{
            // Split q: the clone takes the shorter length.
            const int nq = ++nodecnt;
            T[nq] = T[q];
            T[nq].len = T[p].len + 1;
            while(p != -1 && T[p].nx[ch] == q){
                T[p].nx[ch] = nq;
                p = T[p].fa;
            }
            T[q].fa = nq;
            T[np].fa = nq;
        }
    }

    last = np;
    ++siz[np];
    mx[np] = a[i];
    mn[np] = a[i];
}
// num[u]: number of suffix pairs whose lowest common ancestor is state u.
ll num[maxn];
// Input string; main reverses it in place before building the automaton.
char s[maxn];
// ans1[r] / ans2[r]: pair count and best weight product for length r
// (filled per exact length by dfs, turned into ">= r" totals by main).
ll ans1[maxn];
ll ans2[maxn];
// Feeds z into a running top-two pair: x = largest, y = second largest.
// Ties are kept as separate values: a z equal to x demotes the old x to y.
inline void update(int &x,int &y,int z){
    if(z < x){
        if(z >= y) y = z;
        return;
    }
    y = x;
    x = z;
}
// Feeds z into a running bottom-two pair: x = smallest, y = second smallest.
// Ties are kept as separate values: a z equal to x demotes the old x to y.
inline void downpdate(int &x,int &y,int z){
    if(z > x){
        if(z <= y) y = z;
        return;
    }
    y = x;
    x = z;
}
#define v G[i].to
void dfs(int u,int fa){
for(int i = head[u];i;i = G[i].next){
if(v == fa) continue;
dfs(v,u);
num[u] += 1LL*siz[u]*siz[v];
siz[u] += siz[v];
update(mx[u],cmx[u],mx[v]);update(mx[u],cmx[u],cmx[v]);
downpdate(mn[u],cmn[u],mn[v]);
downpdate(mn[u],cmn[u],cmn[v]);
}
if(mx[u] != mx[maxn-1] && cmx[u] != cmx[maxn-1]){
ans2[T[u].len] = max(ans2[T[u].len],max(1LL*mx[u]*cmx[u],1LL*mn[u]*cmn[u]));
}
ans1[T[u].len] += num[u];
}
#undef v
// Reads n, the string and the n weights. Builds the suffix automaton of the
// reversed string, whose parent tree is the suffix tree of the original.
// Runs the tree DP, then prints one line per r = 0..n-1: the number of suffix
// pairs with LCP >= r, and the largest weight product among them (0 if none).
int main(){
    // Sentinels: mx/cmx/ans2 start very negative, mn/cmn very large.
    // Slot maxn-1 is assumed never to be used as a state or length, so it
    // keeps the sentinel and comparisons against it detect "no value".
    memset(mx,-0x3f,sizeof mx);memset(cmx,-0x3f,sizeof cmx);
    memset(mn,0x3f,sizeof mn);memset(cmn,0x3f,sizeof cmn);
    memset(ans2,-0x3f,sizeof ans2);
    T[last = nodecnt = 0].fa = -1;   // state 0 is the root
    read(n);scanf("%s",s);
    for(int i=0;i<n;++i) read(a[i]);
    // Suffixes of s become prefixes of the reversed string.
    reverse(s,s+n);reverse(a,a+n);
    for(int i=0;i<n;++i) insert(s[i],i);
    // Parent tree (suffix tree) as a parent -> child adjacency list.
    for(int i=1;i<=nodecnt;++i) add(T[i].fa,i);
    dfs(0,0);
    // Per-length values -> ">= r" totals.
    for(int i=n-2;i>=0;--i){
        ans1[i] += ans1[i+1];
        ans2[i] = max(ans2[i],ans2[i+1]);
    }
    for(int i=0;i<n;++i){
        if(ans2[i] == ans2[maxn-1]) ans2[i] = 0;   // no pair at this length
        printf("%lld %lld\n",ans1[i],ans2[i]);
    }
    return 0;
}