跟仙人掌其实没啥关系…

Here

注意:每一次都 $O(n)$ 地算一下"某些点都是黑点"的概率,总复杂度其实并不是 $O(n^2)$,因为每个环只用算一次(并且结果按环长记忆化),而所有环的长度之和是 $O(n)$。

#include <cctype>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// 64-bit signed integer shorthand.
typedef long long LL;
// Capacity of the per-vertex arrays (index 0 is left free; vertices are 1-based in main()).
const int MAXN = 100000 + 5;
// Capacity of the per-edge arrays u[], v[], in[].
const int MAXM = 250000 + 5;
// Modulus for all arithmetic; main() takes inverses as x^(mod-2), which relies on it being prime.
const int mod = 998244353;
// Reads the next run of decimal digits from stdin into `num`.
//
// Every non-digit character before the number is skipped, and the single
// character that terminates the digit run is consumed as well. Signs are not
// interpreted, so only non-negative values can be parsed.
//
// If the input ends before any digit is found, `num` is set to 0 and the
// function returns (the previous version looped forever in that case).
inline void read(int &num) {
    // Keep the character in an int: getchar() returns EOF out of band, and
    // isdigit() is only defined for EOF or values representable as unsigned char.
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)) ch = getchar();
    num = 0;
    for (; ch != EOF && isdigit(ch); ch = getchar())
        num = num * 10 + (ch - '0');
}
// Modular exponentiation by repeated squaring: returns a^b reduced with `% mod`.
// (Despite the name this is a power, not a product.) Returns 1 when b == 0.
// `b` is expected to be non-negative; the loop only stops once b reaches 0.
inline int qmul(int a, int b) {
    long long base = a;
    long long acc = 1;
    for (; b; b >>= 1) {
        // Fold in the current square whenever the low exponent bit is set.
        if (b & 1) acc = acc * base % mod;
        base = base * base % mod;
    }
    return static_cast<int>(acc);
}
// Header values, read by main() in this order: n, m, t, w.
// main() treats n as the vertex count (loops 1..n) and m as the edge count
// (loops 1..m); t is used as an exponent and w as an on/off flag for ans[1].
int n, m, t, w;
// Disjoint-set parent table used by find(); main() initialises bjc[i] = i.
int bjc[MAXN];
// Endpoints of the i-th input edge (1-based).
int u[MAXM];
int v[MAXM];
// in[i] is set by main() when edge i joined two different disjoint-set
// components at the moment it was read.
bool in[MAXM];
// Disjoint-set lookup with full path compression: returns the representative
// of x and re-points every node on the walked path directly at it.
int find(int x) {
    // Pass 1: walk up to the representative.
    int root = x;
    while (bjc[root] != root) root = bjc[root];
    // Pass 2: re-hang the nodes on the path under the representative.
    while (bjc[x] != root) {
        const int next = bjc[x];
        bjc[x] = root;
        x = next;
    }
    return root;
}
namespace pou {
int fir[MAXN], cnt; struct edge { int to, nxt; }e[MAXN<<1];
inline void addedge(int x, int y) { e[cnt] = (edge){ y, fir[x] }, fir[x] = cnt++; }
int sz[MAXN], top[MAXN], fa[MAXN], son[MAXN], dep[MAXN], dfn[MAXN], tmr, seq[MAXN];
inline void dfs1(int u, int ff) {
dep[u] = dep[fa[u]=ff] + (sz[u]=1);
for(int v, i = fir[u]; ~i; i = e[i].nxt)
if((v=e[i].to) != ff) {
dfs1(v, u), sz[u] += sz[v];
if(sz[v] > sz[son[u]]) son[u] = v;
}
}
inline void dfs2(int u, int tp) {
top[u] = tp; seq[dfn[u]=++tmr] = u;
if(son[u]) dfs2(son[u], tp);
for(int v, i = fir[u]; ~i; i = e[i].nxt)
if((v=e[i].to) != fa[u] && v != son[u])
dfs2(v, v);
}
bool col[MAXN<<2], lz[MAXN<<2];
void upd(int i) { col[i] = col[i<<1] | col[i<<1|1]; }
void mt(int i) {
if(lz[i]) {
lz[i<<1] = lz[i<<1|1] = col[i<<1] = col[i<<1|1] = 1;
lz[i] = 0;
}
}
void cover(int i, int l, int r, int x, int y) {
if(l == x && r == y) {
col[i] = lz[i] = 1;
return;
}
mt(i);
int mid = (l + r) >> 1;
if(y <= mid) cover(i<<1, l, mid, x, y);
else if(x > mid) cover(i<<1|1, mid+1, r, x, y);
else cover(i<<1, l, mid, x, mid), cover(i<<1|1, mid+1, r, mid+1, y);
upd(i);
}
inline int Cover(int x, int y) {
while(top[x] != top[y]) {
if(dep[top[x]] < dep[top[y]]) swap(x, y);
cover(1, 1, n, dfn[top[x]], dfn[x]); x = fa[top[x]];
}
if(x == y) return x;
if(dep[x] < dep[y]) swap(x, y);
cover(1, 1, n, dfn[y]+1, dfn[x]);
return y;
}
bool query(int i, int l, int r, int x, int y) {
if(l == x && r == y) return col[i];
mt(i);
int mid = (l + r) >> 1; bool res;
if(y <= mid) res = query(i<<1, l, mid, x, y);
else if(x > mid) res = query(i<<1|1, mid+1, r, x, y);
else res = (query(i<<1, l, mid, x, mid) || query(i<<1|1, mid+1, r, mid+1, y));
upd(i); return res;
}
inline bool Query(int x, int y) {
while(top[x] != top[y]) {
if(dep[top[x]] < dep[top[y]]) swap(x, y);
if(query(1, 1, n, dfn[top[x]], dfn[x])) return 1;
x = fa[top[x]];
}
if(x == y) return 0;
if(dep[x] < dep[y]) swap(x, y);
return query(1, 1, n, dfn[y]+1, dfn[x]);
}
}
// P[i] is set in main() to ((n-i)/n)^t modulo `mod`, for 1 <= i < n.
int P[MAXN];
// Running totals maintained by main(): ans[0] is built from P[], ans[1] from solve().
int ans[2];
// After pre(): inv[i] holds the modular inverse of i! and fac[i] holds i!.
int inv[MAXN];
int fac[MAXN];
// Power table filled by pre(); its entries are qmul(i, t).
int MULT[MAXN];
// Precomputes the combinatorial tables for indices 0..N:
//   fac[i]  = i!            (mod `mod`)
//   inv[i]  = (i!)^(-1)     (mod `mod`)
//   MULT[i] = i^t           (mod `mod`), via qmul(i, t)
// Reads the global `t`, so it must be called after `t` has been read.
inline void pre(int N) {
    inv[0] = inv[1] = fac[0] = fac[1] = 1;
    for (int i = 2; i <= N; ++i) {
        fac[i] = 1ll * fac[i-1] * i % mod;
        // Inverse of i from the inverse of (mod % i): inv(i) = -(mod / i) * inv(mod % i).
        inv[i] = 1ll * (mod - mod/i) * inv[mod%i] % mod;
    }
    // Turn the plain inverses into inverse factorials by a running product.
    for (int i = 2; i <= N; ++i)
        inv[i] = 1ll * inv[i-1] * inv[i] % mod;
    // Fill the table from index 0, not 2. solve() reads MULT[n-j], which hits
    // index 1 when j == n-1; a zero left there drops the 1^t = 1 term from its sum.
    for (int i = 0; i <= N; ++i) MULT[i] = qmul(i, t);
}
// Binomial coefficient "n choose m" modulo `mod`, from the fac[] / inv[]
// tables built by pre(). Returns 0 when m exceeds n.
inline int C(int n, int m) {
    if (n < m) return 0;
    const long long partial = 1ll * fac[n] * inv[m] % mod;
    return partial * inv[n-m] % mod;
}
// all = n^t, invall = all^(-1) and invn = n^(-1), each modulo `mod` (set in main()).
int all, invall, invn;
// Cached solve(1) and solve(2), assigned in main().
int POINT, EDGE;
// Memo table for solve(); main() fills it with -1 meaning "not computed yet".
int f[MAXN];
// Returns (all - sum_{j=1..i} (-1)^(j+1) * C(i, j) * MULT[n-j]) * invall,
// reduced with `% mod` (the value may be a negative residue).
// Returns 0 for i == 0 and for i > t. Results are memoised in f[].
// NOTE(review): the alternating sum has the shape of an inclusion-exclusion
// count; presumably this is the probability that i fixed vertices are each
// picked at least once in t uniform draws. Confirm against the problem statement.
inline int solve(int i) {
    if (i == 0 || i > t) return 0;
    if (f[i] != -1) return f[i];  // already computed: reuse the memoised value
    int acc = all;
    int sign = 1;
    for (int j = 1; j <= i; ++j) {
        acc = (acc - 1ll * sign * C(i, j) * MULT[n-j] % mod) % mod;
        sign = -sign;
    }
    f[i] = 1ll * acc * invall % mod;
    return f[i];
}
// Driver. Reads "n m t w" and then m edges from cactus.in, and writes one
// line to cactus.out after every edge.
//
// Edges that join two disjoint-set components form a spanning forest, which
// is decomposed by pou::dfs1 / pou::dfs2. The edges are then replayed in input
// order:
//   * a forest edge subtracts P[2] from ans[0] (and EDGE from ans[1] if w);
//   * a non-forest edge whose tree path has no marked edge does the same,
//     marks that path, and adds P[len] (and solve(len) if w), where len is
//     the number of vertices on the path;
//   * any other edge changes nothing.
// NOTE(review): this looks like "vertices - edges + cycles" bookkeeping for an
// expected component count. Confirm against the problem statement.
int main () {
    freopen("cactus.in", "r", stdin);
    freopen("cactus.out", "w", stdout);

    read(n);
    read(m);
    read(t);
    read(w);
    pre(n);
    memset(f, -1, sizeof f);

    all = qmul(n, t);
    invall = qmul(all, mod-2);
    invn = qmul(n, mod-2);
    for (int k = 1; k < n; ++k)
        P[k] = qmul(1ll*(n-k)*invn%mod, t);
    ans[0] = 1ll * P[1] * n % mod;
    POINT = solve(1);
    ans[1] = 1ll * POINT * n % mod;
    EDGE = solve(2);

    for (int i = 1; i <= n; ++i) {
        bjc[i] = i;
        pou::fir[i] = -1;
    }

    // Read all edges up front and keep those that connect two components.
    for (int i = 1; i <= m; ++i) {
        read(u[i]);
        read(v[i]);
        const int rootU = find(u[i]);
        const int rootV = find(v[i]);
        if (rootU == rootV) continue;
        in[i] = 1;
        pou::addedge(u[i], v[i]);
        pou::addedge(v[i], u[i]);
        bjc[rootV] = rootU;
    }

    // Decompose every tree of the forest (sz stays 0 until a vertex is visited).
    for (int i = 1; i <= n; ++i) {
        if (pou::sz[i]) continue;
        pou::dfs1(i, 0);
        pou::dfs2(i, i);
    }

    for (int i = 1; i <= m; ++i) {
        const bool forestEdge = in[i];
        // Query() is only evaluated for non-forest edges.
        if (forestEdge || !pou::Query(u[i], v[i])) {
            ans[0] = (ans[0] - P[2]) % mod;
            if (w) ans[1] = (ans[1] - EDGE) % mod;
            if (!forestEdge) {
                // This edge closes a cycle over so-far unmarked tree edges.
                const int meet = pou::Cover(u[i], v[i]);
                const int len = pou::dep[u[i]] + pou::dep[v[i]] - (pou::dep[meet]<<1) + 1;
                ans[0] = (ans[0] + P[len]) % mod;
                if (w) ans[1] = (ans[1] + solve(len)) % mod;
            }
        }
        // The totals may be negative residues; normalise only when printing.
        int out;
        if (w) out = ((ans[0] + ans[1]) % mod + mod) % mod;
        else out = (ans[0]+mod) % mod;
        printf("%d\n", out);
    }
}
05-12 23:33