「NOI2018」冒泡排序

考虑冒泡排序中一个位置上的数向左移动的步数 \(Lstep\) 为左边比它大的数的个数,向右移动的步数 \(Rstep\) 为右边比它大的数的个数,如果 \(Lstep,Rstep\) 中有一个不为 \(0\) ,那么显然不会取到下界,因为产生了浪费的步数,题面给的提示在这里非常有用,如果至少有一个为 \(0\) ,那么显然没有产生浪费操作,取到下界,所以一个合法排列的充要条件就是对于所有位置 \(Lstep\times Rstep=0\) ,即该排列的最长下降子序列长度 \(\leq 2\) 。

先不考虑字典序的限制,只考虑求出一个合法的排列,记 \(dp_{i,j}\) 为前 \(i\) 个数,后面数中有 \(j\) 个比前 \(i\) 个数的最大值要小,此时前 \(i\) 位是一个合法排列的方案数,那么考虑这一步如果选一个小于最大值的数,一定要选最小的数,否则就会出现长度 \(>2\) 的最长下降子序列,否则可以随便选,那么 \(dp_{i,j}\) 可以转移到 \(dp_{i+1,k},j-1\leq k\leq n-i-1\) 。考虑加上字典序的限制,相当于对每一次转移到的 \(k\) 做一个下界限制,稍微改一改就得到了一个 \(\mathcal O(n^2)\) 的 80分做法,这么简单的套路去年考的时候居然没想到。

其实每次 \(k\) 的取值是 \(\geq -1\) 的任何数,因为如果 \(k > n -i+1\) 的话,就再也转移不回 \(dp_{n,0}\) 了,对答案没有影响,然后把每次取的 \(k\) 都加 \(1\) ,问题就转化为 \((0,0)\) 到 \((n,n)\) 不能低于 \(y=-1\) 的一个格路计数问题了,此时不加上字典序的限制就是卡特兰数,加上字典序的限制就枚举再哪里超过了字典序的限制,然后的方案数也是可以 \(O(1)\) 算的,类似于卡特兰数的推导。

code

/*program by mangoyang*/
#pragma GCC optimize("Ofast", "inline")
#include <bits/stdc++.h>
#define inf (0x7f7f7f7f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
int ch = 0, f = 0; x = 0;
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
if(f) x = -x;
}
const int N = 1200005, mod = 998244353;
int a[N], mx[N], mn[N], js[N], lim[N], inv[N], n;
inline void up(int &x, int y){
x = x + y >= mod ? x + y - mod : x + y;
}
inline int Pow(int a, int b){
int ans = 1;
for(; b; b >>= 1, a = 1ll * a * a % mod)
if(b & 1) ans = 1ll * ans * a % mod;
return ans;
}
inline int C(int x, int y){
if(x < y || x < 0 || y < 0) return 0;
return 1ll * js[x] * inv[y] % mod * inv[x-y] % mod;
}
inline int calc(int x, int y){
int res = 0;
up(res, C(n - x + n - y, n - x));
up(res, mod - C(n - x + n - y, n - y - 1));
return res;
}
namespace Bit{
int s[N];
inline void add(int x){
for(int i = x; i <= n; i += i & -i) s[i]++;
}
inline int query(int x){
int res = 0;
for(int i = x; i; i -= i & -i) res += s[i];
return res;
}
}
inline void solve(){
read(n);
for(int i = 1; i <= n; i++) read(a[i]);
mn[n] = a[n];
for(int i = n - 1; i >= 1; i--) mn[i] = min(a[i], mn[i+1]);
mx[1] = a[1];
for(int i = 2; i <= n; i++) mx[i] = max(mx[i-1], a[i]);
for(int i = n; i >= 1; i--)
lim[i] = Bit::query(mx[i]), Bit::add(a[i]);
for(int i = 1; i <= n; i++) lim[i] += i;
int res = 0;
for(int i = 1; i <= n; i++){
if(lim[i] < n) up(res, calc(i - 1, lim[i] + 1));
if(lim[i-1] > lim[i]) break;
if(lim[i-1] == lim[i] && mn[i] < a[i]) break;
}
cout << res << endl;
for(int i = 0; i <= n; i++) Bit::s[i] = 0;
}
int main(){
js[0] = inv[0] = 1;
for(int i = 1; i < N; i++){
js[i] = 1ll * js[i-1] * i % mod;
inv[i] = Pow(js[i], mod - 2);
}
int T; read(T); while(T--) solve();
}
05-15 21:19