torchscript相关知识介绍(二)

一、TORCHSCRIPT 语言参考

1、TorchScript 是 Python 的静态类型子集,可以直接编写(使用@torch.jit.script装饰器)或者通过跟踪(tracing)从python代码自动生成。当使用跟踪(tracing)时,代码会通过仅仅记录张量上的真实OP操作,来自动转换为 Python 的子集;并简单地执行和丢弃其他周围的 Python 代码。

注:python装饰器学习可以参考:链接1、链接2。

当直接使用 @torch.jit.script装饰器书写 TorchScript ,程序员只能使用 TorchScript 支持的 Python 子集。本节记录了 TorchScript 中支持的内容,就好像它是独立语言的语言参考一样。

本参考资料中未提及的任何 Python 功能都不属于 TorchScript。有关可用 Pytorch 张量方法、模块和函数的完整参考,请参阅内置函数。

作为 Python 的子集,任何有效的 TorchScript 函数也是有效的 Python 函数。这使得禁用 TorchScript以及使用类似pdb这样的标准python工具来调试函数变得成为可能!

反之则不然:有许多有效的 Python 程序不是有效的 TorchScript 程序。相反,TorchScript 专注于在 PyTorch 中表示神经网络模型所需的 Python 特性。

二、类型(type)

TorchScript 和完整的 Python 语言之间的最大区别在于,TorchScript 仅支持表达神经网络模型所需的一小部分类型。特别是,TorchScript 支持如下类型:

torchscript相关知识介绍(二)_第1张图片

torchscript相关知识介绍(二)_第2张图片

 与 Python 不同,TorchScript 函数中的每个变量都必须有一个静态类型(重要特性!!!)。这使得优化 TorchScript 函数变得更加容易。

这里要说明静态类型和动态类型的区别了:

理解静态与动态之别,我们要从变量赋值这个操作为切入点。静态类型语言中,变量的类型必须先声明,即在创建的那一刻就已经确定好变量的类型,而后的使用中,你只能将这一指定类型的数据赋值给变量。如果强行将其他不相干类型的数据赋值给它,就会引发错误。

torchscript相关知识介绍(二)_第3张图片

在静态语言中,一旦声明一个变量是int类型,之后就只能将int类型的数据赋值给它,否则就会引发错误,而动态类型则没有这样的限制,你将什么类型的数据赋值给变量,这个变量就是什么类型

torchscript相关知识介绍(二)_第4张图片

以下语言,皆属于动态类型:

  1. PHP
  2. Ruby
  3. Python

常见的静态类型语言则有:

  1. C
  2. C++
  3. JAVA
  4. C#

示例(类型不匹配的这种情况):

import torch

@torch.jit.script
def an_error(x):
    if x:
        r = torch.rand(1)
    else:
        r = 4
    return r
Exception has occurred: RuntimeError


Type mismatch: r is set to type Tensor in the true branch and type int in the false branch:
  File "/home/wt-yjy/project/code/test_script/main.py", line 5
@torch.jit.script
def an_error(x):
    if x:
    ~~~~~
        r = torch.rand(1)
        ~~~~~~~~~~~~~~~~~
    else:
    ~~~~~
        r = 4
        ~~~~~ <--- HERE
    return r
and was used here:
  File "/home/wt-yjy/project/code/test_script/main.py", line 9
    else:
        r = 4
    return r
           ~ <--- HERE
  File "/home/wt-yjy/project/code/test_script/main.py", line 4, in 
    def an_error(x):

看起来这个就是违背了python语言的动态特性了啊,实际上这里是由于装饰器

@torch.jit.script,将函数 an_error()变成了python的子集,而且是静态的。这正是验证了前面说的:“与 Python 不同,TorchScript 函数中的每个变量都必须有一个静态类型”

三、不支持的typing构造

TorchScript 不支持typing模块的所有功能和类型。其中一些是未来不太可能添加的更基本的东西,而如果有足够的用户需求使其成为优先事项,则可能会添加其他东西。

模块中的这些类型和功能typing在 TorchScript 中不可用。

关于typing(即类型注解支持)的知识可以参考:

typing —— 类型注解支持 — Python 3.10.6 文档

Python 中 typing 模块和类型注解的使用 | 静觅

这里也有说明:

由于 Python 的特性,很多情况下我们并不用去声明它的类型,因此从方法定义上面来看,我们实际上是不知道一个方法的参数到底应该传入什么类型的。 这样其实就造成了很多不方便的地方,在某些情况下一些复杂的方法,如果不借助于一些额外的说明,我们是不知道参数到底是什么类型的。 因此,Python 中的类型注解就显得比较重要了

类型注解:

在python3.5中,Python PEP 484 引入了类型注解(type hints)。在 Python 3.6 中,PEP 526 又进一步引入了变量注解(Variable Annotations),所以上面的代码我们改写成如下写法:

a:int = 2
print('5 + a = ', 5 + a)


def add(a:int)->int:
    return a+1

具体的语法可以归纳为两点:

1)在声明变量时,变量的后面可以加一个冒号,后面再写上变量的类型,如int、list等

2)在声明方法返回值的时候,可以在方法的后面加上一个箭头,后面加上返回值的类型,如int、list等等。

3)在 PEP 8 中,具体的格式是这样规定的:

     1)在声明变量类型时,变量后方紧跟一个冒号。冒号后面跟一个空格,再跟上变量的类型。

     2)在声明方法返回值的时候,箭头左边是方法定义,箭头右边是返回值类型,箭头左右两边都要留有空格。

 上面只是用一个简单的 int 类型做了实例,下面我们再看下一些相对复杂的数据结构,例如列表、元组、字典等类型怎么样来声明。 可想而知了,列表用 list 表示,元组用 tuple 表示,字典用 dict 来表示,那么很自然地,在声明的时候我们就很自然地写成这样了:

names: list = ['Germey', 'Guido']
version: tuple = (3, 7, 4)
operations: dict = {'show': False, 'sort': True}

这么看上去没有问题,确实声明为了对应的类型,但实际上并不能反映整个列表、元组的结构,比如我们只通过类型注解是不知道 names 里面的元素是什么类型的,只知道 names 是一个列表 list 类型,实际上里面都是字符串 str 类型。我们也不知道 version 这个元组的每一个元素是什么类型的,实际上是 int 类型。但这些信息我们都无从得知。因此说,仅仅凭借 list、tuple 这样的声明是非常 “弱” 的,我们需要一种更强的类型声明。 这时候我们就需要借助于 typing 模块了,它提供了非常 “强 “的类型支持,比如 List[str]Tuple[int, int, int] 则可以表示由 str 类型的元素组成的列表和由 int 类型的元素组成的长度为 3 的元组。所以上文的声明写法可以改写成下面的样子:

from typing import List, Tuple, Dict

names: List[str] = ['Germey', 'Guido']
version: Tuple[int, int, int] = (3, 7, 4)
operations: Dict[str, bool] = {'show': False, 'sort': True}

这样一来,变量的类型便可以非常直观的体现出来了。目前typing模块也已经加入到了Python标准库中,不需要安装第三方模块,我们就可以直接使用了。

typing

下面我们再来详细看下typing模块的具体用法,这里主要会介绍一些常用的注解类型,如List、Tuple、Dict、Sequence等等。了解了每个类型的具体使用方法,我们可以得心应手的对任何变量进行声明了。在引入的时候就直接通过typing模块引入就好了,例如:

from typing import List,Tuple

List

List、列表,是list的泛型,基本等同于list,其后紧跟一个方括号,里面代表了构成这个列表的元素类型,如由数字构成的列表可以声明为:

var:List[int or float] = [2, 3.5]

另外还可以嵌套声明都是可以的:

var: List[List[int]] = [[1, 2], [2, 3]]

Tuple、NamedTuple

Tuple、元组,是 tuple 的泛型,其后紧跟一个方括号,方括号中按照顺序声明了构成本元组的元素类型,如 Tuple[X, Y] 代表了构成元组的第一个元素是 X 类型,第二个元素是 Y 类型。 比如想声明一个元组,分别代表姓名、年龄、身高,三个数据类型分别为 str、int、float,那么可以这么声明:

person: Tuple[str, int, float] = ('Mike', 22, 1.75)

同样地也可以使用类型嵌套。 NamedTuple,是 collections.namedtuple 的泛型,实际上就和 namedtuple 用法完全一致,但个人其实并不推荐使用 NamedTuple,推荐使用 attrs 这个库来声明一些具有表征意义的类。

Dict、Mapping、MutableMapping

Dict、字典,是 dict 的泛型;Mapping,映射,是 collections.abc.Mapping 的泛型。

根据官方文档,Dict 推荐用于注解返回类型,Mapping 推荐用于注解参数。它们的使用方法都是一样的,其后跟一个中括号,中括号内分别声明键名、键值的类型,如:

def size(rect: Mapping[str, int]) -> Dict[str, int]:
    return {'width': rect['width'] + 100, 'height': rect['width'] + 100}

这里将 Dict 用作了返回值类型注解,将 Mapping 用作了参数类型注解。 MutableMapping 则是 Mapping 对象的子类,在很多库中也经常用 MutableMapping 来代替 Mapping。

Set、AbstractSet

Set、集合,是 set 的泛型;AbstractSet、是 collections.abc.Set 的泛型。根据官方文档,Set 推荐用于注解返回类型,AbstractSet 用于注解参数。它们的使用方法都是一样的,其后跟一个中括号,里面声明集合中元素的类型,如:

def describe(s: AbstractSet[int]) -> Set[int]:
    return set(s)

这里将 Set 用作了返回值类型注解,将 AbstractSet 用作了参数类型注解。

Sequence

Sequence,是 collections.abc.Sequence 的泛型,在某些情况下,我们可能并不需要严格区分一个变量或参数到底是列表 list 类型还是元组 tuple 类型,我们可以使用一个更为泛化的类型,叫做 Sequence,其用法类似于 List,如:

def square(elements: Sequence[float]) -> List[float]:
    return [x ** 2 for x in elements]

NoReturn

NoReturn,当一个方法没有返回结果时,为了注解它的返回类型,我们可以将其注解为 NoReturn,例如:

def hello() -> NoReturn:
    print('hello')

Any

Any,是一种特殊的类型,它可以代表所有类型,静态类型检查器的所有类型都与 Any 类型兼容,所有的无参数类型注解和返回类型注解的都会默认使用 Any 类型,也就是说,下面两个方法的声明是完全等价的:

def add(a):
    return a + 1

def add(a: Any) -> Any:
    return a + 1

原理类似于 object,所有的类型都是 object 的子类。但如果我们将参数声明为 object 类型,静态参数类型检查便会抛出错误,而 Any 则不会,具体可以参考官方文档的说明:typing —— 类型注解支持 — Python 3.10.6 文档。

TypeVar

TypeVar,我们可以借助它来自定义兼容特定类型的变量,比如有的变量声明为 int、float、None 都是符合要求的,实际就是代表任意的数字或者空内容都可以,其他的类型则不可以,比如列表 list、字典 dict 等等,像这样的情况,我们可以使用 TypeVar 来表示。 例如一个人的身高,便可以使用 int 或 float 或 None 来表示,但不能用 dict 来表示,所以可以这么声明:

height = 1.75
Height = TypeVar('Height', int, float, None)
def get_height() -> Height:
    return height

这里我们使用 TypeVar 声明了一个 Height 类型,然后将其用于注解方法的返回结果。

NewType

NewType,我们可以借助于它来声明一些具有特殊含义的类型,例如像 Tuple 的例子一样,我们需要将它表示为 Person,即一个人的含义,但从表面上声明为 Tuple 并不直观,所以我们可以使用 NewType 为其声明一个类型,如:

Person = NewType('Person', Tuple[str, int, float])
person = Person(('Mike', 22, 1.75))

这里实际上 person 就是一个 tuple 类型,我们可以对其像 tuple 一样正常操作。

Callable

Callable,可调用类型,它通常用来注解一个方法,比如我们刚才声明了一个 add 方法,它就是一个 Callable 类型:

print(Callable, type(add), isinstance(add, Callable))

运行结果:

typing.Callable  True

在这里虽然二者 add 利用 type 方法得到的结果是 function,但实际上利用 isinstance 方法判断确实是 True。 Callable 在声明的时候需要使用 Callable[[Arg1Type, Arg2Type, ...], ReturnType] 这样的类型注解,将参数类型和返回值类型都要注解出来,例如:

def date(year: int, month: int, day: int) -> str:
    return f'{year}-{month}-{day}'

def get_date_fn() -> Callable[[int, int, int], str]:
    return date

这里首先声明了一个方法 date,接收三个 int 参数,返回一个 str 结果,get_date_fn 方法返回了这个方法本身,它的返回值类型就可以标记为 Callable,中括号内分别标记了返回的方法的参数类型和返回值类型。

Union

Union,联合类型,Union[X, Y] 代表要么是 X 类型,要么是 Y 类型。 联合类型的联合类型等价于展平后的类型:

Union[Union[int, str], float] == Union[int, str, float]

仅有一个参数的联合类型会坍缩成参数自身,比如:

Union[int] == int

多余的参数会被跳过,比如:

Union[int, str, int] == Union[int, str]

在比较联合类型的时候,参数顺序会被忽略,比如:

Union[int, str] == Union[str, int]

这个在一些方法参数声明的时候比较有用,比如一个方法,要么传一个字符串表示的方法名,要么直接把方法传过来:

def process(fn: Union[str, Callable]):
    if isinstance(fn, str):
        # str2fn and process
        pass
    elif isinstance(fn, Callable):
        fn()

这样的声明在一些类库方法定义的时候十分常见。

Optional

Optional,意思是说这个参数可以为空或已经声明的类型,即 Optional[X] 等价于 Union[X, None]。 但值得注意的是,这个并不等价于可选参数,当它作为参数类型注解的时候,不代表这个参数可以不传递了,而是说这个参数可以传为 None。 如当一个方法执行结果,如果执行完毕就不返回错误信息, 如果发生问题就返回错误信息,则可以这么声明:

def judge(result: bool) -> Optional[str]:
    if result: return 'Error Occurred'

Generator

如果想代表一个生成器类型,可以使用 Generator,它的声明比较特殊,其后的中括号紧跟着三个参数,分别代表 YieldType、SendType、ReturnType,如:

def echo_round() -> Generator[int, float, str]:
    sent = yield 0
    while sent >= 0:
        sent = yield round(sent)
    return 'Done'

在这里 yield 关键字后面紧跟的变量的类型就是 YieldType,yield 返回的结果的类型就是 SendType,最后生成器 return 的内容就是 ReturnType。 当然很多情况下,生成器往往只需要 yield 内容就够了,我们是不需要 SendType 和 ReturnType 的,可以将其设置为空,如:

def infinite_stream(start: int) -> Generator[int, None, None]:
    while True:
        yield start
        start += 1

案例实战

接下来让我们看一个实际的项目,看看经常用到的类型一般是怎么使用的。 这里我们看的库是 requests-html,是由 Kenneth Reitz 所开发的,其 GitHub 地址为:https://github.com/psf/requests-html,下面我们主要看看它的源代码中一些类型是如何声明的。 这个库的源代码其实就一个文件,那就是 https://github.com/psf/requests-html/blob/master/requests_html.py,我们看一下它里面的一些 typing 的定义和方法定义。 首先 Typing 的定义部分如下:

from typing import Set, Union, List, MutableMapping, Optional

_Find = Union[List['Element'], 'Element']
_XPath = Union[List[str], List['Element'], str, 'Element']
_Result = Union[List['Result'], 'Result']
_HTML = Union[str, bytes]
_BaseHTML = str
_UserAgent = str
_DefaultEncoding = str
_URL = str
_RawHTML = bytes
_Encoding = str
_LXML = HtmlElement
_Text = str
_Search = Result
_Containing = Union[str, List[str]]
_Links = Set[str]
_Attrs = MutableMapping
_Next = Union['HTML', List[str]]
_NextSymbol = List[str]

这里可以看到主要用到的类型有 Set、Union、List、MutableMapping、Optional,这些在上文都已经做了解释,另外这里使用了多次 Union 来声明了一些新的类型,如 _Find 则要么是是 Element 对象的列表,要么是单个 Element 对象,_Result 则要么是 Result 对象的列表,要么是单个 Result 对象。另外 _Attrs 其实就是字典类型,这里用 MutableMapping 来表示了,没有用 Dict,也没有用 Mapping。 接下来再看一个 Element 类的声明:

class Element(BaseParser):
    """An element of HTML.
    :param element: The element from which to base the parsing upon.
    :param url: The URL from which the HTML originated, used for ``absolute_links``.
    :param default_encoding: Which encoding to default to.
    """

    __slots__ = [
        'element', 'url', 'skip_anchors', 'default_encoding', '_encoding',
        '_html', '_lxml', '_pq', '_attrs', 'session'
    ]

    def __init__(self, *, element, url: _URL, default_encoding: _DefaultEncoding = None) -> None:
        super(Element, self).__init__(element=element, url=url, default_encoding=default_encoding)
        self.element = element
        self.tag = element.tag
        self.lineno = element.sourceline
        self._attrs = None

    def __repr__(self) -> str:
        attrs = ['{}={}'.format(attr, repr(self.attrs[attr])) for attr in self.attrs]
        return "".format(repr(self.element.tag), ' '.join(attrs))

    @property
    def attrs(self) -> _Attrs:
        """Returns a dictionary of the attributes of the :class:`Element `
        (`learn more `_).
        """
        if self._attrs is None:
            self._attrs = {k: v for k, v in self.element.items()}

            # Split class and rel up, as there are ussually many of them:
            for attr in ['class', 'rel']:
                if attr in self._attrs:
                    self._attrs[attr] = tuple(self._attrs[attr].split())

        return self._attrs

这里 __init__ 方法接收非常多的参数,同时使用 _URL 、_DefaultEncoding 进行了参数类型注解,另外 attrs 方法使用了 _Attrs 进行了返回结果类型注解。 整体看下来,每个参数的类型、返回值都进行了清晰地注解,代码可读性大大提高。 以上便是类型注解和 typing 模块的详细介绍。

torchscript相关知识介绍(二)_第5张图片

 

torchscript相关知识介绍(二)_第6张图片

 

typing本文档中未明确列出的模块中的任何其他功能均不受支持。

四、默认类型

默认情况下,TorchScript 函数的所有参数都假定为 Tensor。要指定 TorchScript 函数的参数是另一种类型,可以使用上面列出的类型使用 MyPy 样式的类型注释。

import torch

@torch.jit.script
def foo(x, tup):
    # type: (int, Tuple[Tensor, Tensor]) -> Tensor
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

也可以使用 typing模块中的 Python 3 类型提示来注释类型。

import torch
from typing import Tuple

@torch.jit.script
def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

一个空列表假定为 List[Tensor] 以及一个空字典假定为Dict [str, Tensor]

实例化其他类型的空列表或字典,请使用Python 3 类型提示。Dict[str, Tensor]

示例(Python 3 的类型注释):

import torch
import torch.nn as nn
from typing import Dict, List, Tuple

class EmptyDataStructures(torch.nn.Module):
    def __init__(self):
        super(EmptyDataStructures, self).__init__()

    def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]:
        # This annotates the list to be a `List[Tuple[int, float]]`
        my_list: List[Tuple[int, float]] = []
        for i in range(10):
            my_list.append((i, x.item()))

        my_dict: Dict[str, int] = {}
        return my_list, my_dict

x = torch.jit.script(EmptyDataStructures())

五、可选类型细化(Optional Type Refinement

当与在if语句的条件内或在断言assert条件下的检查时的作比较None中,TorchScript将优化optional[T]类型变量的类型。编译器可以推理与and、or和not组合的多个None检查。对于未显式写入的if语句的else块,也会进行细化。

None的检查必须在if条件语句下;将None检查分配给变量并在if语句的条件中使用它这将不会优化检查中的变量类型。只有局部变量会被细化,就像属性self.x不会。它(局部变量如self.x)必须被分配给要细化的局部变量。

示例(优化参数和局部变量的类型):

import torch
import torch.nn as nn
from typing import Optional

class M(nn.Module):
    z: Optional[int]

    def __init__(self, z):
        super(M, self).__init__()
        # If `z` is None, its type cannot be inferred, so it must
        # be specified (above)
        self.z = z

    def forward(self, x, y, z):
        # type: (Optional[int], Optional[int], Optional[int]) -> int
        if x is None:
            x = 1
            x = x + 1

        # Refinement for an attribute by assigning it to a local
        z = self.z
        if y is not None and z is not None:
            x = y + z

        # Refinement via an `assert`
        assert z is not None
        x += z
        return x

module = torch.jit.script(M(2))
module = torch.jit.script(M(None))

六、TorchScript 类(TorchScript Classes

警告:

TorchScript 类支持是实验性的。目前它最适合简单的类似记录的类型(想想NamedTuple附加方法)。

如果 Python 类使用 注释,则可以在 TorchScript 中使用@torch.jit.script,类似于声明 TorchScript 函数的方式:

@torch.jit.script
class Foo:
  def __init__(self, x, y):
    self.x = x

  def aug_add_x(self, inc):
    self.x += inc

这个子集是受限的:

  • 所有函数都必须是有效的 TorchScript 函数(包括__init__())。

  • 类必须是新式类,因为我们使用__new__()pybind11 来构造它们。

  • TorchScript 类是静态类型的。成员只能通过在__init__()方法中赋值给 self 来声明。

例如,分配给方法self外部__init__()

@torch.jit.script
class Foo:
  def assign_x(self):
    self.x = torch.rand(2, 3)

将导致:

  • 类的主体中不允许除方法定义外的任何表达式。

  • 不支持继承或任何其他多态策略,除了继承自object指定新样式类。

定义类后,它可以像任何其他 TorchScript 类型一样在 TorchScript 和 Python 中互换使用:

# Declare a TorchScript class
@torch.jit.script
class Pair:
  def __init__(self, first, second):
    self.first = first
    self.second = second

@torch.jit.script
def sum_pair(p):
  # type: (Pair) -> Tensor
  return p.first + p.second

p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))

七、TorchScript 枚举

Python 枚举可以在 TorchScript 中使用,无需任何额外的注释或代码:

from enum import Enum


class Color(Enum):
    RED = 1
    GREEN = 2

@torch.jit.script
def enum_fn(x: Color, y: Color) -> bool:
    if x == Color.RED:
        return True

    return x == y

定义枚举后,它可以像任何其他 TorchScript 类型一样在 TorchScript 和 Python 中互换使用。枚举值的类型必须是int、 floatstr。所有值必须属于同一类型;不支持枚举值的异构类型。

八、命名元组

生成的类型collections.namedtuple可以在 TorchScript 中使用。

import torch
import collections

Point = collections.namedtuple('Point', ['x', 'y'])

@torch.jit.script
def total(point):
    # type: (Point) -> Tensor
    return point.x + point.y

p = Point(x=torch.rand(3), y=torch.rand(3))
print(total(p))

九、可迭代对象

某些函数(例如zipand enumerate)只能对可迭代类型进行操作。TorchScript 中的可迭代类型包括Tensors、列表、元组、字典、字符串 torch.nn.ModuleList和torch.nn.ModuleDict.

十、表达式

支持以下 Python 表达式。

(1)字面量

True
False
None
'string literals'
"string literals"
3  # interpreted as int
3.4  # interpreted as a float

(2)清单建设

假定空列表具有 type List[Tensor]。其他列表文字的类型是从成员的类型派生的。有关详细信息,请参阅默认类型。

[3, 4]
[]
[torch.rand(3), torch.rand(4)]

(3)元组构造

(3, 4)
(3,)

(4)字典结构

假设空字典具有 type 。其他 dict 文字的类型是从成员的类型派生的。有关详细信息,请参阅默认类型。Dict[str, Tensor]

{'hello': 3}
{}
{'a': torch.rand(3), 'b': torch.rand(4)}

(5)变量

有关如何解析变量,请参阅变量解析。

my_variable_name

(6)算术运算符

a + b
a - b
a * b
a / b
a ^ b
a @ b

(7)比较运算符

a == b
a != b
a < b
a > b
a <= b
a >= b

(8)逻辑运算符

a and b
a or b
not b

(9)下标和切片

t[0]
t[-1]
t[0:2]
t[1:]
t[:1]
t[:]
t[0, 1]
t[0, 1:2]
t[0, :1]
t[-1, 1:, 0]
t[1:, -1, 0]
t[i:j, i]

(10)函数调用

调用内置函数

torch.rand(3, dtype=torch.int)

调用其他脚本函数:

import torch

@torch.jit.script
def foo(x):
    return x + 1

@torch.jit.script
def bar(x):
    return foo(x)

十一、方法调用

调用内置类型的方法,如张量:x.mm(y)

在模块上,必须先编译方法,然后才能调用它们。TorchScript 编译器递归地编译它在编译其他方法时看到的方法。默认情况下,编译在该forward方法上开始。任何被调用的方法forward都将被编译,任何被这些方法调用的方法都将被编译,以此类推。要在 以外的方法开始编译forward,请使用@torch.jit.export装饰器(forward隐式标记为@torch.jit.export)。

直接调用子模块(例如self.resnet(input))相当于调用它的forward方法(例如self.resnet.forward(input)

import torch
import torch.nn as nn
import torchvision

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        means = torch.tensor([103.939, 116.779, 123.68])
        self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1))
        resnet = torchvision.models.resnet18()
        self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224))

    def helper(self, input):
        return self.resnet(input - self.means)

    def forward(self, input):
        return self.helper(input)

    # Since nothing in the model calls `top_level_method`, the compiler
    # must be explicitly told to compile this method
    @torch.jit.export
    def top_level_method(self, input):
        return self.other_helper(input)

    def other_helper(self, input):
        return input + 10

# `my_script_module` will have the compiled methods `forward`, `helper`,
# `top_level_method`, and `other_helper`
my_script_module = torch.jit.script(MyModule())

(1)三元表达式

x if x > y else y

(2)转换(casts)

float(ten)
int(3.5)
bool(ten)
str(2)``

(3)访问模块参数

self.my_parameter
self.my_submodule.my_parameter

十二、语句(Statements)

TorchScript支持以下类型的声明。

(1)简单赋值

a = b
a += b # short-hand for a = a + b, does not operate in-place on a
a -= b

(2)模式匹配赋值

a, b = tuple_or_list
a, b, *c = a_tuple

(3)多元赋值

a = b, c = tup

(4)打印语句

print("the result of an add:", a + b)

(5)条件语句

if a < 4:
    r = -a
elif a < 3:
    r = a + a
else:
    r = 3 * a

除了布尔值之外,浮点数、整数和张量也可以在条件中使用,并将被隐式转换为布尔值。

(6)While 循环

a = 0
while a < 4:
    print(a)
    a += 1

(7)带范围的循环

x = 0
for i in range(10):
    x *= i

(8)元组上的 for 循环

这些展开循环,为元组的每个成员生成一个主体。正文必须对每个成员进行正确的类型检查。

tup = (3, torch.rand(4))
for x in tup:
    print(x)

(9)对于常量 nn.ModuleList 上的循环

要在编译后的方法中使用nn.ModuleList,必须通过将属性名称添加到__constants__ 类型列表来将其标记为常量。

a 上的 for 循环nn.ModuleList将在编译时展开循环体,其中包含常量模块列表的每个成员。

class SubModule(torch.nn.Module):
    def __init__(self):
        super(SubModule, self).__init__()
        self.weight = nn.Parameter(torch.randn(2))

    def forward(self, input):
        return self.weight + input

class MyModule(torch.nn.Module):
    __constants__ = ['mods']

    def __init__(self):
        super(MyModule, self).__init__()
        self.mods = torch.nn.ModuleList([SubModule() for i in range(10)])

    def forward(self, v):
        for module in self.mods:
            v = module(v)
        return v


m = torch.jit.script(MyModule())

(10)中断并继续

for i in range(5):
    if i == 1:
        continue
    if i == 3:
        break
    print(i)

(11)返回

return a, b

十三、可变分辨率

TorchScript 支持 Python 的变量解析(即范围)规则的一个子集。局部变量的行为与 Python 中的相同,除了在通过函数的所有路径中变量必须具有相同类型的限制。如果变量在 if 语句的不同分支上具有不同的类型,则在 if 语句结束后使用它是错误的。

同样,如果一个变量只在函数的某些路径上定义,则不允许使用该变量。

例子:

@torch.jit.script
def foo(x):
    if x < 0:
        y = 4
    print(y)
Traceback (most recent call last):
  ...
RuntimeError: ...

y is not defined in the false branch...
@torch.jit.script...
def foo(x):
    if x < 0:
    ~~~~~~~~~
        y = 4
        ~~~~~ <--- HERE
    print(y)
and was used here:
    if x < 0:
        y = 4
    print(y)
          ~ <--- HERE...

定义函数时,非局部变量在编译时解析为 Python 值。然后使用 Python 值的使用中描述的规则将这些值转换为 TorchScript 值。

十四、Python 值的使用

为了使编写TorchScript更加方便,我们允许脚本代码引用周围范围中的Python值。例如,只要存在对torch的引用,TorchScript编译器就会在声明函数时将其解析为torch Python模块。这些Python值不是TorchScript的第一类部分。相反,它们在编译时被分解成TorchScript支持的基本类型。这取决于编译时引用的Python值的动态类型。本节介绍访问TorchScript中的Python值时使用的规则。

(1)函数

TorchScript可以调用Python函数。当增量地转换一个模型为TorchScript时,此功能非常有用。模型可以一个函数一个函数地移动到TorchScript中,保留对Python函数的调用。这样,您可以在执行过程中逐步检查模型的正确性。

torch.jit.is_scripting()

当在编译时该函数返回True,否则返回False。

这对于@unused 装饰器非常有用,因为它可以在模型中保留尚未与TorchScript兼容的代码。。

测试代码如下:

import torch

@torch.jit.unused
def unsupported_linear_op(x):
    return x

def linear(x):
   if torch.jit.is_scripting():
      return torch.linear(x)
   else:
      return unsupported_linear_op(x)

torch.jit.is_tracing()

在跟踪(tracing)中返回True(如果在使用torch.jit.trace跟踪代码期间调用函数),否则返回False。

十五、Python 模块上的属性查找

TorchScript 可以在模块上查找属性。像这样的内置函数torch.add 可以通过这种方式访问​​。这允许 TorchScript 调用其他模块中定义的函数。

十六、Python 定义的常量

TorchScript 还提供了一种使用 Python 中定义的常量的方法。这些可用于将超参数硬编码到函数中,或定义通用常量。有两种方法可以指定 Python 值应被视为常量。

  1. 假定作为模块属性查找的值是恒定的:

import math
import torch

@torch.jit.script
def fn():
    return math.pi

   2.ScriptModule 的属性可以通过使用注释来标记为常量Final[T]

import torch
import torch.nn as nn

class Foo(nn.Module):
    # `Final` from the `typing_extensions` module can also be used
    a : torch.jit.Final[int]

    def __init__(self):
        super(Foo, self).__init__()
        self.a = 1 + 4

    def forward(self, input):
        return self.a + input

f = torch.jit.script(Foo())

支持的常量 Python 类型是

  • int

  • float

  • bool

  • torch.device

  • torch.layout

  • torch.dtype

  • 包含支持类型的元组

  • torch.nn.ModuleList可以在 TorchScript for 循环中使用

十七、模块属性(Module Attributes

torch.nn.Parameter包装器和register_buffer可用于为模块分配张量。

如果可以推断其类型,分配给已编译模块的其他值将被添加到已编译模块中。

TorchScript中可用的所有类型都可以用作模块属性。张量属性在语义上与缓冲区相同。空列表和空字典的类型以及None值不能推断,必须通过PEP 526样式类注释指定。如果无法推断类型且未显式注释,则不会将其作为属性添加到结果ScriptModule。

例子:

from typing import List, Dict

class Foo(nn.Module):
    # `words` is initialized as an empty list, so its type must be specified
    words: List[str]

    # The type could potentially be inferred if `a_dict` (below) was not
    # empty, but this annotation ensures `some_dict` will be made into the
    # proper type
    some_dict: Dict[str, int]

    def __init__(self, a_dict):
        super(Foo, self).__init__()
        self.words = []
        self.some_dict = a_dict

        # `int`s can be inferred
        self.my_int = 10

    def forward(self, input):
        # type: (str) -> int
        self.words.append(input)
        return self.some_dict[input] + self.my_int

f = torch.jit.script(Foo({'hi': 2}))

pytorch-doc-zh/84.md at 09d11048ec7b4f3d8c90884b6a755d9546cb79c0 · apachecn/pytorch-doc-zh · GitHub

pytorch-doc-zh/jit.md at 09d11048ec7b4f3d8c90884b6a755d9546cb79c0 · apachecn/pytorch-doc-zh · GitHub

参考:静态类型与动态类型编程语言之间的区别 - 知乎

你可能感兴趣的:(部署,深度学习,pytorch,pytorch,人工智能,python)