description
你将向敌方发起进攻!敌方的防御阵地可以用一个 \(N\times M\) 的 \(01\) 矩阵表示,标为 \(1\) 的表示有效区域,标为 \(0\) 的是敌人的预警装置。
你将发起 \(K\) 轮进攻,每一轮从所有 \(\frac{NM(N+1)(M+1)}{4}\) 种可能中选定一个矩形区域对其进行轰炸。如果 \(K\) 轮后存在一个有效区域每次都被轰炸到,并且没有一次触发敌人的预警装置,那么将对敌人造成致命打击。现在你想知道一共有多少种不同的轰炸方案能对敌人造成致命打击,输出对 \(998244353\) 取模的结果。
solution
如果是树上选连通块要求有交集,可以用 "点 - 边" 的容斥技巧(参考「十二省联考 2019」希望)。
如果是网格图,我们类似地有 "1×1 - 1×2 - 2×1 + 2×2"。这样算出来每种交集恰好贡献 1。
接下来只需要考虑求多少个合法矩形包含某个 "1×1"(其他三种同理)。
可以考虑差分。分别求每个点作为左上角/右上角/左下角/右下角的时候有多少合法矩形,从而计算每个点在差分中的贡献。
这个可以单调栈 \(O(n^2)\) 随便做。
因为要快速幂,总时间复杂度 \(O(n^2\log k)\)。
accepted code
#include <cstdio>
#include <algorithm>
using namespace std;
const int MAXN = 2000;
const int MOD = 998244353;
#define rep(i, x, n) for(int i=x;i<=n;i++)
#define per(i, x, n) for(int i=x;i>=n;i--)
inline int add(int x, int y) {x += y; return x >= MOD ? x - MOD : x;}
inline int sub(int x, int y) {x -= y; return x < 0 ? x + MOD : x;}
inline int mul(int x, int y) {return (int)(1LL * x * y % MOD);}
int pow_mod(int b, int p) {
int ret = 1;
for(int i=p;i;i>>=1,b=mul(b,b))
if( i & 1 ) ret = mul(ret, b);
return ret;
}
int a[MAXN + 5][MAXN + 5], N, M, K;
int h[MAXN + 5], stk[MAXN + 5], tp;
int s1[MAXN + 5][MAXN + 5], s2[MAXN + 5][MAXN + 5];
int s3[MAXN + 5][MAXN + 5], s4[MAXN + 5][MAXN + 5];
void init() {
rep(j, 1, M) h[j] = 0;
rep(i, 1, N) {
rep(j, 1, M) h[j] = (a[i][j] ? h[j] + 1 : 0);
int cnt = 0; stk[tp = 1] = 0;
rep(j, 1, M) {
while( tp && h[j] < h[stk[tp]] ) {
int x = stk[tp--];
cnt -= (x - stk[tp])*h[x];
}
cnt += (j - stk[tp])*h[j], stk[++tp] = j, s1[i][j] = cnt;
}
}
rep(j, 1, M) h[j] = 0;
rep(i, 1, N) {
rep(j, 1, M) h[j] = (a[i][j] ? h[j] + 1 : 0);
int cnt = 0; stk[tp = 1] = M + 1;
per(j, M, 1) {
while( tp && h[j] < h[stk[tp]] ) {
int x = stk[tp--];
cnt -= (stk[tp] - x)*h[x];
}
cnt += (stk[tp] - j)*h[j], stk[++tp] = j, s2[i][j] = cnt;
}
}
rep(j, 1, M) h[j] = 0;
per(i, N, 1) {
rep(j, 1, M) h[j] = (a[i][j] ? h[j] + 1 : 0);
int cnt = 0; stk[tp = 1] = 0;
rep(j, 1, M) {
while( tp && h[j] < h[stk[tp]] ) {
int x = stk[tp--];
cnt -= (x - stk[tp])*h[x];
}
cnt += (j - stk[tp])*h[j], stk[++tp] = j, s3[i][j] = cnt;
}
}
rep(j, 1, M) h[j] = 0;
per(i, N, 1) {
rep(j, 1, M) h[j] = (a[i][j] ? h[j] + 1 : 0);
int cnt = 0; stk[tp = 1] = M + 1;
per(j, M, 1) {
while( tp && h[j] < h[stk[tp]] ) {
int x = stk[tp--];
cnt -= (stk[tp] - x)*h[x];
}
cnt += (stk[tp] - j)*h[j], stk[++tp] = j, s4[i][j] = cnt;
}
}
per(i, N, 1) per(j, M, 1) {
s1[i][j] = add(s1[i][j], sub(add(s1[i+1][j], s1[i][j+1]), s1[i+1][j+1]));
s2[i][j] = add(s2[i][j], sub(add(s2[i+1][j], s2[i][j+1]), s2[i+1][j+1]));
s3[i][j] = add(s3[i][j], sub(add(s3[i+1][j], s3[i][j+1]), s3[i+1][j+1]));
s4[i][j] = add(s4[i][j], sub(add(s4[i+1][j], s4[i][j+1]), s4[i+1][j+1]));
}
}
int get(int dx, int dy) {
int ans = 0;
rep(i, dx + 1, N) rep(j, dy + 1, M) {
int x = sub(add(s1[i][j], s4[i+1-dx][j+1-dy]), add(s2[i][j+1-dy], s3[i+1-dx][j]));
ans = add(ans, pow_mod(x, K));
}
return ans;
}
char str[MAXN + 5];
int main() {
scanf("%d%d%d", &N, &M, &K);
rep(i, 1, N) {
scanf("%s", str + 1);
rep(j, 1, M) a[i][j] = str[j] - '0';
}
int ans = 0; init();
ans = add(ans, get(0, 0)), ans = sub(ans, get(1, 0));
ans = sub(ans, get(0, 1)), ans = add(ans, get(1, 1));
printf("%d\n", ans);
}
details
感觉学了这么久,啥也没学会(无奈.jpg)
这种非常套路而且之前见过类似套路的题都做不出来。果然人还是菜啊。