首先思考一个朴素的做法
将b[]维护成一个线段树套有序表,每次修改a[]用线段树+lazy tag
并在线段树的子区间上在有序表中二分更新这段区间中a[i]>=b[i]的值,复杂度O(nlog^2)
有没有更优的做法?
考虑在一次修改操作中,查询的数是相同的,并且b[]这个有序表始终不变
因此在生成b[]的线段树套有序表过程中(其实就是归并排序的记录)
我们维护每个数在左右子区间中名次。
这样修改的时候我们只要对b[]排好序的整体做一次二分,找到b[m]<=x<b[m+1]
把修改成x当作修改成b[m],这样我们就能O(1)地更新每个子区间的计数,问题得解
#include<bits/stdc++.h>
#define mp make_pair
using namespace std;
const int mo=;
vector< pair<int,int> > tr[*];
int laz[*],s[*],a[],b[],c[];
int A,B,n,m,t;
int C = ~(<<), M = (<<)-;
int rnd(int last,int &a,int &b)
{
a = ( + (last >> )) * (a & M) + (a >> );
b = ( + (last >> )) * (b & M) + (b >> );
return (C & ((a << ) + b)) % ;
} void build(int i,int l,int r)
{
laz[i]=-;
tr[i].clear();
if (l==r)
{
s[i]=(a[l]>=b[l]);
}
else {
int m=(l+r)>>;
build(i*,l,m);
build(i*+,m+,r);
s[i]=s[i*]+s[i*+];
int x=l,y=m+,h=l;
while (x<=m&&y<=r)
{
if (b[x]<b[y])
{
tr[i].push_back(mp(x-l,y-m-));
c[h++]=b[x++];
}
else {
tr[i].push_back(mp(x-l-,y-m-));
c[h++]=b[y++];
}
}
while (x<=m)
{
tr[i].push_back(mp(x-l,r-m-));
c[h++]=b[x++];
}
while (y<=r)
{
tr[i].push_back(mp(m-l,y-m-));
c[h++]=b[y++];
}
for (int k=l; k<=r; k++) b[k]=c[k];
}
} void push(int i)
{
int x=laz[i];
laz[i*]=(x==-)?-:tr[i][x].first;
laz[i*+]=(x==-)?-:tr[i][x].second;
s[i*]=laz[i*]+;
s[i*+]=laz[i*+]+;
laz[i]=-;
} void add(int i,int l,int r,int x,int y,int w)
{
if (x<=l&&y>=r)
{
laz[i]=w;
s[i]=w+;
}
else {
int m=(l+r)>>;
if (laz[i]>-) push(i);
if (x<=m) add(i*,l,m,x,y,w==-?w:tr[i][w].first);
if (y>m) add(i*+,m+,r,x,y,w==-?w:tr[i][w].second);
s[i]=s[i*]+s[i*+];
}
} int ask(int i,int l,int r,int x,int y)
{
if (x<=l&&y>=r) return s[i];
else {
int m=(l+r)>>,ans=;
if (laz[i]>-) push(i);
if (x<=m) ans+=ask(i*,l,m,x,y);
if (y>m) ans+=ask(i*+,m+,r,x,y);
return ans;
}
} int main()
{
int cas;
scanf("%d",&cas);
while (cas--)
{
scanf("%d%d%d%d",&n,&m,&A,&B);
for (int i=; i<=n; i++) scanf("%d",&a[i]);
for (int i=; i<=n; i++) scanf("%d",&b[i]);
b[n+]=;
build(,,n);
int ans=;
int last=;
for (int i=; i<=m; i++)
{
int l=rnd(last,A,B)%n+,r=rnd(last,A,B)%n+,x=rnd(last,A,B)+;
if (l>r) swap(l,r);
if ((l+r+x)%==)
{
last=ask(,,n,l,r);
ans=(ans+1ll*i*last%mo)%mo;
}
else {
int w=upper_bound(b+,b++n,x)-b-;
add(,,n,l,r,w);
}
}
printf("%d\n",ans);
}
}