题目链接:https://www.lanqiao.cn/paper/4403/problem/19720/
根据题目的意思,每一行有一个 R R R 值,每一列有一个 C C C 值。合法的操作的定义是,从 R 1 R_1 R1 这一行跳到一个恰好比 R 1 R_1 R1 大的 R 2 R_2 R2 这一行,而 R 1 R_1 R1 和 R 2 R_2 R2 之间不可以有 R R R 值。列的跳法同理。
那么思路分为两步。
首先,假设位于区间 ( R 1 , R 2 ) (R_1,R_2) (R1,R2) 之间的 R R R 值数有 r r r 个,记为 R m 1 , ⋯ , R m r R_{m1},\cdots,R_{mr} Rm1,⋯,Rmr,假设它们所对应的行的总数有 r r r 个。而 R i R_i Ri 对应的行一共有 r i r_i ri 个。
比如 R 1 = 3 , R 2 = 7 R_1 = 3, R_2 = 7 R1=3,R2=7 之间有一个 R = 4 R=4 R=4 对应 2 2 2 行,而 R = 6 R=6 R=6 对应 3 3 3 行,那么 R m 1 = 4 , R m 2 = 6 , r = 2 R_{m1}=4,R_{m2}=6,r=2 Rm1=4,Rm2=6,r=2。
那么这意味着光考虑行的话,从起点到终点恰好是跳了 r + 1 r + 1 r+1 次,第一次的落脚点(即落在哪一行)有 r m 1 r_{m1} rm1 种可能,第二次的落脚点有 r m 2 r_{m2} rm2 种可能……最后一次落脚只能落在终点。所以根据乘法原理,落脚行的选法有 ∏ i = 1 r r m i \prod_{i=1}^r r_{mi} ∏i=1rrmi 种。
同理,从起点到终点的过程中,经过了哪些列呢?落脚列的选法有 ∏ i = 1 c c m i \prod_{i=1}^c c_{mi} ∏i=1ccmi 种。
根据上面的分析,从起点到终点的 r + c + 2 r+c+2 r+c+2 步中,一共有 r + 1 r+1 r+1 次跳跃是切换行,而 c + 1 c+1 c+1 次跳跃是在切换列。因此我们这里需要一个组合数,表示这 r + c + 2 r+c+2 r+c+2 步中哪些是在切换行—— C ( r + c + 2 , r + 1 ) C(r+c+2,r+1) C(r+c+2,r+1)。
据此,答案就是 ∏ i = 1 r r m i × ∏ i = 1 c c m i × C ( r + c + 2 , r + 1 ) \prod _ {i=1}^r r _ {mi}\times\prod _ {i=1}^c c _ {mi}\times C(r+c+2,r+1) ∏i=1rrmi×∏i=1ccmi×C(r+c+2,r+1) 再取个模。
求 ∏ i = 1 r r m i \prod _ {i=1}^r r _ {mi} ∏i=1rrmi 或者 ∏ i = 1 c c m i \prod _ {i=1}^c c _ {mi} ∏i=1ccmi 的过程中,我们不可能每次真的进行一个连乘,可以使用线段树或者倍增技巧进行优化。我用的是后者,先用 O ( n log n ) O(n\log n) O(nlogn)(对于行来说)的复杂度初始化一个倍增数组,然后每次查询乘积只需要 O ( log n ) O(\log n) O(logn) 的复杂度。
然后就是求组合数 C ( n , k ) C(n,k) C(n,k) 的问题。当然有组合数的递推公式,比如 C ( n , k ) = C ( n − 1 , k ) + C ( n − 1 , k − 1 ) C(n,k)=C(n-1,k)+C(n-1,k-1) C(n,k)=C(n−1,k)+C(n−1,k−1),但是它的时间复杂度高,不行。如果用动态规划初始化一个数组存储所有的组合数,时空复杂度都太高。注意到我们这里的 m o d = 1 0 9 + 7 \mathit{mod}=10^9+7 mod=109+7 是一个质数,我们可以采用乘法逆元的方式。
关于乘法逆元,可以参考这篇博客。
C ( n , k ) = n ! k ! ( n − k ) ! C(n,k)=\cfrac{n!}{k!(n-k)!} C(n,k)=k!(n−k)!n!
观察到上面的组合数公式,我们可以用 O ( n ) O(n) O(n) 的复杂度初始化一个阶乘数组fact
,其中fact[i] = i!
。可以初始化一个阶乘的逆元数组invs
,其中invs[i]
存储的是 i ! i! i! 模 1 0 9 + 7 10^9+7 109+7 的逆元。这个逆元数组的初始化复杂度是 O ( n log n ) O(n\log n) O(nlogn)。
在初始化数组fact
和invs
结束之后,求组合数就是 O ( 1 ) O(1) O(1) 的时间复杂度了。
算法的总时间复杂度: O ( ( n + m ) log ( n + m ) + T log n m ) O((n+m)\log(n+m)+T\log nm) O((n+m)log(n+m)+Tlognm)。
import os
import sys
def make_array(length,val):
return [val for _ in range(length)]
def make_2d_array(rows,cols,val):
return [[val for _ in range(cols)] for _ in range(rows)]
def make_3d_array(d1,d2,d3,val):
return [[[val for _ in range(d3)] for _ in range(d2)] for _ in range(d1)]
def read_int():
return int(input())
def read_ints():
return [int(i) for i in input().split()]
mod = int(1e9) + 7
n,m,T = read_ints()
dup_rows = []
prod_rows = make_2d_array(n,18,0)
rrfl = make_array(n,0)
last = -1
rr = read_ints()
for idx,it in enumerate(sorted(list(enumerate(rr)),key=lambda x:x[1])):
# rows: [(idx,val),...]
if it[1] != last:
last = it[1]
dup_rows.append([last,1])
else:
dup_rows[-1][1] += 1
rrfl[it[0]] = len(dup_rows) - 1
for i,val in enumerate(dup_rows):
prod_rows[i][0] = val[1]
for i in range(1,18):
for j in range(0,len(dup_rows) - (1 << i) + 1):
prod_rows[j][i] = prod_rows[j][i - 1] * prod_rows[j + (1 << (i - 1))][i - 1] % mod
dup_cols = []
crfl = make_array(m,0)
prod_cols = make_2d_array(m,18,0)
last = -1
cc = read_ints()
for idx,it in enumerate(sorted(list(enumerate(cc)),key=lambda x:x[1])):
if it[1] != last:
last = it[1]
dup_cols.append([last,1])
else:
dup_cols[-1][1] += 1
crfl[it[0]] = len(dup_cols) - 1
for i,val in enumerate(dup_cols):
prod_cols[i][0] = val[1]
for i in range(1,18):
for j in range(0,len(dup_cols) - (1 << i) + 1):
prod_cols[j][i] = prod_cols[j][i - 1] * prod_cols[j + (1 << (i - 1))][i - 1] % mod
def cal_prod(begin,count,arr):
ans = 1;base = 0
while count > 0:
if count & 1:
ans = ans * arr[begin][base] % mod
begin += 1 << base
count >>= 1
base += 1
return ans
fact = make_array(n + m + 5,1)
for i in range(2,n + m + 5):
fact[i] = fact[i - 1] * i % mod
invs = make_array(n + m + 5,1)
def extended_gcd(a, b):
if a == 0:
return b, 0, 1
gcd, x1, y1 = extended_gcd(b % a, a)
x = y1 - (b // a) * x1
y = x1
return gcd, x, y
def mod_inverse(b, k=mod):
gcd, x, _ = extended_gcd(b, k)
if gcd != 1:
raise ValueError(f"No modular inverse for {b} mod {k}")
return x % k
for i in range(n + m + 5):
invs[i] = mod_inverse(fact[i])
def C(tol,choice):
return fact[tol] * invs[tol - choice] * invs[choice] % mod
for ct in range(T):
ans = 1
r1,c1,r2,c2 = [i - 1 for i in read_ints()]
# if ct == 11:
# print("debug: r1,r2,c1,c2:",r1,r2,c1,c2)
# print("debug: len(rr),len(cc):",len(rr),len(cc))
# print("debug:",rrfl[r1],rrfl[r2],crfl[c1],crfl[c2])
# print("debug:",rr[r1],rr[r2],cc[c1],cc[c2])
# exit(0)
if rrfl[r1] >= rrfl[r2]:
if r1 != r2:
print(0)
continue
else:
ans *= cal_prod(rrfl[r1] + 1,rrfl[r2] - rrfl[r1] - 1,prod_rows)
if crfl[c1] >= crfl[c2]:
if c1 != c2:
print(0)
continue
else:
ans *= cal_prod(crfl[c1] + 1,crfl[c2] - crfl[c1] - 1,prod_cols)
print(ans * C(rrfl[r2] + crfl[c2] - rrfl[r1] - crfl[c1],crfl[c2] - crfl[c1]) % mod)