py装饰器强行DFS,突破递归深度限制

[python刷题模板] py装饰器强行DFS,突破递归深度限制

    • 一、 算法&数据结构
      • 1. 描述
      • 2. 复杂度分析
      • 3. 常见应用
      • 4. 常用优化
    • 二、 模板代码
      • 1. 矩阵dfs(状态转移*4)CF1032C. Playing Piano
      • 2. 构造遍历(每层状态最多是5) CF337 A. Maze
    • 三、其他
    • 四、更多例题
    • 五、参考链接

一、 算法&数据结构

1. 描述

众所周知,py空间优化很差:1.1e6的数组开不了,这导致了二维dp不滚动过不了;区间dp也不行。
同时,dfs时,py默认限制1000深度(print(sys.getrecursionlimit())),超过了就会报RE:超过最大递归深度限制。
当然可以使用sys.setrecursionlimit(n)来突破这个限制,但是n大了会爆MLE。

于是产生了这个装饰器,来源是群里@Be。
用这个装饰dfs,对dfs返回值进行调整,可以突破py dfs的限制。(依然很占空间,但很多题可以冲了)。
  • 注意,如果要剪枝、提前退出,要使用全局flag,dfs完了立刻判断决定是否break
  • 效率上比其他做法可能差几倍,有可能TLE,比如cf377a,bfs155ms dfs783ms
def bootstrap(f, stack=[]):
    def wrappedfunc(*args, **kwargs):
        if stack:
            return f(*args, **kwargs)
        else:
            to = f(*args, **kwargs)
            while True:
                if type(to) is GeneratorType:
                    stack.append(to)
                    to = next(to)
                else:
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
            return to

    return wrappedfunc

2. 复杂度分析

  1. dfs复杂度

3. 常见应用

  1. 状态转移小时可以用dfs回溯莽过的题。

4. 常用优化

二、 模板代码

1. 矩阵dfs(状态转移*4)CF1032C. Playing Piano

例题: CF1032C. Playing Piano

20220818的茶
py装饰器强行DFS,突破递归深度限制_第1张图片

灵神的题解:
纯构造做法
https://www.luogu.com.cn/blog/endlesscheng/solution-cf1032c

每一个上升段和下降段都可以一段段地处理。
对于上升段,让起始值尽量小,每次增长 1。
对于下降段,让起始值尽量大,每次减少 1。
相等的段,元素可以取 2 或 3,这样不会妨碍上升段和下降段的起始值。
  • 这题正解是构造,但是可以dfs硬莽:每层数量有限,回溯即可。
  • 注意每次dfs完判断flag以break
import sys
from collections import *
from itertools import *
from math import *
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from types import GeneratorType

if sys.hexversion == 50924784:
    sys.stdin = open('cfinput.txt')
    RI = lambda: map(int, input().split())
else:
    input = sys.stdin.readline
    input_int = io.BytesIO(os.read(0, os.fstat(0).st_size)).readline
    RI = lambda: map(int, input_int().split())

RS = lambda: input().strip().split()
RILST = lambda: list(RI())

MOD = 10 ** 9 + 7
"""https://codeforces.com/problemset/problem/1032/C

输入 n (≤1e5) 和一个长为 n 的数组 a (1≤a[i]≤2e5)。

构造一个长为 n 的数组 b,满足:
1. 1≤b[i]≤5;
2. 如果 a[i]a[i+1],则 b[i]>b[i+1];
4. 如果 a[i]=a[i+1],则 b[i]≠b[i+1];
如果不存在这样的 b 则输出 -1,否则输出任意一个满足要求的 b。
1 1 4 2 2
"""

D = list(range(1, 6))


# 124 ms
def solve1(n, a):
    if n == 1:
        return print(1)
    b = [1] * n
    if a[0] > a[1]:
        b[0] = 5

    for i in range(1, n):
        x, y = a[i - 1], a[i]
        if y < x:  # 降
            b[i] = b[i - 1] - 1
            if b[i] < 1:
                return print(-1)
            if i < n - 1:
                if y < a[i + 1]:  # i是谷,则直接置1最优
                    b[i] = 1
                elif y == a[i + 1]:
                    # i和右边相同:
                    # 若i+1后要降,则i+1置5最优,i置1;
                    # 若后边升,不操作:i+1尽量小(1)优,i不变(不要趋向1,但可能已经是1,则i+1只能是2)
                    if i < n - 2 and a[i + 1] > a[i + 2]:
                        b[i] = 1
        elif y > x:  # 升
            b[i] = b[i - 1] + 1
            if b[i] > 5:
                return print(-1)
            if i < n - 1:
                if y > a[i + 1]:
                    b[i] = 5
                elif y == a[i + 1]:
                    if i < n - 2 and a[i + 1] < a[i + 2]:
                        b[i] = 5
        else:  # 相同
            b[i] = 3 if b[i - 1] == 2 else 2  # 先随便置一个数
            if i < n - 1:
                if y < a[i + 1]:  # 若后边升,则置尽量小1/2
                    b[i] = 2 if b[i - 1] == 1 else 1
                elif y > a[i + 1]:  # 若后边降,则置尽量大5/4
                    b[i] = 4 if b[i - 1] == 5 else 5
                # else:  # 若相同,置一个不耽误右边值变成1/5的数(2/3/4任选)
                #     b[i] = 3 if b[i - 1] == 2 else 2

    print(' '.join(map(str, b)))


# 	389 ms
def solve2(n, a):
    if n == 1:
        return print(1)
    f = [[0] * 5 for _ in range(n)]
    f[0] = [1] * 5
    for i in range(1, n):
        x, y = a[i - 1], a[i]
        flag = 0
        if y < x:  # 降
            for j in range(5):
                if any(k for k in f[i - 1][j + 1:]):
                    flag = f[i][j] = 1
        elif y > x:
            for j in range(5):
                if any(k for k in f[i - 1][:j]):
                    flag = f[i][j] = 1
        else:
            for j in range(5):
                if any(f[i - 1][k] for k in range(5) if k != j):
                    flag = f[i][j] = 1
        if not flag:
            return print(-1)
    # print(f)
    b = [f[-1].index(1)]
    for i in range(n - 2, -1, -1):
        x, y = a[i], a[i + 1]
        for j in range(5):
            if f[i][j] and ((x < y and j < b[-1]) or (x > y and j > b[-1]) or (x == y and j != b[-1])):
                b.append(j)
                break
    # print(b)
    print(' '.join(map(lambda x: str(x + 1), b[::-1])))


# 249	 ms
def solve1(n, a):
    if n == 1:
        return print(1)
    f = [[0] * 5 for _ in range(n)]
    f[0] = [1] * 5
    for i in range(1, n):
        x, y = a[i - 1], a[i]
        flag = 0
        g = f[i - 1]
        if y < x:  # 降
            for j in range(5):
                for k in range(j + 1, 5):
                    if g[k]:
                        flag = f[i][j] = 1
                        break
        elif y > x:
            for j in range(5):
                for k in range(j - 1, -1, -1):
                    if g[k]:
                        flag = f[i][j] = 1
                        break
        else:
            for j in range(5):
                for k in range(5):
                    if k != j and g[k]:
                        flag = f[i][j] = 1
                        break
        if not flag:
            return print(-1)
    # print(f)
    b = [f[-1].index(1)]
    for i in range(n - 2, -1, -1):
        x, y = a[i], a[i + 1]
        for j in range(5):
            if f[i][j] and ((x < y and j < b[-1]) or (x > y and j > b[-1]) or (x == y and j != b[-1])):
                b.append(j)
                break
    # print(b)
    print(' '.join(map(lambda x: str(x + 1), b[::-1])))


def bootstrap(f, stack=[]):
    def wrappedfunc(*args, **kwargs):
        if stack:
            return f(*args, **kwargs)
        else:
            to = f(*args, **kwargs)
            while True:
                if type(to) is GeneratorType:
                    stack.append(to)
                    to = next(to)
                else:
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
            return to

    return wrappedfunc


def solve(n, a):
    bo = [1] * n
    up = [5] * n
    for i in range(1, n):
        x, y = a[i - 1], a[i]
        if y > x:
            bo[i] = bo[i - 1] + 1
        elif y < x:
            up[i] = up[i - 1] - 1
        if up[i] < bo[i]:
            return print(-1)

    for i in range(n - 2, -1, -1):
        x, y = a[i], a[i + 1]
        if x > y:
            bo[i] = bo[i + 1] + 1
        elif x < y:
            up[i] = up[i + 1] - 1

        if up[i] < bo[i]:
            return print(-1)
    b = [0] * n

    # print(a)
    # print(up)
    # print(bo)

    ok = False
    def dfs1(i):
        # print(b)
        if i == n:
            nonlocal ok
            ok = True
            return True
        for j in range(bo[i], up[i] + 1):
            if i > 0:
                if a[i] == a[i - 1] and j == b[i - 1]:
                    continue
                if a[i] > a[i - 1] and j <= b[i - 1]:
                    continue
                # if i==3 :
                #     print(a[i] ,a[i - 1] , j , b[i])
                if a[i] < a[i - 1] and j >= b[i - 1]:
                    break
            b[i] = j
            if dfs(i + 1):
                return True
        return False

    @bootstrap
    def dfs(i):
        nonlocal ok
        if i == n:
            ok = True
            yield None

        for j in range(bo[i], up[i] + 1):
            if i > 0:
                if a[i] == a[i - 1] and j == b[i - 1]:
                    continue
                if a[i] > a[i - 1] and j <= b[i - 1]:
                    continue
                # if i==3 :
                #     print(a[i] ,a[i - 1] , j , b[i])
                if a[i] < a[i - 1] and j >= b[i - 1]:
                    break
            b[i] = j
            yield dfs(i + 1)
            if ok:  # 注意 dfs完成后立刻判断,否则会继续for造成数据错误
                break
        yield None

    dfs(0)
    if ok:
        print(' '.join(map(str, b)))
    else:
        print(-1)


if __name__ == '__main__':
    n, = RI()
    a = RILST()


    solve(n, a)


2. 构造遍历(每层状态最多是5) CF337 A. Maze

链接: CF337 A. Maze

20220920的茶
py装饰器强行DFS,突破递归深度限制_第2张图片

  • 这题比较简单,dfs或bfs时,先访问到的k个位置一定是连通的,剩下的位置填上即可。
import sys
from collections import *
from itertools import *
from math import *
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from types import GeneratorType

if sys.hexversion == 50924784:
    sys.stdin = open('cfinput.txt')
    RI = lambda: map(int, input().split())
else:
    input = sys.stdin.readline
    input_int = sys.stdin.buffer.readline
    RI = lambda: map(int, input_int().split())

RS = lambda: input().strip().split()
RILST = lambda: list(RI())

MOD = 10 ** 9 + 7
"""https://codeforces.com/problemset/problem/377/A

输入 n(≤500) m(≤500) k 和一个 n 行 m 列的网格图,'#' 表示墙,'.' 表示平地。
保证所有 '.' 可以互相到达(四方向连通)。保证 k 小于 '.' 的个数。
你需要把恰好 k 个 '.' 修改成 '#',使得剩余的所有 '.' 仍然是可以互相到达的。
输出修改后的网格图。

输入
3 4 2
#..#
..#.
#...
输出
#.X#
X.#.
#...

输入
5 4 5
#...
#.#.
.#..
...#
.#.#
输出
#XXX
#X#.
X#..
...#
.#.#
"""


# bfs 312	 ms
def solve1(m, n, k, g):
    def print_g():
        for b in g:
            print(''.join(b))

    if k == 0:
        return print_g()

    points = []
    for i in range(m):
        for j in range(n):
            if g[i][j] == '.':
                points.append((i, j))

    k = len(points) - k  # 只能保留k个点

    def get_route():
        DIRS = [(0, 1), (0, -1), (1, 0), (-1, 0)]

        def inside(x, y):
            return 0 <= x < m and 0 <= y < n

        q = deque([points[0]])
        vis = set(q)
        if len(vis) == k:
            return vis
        while q:
            x, y = q.popleft()
            for dx, dy in DIRS:
                a, b = x + dx, y + dy
                if inside(a, b) and g[a][b] == '.' and (a, b) not in vis:
                    vis.add((a, b))
                    if len(vis) == k:
                        return vis
                    q.append((a, b))
        return vis

    # 除了路径上的点都改成X
    keep = get_route()
    for x, y in points:
        if (x, y) not in keep:
            g[x][y] = 'X'

    print_g()


# bfs翻过来做 直接从x修改到.;本题dfs 会爆栈 171 ms
def solve2(m, n, k, g):
    def print_g():
        for b in g:
            print(''.join(b))

    if k == 0:
        return print_g()

    cnt = 0
    start = (0, 0)
    for i in range(m):
        for j in range(n):
            if g[i][j] == '.':
                g[i][j] = 'X'
                if 0 == cnt:
                    start = (i, j)
                cnt += 1

    k = cnt - k  # 只能保留k个点
    DIRS = [(0, 1), (0, -1), (1, 0), (-1, 0)]

    def inside(x, y):
        return 0 <= x < m and 0 <= y < n

    g[start[0]][start[1]] = '.'
    k -= 1
    if k == 0:
        return print_g()
    q = deque([start])

    while q:
        x, y = q.popleft()
        for dx, dy in DIRS:
            a, b = x + dx, y + dy
            if inside(a, b) and g[a][b] == 'X':
                g[a][b] = '.'
                k -= 1
                if 0 == k:
                    return print_g()
                q.append((a, b))
    print_g()


# 155ms
def cf377a():
    n, m, k = RI()
    if k == 0:
        for _ in range(n):
            b, = RS()
            print(b)
    else:
        g = []
        cnt = 0
        for _ in range(n):
            b, = RS()
            cnt += b.count('.')
            g.append(list(b.replace('.', 'X')))
        m, n = n, m

        def get_start():
            for i in range(m):
                for j in range(n):
                    if g[i][j] == 'X':
                        return i, j

        start = get_start()

        def print_g():
            print('\n'.join(''.join(b) for b in g))

        k = cnt - k  # 只能保留k个点

        g[start[0]][start[1]] = '.'
        k -= 1
        if k == 0:
            return print_g()
        q = deque([start])

        while q:
            x, y = q.popleft()
            for dx, dy in (0, 1), (0, -1), (1, 0), (-1, 0):
                a, b = x + dx, y + dy
                if 0 <= a < m and 0 <= b < n and g[a][b] == 'X':
                    g[a][b] = '.'
                    k -= 1
                    if 0 == k:
                        return print_g()
                    q.append((a, b))
        print_g()


def bootstrap(f, stack=[]):
    def wrappedfunc(*args, **kwargs):
        if stack:
            return f(*args, **kwargs)
        else:
            to = f(*args, **kwargs)
            while True:
                if type(to) is GeneratorType:
                    stack.append(to)
                    to = next(to)
                else:
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
            return to

    return wrappedfunc


# dfs
def solve(m, n, k, g):
    def print_g():
        print('\n'.join(''.join(b) for b in g))

    if k == 0:
        return print_g()
    cnt = sum(b.count('.') for b in g)
    k = cnt - k  # 只能保留k个点
    DIRS = [(0, 1), (0, -1), (1, 0), (-1, 0)]

    def inside(x, y):
        return 0 <= x < m and 0 <= y < n

    vis = set()

    @bootstrap
    def dfs2(x, y):
        vis.add((x, y))
        if len(vis) == k:
            yield None

        for dx, dy in DIRS:
            a, b = x + dx, y + dy
            if inside(a, b) and g[a][b] == '.' and (a, b) not in vis and len(vis) < k:
                yield dfs(a, b)
        yield None
    cnt = 0
    @bootstrap
    def dfs(x, y):
        nonlocal cnt
        cnt += 1
        print(cnt)
        if len(vis) == k:
            yield None
        vis.add((x, y))

        for dx, dy in DIRS:
            a, b = x + dx, y + dy
            if inside(a, b) and g[a][b] == '.' and (a, b) not in vis:
                yield dfs(a, b)
                # print(r)
                # if r is None:
                #     yield None
                # yield r

        yield None

    def dfs1(x, y):
        vis.add((x, y))
        if len(vis) == k:
            return True
        for dx, dy in DIRS:
            a, b = x + dx, y + dy
            if inside(a, b) and g[a][b] == '.' and (a, b) not in vis:
                if dfs(a, b):
                    return True
        return False

    for i in range(m):
        for j in range(n):
            if g[i][j] == '.':
                dfs(i, j)

                # print(vis)
                for x in range(m):
                    for y in range(n):
                        if g[x][y] == '.' and (x, y) not in vis:
                            g[x][y] = 'X'
                return print_g()


if __name__ == '__main__':
    n, m, k = RI()

    a = []
    for _ in range(n):
        b, = RS()
        a.append(list(b))

    solve(n, m, k, a)
    # cf377a()


三、其他

四、更多例题

五、参考链接

  • 链接:

你可能感兴趣的:(python刷题模板,深度优先,算法)