[ZJOI2009]对称的正方形(矩阵哈希+二分)

对称的正方形

时间限制: 1000 M S 1000MS 1000MS 内存限制: 128 M B 128 MB 128MB

问题描述

O r e z Orez Orez很喜欢搜集一些神秘的数据,并经常把它们排成一个矩阵进行研究。最近, O r e z Orez Orez又得到了一些数据,并已经把它们排成了一个 n n n m m m列的矩阵。通过观察, O r e z Orez Orez发现这些数据蕴涵了一个奇特的数,就是矩阵中上下对称且左右对称的正方形子矩阵的个数。 O r e z Orez Orez自然很想知道这个数是多少,可是矩阵太大,无法去数。只能请你编个程序来计算出这个数。

输入格式

文件的第一行为两个整数 n n n m m m。接下来 n n n行每行包含 m m m个正整数,表示 O r e z Orez Orez得到的矩阵。

输出格式

文件中仅包含一个整数 a n s w e r answer answer,表示矩阵中有 a n s w e r answer answer个上下左右对称的正方形子矩阵。

样例输入

5 5 5 5 5 5
4 4 4 2 2 2 4 4 4 4 4 4 4 4 4
3 3 3 1 1 1 4 4 4 4 4 4 3 3 3
3 3 3 5 5 5 3 3 3 3 3 3 3 3 3
3 3 3 1 1 1 5 5 5 3 3 3 3 3 3
4 4 4 2 2 2 1 1 1 2 2 2 4 4 4

样例输出

27 27 27

数据范围

对于 30 30 30%的数据 n , m ≤ 100 n,m≤100 nm100
对于 100 100 100%的数据 n , m ≤ 1000 n,m≤1000 nm1000 ,矩阵中的数的大小 ≤ 1 0 9 ≤10^9 109


解析

听说有很多神仙是用 M a n a c h e r Manacher Manacher做的 ,然而本蒟蒻并不会
所以我们用一种比较简单粗暴且易于理解的算法—— h a s h hash hash来替代。(说白了就是弱)
首先我们会发现两个很 (显然且) 有用的性质:

  1. 如果一个正方形子矩阵是对称的且边长 > 2 >2 >2,那么比它小一圈的正方形子矩阵也一定是对称的。
  2. 正方形子矩阵的对称中心至多只有 O ( 2 n m ) O(2nm) O(2nm)
    那么我们可以枚举正方形子矩阵的对称中心,并二分此对称中心的最大边长。
    至于判定该正方形子矩阵是否对称,我们可以通过矩阵hash来解决。
    设一个矩阵 a [ 1.. i ] [ 1.. j ] a[1..i][1..j] a[1..i][1..j] h a s h hash hash值为 Σ p 1 i ∗ p 2 j ∗ a [ i ] [ j ] Σp_1^i*p_2^j*a[i][j] Σp1ip2ja[i][j]
    我们把原矩阵、原矩阵上下翻转、原矩阵左右翻转分别做一次 h a s h hash hash,判定时只要把对应矩阵的 h a s h hash hash值用二维前缀和求出来并简单处理一下行差、列差对 p 1 , p 2 p_1,p_2 p1,p2乘方次数的影响之后判断是否相等即可。

T i p s : Tips: Tips:

  1. 本题时限较紧,请提前预处理 p 1 i , p 2 j p_1^i,p_2^j p1i,p2j
  2. 如果你对自己的常数不是非常自信的话请不要写双 h a s h hash hash,写了也不要用 p a i r pair pair cyl大佬已经身先士卒地T了
  3. 别把 m m m打成 n n n,本蒟蒻对此已经不想说什么了
  4. 枚举偶数边长时不一定有答案

代码

#include 
#define ll long long
using namespace std;
const int maxn = 1005;
const int p1 = 29;
const int p2 = 31;
const int mod = 1e9 + 7;
int n , m;
int a[maxn][maxn] , b[maxn][maxn] , c[maxn][maxn];
ll pow_x[maxn] , pow_y[maxn];
int min(int x , int y){return x < y ? x : y;}
int read()
{
    char ch = getchar(); bool f = 1;
    while(ch < '0' || ch > '9') f &= ch != '-' , ch = getchar();
    int res = 0;
    while(ch >= '0' && ch <= '9')  res = (res << 3) + (res << 1) + (ch ^ 48) , ch = getchar();
    return f ? res : -res;
}
void pow_init()
{
    pow_x[0] = pow_y[0] = 1;
    for(int i = 1;i <= n;i++) pow_x[i] = pow_x[i - 1] * p1 % mod;
    for(int i = 1;i <= m;i++) pow_y[i] = pow_y[i - 1] * p2 % mod;
}
struct HASH
{
    private:
        ll s[maxn][maxn];
    public:
        void init(int (*x)[maxn])
        {
            for(int i = 1;i <= n;i++)
                for(int j = 1;j <= m;j++)
                {
                    ll tmp = pow_x[i] * pow_y[j] % mod * x[i][j] % mod;
                    s[i][j] = ((s[i - 1][j] + s[i][j - 1]) % mod - s[i - 1][j - 1] + tmp + mod) % mod;
                }
        }
        ll sum(int bx , int by , int ex , int ey){return ((s[ex][ey] - s[bx - 1][ey] - s[ex][by - 1] + s[bx - 1][by - 1]) % mod + mod) % mod;}
}hash1 , hash2 , hash3;
bool check1(int len , int i , int j)
{
    int bx1 = i - len , by1 = j - len , ex1 = i + len , ey1 = j + len;
    ll res1 = hash1.sum(bx1 , by1 , ex1 , ey1);
    int bx2 = bx1 , by2 = m - ey1 + 1 , ex2 = ex1 , ey2 = m - by1 + 1;
    ll res2 = hash2.sum(bx2 , by2 , ex2 , ey2);
    int bx3 = n - ex1 + 1 , by3 = by1 , ex3 = n - bx1 + 1 , ey3 = ey1;
    ll res3 = hash3.sum(bx3 , by3 , ex3 , ey3);
    ll res4 = res1;
    if(by1 > by2) res2 = (res2 * pow_y[by1 - by2]) % mod;
    if(by1 < by2) res1 = (res1 * pow_y[by2 - by1]) % mod;
    if(bx1 > bx3) res3 = (res3 * pow_x[bx1 - bx3]) % mod;
    if(bx1 < bx3) res4 = (res4 * pow_x[bx3 - bx1]) % mod;
    return res1 == res2 && res4 == res3;
}
bool check2(int len , int i , int j)
{
    int bx1 = i - len , by1 = j - len , ex1 = i + len + 1 , ey1 = j + len + 1;
    ll res1 = hash1.sum(bx1 , by1 , ex1 , ey1);
    int bx2 = bx1 , by2 = m - ey1 + 1 , ex2 = ex1 , ey2 = m - by1 + 1;
    ll res2 = hash2.sum(bx2 , by2 , ex2 , ey2);
    int bx3 = n - ex1 + 1 , by3 = by1 , ex3 = n - bx1 + 1 , ey3 = ey1;
    ll res3 = hash3.sum(bx3 , by3 , ex3 , ey3);
    ll res4 = res1;
    if(by1 > by2) res2 = (res2 * pow_y[by1 - by2]) % mod;
    if(by1 < by2) res1 = (res1 * pow_y[by2 - by1]) % mod;
    if(bx1 > bx3) res3 = (res3 * pow_x[bx1 - bx3]) % mod;
    if(bx1 < bx3) res4 = (res4 * pow_x[bx3 - bx1]) % mod;
    return res1 == res2 && res4 == res3;
}
int main()
{
    n = read() , m = read();
    for(int i = 1;i <= n;i++)
        for(int j = 1;j <= m;j++) a[i][j] = b[i][m - j + 1] = c[n - i + 1][j] = read();
    pow_init();
    hash1.init(a) , hash2.init(b) , hash3.init(c);
    int ans = 0;
    for(int i = 1;i <= n;i++)
        for(int j = 1;j <= m;j++)
        {
            int l = 0 , r = min(min(i - 1 , n - i) , min(j - 1 , m - j));
            while(l < r)
            {
                int mid = l + r + 1 >> 1;
                if(check1(mid , i , j)) l = mid;
                else r = mid - 1;
            }
            ans += l + 1;
        }
    for(int i = 1;i < n;i++)
        for(int j = 1;j < m;j++)
        {
            int l = 0 , r = min(min(i - 1 , n - i - 1) , min(j - 1 , m - j - 1)) , res = -1;
            while(l <= r)
            {
                int mid = l + r >> 1;
                if(check2(mid , i , j)) res = mid , l = mid + 1;
                else r = mid - 1;
            }
            ans += res + 1;
        }
    printf("%d\n",ans);
    return 0;
}

你可能感兴趣的:(二分,哈希)