题意
有一个只含有\(+,\times\)与数字的表达式,请求出它的值取模\(998244353\)
这个表达式以这样的形式给出:
有\(n\)段,每段有一个数\(r_i\)与一个字符串\(s_i\),表示这个字符串首尾相接重复\(r_i\)次
这\(n\)段的所有字符串首尾拼接起来即为该表达式
\(n\leq 5\times10^4,r_i\leq 10^9,|s_i|\leq 10\)
解法
由于整个串只有\(+ ,\times\)两种符号,考虑分开处理:
设三个量来表示当前状态:\(S,M,T\)
\(S:\)最后一个\(+\)号之前的总和
\(M:\)最后一个\(+\)号之后的到最后一个\(\times\)为止的乘积(没有\(\times\)号就为\(1\))
\(T:\)最后一个\(+\)号到末尾的项的结果(如果末尾为一个符号那么\(T=0\))
\(e:e=1\)(由于矩阵内有\(1\),所以要添加\(e\)这一项方便转移)
最后的答案即为\(S+T\)
这里\(M\)的状态设计很巧妙,实际上是为了给\(T\)的转移做准备的
对于\(+\)号:\((S,M,T,e)\to (S+T,e,0,e)\)
对于\(\times\)号:\((S,M,T,e)\to (S,T,0,e)\)
对于一个数字\(k\): \((S,M,T,e)\to (S,M,10T+kM,e)\)
实际上是一个乘法分配率,由于最后一个数是\(\frac{T}{M}\),新的\(T\)应该是\((\frac{10T}{M}+k)\times M\)即\(10T+kM\)
这样,每个字符的贡献就被我们分开了,成了一个独立的转移矩阵
由于矩乘有结合律,我们处理出每段的转移矩阵后快速幂矩乘结合即可
代码
#include <cstdio>
#include <cstring>
using namespace std;
const int mod = 998244353;
int N, R;
char op[20];
inline int mtp(int x, int y) { return 1LL * x * y % mod; }
inline int inc(int& x, int y) { (x += y) > mod ? x -= mod : x; }
struct matrix {
int n, m;
int a[5][5];
matrix(int _n = 0, int _m = 0, int v = 0) {
memset(a, 0, sizeof a);
n = _n, m = _m;
for (int i = 0; i < n; ++i) a[i][i] = v;
}
matrix operator * (const matrix& b) const {
matrix res = matrix(n, b.m);
for (int i = 0; i < n; ++i)
for (int j = 0; j < m; ++j)
for (int k = 0; k < b.m; ++k)
inc(res.a[i][j], mtp(a[i][k], b.a[k][j]));
return res;
}
matrix operator ^ (int k) const {
matrix x = *this;
matrix res = matrix(x.n, x.m, 1);
for (; k; x = x * x, k >>= 1)
if (k & 1) res = res * x;
return res;
}
};
matrix add, mul, ans, num[10];
void init() {
add = matrix(4, 4);
add.a[0][0] = add.a[2][0] = add.a[3][1] = add.a[3][3] = 1;
mul = matrix(4, 4);
mul.a[0][0] = mul.a[2][1] = mul.a[3][3] = 1;
for (int i = 0; i < 10; ++i) {
num[i] = matrix(4, 4);
num[i].a[0][0] = num[i].a[1][1] = num[i].a[3][3] = 1;
num[i].a[2][2] = 10, num[i].a[1][2] = i;
}
}
int main() {
scanf("%d", &N);
init();
ans = matrix(1, 4);
ans.a[0][1] = ans.a[0][3] = 1;
while (N--) {
scanf("%d%s", &R, op + 1);
matrix tmp = matrix(4, 4, 1);
for (int i = 1, sz = strlen(op + 1); i <= sz; ++i)
if (op[i] == '+')
tmp = tmp * add;
else if (op[i] == '*')
tmp = tmp * mul;
else
tmp = tmp * num[op[i] - 48];
ans = ans * (tmp ^ R);
}
printf("%d\n", (ans.a[0][0] + ans.a[0][2]) % mod);
return 0;
}