前缀和通常用来解决区间(求和、乘积、异或和)问题,单次询问从O(n)变成O(1)。
例题: 1838. 最高频元素的频数
题解: [LeetCode解题报告] 1838. 最高频元素的频数
排序可以让差别最小的元素放在一起。
class Solution:
def maxFrequency(self, nums: List[int], k: int) -> int:
nums.sort()
n = len(nums)
presum = list(accumulate(nums,initial=0))
def sum_interval(i,j):
return presum[j+1]-presum[i]
def calc_s(i,j):
return nums[j]*(j-i+1)-sum_interval(i,j)
ans = 1
l=r=0
while l<n and r<n:
s = calc_s(l,r)
if s <= k:
ans = max(ans,r-l+1)
r += 1
else:
l += 1
if l==r:
r += 1
return ans
链接: 1590. 使数组和能被 P 整除
题解: [LeetCode解题报告] 1590. 使数组和能被 P 整除
同余
需要一点数论的知识,看了题解才会。
class Solution:
def minSubarray(self, nums: List[int], p: int) -> int:
n = len(nums)
total = sum(nums)
mod = total%p
if mod == 0:
return 0
elif n == 1:
return -1
# 如果某个子数组的和%p==mod,就可以移除。
# 计算前缀和的过程中,用字典记录前边每个和模p结果的最近下标,
# 如果当前模-mod存在,那么这个差就是这个区间的和,模==mod
ans = n # 不允许移除整个数组
presum = 0
pres = {0:-1}
for i in range(n):
presum += nums[i]
cur_mod = presum%p
need_mod = (cur_mod+p-mod)%p
# print(presum,cur_mod,need_mod)
need_idx = pres.get(need_mod,-2)
if need_idx != -2 :
ans = min(ans,i-need_idx)
if ans == 1:
return 1
pres[cur_mod] = i
if ans == n: # 不允许移除整个数组
return -1
return ans
1589. 所有排列中的最大和
[LeetCode解题报告] 1589. 所有排列中的最大和
实际上是求每个点的访问次数,求完后排序,频次越高的位置放更大的数。
因此数据和频次都排序,求卷积即可。
其实就是IUOP,利用差分数组。类似树状数组。
树状数组能AC此题,多一层lg的时间复杂度。
class Solution:
def maxSumRangeQuery(self, nums: List[int], requests: List[List[int]]) -> int:
n = len(nums)
mod = 10**9+7
nums.sort(reverse=True)
diff = [0]*(n+1)
for start,end in requests:
diff[start] += 1
diff[end+1] -= 1
freq = [0] * n
s = 0
for i in range(n):
s += diff[i]
freq[i] = s
ans = 0
freq.sort(reverse=True)
for i in range(n):
if freq[i] ==0:
break
ans += freq[i]*nums[i]%mod
ans %= mod
return ans
链接: 1712. 将数组分成三个子数组的方案数
枚举中间区间的左右端点即可。
参考链接: [LeetCode解题报告] 1712. 将数组分成三个子数组的方案数
class Solution:
def waysToSplit(self, nums: List[int]) -> int:
n = len(nums)
mod = 10**9+7
presum = list(accumulate(nums,initial=0))
def sum_interval(i,j):
return presum[j+1]-presum[i]
ans = 0
# 随着i右移,j_min是右移的,j_max右移的,所以可以三指针寻找。O(n)
j_min,j_max = 1,1
for i in range(1,n-1):
left = sum_interval(0,i-1)
j_min = max(j_min,i)
while j_min<n-1 and sum_interval(i,j_min) < left:
j_min += 1
# print(i,j_min,j_max,ans,left,sum_interval(i,j_min),sum_interval(j_min+1,n-1))
if j_min >= n-1 :
return ans
if sum_interval(j_min+1,n-1) < sum_interval(i,j_min):
continue
while j_max+1<n-1 and sum_interval(i,j_max+1) <= sum_interval(j_max+2,n-1):
j_max += 1
ans += j_max-j_min+1
ans %= mod
return ans
链接: 1442. 形成两个异或相等数组的三元组数目
证明复杂一些,代码倒是短。
参考链接: [LeetCode解题报告] 1442. 形成两个异或相等数组的三元组数目
class Solution:
def countTriplets(self, arr: List[int]) -> int:
n = len(arr)
presum = list(accumulate(arr,initial=0,func=xor))
def sum_interval(i,j):
return presum[j+1]^presum[i]
cnt,total = defaultdict(int),defaultdict(int)
ans = 0
for k in range(0,n):
# presum[k+1] == presum[i] 则[i,k]中间j任意,都是满足需求的三元组
s = presum[k+1]
if s in cnt:
ans += cnt[s]*k - total[s]
cnt[presum[k]] += 1
total[presum[k]] += k
return ans