题目

LOJ2292

分析

比较神奇的一个区间 DP ,我看了很多题解都没看懂,

先明确一下题意:abcde 取完 c 后变成 abde ,可以取 bd 这样取 c 后新增的连续段。因此这题需要区间 DP。

能发现取一段区间的代价只与这段区间的最大值和最小值有关。那么用 \(f_{i,j,l,r}\) 表示将区间 \([i,j]\) 取到只剩下值在 \([l,r]\) 中的数的最小代价,\(g_{i,j}\) 表示取完区间 \([i,j]\) 的最小代价,则 \(g_{1,n}\) 就是答案。

考虑怎么转移。对于一段区间而言,取区间末尾的那个数不会创造出新的连续段,也就是说不存在一步必须要取完末尾的数才能取到。因此末尾的数一定可以最后一步再取。那么大力枚举末尾的数是和前面多少个数一起取的,就有转移:

\[f_{i,j,l,r}=\min_{k=i+1}^j f_{i,k-1,l,r}+g_{k,j}\]

此外,如果末尾的数在 \([l,r]\) 中,那么也可以不取。此时有转移:

\[f_{i,j,l,r}=f_{i,j-1,l,r}\]

\(g\) 的转移就是枚举取最后一步时剩下的最大值和最小值,然后加上取这一次的代价。即:

\[g_{i,j}=\min_{l=1}^{m}\min_{r=l}^m f_{i,j,l,r}+a+b(r-l)^2\]

其中 \(m\) 是权值的最大值。

离散化后时间复杂度 \(O(n^5)\)

代码

#include <cstdio>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;

namespace zyt
{
    const int N = 55, INF = 0x3f3f3f3f;
    int n, a, b, f[N][N][N][N], g[N][N], arr[N], tmp[N];
    int sq(const int x)
    {
        return x * x;
    }
    int work()
    {
        memset(f, INF, sizeof(f));
        memset(g, INF, sizeof(g));
        scanf("%d%d%d", &n, &a, &b);
        for (int i = 1; i <= n; i++)
            scanf("%d", &arr[i]), tmp[i] = arr[i];
        sort(tmp + 1, tmp + n + 1);
        int cnt = unique(tmp + 1, tmp + n + 1) - tmp - 1;
        for (int i = 1; i <= n; i++)
            arr[i] = lower_bound(tmp + 1, tmp + cnt + 1, arr[i]) - tmp;
        for (int i = 1; i <= n; i++)
            memset(f[i][i - 1], 0, sizeof(f[i][i - 1]));
        for (int len = 1; len <= n; len++)
            for (int i = 1; i + len - 1 <= n; i++)
            {
                int j = i + len - 1;
                for (int l = 1; l <= cnt; l++)
                    for (int r = l; r <= cnt; r++)
                    {
                        if (l <= arr[j] && arr[j] <= r)
                            f[i][j][l][r] = min(f[i][j][l][r], f[i][j - 1][l][r]);
                        for (int k = i + 1; k <= j; k++)
                            f[i][j][l][r] = min(f[i][j][l][r], f[i][k - 1][l][r] + g[k][j]);
                    }
                for (int l = 1; l <= cnt; l++)
                    for (int r = l; r <= cnt; r++)
                        g[i][j] = min(g[i][j], f[i][j][l][r] + a + b * sq(tmp[r] - tmp[l]));
            }
        printf("%d", g[1][n]);
        return 0;
    }
}
int main()
{
    return zyt::work();
}
12-22 05:30
查看更多