树状数组:主要是用于求动态区间连续和。其时间复杂度为logn。
线段树:其是树状数组的plus。
动态求连续区间和(原题链接)
线段树和树状数组都可做
代码如下:
#1264.动态求连续区间和
#定义树状数组的函数:lowbit,add,query
n,m=map(int,input().split())
lst=list(map(int,input().split()))
tree_list=[0 for _ in range(n+1)]
def lowbit(x):
return x&-x
def add(x,v):
while x<n+1:
tree_list[x]+=v
x+=lowbit(x)
def query(x):
res=0
while x>0:
res+=tree_list[x]
x-=lowbit(x)
return res
#初始化
for i in range(1,n+1):
add(i,lst[i-1])
for _ in range(m):
k,a,b=map(int,input().split())
if k:
add(a,b)
else:
print(query(b)-query(a-1))
#接下来利用线段树进行求解
#用线段树进行求解
#线段树本质上还是作为二叉树进行使用
[n,m]=list(map(int,input().split()))
lst=list(map(int,input().split()))
#定义节点参数
class Node:
def __init__(self,l=0,r=0,s=0) :
self.l=l
self.r=r
self.sum=s
N = 100010
tr = [Node() for _ in range(4 * N)]#线段树数组
def push_up(u):
#u:根节点的编号
#return 左右儿子的和
tr[u].sum=tr[u<<1].sum+tr[u<<1|1].sum
def bulid(u,l,r):
"""
u:当前节点编号
l:当前节点的左边界
r:当前节点的右边界
"""
if l==r:#叶节点
tr[u]=Node(l,r,lst[r])
else:
tr[u]=Node(l,r)
mid=(l+r)>>1
bulid(u<<1,l,mid)
bulid(u<<1|1,mid+1,r)
push_up(u)
def query(u,l,r):
"""
u:节点编号
l:左端点
r:右端点
return:返回区间和
"""
if tr[u].l>=l and tr[u].r<=r:
return tr[u].sum
mid=(tr[u].l+tr[u].r)>>1
sum=0
#这一部分类似二分的思想
if l<=mid:
#和左子节点有交集
sum+=query(u<<1,l,r)
if r>mid:
sum+=query(u<<1|1,l,r)
return sum
def modify(u,x,v):
"""
u:节点编号
x:插入位置
v:插入的值
"""
if tr[u].l==tr[u].r:
#叶节点
tr[u].sum+=v
else:
mid=(tr[u].l+tr[u].r)>>1
if x<=mid:
modify(u<<1,x,v)
else:
modify(u<<1|1,x,v)
push_up(u)
#初始化线段树
bulid(1,0,n-1)
for _ in range(m):
k, a, b = map(int, input().split())
if k == 0:
print(query(1, a-1, b-1))
else:
modify(1, a-1, b)
数星星(原题链接)
利用线段树的思维去理解。由于输入数据按照y轴递增,故我们阔以动态维护树状数组,就阔以获得每个星星其下左的星星数,也就是级数。
代码如下:
n=int(input())
tr=[0 for _ in range(32010)]
def lowbit(x):
return x&-x
def add(x,v):
while x<32010:
tr[x]+=v
x+=lowbit(x)
def query(x):
s=0
while x>0:
s+=tr[x]
x-=lowbit(x)
return s
level_lst=[0 for _ in range(n)]
for _ in range(n):
x,_=map(int,input().split())
add(x+1,1)
level_lst[query(x+1)-1]+=1
for i in level_lst:
print(i)
数列区间最大值(原题链接)
利用线段树的思维去理解。但python由于IO操作过于繁琐,故无法AC
代码如下:
#1270.数列区间最大值
n,m=map(int,input().split())
lst=list(map(int,input().split()))
#用线段树的思想去解决
class Node():
def __init__(self,l=0,r=0,m=0):
self.l=l
self.r=r
self.max=m
tr=[Node() for _ in range(4*n)]
def push_up(u):
tr[u].max=max(tr[u<<1].max,tr[u<<1|1].max)
def bulid(u,l,r):
if l==r:
tr[u]=Node(l,r,lst[r])
else:
tr[u]=Node(l,r)#未初始化!!!
mid=l+r>>1
bulid(u<<1,l,mid)
bulid(u<<1|1,mid+1,r)
push_up(u)
def query(u,l,r):
if tr[u].l>=l and tr[u].r<=r:
return tr[u].max
left=right=-1
mid=(tr[u].l+tr[u].r)>>1
if l<=mid:
#与左边有交集
left=query(u<<1,l,r)
if r>mid:
right=query(u<<1|1,l,r)
return max(left,right)
#初始化线段树
bulid(1,0,n-1)
while m:
a,b=map(int,input().split())
print(query(1,a-1,b-1))
m-=1
#但是我们发现python的版本无法AC
小朋友排队(原题链接)
我们知道该题是冒泡排序照进现实,同时需要明白我们直接利用冒泡排序来进行统计是不行的,其并非有针对性的指针。故我们需要引入一个树状数组进行动态维护。
代码如下:
n=int(input())
lst=list(map(lambda x:int(x)+1,input().split()))
k_list=[0 for _ in range(n)]
#记录每个数出现的次数,此处采取动态维护,所以是按照顺序进行记录的
tree_list=[0 for _ in range(1000010)]
def lowbit(x):
return x&-x
def add(x,v):
while x<1000010:
tree_list[x]+=v
x+=lowbit(x)
def query(x):
cnt=0
while x:
cnt+=tree_list[x]
x-=lowbit(x)
return cnt
#统计k1:对于每个位置其前面大于他的数的个数
#这里就可以类比是数星星那个题
for i in range(n):
k_list[i]=query(1000009)-query(lst[i])
add(lst[i],1)
tree_list=[0 for _ in range(1000010)]
#统计k2:后面有多少数比他小
for i in range(n-1,-1,-1):
k_list[i]+=query(lst[i]-1)
add(lst[i],1)
res=0
for num in k_list:
res+=num*(num+1)/2
res=int(res)
print(res)
油漆面积(原题链接)
get到新的技能:扫描线。如何利用扫描线+线段树来解决油漆面积问题。
代码如下:
#表示与纵轴相平行的线段
class seg:
def __init__(self,x=0,y1=0,y2=0,k=0) :
"""
x:表示线段index
y1:表示下y
y2:表示上y
k:-1表示出,+1表示进
"""
self.x,self.y1,self.y2,self.k=x,y1,y2,k
#构建线段树节点
class Node:
def __init__(self,l=0,r=0,cnt=0,length=0) :
self.l,self.r,self.cnt,self.len=l,r,cnt,length
#cnt:完整覆盖区间[l,r]的次数
#len:表示被覆盖的长度
def push_up(u):
if t[u].cnt>0:
#说明被完全覆盖了
t[u].len=t[u].r-t[u].l+1
elif t[u].l==t[u].r:
t[u].len=0#叶节点,len不是1,因为该节点未被覆盖(上面的if跑掉了)
else:
t[u].len=t[u<<1].len+t[u<<1|1].len
def bulid(u,l,r):
t[u]=Node(l,r)
if l==r:
return
mid=l+r>>1
bulid(u<<1,l,mid)
bulid(u<<1|1,mid+1,r)
def query(u,l,r):
#此处不需要进行查询操作,只需要查询根节点的len
pass
def modify(u,l,r,v):
if t[u].l>=l and t[u].r<=r:
t[u].cnt+=v
else:
mid=t[u].l+t[u].r>>1
if l<=mid:
modify(u<<1,l,r,v)
if r>mid:
modify(u<<1|1,l,r,v)
push_up(u)
N=10**4+10
s=[]
t=[Node() for i in range(N*4)]
n=int(input())
m,res=0,0
for _ in range(n):
x1,y1,x2,y2=map(int,input().split())
s.append(seg(x1,y1,y2,1))
s.append(seg(x2,y1,y2,-1))
m+=2
s.sort(key=lambda e:e.x)
bulid(1,0,10000)#对线段树进行初始化
for i in range(m):
if i:res+=t[1].len*(s[i].x-s[i-1].x)
modify(1,s[i].y1,s[i].y2-1,s[i].k)
print(res)
三体攻击(原题链接)
利用三维前缀和来求解,更为简便。差分对于一个矩阵的操作数是最少的。
代码如下:
a,b,c,m=map(int,input().split())
def get(i,j,k):
return ((i-1)*b+j-1)*c+k-1
"""
s:原数组
b_:初始差分数组
op:行为矩阵
tmp:会改变的差分数组
add:对lst进行差分处理
"""
def add(x1,x2,y1,y2,z1,z2,c,lst):
lst[get(x1,y1,z1)]+=c
lst[get(x1,y1,z2+1)]-=c
lst[get(x1,y2+1,z1)]-=c
lst[get(x1,y2+1,z2+1)]+=c
lst[get(x2+1,y1,z1)]-=c
lst[get(x2+1,y1,z2+1)]+=c
lst[get(x2+1,y2+1,z1)]+=c
lst[get(x2+1,y2+1,z2+1)]-=c
def check(mid):
#计算前缀和来进行判断
tmp=b_.copy()
for i in range(mid):
[la,ra,lb,rb,lc,rc,h]=op[i]
#对tmp进行构造
add(la,ra,lb,rb,lc,rc,-h,tmp)
#print(tmp[:get(a,b,c)])
global s
#构造前缀和
s=[0 for _ in range(len(s))]
for i in range(1,a+1):
for j in range(1,b+1):
for k in range(1,c+1):
s[get(i,j,k)]=tmp[get(i,j,k)]+s[get(i-1,j,k)]+s[get(i,j-1,k)]+s[get(i,j,k-1)]-s[get(i-1,j-1,k)]-s[get(i-1,j,k-1)]-s[get(i,j-1,k-1)]+s[get(i-1,j-1,k-1)]
#print(s[get(i,j,k)])
if s[get(i,j,k)]<0:return True
return False
s=list(map(int,input().split()))
s.extend([0 for _ in range(1500000-len(s))])
b_=[0 for _ in range(len(s))]
#构建初始差分数组
for i in range(1,a+1):
for j in range(1,b+1):
for k in range(1,c+1):
add(i,i,j,j,k,k,s[get(i,j,k)],b_)
#print('b_:',b_[:get(a,b,c)])
op=[]
for _ in range(m):
op.append(list(map(int,input().split())))
l=-1
r=m
#二分查找答案
while l!=r-1:
mid=l+r>>1
if check(mid):
r=mid
else:
l=mid
print(r)
螺旋折线(原题链接)
数学规律。我们阔以将折线掰下来看,就很简单了。
代码如下:
#1237.螺旋线
x,y=map(int,input().split())
#求出max
n=max(abs(x),abs(y))
in_sum=8*(n*(n+1)//2)
if x==-n and y>-n:
print(in_sum-(8*n-(y+n)))
elif x>-n and y==n:
print(in_sum-(8*n-(y+n))+(x+n))
elif x==n and y<n:
print(in_sum-4*n+(n-y))
elif x<n and y==-n:
print(in_sum-2*n+(n-x))
难的要死。