bzoj1414 [ZJOI2009]对称的正方形(二分答案+二维哈希)

首先我们考虑偶数个点和奇数个点的方阵枚举中心方式不太相同,我们用类似manacher的处理方法,填上一堆0,把他们全都变成奇数的情况。然后我们枚举每一个点作为中心,二分答案找到以这个点为中心最大的合法方阵。就可以直接统计这个点对答案的贡献了。这样已经是 O(n2logn) 的了,我们需要O(1)判断一个方阵是否上下左右均对称。类似不用manacher求最长回文子串的方法,把这个子串镜像过来求最长公共子串,我们分别做出这个矩阵的上下镜面和左右镜面,然后每次就只需要判定这三个方阵区域是否相同。可以用二维hash预处理来O(1)判断。

#include 
#include 
#include 
using namespace std;
#define ll long long
#define N 2010
#define inf 0x3f3f3f3f
#define k1 1000003
#define k2 101
#define uint unsigned int 
inline int read(){
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
    return x*f;
}
int n,m,a[N][N],ans=0;
uint hs[3][N][N],bin1[N],bin2[N];
inline uint hash(int op,int x1,int y1,int x2,int y2){
    uint res=hs[op][x2][y2]-hs[op][x1-1][y2]*bin1[x2-x1+1]-hs[op][x2][y1-1]*bin2[y2-y1+1];
    return res+hs[op][x1-1][y1-1]*bin2[y2-y1+1]*bin1[x2-x1+1];
}
inline bool jud(int x,int y,int len){
    int x1=x-len+1,x2=x+len-1,y1=y-len+1,y2=y+len-1;
    uint v1=hash(0,x1,y1,x2,y2),v2=hash(1,n-x2+1,y1,n-x1+1,y2),v3=hash(2,x1,m-y2+1,x2,m-y1+1);
    if(v1!=v2||v1!=v3) return 0;
    return 1;
}
int main(){
//  freopen("a.in","r",stdin);
    n=read();m=read();bin1[0]=bin2[0]=1;
    for(int i=1;i<=n;++i)
        for(int j=1;j<=m;++j) a[i*2-1][j*2-1]=read();
    n=n*2-1;m=m*2-1;
    for(int i=1;i<=n;++i) bin1[i]=bin1[i-1]*k1;
    for(int i=1;i<=m;++i) bin2[i]=bin2[i-1]*k2;
    for(int i=1;i<=n;++i)
        for(int j=1;j<=m;++j) hs[0][i][j]=hs[0][i][j-1]*k2+a[i][j];
    for(int i=1;i<=n;++i)
        for(int j=1;j<=m;++j) hs[0][i][j]+=hs[0][i-1][j]*k1;
    for(int i=1;i<=n;++i)
        for(int j=1;j<=m;++j) hs[1][i][j]=hs[1][i][j-1]*k2+a[n-i+1][j];
    for(int i=1;i<=n;++i)
        for(int j=1;j<=m;++j) hs[1][i][j]+=hs[1][i-1][j]*k1;
    for(int i=1;i<=n;++i)
        for(int j=1;j<=m;++j) hs[2][i][j]=hs[2][i][j-1]*k2+a[i][m-j+1];
    for(int i=1;i<=n;++i)
        for(int j=1;j<=m;++j) hs[2][i][j]+=hs[2][i-1][j]*k1;
    for(int i=1;i<=n;++i)
        for(int j=1;j<=m;++j){
            if(i+j&1) continue;
            int l=1,r=min(min(i,n-i+1),min(j,m-j+1));
            while(l<=r){
                int mid=l+r>>1;
                if(jud(i,j,mid)) l=mid+1;else r=mid-1;
            }l--;if(i&1) ans+=l+1>>1;else ans+=l>>1;
        }printf("%d\n",ans);
    return 0;
}

你可能感兴趣的:(bzoj,二分答案,Hash,manacher)