TaiChi Lang 让Python代码提速100倍!(高性能计算、图形学、仿真等领域;加速 Python 中计算密集任务程序;希望使用 Python 开发但部署到其它环境)

1、TaiChi简介

Taichi 起步于 MIT 的计算机科学与人工智能实验室(CSAIL),设计初衷是便利计算机图形学研究人员的日常工作,帮助他们快速实现适用于 GPU 的视觉计算和物理模拟算法。 Taichi 选择了一条创新的路径:嵌入于 Python,使用即时编译(JIT)架构(如 LLVM、SPIR-V),将 Python 源代码转化为 GPU 或 CPU 的原生指令,在开发时和运行时均提供优越性能。

当然,以 Python 为前端的领域特定语言(DSL)不是什么新奇的创造。 过去几年里,Halide、PyTorch、TVM 等框架发展成熟,实际已塑造了图像处理、深度学习等领域的标准。 Taichi 与这些框架的最大区别在于其指令式编程范式。 作为一种领域特定语言,Taichi 并不专长于特定的某种计算模式。 这意味着更大的灵活度。 也许有人会假定灵活性需要牺牲优化程度,但对 Taichi 而言并非如此,主要有以下几个原因:

  1. Taichi 的工作负荷呈现出不可利用的特点(如不支持逐元素运算),也就是说算法强度是固定不变的。 只要切换到 GPU 后端,用户就可以收获明显的性能提升。
  2. 传统深度学习框架使用的运算符都是简单的数学表达式,需要在计算图层面融合运算符,以实现更高的算法强度。但 Taichi 的指令式编程让用户能在一个 kernel 中轻松完成大量计算。 我们将这样的 kernel 命名为 mega-kernel。
  3. Taichi 使用各种编译器技术大幅优化源代码,包括公共子表达式消除、死码删除、控制流图分析。 这些优化手段适用于各个后端,因为 Taichi 有自己的中间表示(IR)层。
    即时编译提供了更多的优化机会。

尽管如此,Taichi 远不止于一个 Python 的即时转译器。 最初的设计目标之一是将计算与数据结构解耦。 为此,Taichi 提供一套通用的数据容器,叫做 SNode (/ˈsnoʊd/)。 SNode 可以方便地构造或稠密或稀疏的多维 field,并形成清晰的层级。 在 AoS 和 SoA 两种内存布局间切换只需不到 10 行代码。 这启发了很多数值模拟领域的使用案例。 若你想学习如何使用这种数据容器,请查看 field(高级)、稀疏空间数据结构 或 原始 Taichi 论文。

我们将解耦的概念进一步延伸到数据类型。 由于 GPU 内存容量和带宽已成为当前的主要瓶颈,让每个内存单位存储更多的数据变得至关重要。 2021 年,Taichi 引入了可定制量化类型, 允许定义任意位数的定点数或浮点数(但仍不能超过 64 位)。 从此,在单个 GPU 设备上进行超 4 亿粒子的 MPM 模拟成为可能。 论文《QuanTaichi》对此进行了详细介绍。

Taichi 是直观的语言。 如果你使用 Python,你就能使用 Taichi。 当你用 Taichi 编程,程序自动选择在 GPU 运行(CPU 为替补)。 Taichi 问世后,这样一个并不复杂的理念有幸获得了诸多关注,在众多贡献者的努力下,现在 Taichi 支持更多后端,包括 Vulkan、OpenGL,和 Direct X(仍在进展中)。 没有一个强大且专注的社区,Taichi 无法走到今天。

2、代码示例:

计算素数的个数:

"""Count the number of primes in range [1, n].
"""

def is_prime(n: int):
    result = True
    for k in range(2, int(n ** 0.5) + 1):
        if n % k == 0:
            result = False
            break
    return result
    
def count_primes(n: int) -> int:
    count = 0
    for k in range(2, n):
        if is_prime(k):
            count += 1
        
    return count
    
print(count_primes(10000000))

75.7s

使用taichi加速

"""Count the number of primes below a given bound.
"""
import taichi as ti
ti.init()

@ti.func
def is_prime(n: int):
    result = True
    for k in range(2, int(n ** 0.5) + 1):
        if n % k == 0:
            result = False
            break
    return result

@ti.kernel
def count_primes(n: int) -> int:
    count = 0
    for k in range(2, n):
        if is_prime(k):
            count += 1

    return count

print(count_primes(1000000))

1.4s

@ti.func与@ti.kernel的差异参考如下链接:https://docs.taichi-lang.org/zh-Hans/docs/syntax


参考链接:1、https://forum.taichi.graphics/,2、https://mp.weixin.qq.com/s/epc6Gtci5elAwOiRGQ8wmQ

你可能感兴趣的:(人生苦短,我用Python,python,TaiChi)