题意

有一个只含有\(+,\times\)与数字的表达式,请求出它的值取模\(998244353\)

这个表达式以这样的形式给出:

\(n\)段,每段有一个数\(r_i\)与一个字符串\(s_i\),表示这个字符串首尾相接重复\(r_i\)

\(n\)段的所有字符串首尾拼接起来即为该表达式

\(n\leq 5\times10^4,r_i\leq 10^9,|s_i|\leq 10\)


解法

由于整个串只有\(+ ,\times\)两种符号,考虑分开处理:

设三个量来表示当前状态:\(S,M,T\)

\(S:\)最后一个\(+\)号之前的总和

\(M:\)最后一个\(+\)号之后的到最后一个\(\times\)为止的乘积(没有\(\times\)号就为\(1\)

\(T:\)最后一个\(+\)号到末尾的项的结果(如果末尾为一个符号那么\(T=0\)

\(e:e=1\)(由于矩阵内有\(1\),所以要添加\(e\)这一项方便转移)

最后的答案即为\(S+T\)

这里\(M\)的状态设计很巧妙,实际上是为了给\(T\)的转移做准备的

对于\(+\)号:\((S,M,T,e)\to (S+T,e,0,e)\)

对于\(\times\)号:\((S,M,T,e)\to (S,T,0,e)\)

对于一个数字\(k\)\((S,M,T,e)\to (S,M,10T+kM,e)\)

实际上是一个乘法分配率,由于最后一个数是\(\frac{T}{M}\),新的\(T\)应该是\((\frac{10T}{M}+k)\times M\)\(10T+kM\)

这样,每个字符的贡献就被我们分开了,成了一个独立的转移矩阵

由于矩乘有结合律,我们处理出每段的转移矩阵后快速幂矩乘结合即可


代码

#include <cstdio>
#include <cstring>

using namespace std;

const int mod = 998244353;

int N, R;

char op[20];

inline int mtp(int x, int y) { return 1LL * x * y % mod; }
inline int inc(int& x, int y) { (x += y) > mod ? x -= mod : x; }

struct matrix {

    int n, m;
    int a[5][5];

    matrix(int _n = 0, int _m = 0, int v = 0) {
        memset(a, 0, sizeof a);
        n = _n, m = _m;
        for (int i = 0; i < n; ++i)  a[i][i] = v;
    }

    matrix operator * (const matrix& b) const {
        matrix res = matrix(n, b.m);
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < m; ++j)
                for (int k = 0; k < b.m; ++k)
                    inc(res.a[i][j], mtp(a[i][k], b.a[k][j]));
        return res;
    }

    matrix operator ^ (int k) const {
        matrix x = *this;
        matrix res = matrix(x.n, x.m, 1);
        for (; k; x = x * x, k >>= 1)
            if (k & 1)  res = res * x;
        return res;
    }

};

matrix add, mul, ans, num[10];

void init() {
    add = matrix(4, 4);
    add.a[0][0] = add.a[2][0] = add.a[3][1] = add.a[3][3] = 1;
    mul = matrix(4, 4);
    mul.a[0][0] = mul.a[2][1] = mul.a[3][3] = 1;
    for (int i = 0; i < 10; ++i) {
        num[i] = matrix(4, 4);
        num[i].a[0][0] = num[i].a[1][1] = num[i].a[3][3] = 1;
        num[i].a[2][2] = 10, num[i].a[1][2] = i;
    }
}

int main() {

    scanf("%d", &N);

    init();

    ans = matrix(1, 4);
    ans.a[0][1] = ans.a[0][3] = 1;

    while (N--) {
        scanf("%d%s", &R, op + 1);

        matrix tmp = matrix(4, 4, 1);

        for (int i = 1, sz = strlen(op + 1); i <= sz; ++i)
            if (op[i] == '+')
                tmp = tmp * add;
            else if (op[i] == '*')
                tmp = tmp * mul;
            else
                tmp = tmp * num[op[i] - 48];

        ans = ans * (tmp ^ R);
    }

    printf("%d\n", (ans.a[0][0] + ans.a[0][2]) % mod);

    return 0;
}
01-26 05:50
查看更多