str2int

\[Time Limit: 3000 ms\quad Memory Limit: 131072 kB
\]

题意

给出 \(n\) 个串,求出这 \(n\) 个串所有子串代表的数字的和。

思路

首先可以把这些串合并起来,串与串之间用没出现过的字符分隔开来,然后构建后缀自动机,因为后缀自动机上从 \(root\) 走到的任意节点都是一个子串,所有可以利用这个性质来做。

一开始我的做法是做 \(dfs\),令 \(dp[i]\) 表示节点 \(i\) 的贡献,转移就是 \(dp[v] = dp[v]+tmp*10+j\),表示从 \(root\) 到 \(u\) 的权值是\(tmp\),\(v\) 是 \(u\) 往 \(j\)走的下一个点。结果显然超时了。

我们发现对于\(dp[u]->dp[v]\)过程,如果之前走到 \(dp[u]\) 的有 \(12\),\(2\) 两步,假设现在往 \(3\) 这条边走,得到 \(12*10+3\),\(2*10+3\),那么其实这些值的贡献是可以一次性计算的,无论之前走到 \(dp[u]\) 的有几条路,都需要让他们全部 \(*10\),而 \(3\) 的贡献则是由走到 \(dp[u]\) 的路径数确定的。

那么我们就可以得到第二个方程:

  1. \(dp1[i]\) 表示节点 \(i\) 的贡献
  2. \(dp2[i]\) 表示之前有多少种方案走到 \(i\)
  3. \(dp1[v] = dp1[v] + dp1[u]*10 + dp2[u]*j\)
  4. \(dp2[v] = dp[2[v] + dp2[v]\)

最后为了去除前导零,只要控制从 \(root\) 出来的边最少从 \(1\) 开始就可以了。

如此计算后,\(\sum dp1[i]\) 就是最后的答案。

#include <map>
#include <set>
#include <list>
#include <ctime>
#include <cmath>
#include <stack>
#include <queue>
#include <cfloat>
#include <string>
#include <vector>
#include <cstdio>
#include <bitset>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define lowbit(x) x & (-x)
#define mes(a, b) memset(a, b, sizeof a)
#define fi first
#define se second
#define pii pair<int, int>
#define INOPEN freopen("in.txt", "r", stdin)
#define OUTOPEN freopen("out.txt", "w", stdout) typedef unsigned long long int ull;
typedef long long int ll;
const int maxn = 2e5 + 10;
const int maxm = 1e5 + 10;
const ll mod = 2012;
const ll INF = 1e18 + 100;
const int inf = 0x3f3f3f3f;
const double pi = acos(-1.0);
const double eps = 1e-8;
using namespace std; int n, m;
int cas, tol, T; struct Sam {
struct Node {
int next[20];
int fa, len;
void init() {
mes(next, 0);
fa = len = 0;
}
} node[maxn<<1];
int dp1[maxn<<1], dp2[maxn<<1];
bool vis[maxn<<1];
int tax[maxn<<1], gid[maxn<<1];
int last, sz;
void init() {
mes(dp1, 0);
mes(dp2, 0);
last = sz = 1;
node[sz].init();
}
void insert(int k) {
int p = last, np = last = ++sz;
node[np].init();
node[np].len = node[p].len + 1;
for(; p&&!node[p].next[k]; p=node[p].fa)
node[p].next[k] = np;
if(p == 0) {
node[np].fa = 1;
} else {
int q = node[p].next[k];
if(node[q].len == node[p].len+1) {
node[np].fa = q;
} else {
int nq = ++sz;
node[nq] = node[q];
node[nq].len = node[p].len+1;
node[np].fa = node[q].fa = nq;
for(; p&&node[p].next[k]==q; p=node[p].fa)
node[p].next[k] = nq;
}
}
}
void solve() {
int ans = 0;
for(int i=0; i<=sz; i++) tax[i] = 0;
for(int i=1; i<=sz; i++) tax[node[i].len]++;
for(int i=1; i<=sz; i++) tax[i] += tax[i-1];
for(int i=1; i<=sz; i++) gid[tax[node[i].len]--] = i;
dp2[1] = 1;
for(int i=1; i<=sz; i++) {
int u = gid[i];
ans = (ans+dp1[u])%mod;
// printf("%d %d %d\n", u, dp1[u], dp2[u]);
for(int j=(u==1 ? 1:0); j<=9; j++) {
if(node[u].next[j+1] == 0) continue;
int nst = node[u].next[j+1];
dp1[nst] = (dp1[nst] + dp1[u]*10 + j*dp2[u])%mod;
dp2[nst] = (dp2[nst] + dp2[u])%mod;
}
}
printf("%d\n", ans);
}
} sam;
char s[maxn], t[maxn]; int main() {
while(~scanf("%d", &T)) {
mes(s, 0);
n = 0;
while(T--) {
scanf("%s", t+1);
int tlen = strlen(t+1);
for(int i=1; i<=tlen; i++) {
s[++n] = t[i]-'0'+1;
}
s[++n] = 11;
}
sam.init();
for(int i=1; i<=n; i++) {
sam.insert(s[i]);
}
sam.solve();
}
return 0;
}
05-11 14:44