众所周知,py空间优化很差:1.1e6的数组开不了,这导致了二维dp不滚动过不了;区间dp也不行。
同时,dfs时,py默认限制1000深度(print(sys.getrecursionlimit())),超过了就会报RE:超过最大递归深度限制。
当然可以使用sys.setrecursionlimit(n)来突破这个限制,但是n大了会爆MLE。
于是产生了这个装饰器,来源是群里@Be。
用这个装饰器装饰 dfs,并改写 dfs 的写法:递归调用 `dfs(...)` 改为 `yield dfs(...)`,`return x` 改为 `yield x`(函数末尾也要 `yield None`),就可以突破 py 的递归深度限制。(依然很占空间,但很多题可以冲了。)
def bootstrap(f, stack=[]):
    """Run a generator-based recursive function on an explicit stack.

    Decorate a function written as a generator: each recursive call becomes
    ``yield f(...)`` and each return becomes ``yield value`` (a plain
    ``return value`` inside the generator is accepted as well).  The
    outermost call drives every pending generator from a loop, so Python's
    recursion limit is never hit.

    ``stack`` is a mutable default on purpose: it is shared by every
    function decorated with this ``bootstrap``.  A non-empty stack means a
    driver loop is already running, so nested calls just hand back their
    generator for that loop to schedule.
    """
    from functools import wraps

    @wraps(f)
    def wrappedfunc(*args, **kwargs):
        if stack:
            # Inside a running driver loop: return the generator unevaluated.
            return f(*args, **kwargs)
        to = f(*args, **kwargs)
        if type(to) is not GeneratorType:
            # f is not a generator function; there is nothing to drive.
            return to
        stack.append(to)
        value = None
        try:
            while stack:
                try:
                    # send(None) on a fresh generator is equivalent to next().
                    to = stack[-1].send(value)
                except StopIteration as stop:
                    # The frame finished with ``return``; pass its value up.
                    stack.pop()
                    value = stop.value
                    continue
                if type(to) is GeneratorType:
                    # A recursive call: run the child before resuming the parent.
                    stack.append(to)
                    value = None
                else:
                    # The frame "returned" ``to``: drop it, resume its caller.
                    stack.pop()
                    value = to
        finally:
            # If an exception escapes, stale frames would otherwise stay on the
            # shared stack and every later call would return an undriven
            # generator instead of a result.
            stack.clear()
        return value

    return wrappedfunc
例题: CF1032C. Playing Piano
灵神的题解:
纯构造做法
https://www.luogu.com.cn/blog/endlesscheng/solution-cf1032c
每一个上升段和下降段都可以一段段地处理。
对于上升段,让起始值尽量小,每次增长 1。
对于下降段,让起始值尽量大,每次减少 1。
相等的段,元素可以取 2 或 3,这样不会妨碍上升段和下降段的起始值。
import sys
from collections import *
from itertools import *
from math import *
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from types import GeneratorType
if sys.hexversion == 50924784:
    # 50924784 == 0x03090CF0, i.e. CPython 3.9.12 -- presumably the author's
    # local interpreter, where input is read from a file instead of stdin.
    sys.stdin = open('cfinput.txt')
    RI = lambda: map(int, input().split())
else:
    # Slurp all of stdin (fd 0) at once for fast parsing.
    # NOTE(review): st_size assumes stdin is a regular file (as on judges);
    # for a pipe it may be 0.
    input_int = io.BytesIO(os.read(0, os.fstat(0).st_size)).readline
    # fd 0 is already drained, so sys.stdin.readline would see nothing;
    # the text reader must use the same in-memory buffer.
    input = lambda: input_int().decode()
    RI = lambda: map(int, input_int().split())
RS = lambda: input().strip().split()  # one line as a list of str tokens
RILST = lambda: list(RI())  # one line as a list of ints
MOD = 10 ** 9 + 7
"""https://codeforces.com/problemset/problem/1032/C
输入 n (≤1e5) 和一个长为 n 的数组 a (1≤a[i]≤2e5)。
构造一个长为 n 的数组 b,满足:
1. 1≤b[i]≤5;
2. 如果 a[i]<a[i+1],则 b[i]<b[i+1];
3. 如果 a[i]>a[i+1],则 b[i]>b[i+1];
4. 如果 a[i]=a[i+1],则 b[i]≠b[i+1];
如果不存在这样的 b 则输出 -1,否则输出任意一个满足要求的 b。
1 1 4 2 2
"""
D = [*range(1, 6)]  # the allowed values of b[i]: 1, 2, 3, 4, 5
# 124 ms
def solve1(n, a):
    """Greedy construction for CF1032C (Playing Piano); prints b or -1.

    Scans left to right.  Inside a descent b[i] = b[i-1] - 1 and inside an
    ascent b[i] = b[i-1] + 1.  Turning points are pinned (valley -> 1,
    peak -> 5) so the following run starts with the most room.  Equal
    neighbours only need b[i] != b[i-1], so they pick the value that suits
    the run after them.  Prints -1 as soon as a value leaves [1, 5].

    Args:
        n: length of ``a``.
        a: the input sequence.
    """
    if n == 1:
        return print(1)
    b = [1] * n
    if a[0] > a[1]:
        b[0] = 5
    elif a[0] == a[1]:
        # Leading equal pair: choose a neutral value so b[1] can still be 1
        # (ascent next) or 5 (descent next).  With b[0] = 1, b[1] would be
        # forced to 2, and e.g. a = [1, 1, 2, 3, 4, 5] would wrongly give -1.
        b[0] = 3
    for i in range(1, n):
        x, y = a[i - 1], a[i]
        if y < x:  # descending step
            b[i] = b[i - 1] - 1
            if b[i] < 1:
                return print(-1)
            if i < n - 1:
                if y < a[i + 1]:  # i is a valley: 1 is optimal
                    b[i] = 1
                elif y == a[i + 1]:
                    # a[i] == a[i+1]: if the sequence descends after i+1,
                    # b[i+1] should be 5, so drop b[i] to 1.  Otherwise keep
                    # b[i] as is; b[i+1] then takes 1 (or 2 if b[i] is
                    # already 1).
                    if i < n - 2 and a[i + 1] > a[i + 2]:
                        b[i] = 1
        elif y > x:  # ascending step
            b[i] = b[i - 1] + 1
            if b[i] > 5:
                return print(-1)
            if i < n - 1:
                if y > a[i + 1]:  # i is a peak: 5 is optimal
                    b[i] = 5
                elif y == a[i + 1]:
                    # Mirror image of the descending case.
                    if i < n - 2 and a[i + 1] < a[i + 2]:
                        b[i] = 5
        else:  # equal step: only b[i] != b[i-1] is required
            b[i] = 3 if b[i - 1] == 2 else 2  # default when more equals follow
            if i < n - 1:
                if y < a[i + 1]:  # ascent next: as small as possible (1 or 2)
                    b[i] = 2 if b[i - 1] == 1 else 1
                elif y > a[i + 1]:  # descent next: as large as possible (5 or 4)
                    b[i] = 4 if b[i - 1] == 5 else 5
                # Otherwise the 2/3 default never prevents the right-hand
                # neighbour from becoming 1 or 5.
    print(' '.join(map(str, b)))
# 389 ms
def solve2(n, a):
    """Feasibility DP for CF1032C; prints one valid b, or -1.

    reach[i][v] is 1 when b[i] = v + 1 can end a valid prefix b[0..i].
    The answer is rebuilt right to left, always taking the smallest value
    compatible with the element already chosen to its right.
    """
    if n == 1:
        return print(1)
    reach = [[1] * 5]
    for i in range(1, n):
        prev = reach[-1]
        if a[i] < a[i - 1]:
            # b[i] < b[i-1]: some larger value must be reachable at i-1.
            row = [1 if any(prev[v + 1:]) else 0 for v in range(5)]
        elif a[i] > a[i - 1]:
            # b[i] > b[i-1]: some smaller value must be reachable at i-1.
            row = [1 if any(prev[:v]) else 0 for v in range(5)]
        else:
            # b[i] != b[i-1]: any other reachable value works.
            row = [1 if any(prev[u] for u in range(5) if u != v) else 0
                   for v in range(5)]
        if 1 not in row:
            return print(-1)
        reach.append(row)

    picks = [0] * n
    picks[-1] = reach[-1].index(1)
    for i in range(n - 2, -1, -1):
        nxt = picks[i + 1]
        if a[i] < a[i + 1]:
            allowed = range(nxt)
        elif a[i] > a[i + 1]:
            allowed = range(nxt + 1, 5)
        else:
            allowed = [v for v in range(5) if v != nxt]
        picks[i] = next(v for v in allowed if reach[i][v])
    print(' '.join(str(v + 1) for v in picks))
# 249 ms
def solve1(n, a):
    """Feasibility DP for CF1032C with each DP row packed into a 5-bit mask.

    Bit v of masks[i] is set when b[i] = v + 1 can end a valid prefix
    b[0..i].  Prints one valid b (rebuilt right to left, smallest compatible
    value first) or -1.
    """
    if n == 1:
        return print(1)
    full = 0b11111
    masks = [full]
    for i in range(1, n):
        prev = masks[-1]
        cur = 0
        for v in range(5):
            if a[i] < a[i - 1]:
                hit = prev >> (v + 1)  # some larger value reachable before
            elif a[i] > a[i - 1]:
                hit = prev & ((1 << v) - 1)  # some smaller value reachable before
            else:
                hit = prev & ~(1 << v)  # some different value reachable before
            if hit:
                cur |= 1 << v
        if not cur:
            return print(-1)
        masks.append(cur)

    def lowest(mask):
        # Index of the least significant set bit.
        return (mask & -mask).bit_length() - 1

    picks = [0] * n
    picks[-1] = lowest(masks[-1])
    for i in range(n - 2, -1, -1):
        t = picks[i + 1]
        if a[i] < a[i + 1]:
            allowed = (1 << t) - 1
        elif a[i] > a[i + 1]:
            allowed = full & ~((1 << (t + 1)) - 1)
        else:
            allowed = full & ~(1 << t)
        picks[i] = lowest(masks[i] & allowed)
    print(' '.join(str(v + 1) for v in picks))
def bootstrap(f, stack=[]):
    """Run a generator-based recursive function on an explicit stack.

    Decorate a function written as a generator: each recursive call becomes
    ``yield f(...)`` and each return becomes ``yield value`` (a plain
    ``return value`` inside the generator is accepted as well).  The
    outermost call drives every pending generator from a loop, so Python's
    recursion limit is never hit.

    ``stack`` is a mutable default on purpose: it is shared by every
    function decorated with this ``bootstrap``.  A non-empty stack means a
    driver loop is already running, so nested calls just hand back their
    generator for that loop to schedule.
    """
    from functools import wraps

    @wraps(f)
    def wrappedfunc(*args, **kwargs):
        if stack:
            # Inside a running driver loop: return the generator unevaluated.
            return f(*args, **kwargs)
        to = f(*args, **kwargs)
        if type(to) is not GeneratorType:
            # f is not a generator function; there is nothing to drive.
            return to
        stack.append(to)
        value = None
        try:
            while stack:
                try:
                    # send(None) on a fresh generator is equivalent to next().
                    to = stack[-1].send(value)
                except StopIteration as stop:
                    # The frame finished with ``return``; pass its value up.
                    stack.pop()
                    value = stop.value
                    continue
                if type(to) is GeneratorType:
                    # A recursive call: run the child before resuming the parent.
                    stack.append(to)
                    value = None
                else:
                    # The frame "returned" ``to``: drop it, resume its caller.
                    stack.pop()
                    value = to
        finally:
            # If an exception escapes, stale frames would otherwise stay on the
            # shared stack and every later call would return an undriven
            # generator instead of a result.
            stack.clear()
        return value

    return wrappedfunc
def solve(n, a):
    """Bounded backtracking solution for CF1032C; prints b or -1.

    bo[i] / up[i] are the smallest / largest values b[i] may take given the
    strictly monotone runs through i.  An ascent into i (or a descent out of
    i) of length L forces b[i] >= L + 1.  Symmetrically, a descent into i
    (or an ascent out of i) caps b[i] from above.  A depth-first search then
    assigns values left to right within [bo[i], up[i]].  It is wrapped in
    @bootstrap so deep inputs do not hit the recursion limit.
    """
    bo = [1] * n
    up = [5] * n
    for i in range(1, n):
        if a[i] > a[i - 1]:
            bo[i] = bo[i - 1] + 1
        elif a[i] < a[i - 1]:
            up[i] = up[i - 1] - 1
        if up[i] < bo[i]:
            return print(-1)
    for i in range(n - 2, -1, -1):
        # Combine with the forward-pass bound rather than overwriting it: at
        # a peak (valley) both neighbouring runs constrain b[i].
        if a[i] > a[i + 1]:
            bo[i] = max(bo[i], bo[i + 1] + 1)
        elif a[i] < a[i + 1]:
            up[i] = min(up[i], up[i + 1] - 1)
        if up[i] < bo[i]:
            return print(-1)

    b = [0] * n
    ok = False

    @bootstrap
    def dfs(i):
        nonlocal ok
        if i == n:
            ok = True
            yield None
        for j in range(bo[i], up[i] + 1):
            if i > 0:
                if a[i] == a[i - 1] and j == b[i - 1]:
                    continue
                if a[i] > a[i - 1] and j <= b[i - 1]:
                    continue
                if a[i] < a[i - 1] and j >= b[i - 1]:
                    # j only grows, so no later j can be below b[i-1].
                    break
            b[i] = j
            yield dfs(i + 1)
            if ok:
                # Stop right after success; continuing the loop would
                # overwrite b[i] with a different value.
                break
        yield None

    dfs(0)
    if ok:
        print(' '.join(map(str, b)))
    else:
        print(-1)
if __name__ == '__main__':
    # First line: n; second line: the n values of a.
    (n,) = RI()
    solve(n, RILST())
链接: CF377 A. Maze
import sys
from collections import *
from itertools import *
from math import *
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from types import GeneratorType
if sys.hexversion == 50924784:
    # 50924784 == 0x03090CF0, i.e. CPython 3.9.12 -- presumably the author's
    # local interpreter, where input is read from a file instead of stdin.
    sys.stdin = open('cfinput.txt')
    RI = lambda: map(int, input().split())
else:
    # Both readers go through sys.stdin.buffer.  Mixing sys.stdin.readline
    # (the text layer, which reads ahead into its own buffer) with
    # sys.stdin.buffer.readline can skip data the text layer already consumed.
    input_int = sys.stdin.buffer.readline
    input = lambda: input_int().decode()
    RI = lambda: map(int, input_int().split())
RS = lambda: input().strip().split()  # one line as a list of str tokens
RILST = lambda: list(RI())  # one line as a list of ints
MOD = 10 ** 9 + 7
"""https://codeforces.com/problemset/problem/377/A
输入 n(≤500) m(≤500) k 和一个 n 行 m 列的网格图,'#' 表示墙,'.' 表示平地。
保证所有 '.' 可以互相到达(四方向连通)。保证 k 小于 '.' 的个数。
你需要把恰好 k 个 '.' 修改成 '#',使得剩余的所有 '.' 仍然是可以互相到达的。
输出修改后的网格图。
输入
3 4 2
#..#
..#.
#...
输出
#.X#
X.#.
#...
输入
5 4 5
#...
#.#.
.#..
...#
.#.#
输出
#XXX
#X#.
X#..
...#
.#.#
"""
# bfs 312 ms
def solve1(m, n, k, g):
    """Wall off k cells of an m x n grid (CF377A) using BFS, then print it.

    Runs a BFS from the first '.' in row-major order until exactly
    (number of '.') - k cells are collected.  A BFS-visited set is always
    connected, so every '.' outside it can be turned into 'X'.  g is a list
    of rows (lists of characters) and is modified in place.
    """
    def print_g():
        for row in g:
            print(''.join(row))

    if k == 0:
        return print_g()
    cells = [(i, j) for i in range(m) for j in range(n) if g[i][j] == '.']
    target = len(cells) - k  # number of '.' cells that must stay open

    def collect():
        seen = {cells[0]}
        if len(seen) == target:
            return seen
        frontier = deque(seen)
        while frontier:
            x, y = frontier.popleft()
            for nx, ny in ((x, y + 1), (x, y - 1), (x + 1, y), (x - 1, y)):
                if (0 <= nx < m and 0 <= ny < n and g[nx][ny] == '.'
                        and (nx, ny) not in seen):
                    seen.add((nx, ny))
                    if len(seen) == target:
                        return seen
                    frontier.append((nx, ny))
        return seen

    kept = collect()
    for cell in cells:
        if cell not in kept:
            r, c = cell
            g[r][c] = 'X'
    print_g()
# Inverted BFS: first turn every '.' into 'X', then reopen cells back to '.';
# a plain recursive DFS overflows the stack on this problem. 171 ms
def solve2(m, n, k, g):
    """Wall off k cells of an m x n grid (CF377A), building the kept region.

    Every '.' is first marked 'X'.  A BFS from the first '.' (row-major)
    then turns cells back into '.' until (number of '.') - k are open again,
    so the open region is connected by construction.  g is a list of rows
    (lists of characters), modified in place, then printed.
    """
    def print_g():
        for row in g:
            print(''.join(row))

    if k == 0:
        return print_g()
    total = 0
    start = (0, 0)
    for i in range(m):
        for j in range(n):
            if g[i][j] != '.':
                continue
            if total == 0:
                start = (i, j)
            g[i][j] = 'X'
            total += 1
    remaining = total - k - 1  # cells still to reopen after the start cell
    sx, sy = start
    g[sx][sy] = '.'
    if remaining == 0:
        return print_g()
    queue = deque([start])
    while queue:
        x, y = queue.popleft()
        for nx, ny in ((x, y + 1), (x, y - 1), (x + 1, y), (x - 1, y)):
            if not (0 <= nx < m and 0 <= ny < n) or g[nx][ny] != 'X':
                continue
            g[nx][ny] = '.'
            remaining -= 1
            if remaining == 0:
                return print_g()
            queue.append((nx, ny))
    print_g()
# 155ms
def cf377a():
    """Read one CF377A test from stdin and print the answer.

    Same inverted BFS as solve2: mark every '.' as 'X' while reading, then
    reopen cells from the first one in row-major order until
    (number of '.') - k cells are '.' again.
    """
    n, m, k = RI()
    if k == 0:
        # Nothing to wall off: echo the grid unchanged.
        for _ in range(n):
            (row,) = RS()
            print(row)
        return
    grid = []
    dots = 0
    for _ in range(n):
        (row,) = RS()
        dots += row.count('.')
        grid.append(list(row.replace('.', 'X')))
    rows, cols = n, m
    start = next(
        ((i, j) for i in range(rows) for j in range(cols) if grid[i][j] == 'X'),
        None,
    )

    def show():
        print('\n'.join(''.join(r) for r in grid))

    remaining = dots - k - 1  # cells still to reopen after the start cell
    grid[start[0]][start[1]] = '.'
    if remaining == 0:
        return show()
    queue = deque([start])
    while queue:
        x, y = queue.popleft()
        for nx, ny in ((x, y + 1), (x, y - 1), (x + 1, y), (x - 1, y)):
            if 0 <= nx < rows and 0 <= ny < cols and grid[nx][ny] == 'X':
                grid[nx][ny] = '.'
                remaining -= 1
                if remaining == 0:
                    return show()
                queue.append((nx, ny))
    show()
def bootstrap(f, stack=[]):
    """Run a generator-based recursive function on an explicit stack.

    Decorate a function written as a generator: each recursive call becomes
    ``yield f(...)`` and each return becomes ``yield value`` (a plain
    ``return value`` inside the generator is accepted as well).  The
    outermost call drives every pending generator from a loop, so Python's
    recursion limit is never hit.

    ``stack`` is a mutable default on purpose: it is shared by every
    function decorated with this ``bootstrap``.  A non-empty stack means a
    driver loop is already running, so nested calls just hand back their
    generator for that loop to schedule.
    """
    from functools import wraps

    @wraps(f)
    def wrappedfunc(*args, **kwargs):
        if stack:
            # Inside a running driver loop: return the generator unevaluated.
            return f(*args, **kwargs)
        to = f(*args, **kwargs)
        if type(to) is not GeneratorType:
            # f is not a generator function; there is nothing to drive.
            return to
        stack.append(to)
        value = None
        try:
            while stack:
                try:
                    # send(None) on a fresh generator is equivalent to next().
                    to = stack[-1].send(value)
                except StopIteration as stop:
                    # The frame finished with ``return``; pass its value up.
                    stack.pop()
                    value = stop.value
                    continue
                if type(to) is GeneratorType:
                    # A recursive call: run the child before resuming the parent.
                    stack.append(to)
                    value = None
                else:
                    # The frame "returned" ``to``: drop it, resume its caller.
                    stack.pop()
                    value = to
        finally:
            # If an exception escapes, stale frames would otherwise stay on the
            # shared stack and every later call would return an undriven
            # generator instead of a result.
            stack.clear()
        return value

    return wrappedfunc
# dfs
def solve(m, n, k, g):
    """Wall off k cells of an m x n grid (CF377A) using DFS, then print it.

    A depth-first search from the first '.' (row-major order) stops once it
    has collected (number of '.') - k cells.  The visited set is connected,
    so every other '.' becomes 'X'.  @bootstrap keeps the search off
    Python's call stack.  g is a list of rows (lists of characters) and is
    modified in place.
    """
    def print_g():
        print('\n'.join(''.join(b) for b in g))

    if k == 0:
        return print_g()
    keep = sum(b.count('.') for b in g) - k  # number of '.' that must stay
    DIRS = [(0, 1), (0, -1), (1, 0), (-1, 0)]
    vis = set()

    @bootstrap
    def dfs(x, y):
        vis.add((x, y))
        for dx, dy in DIRS:
            if len(vis) >= keep:
                # Enough cells collected; unwind without visiting more.
                break
            a, b = x + dx, y + dy
            if 0 <= a < m and 0 <= b < n and g[a][b] == '.' and (a, b) not in vis:
                yield dfs(a, b)
        yield None

    # One search suffices: the problem guarantees all '.' cells are connected.
    start = next((i, j) for i in range(m) for j in range(n) if g[i][j] == '.')
    dfs(*start)
    for x in range(m):
        for y in range(n):
            if g[x][y] == '.' and (x, y) not in vis:
                g[x][y] = 'X'
    return print_g()
if __name__ == '__main__':
    # Header: rows, columns, number of cells to wall off; then the grid rows.
    n, m, k = RI()
    grid = []
    for _ in range(n):
        (line,) = RS()
        grid.append([*line])
    solve(n, m, k, grid)
    # Alternative entry point that parses input itself: cf377a()