题目
比赛界面。
T1
数据范围明示直接\(O(n^2)\)计算,问题就在如何快速计算。
树上路径统计通常会用到差分方法。这里有两棵树,因此我们可以做“差分套差分”,在 A 树上对 B 的差分信息进行差分。在修改的时候,我们就会在 A 上 4 个位置进行修改,每次修改会涉及 B 上 4 个位置的差分修改,因此总共会涉及 16 个差分信息的修改。
回收标记的时候,我们可以先在 A 树上进行 DFS ,回收好子树内的差分信息后,再进行一次 B 的回收,得到当前节点上 B 的真实信息。
时间是\(O(n^2+q\log_2n)\)。
#include <cmath>
#include <cstdio>
typedef long long LL;
const int MAXN = 1e4 + 5, MAXLOG = 15;
template<typename _T>
void read( _T &x )
{
x = 0;char s = getchar();int f = 1;
while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
x *= f;
}
template<typename _T>
void write( _T x )
{
if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
if( 9 < x ){ write( x / 10 ); }
putchar( x % 10 + '0' );
}
struct Tree
{
struct edge
{
int to, nxt;
}Graph[MAXN << 1];
int f[MAXN][MAXLOG];
int head[MAXN], dep[MAXN], seq[MAXN];
int n, ID, lg2, cnt;
void addEdge( const int from, const int to )
{
Graph[++ cnt].to = to, Graph[cnt].nxt = head[from];
head[from] = cnt;
}
void DFS( const int u, const int fa )
{
f[u][0] = fa, seq[++ ID] = u, dep[u] = dep[fa] + 1;
for( int i = head[u], v ; i ; i = Graph[i].nxt )
if( ( v = Graph[i].to ) ^ fa )
DFS( v, u );
}
void init()
{
read( n );
for( int i = 1, a, b ; i < n ; i ++ )
read( a ), read( b ), addEdge( a, b ), addEdge( b, a );
DFS( 1, 0 );
lg2 = log2( n );
for( int j = 1 ; j <= lg2 ; j ++ )
for( int i = 1 ; i <= n ; i ++ )
f[i][j] = f[f[i][j - 1]][j - 1];
}
void balance( int &u, const int stp ) const
{
for( int i = 0 ; ( 1 << i ) <= stp ; i ++ )
if( stp & ( 1 << i ) )
u = f[u][i];
}
int LCA( int u, int v ) const
{
if( dep[u] > dep[v] ) balance( u, dep[u] - dep[v] );
if( dep[v] > dep[u] ) balance( v, dep[v] - dep[u] );
if( u == v ) return u;
for( int i = lg2 ; ~ i ; i -- ) if( f[u][i] ^ f[v][i] ) u = f[u][i], v = f[v][i];
return f[u][0];
}
int fa( const int u ) const { return f[u][0]; }
};
Tree A, B;
int dif[MAXN][MAXN];
LL ans;
void change( int *d, const int u, const int v, const int lca, const int c )
{
d[u] += c, d[v] += c, d[lca] -= c, d[B.fa( lca )] -= c;
}
void recovery( const int u, const int fa )
{
for( int i = A.head[u], v ; i ; i = A.Graph[i].nxt )
if( ( v = A.Graph[i].to ) ^ fa )
recovery( v, u );
for( int i = 1 ; i <= B.n ; i ++ ) dif[fa][i] += dif[u][i];
for( int i = B.n ; i ; i -- )
{
int cur = B.seq[i];
ans ^= 1ll * u * cur * dif[u][cur];
dif[u][B.fa( cur )] += dif[u][cur];
}
}
int main()
{
int Q, a1, a2, b1, b2, c, lcaa, lcab;
A.init(), B.init(), read( Q );
while( Q -- )
{
read( a1 ), read( a2 ), read( b1 ), read( b2 ), read( c );
lcaa = A.LCA( a1, a2 ), lcab = B.LCA( b1, b2 );
change( dif[a1], b1, b2, lcab, c );
change( dif[a2], b1, b2, lcab, c );
change( dif[lcaa], b1, b2, lcab, -c );
change( dif[A.fa( lcaa )], b1, b2, lcab, -c );
}
recovery( 1, 0 );
write( ans ), putchar( '\n' );
return 0;
}
T2
神奇的 DP 配合优化,放个官方题解吧:
T3
字符串简单题,考虑容斥。总方案很好计算,然后考虑不相交字符串的方案数。枚举分界点,然后计算左边和右边的回文串数量即可。
所以我为什么会想到奇怪的 slink 方法呀......它还不卡我......
所以就贴了 slink 的代码。
#include <cstdio>
#include <cstring>
typedef long long LL;
const int MAXN = 2e6 + 5;
const int mod = 998244353, inv2 = 499122177;
template<typename _T>
void read( _T &x )
{
x = 0;char s = getchar();int f = 1;
while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
x *= f;
}
template<typename _T>
void write( _T x )
{
if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
if( 9 < x ){ write( x / 10 ); }
putchar( x % 10 + '0' );
}
int BIT[MAXN], g[MAXN];
int ed[MAXN], slink[MAXN], dif[MAXN];
int ch[MAXN][26], len[MAXN], fa[MAXN], dep[MAXN];
int N, tot, lst;
char S[MAXN];
void add( int &x, const int v ) { x = ( x + v >= mod ? x + v - mod : x + v ); }
int sub( const int x, const int v ) { return ( x < v ? x - v + mod : x - v ); }
LL mul( const int x, const LL v ) { LL tmp = x * v; return tmp < mod ? tmp : tmp % mod; }
void up( int &x ) { x += x & ( -x ); }
void down( int &x ) { x -= x & ( -x ); }
void update( int x, const int v ) { for( ; x <= N ; up( x ) ) add( BIT[x], v ); }
int getSum( int x ) { int ret = 0; for( ; x ; down( x ) ) add( ret, BIT[x] ); return ret; }
int query( const int l, const int r ) { return sub( getSum( r ), getSum( l - 1 ) ); }
void build()
{
int x;
N = strlen( S + 1 );
len[fa[0] = ++ tot] = -1;
for( int i = 1 ; i <= N ; i ++ )
{
x = S[i] - 'a';
while( S[i] ^ S[i - len[lst] - 1] ) lst = fa[lst];
if( ! ch[lst][x] )
{
int cur = ++ tot, p = fa[lst];
while( S[i] ^ S[i - len[p] - 1] ) p = fa[p];
len[cur] = len[lst] + 2, fa[cur] = ch[p][x], ch[lst][x] = cur;
}
ed[i] = lst = ch[lst][x];
}
}
int main()
{
int ans = 0;
scanf( "%s", S + 1 ), build();
for( int i = 2 ; i <= tot ; i ++ )
{
dep[i] = dep[fa[i]] + 1, dif[i] = len[i] - len[fa[i]];
if( dif[i] ^ dif[fa[i]] ) slink[i] = fa[i];
else slink[i] = slink[fa[i]];
}
for( int i = 1 ; i <= N ; i ++ )
{
for( int p = ed[i] ; p ; p = slink[p] )
{
g[p] = query( i - len[slink[p]] - dif[p] + 1, i );
if( slink[p] ^ fa[p] )
{
add( g[p], g[fa[p]] );
add( g[p], mul( ( dep[p] - dep[slink[p]] - 1 ), query( i - dif[p], i ) ) );
}
add( ans, g[p] );
}
add( ans, mul( mul( dep[ed[i]], dep[ed[i]] - 1 ), inv2 ) );
update( i, dep[ed[i]] );
}
write( ans ), putchar( '\n' );
return 0;
}
小结
对于简单题和基础的技巧都还不太熟练,要多加运用。