Sky Full of Stars

题目链接http://codeforces.com/problemset/problem/997/C

数据范围:略。


题解

首先考虑拟对象,如果至少有一行完全相等即可。

这个的答案就需要多步容斥:$\sum\limits_{i = 1} ^ n (-1)^{i + 1}\cdot 3 ^ i\cdot 3 ^ {n \cdot (n - i)}$。

那么至少有一列的答案跟这个一样。

把他俩加一起就是答案么?我们需要减去什么?

显然,需要减掉至少有一行且至少有一列的。

这个怎么弄?

是这样的,如果我们钦定了$i$行$j$列都必须相等之后,$x$行$y$列相等这种情况会被算$C_{x} ^ {i}\times C_{y} ^ {j}$次。

假设,$F_{(i,j)}$表示$i$行$j$列都相等的答案。

假设$A_{(i,j)}$表示$F_{(i,j)}$的容斥系数,发现当$A_{(i,j)} = (-1)^{i + j + 1}$时满足题意。

故此,至少有一列切至少有一行的答案是:

$\sum\limits_{i = 1} ^ n\sum\limits_{j = 1} ^ n (-1) ^ {i + j + 1} C_{n} ^ {i}C_{n}^{j} \cdot 3\cdot 3^{(n - i)\times (n - j)}$。

这个怎么看都是$O(n^2)$的对不对....

我们考虑把他转化转化:

$=\sum\limits_{i = 1} ^ n\sum\limits_{j = 1} ^ n (-1) ^ {i + j + 1} C_{n} ^ {i}C_{n}^{j} \cdot 3^{(n - i)\times (n - j) + 1}$

$=\sum\limits_{i = 1} ^ n\sum\limits_{j = 1} ^ n (-1) ^ {i + j + 1} C_{n} ^ {i}C_{n}^{j} \cdot 3^{n^2 - in - jn +ij + 1}$

$=\sum\limits_{i = 1} ^ n\sum\limits_{j = 1} ^ n (-1) ^ {i + j + 1} C_{n} ^ {i}C_{n}^{j} \cdot 3^{n^2}\cdot 3^{-in}\cdot 3^{-jn}\cdot 3^{ij}$

$=-3^{n^2}\cdot \sum\limits_{i = 1} ^ n (-1) ^ i\cdot C_{n} ^ {i}\cdot 3^{-in} \sum\limits_{j = 1} ^ n (-1) ^ j C_{n}^{j} \cdot 3^{-jn}\cdot 3^{ij}$

$=-3^{n^2}\cdot \sum\limits_{i = 1} ^ n (-1) ^ i\cdot C_{n} ^ {i}\cdot 3^{-in} \sum\limits_{j = 1} ^ n (-1) ^ j C_{n}^{j} \cdot (3^{-n})^j\cdot (3^i)^j$

$=-3^{n^2}\cdot \sum\limits_{i = 1} ^ n (-1) ^ i\cdot C_{n} ^ {i}\cdot 3^{-in} \sum\limits_{j = 1} ^ n C_{n}^{j} \cdot (-3^{i - n})^j$

发现,第二个$\sum$后面的式子类似于二项式反演,$(-3^{i - n})$相当于$(a-b)^n$中的$b$。

故此式子可以被化简成:

$=-3^{n^2}\cdot \sum\limits_{i = 1} ^ n (-1) ^ i\cdot C_{n} ^ {i}\cdot 3^{-in} \cdot (1 - (-3^{i - n}))^n$。

这个就能$O(n)$求了。

代码

#include <bits/stdc++.h>

#define setIO(s) freopen(s".in", "r", stdin), freopen(s".out", "w", stdout)

#define N 1000010

using namespace std;

typedef long long ll;

const int mod = 998244353 ;

int fac[N], inv[N];

int qpow(int x, int y) {
    int ans = 1;
    while (y) {
        if (y & 1) {
            ans = (ll)ans * x % mod;
        }
        y >>= 1;
        x = (ll)x * x % mod;
    }
    return ans;
}

inline int C(int x, int y) {
    return (ll)fac[x] * inv[y] % mod * inv[x - y] % mod;
}

int main() {
    // setIO("c");
    int n;
    cin >> n ;
    fac[0] = inv[0] = 1;
    for (int i = 1; i <= n; i ++ ) {
        fac[i] = (ll)fac[i - 1] * i % mod;
        inv[i] = qpow(fac[i], mod - 2);
    }

    int ans = 0;
    for (int i = 1; i <= n; i ++ ) {
        ans = (ans + (ll)qpow(3, ((ll)n * (n - i) % (mod - 1) + i) % (mod - 1))
        * qpow(mod - 1, i + 1) % mod
        * C(n, i) % mod) % mod;
    }
    ans = ans * 2 % mod;

    int mdl = 0;
    for (int i = 0; i < n; i ++ ) {
        int t = (mod - qpow(3, i)) % mod;
        mdl = (mdl + (ll)C(n, i)
        * qpow(mod - 1, i + 1) % mod
        * ( ( (ll) qpow(t + 1, n) + mod - qpow(t, n) ) % mod) % mod) % mod;
    }
    ans = (ans + (ll)mdl * 3) % mod;
    cout << ans << endl ;
    fclose(stdin), fclose(stdout);
    return 0;
}
01-12 15:45