对于一个数 a , i , j a,_{i,j} a,i,j的贡献,其实就是 ( n − a 1 中 a i , j 这 个 数 出 现 的 次 数 ) ∗ ( n − a 1 中 a i , j 这 个 数 出 现 的 次 数 ) ∗ ( n − a 2 中 a i , j 这 个 数 出 现 的 次 数 ) ∗ . . ∗ ( n − a i − 1 中 a i , j 这 个 数 出 现 的 次 数 ) (n-a_1中a_{i,j}这个数出现的次数)*(n-a_1中a_{i,j}这个数出现的次数)*(n-a_2中a_{i,j}这个数出现的次数)*..*(n-a_{i-1}中a_{i,j}这个数出现的次数) (n−a1中ai,j这个数出现的次数)∗(n−a1中ai,j这个数出现的次数)∗(n−a2中ai,j这个数出现的次数)∗..∗(n−ai−1中ai,j这个数出现的次数)* n m − i n^{m-i} nm−i
然后直接 O ( n m ) O(nm) O(nm)的枚举 a i , j a_{i,j} ai,j
对所有的 a i , j a_{i,j} ai,j预先 s o r t + u n i q u e sort+unique sort+unique然后后面利用 l o w e r lower lower_ b o u n d bound bound O ( l o g 2 ( n m ) ) O(log_2(nm)) O(log2(nm))去找对应的位置
这个位置预存 ( n − a 1 中 a i , j 这 个 数 出 现 的 次 数 ) ∗ ( n − a 1 中 a i , j 这 个 数 出 现 的 次 数 ) ∗ ( n − a 2 中 a i , j 这 个 数 出 现 的 次 数 ) . . ∗ ( n − a i − 1 中 a i , j 这 个 数 出 现 的 次 数 ) (n-a_1中a_{i,j}这个数出现的次数)*(n-a_1中a_{i,j}这个数出现的次数)*(n-a_2中a_{i,j}这个数出现的次数)..*(n-a_{i-1}中a_{i,j}这个数出现的次数) (n−a1中ai,j这个数出现的次数)∗(n−a1中ai,j这个数出现的次数)∗(n−a2中ai,j这个数出现的次数)..∗(n−ai−1中ai,j这个数出现的次数)
然后 n m − i n^{m-i} nm−i预处理一下
时间复杂度好像是 O ( n m l o g 2 ( n m ) ) O(nmlog_2(nm)) O(nmlog2(nm))
#include
#include
#include
#include
#include
#include
#define N 2005
using namespace std;
typedef long long ll;
const int mo = 1e9 + 7;
int a[N*N], b[N*N], id[N*N][5], mi[N], orz[N], n, m;
ll ans;
void read(int &x)
{
int f = 1; x = 0; char s = getchar();
while (s < '0' || s > '9') { if (s == '-') f = -1; s = getchar(); }
while (s >= '0' && s <= '9') { x = x * 10 + (s - '0'); s = getchar(); }
x = x * f;
}
int main()
{
scanf("%d %d", &n, &m);
mi[0] = 1; for (int i = 1; i <= m; i++) mi[i] = (ll)mi[i - 1] * n % mo;
for (int i = 1; i <= m; i++)
for (int j = 1; j <= n; j++)
read(a[n * (i - 1) + j]), b[n * (i - 1) + j] = a[n * (i - 1) + j];
sort(b + 1, b + n * m + 1);
int size = unique(b + 1, b + n * m + 1) - b - 1;
int x, num;
for (int i = 1; i <= m; i++)
{
for (int j = 1; j <= n; j++)
{
x = a[n * (i - 1) + j];
num = lower_bound(b + 1, b + size + 1, x) - b;
if (!id[num][3]) id[num][2] = 1;
if (id[num][1] != i)
{
if (id[num][3]) id[num][2] = (ll)id[num][2] * (n - id[num][4]) % mo;
id[num][4] = 0, id[num][3] += 1, id[num][1] = i;
}
id[num][4]++;
ans = (ans + (ll)id[num][2] * mi[m - id[num][3]] % mo * x % mo) % mo;
}
}
ans = ans % mo;
printf("%lld\n", ans);
}