对称的正方形
时间限制: 1000 M S 1000MS 1000MS 内存限制: 128 M B 128 MB 128MB
问题描述
O r e z Orez Orez很喜欢搜集一些神秘的数据,并经常把它们排成一个矩阵进行研究。最近, O r e z Orez Orez又得到了一些数据,并已经把它们排成了一个 n n n行 m m m列的矩阵。通过观察, O r e z Orez Orez发现这些数据蕴涵了一个奇特的数,就是矩阵中上下对称且左右对称的正方形子矩阵的个数。 O r e z Orez Orez自然很想知道这个数是多少,可是矩阵太大,无法去数。只能请你编个程序来计算出这个数。
输入格式
文件的第一行为两个整数 n n n和 m m m。接下来 n n n行每行包含 m m m个正整数,表示 O r e z Orez Orez得到的矩阵。
输出格式
文件中仅包含一个整数 a n s w e r answer answer,表示矩阵中有 a n s w e r answer answer个上下左右对称的正方形子矩阵。
样例输入
5 5 5 5 5 5
4 4 4 2 2 2 4 4 4 4 4 4 4 4 4
3 3 3 1 1 1 4 4 4 4 4 4 3 3 3
3 3 3 5 5 5 3 3 3 3 3 3 3 3 3
3 3 3 1 1 1 5 5 5 3 3 3 3 3 3
4 4 4 2 2 2 1 1 1 2 2 2 4 4 4样例输出
27 27 27
数据范围
对于 30 30 30%的数据 n , m ≤ 100 n,m≤100 n,m≤100
对于 100 100 100%的数据 n , m ≤ 1000 n,m≤1000 n,m≤1000 ,矩阵中的数的大小 ≤ 1 0 9 ≤10^9 ≤109
听说有很多神仙是用 M a n a c h e r Manacher Manacher做的 ,然而本蒟蒻并不会 。
所以我们用一种比较简单粗暴且易于理解的算法—— h a s h hash hash来替代。(说白了就是弱)
首先我们会发现两个很 (显然且) 有用的性质:
T i p s : Tips: Tips:
#include
#define ll long long
using namespace std;
const int maxn = 1005;
const int p1 = 29;
const int p2 = 31;
const int mod = 1e9 + 7;
int n , m;
int a[maxn][maxn] , b[maxn][maxn] , c[maxn][maxn];
ll pow_x[maxn] , pow_y[maxn];
int min(int x , int y){return x < y ? x : y;}
int read()
{
char ch = getchar(); bool f = 1;
while(ch < '0' || ch > '9') f &= ch != '-' , ch = getchar();
int res = 0;
while(ch >= '0' && ch <= '9') res = (res << 3) + (res << 1) + (ch ^ 48) , ch = getchar();
return f ? res : -res;
}
void pow_init()
{
pow_x[0] = pow_y[0] = 1;
for(int i = 1;i <= n;i++) pow_x[i] = pow_x[i - 1] * p1 % mod;
for(int i = 1;i <= m;i++) pow_y[i] = pow_y[i - 1] * p2 % mod;
}
struct HASH
{
private:
ll s[maxn][maxn];
public:
void init(int (*x)[maxn])
{
for(int i = 1;i <= n;i++)
for(int j = 1;j <= m;j++)
{
ll tmp = pow_x[i] * pow_y[j] % mod * x[i][j] % mod;
s[i][j] = ((s[i - 1][j] + s[i][j - 1]) % mod - s[i - 1][j - 1] + tmp + mod) % mod;
}
}
ll sum(int bx , int by , int ex , int ey){return ((s[ex][ey] - s[bx - 1][ey] - s[ex][by - 1] + s[bx - 1][by - 1]) % mod + mod) % mod;}
}hash1 , hash2 , hash3;
bool check1(int len , int i , int j)
{
int bx1 = i - len , by1 = j - len , ex1 = i + len , ey1 = j + len;
ll res1 = hash1.sum(bx1 , by1 , ex1 , ey1);
int bx2 = bx1 , by2 = m - ey1 + 1 , ex2 = ex1 , ey2 = m - by1 + 1;
ll res2 = hash2.sum(bx2 , by2 , ex2 , ey2);
int bx3 = n - ex1 + 1 , by3 = by1 , ex3 = n - bx1 + 1 , ey3 = ey1;
ll res3 = hash3.sum(bx3 , by3 , ex3 , ey3);
ll res4 = res1;
if(by1 > by2) res2 = (res2 * pow_y[by1 - by2]) % mod;
if(by1 < by2) res1 = (res1 * pow_y[by2 - by1]) % mod;
if(bx1 > bx3) res3 = (res3 * pow_x[bx1 - bx3]) % mod;
if(bx1 < bx3) res4 = (res4 * pow_x[bx3 - bx1]) % mod;
return res1 == res2 && res4 == res3;
}
bool check2(int len , int i , int j)
{
int bx1 = i - len , by1 = j - len , ex1 = i + len + 1 , ey1 = j + len + 1;
ll res1 = hash1.sum(bx1 , by1 , ex1 , ey1);
int bx2 = bx1 , by2 = m - ey1 + 1 , ex2 = ex1 , ey2 = m - by1 + 1;
ll res2 = hash2.sum(bx2 , by2 , ex2 , ey2);
int bx3 = n - ex1 + 1 , by3 = by1 , ex3 = n - bx1 + 1 , ey3 = ey1;
ll res3 = hash3.sum(bx3 , by3 , ex3 , ey3);
ll res4 = res1;
if(by1 > by2) res2 = (res2 * pow_y[by1 - by2]) % mod;
if(by1 < by2) res1 = (res1 * pow_y[by2 - by1]) % mod;
if(bx1 > bx3) res3 = (res3 * pow_x[bx1 - bx3]) % mod;
if(bx1 < bx3) res4 = (res4 * pow_x[bx3 - bx1]) % mod;
return res1 == res2 && res4 == res3;
}
int main()
{
n = read() , m = read();
for(int i = 1;i <= n;i++)
for(int j = 1;j <= m;j++) a[i][j] = b[i][m - j + 1] = c[n - i + 1][j] = read();
pow_init();
hash1.init(a) , hash2.init(b) , hash3.init(c);
int ans = 0;
for(int i = 1;i <= n;i++)
for(int j = 1;j <= m;j++)
{
int l = 0 , r = min(min(i - 1 , n - i) , min(j - 1 , m - j));
while(l < r)
{
int mid = l + r + 1 >> 1;
if(check1(mid , i , j)) l = mid;
else r = mid - 1;
}
ans += l + 1;
}
for(int i = 1;i < n;i++)
for(int j = 1;j < m;j++)
{
int l = 0 , r = min(min(i - 1 , n - i - 1) , min(j - 1 , m - j - 1)) , res = -1;
while(l <= r)
{
int mid = l + r >> 1;
if(check2(mid , i , j)) res = mid , l = mid + 1;
else r = mid - 1;
}
ans += res + 1;
}
printf("%d\n",ans);
return 0;
}