题目大意:给出n个旧单词,要从这n个旧单词中构造新单词。构造条件是 S = Sa + Sb,其中Sa为某个旧单词的非空前缀,Sb为某个单词的非空后缀。求所有的新单词和旧单词中有多少个不同的单词。

思路:将所有单词建成一棵字典树,再将所有单词反转并建成一棵字典树。则第一棵树的结点个数即为不同前缀的数量,第二棵树的结点个数为不同后缀的数量。如果不算重复,取两者的乘积即为新单词的数量。

接下来考虑重复的情况。什么样子会重复呢?假设第一棵树中某个前缀为a1a2a3...X,第二棵树中某个后缀为Xb1b2b3...。那么这就一定会有重复,也就是前缀的末字符和后缀的首字符相同时会重复。考虑a1a2a3...Xb1b2b3...,这个X可能出现在前缀里,也可能出现在后缀里,这也就是多算了一种情况。考虑一般情况,假设第一棵树中某个前缀为a1a2a3...XXX..X(一共m个X),第二棵树中某个后缀为XXX...Xb1b2b3...(一共n个X),那么这两个串组成的串a1a2a3XX...Xb1b2b3...中X的数目有m+n+1种可能(0~m+n),而前缀中X的数目有(m+1)种可能(0~m个),同理后缀中X的数目(n+1)种可能,因此我们在实际计算中一共多算了(m+1)*(n+1) - (m+n+1) = m*n次,即前缀中X的数目与后缀中X数目的乘积。

那么,我们只要分别求出两棵树中每个字符的出现次数并减去它们的乘积即可(这是用了加法乘法原理,假设第一棵树中末尾为X的前缀共有3个,分别含a1、a2、a3个X;第二棵树中开头为X的后缀为3个,分别含b1、b2、b3个X,则一共重复算了a1*b1+a1*b2+a1*b3+a2*b1+a2*b2+a2*b3+a3*b1+a3*b2+a3*b3=(a1+a2+a3)*(b1+b2+b3)次)。这里要注意,只统计深度大于1的结点,因为我们要保证前后缀均非空(如前缀X与后缀Xb就不会有重复,因为第一个的X必须要选)。最后,再加上长度为1的旧单词即可,因为我们构造单词时不会造出长度为1的单词。

#include<cstdio>
#include<cstring>
#include<string>
#include<cctype>
#include<iostream>
#include<set>
#include<map>
#include<cmath>
#include<sstream>
#include<vector>
#include<stack>
#include<queue>
#include<algorithm>
#define fin freopen("a.txt","r",stdin)
#define fout freopen("a.txt","w",stdout)
typedef long long LL;
typedef unsigned long long ULL;
using namespace std;
const int inf = 1e9 + ;
const int maxnode = 4e5 + ;
const int sigma_size = ;
const int maxn = + ;
char s[];
int vis[sigma_size]; struct Tree
{
int ch[maxnode][sigma_size];
int val[maxnode];
int cnt[sigma_size];
int sz;
int idx(char c) { return c - 'a'; }
void init() { memset(ch[], , sizeof ch[]); sz = ; memset(cnt, , sizeof cnt); } void insert(char *s)
{
int n = strlen(s), u = ;
for(int i = ; i < n; i++)
{
int c = idx(s[i]);
if(!ch[u][c])
{
memset(ch[sz], , sizeof ch[sz]);
val[sz] = ;
ch[u][c] = sz++;
if(i) cnt[c]++;
}
u = ch[u][c]; }
val[u] = ;
} }Pre, Suf; int main()
{ int n;
while(scanf("%d", &n) == )
{
Pre.init(); Suf.init();
memset(vis, , sizeof vis);
for(int i = ; i <= n; i++)
{
scanf("%s", s);
Pre.insert(s);
int len = strlen(s);
reverse(s, s+len);
Suf.insert(s);
if(len == ) vis[s[]-'a'] = ;
}
LL ans = LL(Pre.sz-)*LL(Suf.sz-);
for(int i = ; i < sigma_size; i++)
ans -= (LL)Pre.cnt[i] * LL(Suf.cnt[i]);
for(int i = ; i < sigma_size; i++)
if(vis[i]) ++ans;
cout << ans << endl;
}
return ;
}
05-28 16:21