其中
sum[N] 是前缀和,a[i]=1时+1,否则-1
tr1[N] 是以sum[i]为下标(?),存储dp[i]的线段树
tr2[N] 是以 i 为下标,存储dp[i]的线段树
两颗线段树都有单点修改,区间查询最小值的功能。
F 右移用的,避免负数下标。
求大佬找错,或者给个小一点的样例也行,谢谢qwq
#include<bits/stdc++.h>
using namespace std;
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-48;ch=getchar();}
return x*f;
}
//================================================================//
const int N=1e6+4517,F=502005;
struct node{
int l,r,fm;
}tr1[N<<2],tr2[N<<2];
void pushup1(node &fa,node &ll,node &rr){
fa.fm=min(ll.fm,rr.fm);
return;
}
void build1(int l,int r,int u){
tr1[u].l=l;tr1[u].r=r;tr1[u].fm=N;
if(l!=r){
int mid=(l+r)>>1;
build1(l,mid,u<<1);
build1(mid+1,r,u<<1|1);
}
return;
}
void update1(int x,int u,int d){
if(tr1[u].l==x&&tr1[u].r==x){
tr1[u].fm=min(d,tr1[u].fm);
return;
}
int mid=(tr1[u].l+tr1[u].r)>>1;
if(x<=mid)update1(x,u<<1,d);
else update1(x,u<<1|1,d);
pushup1(tr1[u],tr1[u<<1],tr1[u<<1|1]);
return;
}
int find1(int x,int y,int u){
if(tr1[u].l>=x&&tr1[u].r<=y){
return tr1[u].fm;
}
int mid=(tr1[u].l+tr1[u].r)>>1;
if(y<=mid)return find1(x,y,u<<1);
if(x>mid)return find1(x,y,u<<1|1);
return min(find1(x,y,u<<1),find1(x,y,u<<1|1));
}
//===================================================================//
void pushup2(node &fa,node &ll,node &rr){
fa.fm=min(ll.fm,rr.fm);
return;
}
void build2(int l,int r,int u){
tr2[u].l=l;tr2[u].r=r;tr2[u].fm=N;
if(l!=r){
int mid=(l+r)>>1;
build2(l,mid,u<<1);
build2(mid+1,r,u<<1|1);
}
return;
}
void update2(int x,int u,int d){
if(tr2[u].l==x&&tr2[u].r==x){
tr2[u].fm=min(d,tr2[u].fm);
return;
}
int mid=(tr2[u].l+tr2[u].r)>>1;
if(x<=mid)update2(x,u<<1,d);
else update2(x,u<<1|1,d);
pushup2(tr2[u],tr2[u<<1],tr2[u<<1|1]);
return;
}
int find2(int x,int y,int u){
if(tr2[u].l>=x&&tr2[u].r<=y){
return tr2[u].fm;
}
int mid=(tr2[u].l+tr2[u].r)>>1;
if(y<=mid)return find2(x,y,u<<1);
if(x>mid)return find2(x,y,u<<1|1);
return min(find2(x,y,u<<1),find2(x,y,u<<1|1));
}
//=========================================================================================//
int n,m,a[N],sum[N],dp[N],pre[N],t1=1,t2=1;
int main(){
n=read();m=read();
for(int i=1;i<=n;i++)a[i]=read();
for(int i=1;i<=n;i++){
if(a[i]==1){
sum[i]=sum[i-1]+1;
pre[i]=t1;
t2=i+1;
}
else{
sum[i]=sum[i-1]-1;
pre[i]=t2;
t1=i+1;
}
dp[i]=N;
}
build1(1,N-459,1);
build2(1,N-459,1);
dp[0]=0;dp[1]=1;
update1(sum[0]+F,1,0);
update2(F,1,0);
update1(sum[1]+F,1,1);
update2(1+F,1,1);
for(int i=2;i<=n;i++){
if(pre[i]==1)dp[i]=1;
else{
dp[i]=min(find1(sum[i]-m+F,sum[i]+m+F,1),find2(pre[i]+F,i+F,1))+1;
}
update1(sum[i]+F,1,dp[i]);
update2(i+F,1,dp[i]);
}
cout<<dp[n];
return 0;
}