Sky Full of Stars
题目链接:http://codeforces.com/problemset/problem/997/C
数据范围:略。
题解:
首先考虑拟对象,如果至少有一行完全相等即可。
这个的答案就需要多步容斥:$\sum\limits_{i = 1} ^ n (-1)^{i + 1}\cdot 3 ^ i\cdot 3 ^ {n \cdot (n - i)}$。
那么至少有一列的答案跟这个一样。
把他俩加一起就是答案么?我们需要减去什么?
显然,需要减掉至少有一行且至少有一列的。
这个怎么弄?
是这样的,如果我们钦定了$i$行$j$列都必须相等之后,$x$行$y$列相等这种情况会被算$C_{x} ^ {i}\times C_{y} ^ {j}$次。
假设,$F_{(i,j)}$表示$i$行$j$列都相等的答案。
假设$A_{(i,j)}$表示$F_{(i,j)}$的容斥系数,发现当$A_{(i,j)} = (-1)^{i + j + 1}$时满足题意。
故此,至少有一列切至少有一行的答案是:
$\sum\limits_{i = 1} ^ n\sum\limits_{j = 1} ^ n (-1) ^ {i + j + 1} C_{n} ^ {i}C_{n}^{j} \cdot 3\cdot 3^{(n - i)\times (n - j)}$。
这个怎么看都是$O(n^2)$的对不对....
我们考虑把他转化转化:
$=\sum\limits_{i = 1} ^ n\sum\limits_{j = 1} ^ n (-1) ^ {i + j + 1} C_{n} ^ {i}C_{n}^{j} \cdot 3^{(n - i)\times (n - j) + 1}$
$=\sum\limits_{i = 1} ^ n\sum\limits_{j = 1} ^ n (-1) ^ {i + j + 1} C_{n} ^ {i}C_{n}^{j} \cdot 3^{n^2 - in - jn +ij + 1}$
$=\sum\limits_{i = 1} ^ n\sum\limits_{j = 1} ^ n (-1) ^ {i + j + 1} C_{n} ^ {i}C_{n}^{j} \cdot 3^{n^2}\cdot 3^{-in}\cdot 3^{-jn}\cdot 3^{ij}$
$=-3^{n^2}\cdot \sum\limits_{i = 1} ^ n (-1) ^ i\cdot C_{n} ^ {i}\cdot 3^{-in} \sum\limits_{j = 1} ^ n (-1) ^ j C_{n}^{j} \cdot 3^{-jn}\cdot 3^{ij}$
$=-3^{n^2}\cdot \sum\limits_{i = 1} ^ n (-1) ^ i\cdot C_{n} ^ {i}\cdot 3^{-in} \sum\limits_{j = 1} ^ n (-1) ^ j C_{n}^{j} \cdot (3^{-n})^j\cdot (3^i)^j$
$=-3^{n^2}\cdot \sum\limits_{i = 1} ^ n (-1) ^ i\cdot C_{n} ^ {i}\cdot 3^{-in} \sum\limits_{j = 1} ^ n C_{n}^{j} \cdot (-3^{i - n})^j$
发现,第二个$\sum$后面的式子类似于二项式反演,$(-3^{i - n})$相当于$(a-b)^n$中的$b$。
故此式子可以被化简成:
$=-3^{n^2}\cdot \sum\limits_{i = 1} ^ n (-1) ^ i\cdot C_{n} ^ {i}\cdot 3^{-in} \cdot (1 - (-3^{i - n}))^n$。
这个就能$O(n)$求了。
代码:
#include <bits/stdc++.h> #define setIO(s) freopen(s".in", "r", stdin), freopen(s".out", "w", stdout) #define N 1000010 using namespace std; typedef long long ll; const int mod = 998244353 ; int fac[N], inv[N]; int qpow(int x, int y) { int ans = 1; while (y) { if (y & 1) { ans = (ll)ans * x % mod; } y >>= 1; x = (ll)x * x % mod; } return ans; } inline int C(int x, int y) { return (ll)fac[x] * inv[y] % mod * inv[x - y] % mod; } int main() { // setIO("c"); int n; cin >> n ; fac[0] = inv[0] = 1; for (int i = 1; i <= n; i ++ ) { fac[i] = (ll)fac[i - 1] * i % mod; inv[i] = qpow(fac[i], mod - 2); } int ans = 0; for (int i = 1; i <= n; i ++ ) { ans = (ans + (ll)qpow(3, ((ll)n * (n - i) % (mod - 1) + i) % (mod - 1)) * qpow(mod - 1, i + 1) % mod * C(n, i) % mod) % mod; } ans = ans * 2 % mod; int mdl = 0; for (int i = 0; i < n; i ++ ) { int t = (mod - qpow(3, i)) % mod; mdl = (mdl + (ll)C(n, i) * qpow(mod - 1, i + 1) % mod * ( ( (ll) qpow(t + 1, n) + mod - qpow(t, n) ) % mod) % mod) % mod; } ans = (ans + (ll)mdl * 3) % mod; cout << ans << endl ; fclose(stdin), fclose(stdout); return 0; }