PyTorch's torch/_C/_VariableFunctions.pyi contains the following code:
@overload
def rand(size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(*size: _int, generator: Optional[Generator], names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(size: Sequence[Union[_int, SymInt]], *, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(*size: _int, generator: Optional[Generator], out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(size: Sequence[Union[_int, SymInt]], *, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(*size: _int, out: Optional[Tensor] = None, dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(size: Sequence[Union[_int, SymInt]], *, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
@overload
def rand(*size: _int, names: Optional[Sequence[Union[str, ellipsis, None]]], dtype: Optional[_dtype] = None, layout: Optional[_layout] = None, device: Optional[Union[_device, str, None]] = None, pin_memory: Optional[_bool] = False, requires_grad: Optional[_bool] = False) -> Tensor: ...
What do Sequence, Optional and Union, and _int and _bool, mean here? torch/_C/_VariableFunctions.pyi.in offers a clue:
from torch import Tensor, Generator, strided, memory_format, contiguous_format, strided, inf
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, Literal, TypeVar
from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme, _size, _layout, SymInt, Device
So Sequence, Optional, Union and the like are imported from typing, a standard-library module that provides runtime support for type hints. _int, _bool and the others are types that PyTorch defines itself, in torch.types.
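According to Type hints cheat sheet - Standard “duck types”, Sequence stands for sequence types that support the __len__ and __getitem__ methods, such as list, tuple and str. dict and set are not Sequences: set has no __getitem__, and although dict supports both __len__ and __getitem__, it is classified as a Mapping, not a Sequence.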
# Use Iterable for generic iterables (anything usable in "for"),
# and Sequence where a sequence (supporting "len" and "__getitem__") is
# required
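A quick runtime check makes the distinction concrete. typing.Sequence is an alias of collections.abc.Sequence, so isinstance() against that abstract base class gives the same classification. The helper name is_sequence is made up for this sketch.

```python
from collections.abc import Sequence

def is_sequence(obj: object) -> bool:
    # Checks against the ABC that typing.Sequence aliases
    return isinstance(obj, Sequence)

for obj in ([1, 2], (1, 2), "ab", range(3), {"a": 1}, {1, 2}):
    print(type(obj).__name__, is_sequence(obj))
# list True
# tuple True
# str True
# range True
# dict False
# set False
```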
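Compare this with Iterable. According to Python Iterable vs Sequence, an Iterable is a type that supports __iter__ or __getitem__, such as range and the iterator returned by reversed. (Strictly speaking, collections.abc.Iterable only recognizes __iter__; an object that defines only __getitem__ can still be looped over because iter() falls back to calling __getitem__ with 0, 1, 2, ... until IndexError is raised.)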
r = range(4)
r.__getitem__(0)   # 0
r.__iter__()       # <range_iterator object at 0x...>
l = [1, 2, 3]
rv = reversed(l)
rv.__iter__()      # <list_reverseiterator object at 0x...>
rv.__getitem__(0)  # __getitem__ is not supported, but rv supports __iter__, so it still counts as an Iterable
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
# AttributeError: 'list_reverseiterator' object has no attribute '__getitem__'
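A class that defines only __getitem__ shows the gap between the two notions of "iterable": list() and for-loops accept it through the legacy sequence protocol, while an isinstance() check against collections.abc.Iterable, which only looks for __iter__, does not. Squares is a made-up class for this sketch.

```python
from collections.abc import Iterable

class Squares:
    """Defines __getitem__ but not __iter__."""
    def __getitem__(self, index: int) -> int:
        if index >= 5:
            raise IndexError(index)  # tells the iteration fallback to stop
        return index * index

s = Squares()
print(list(s))                  # [0, 1, 4, 9, 16]
print(isinstance(s, Iterable))  # False: the ABC only looks for __iter__
```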
Since a Sequence also has __iter__ and __getitem__, every Sequence is, by definition, also an Iterable.
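l = []
l.__iter__     # <method-wrapper '__iter__' of list object at 0x...>
l.__getitem__  # <built-in method __getitem__ of list object at 0x...>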
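In practice, the choice between the two hints follows from what a function does with its argument. In this sketch (the function names are illustrative), total loops over its input once, so Iterable[int] is enough and even a reversed iterator is accepted, while middle needs len() and indexing and therefore asks for Sequence[int]. The hints are taken from collections.abc, which typing.Iterable and typing.Sequence alias.

```python
from collections.abc import Iterable, Sequence

def total(xs: Iterable[int]) -> int:
    # Only iterates once, so any iterable of ints works
    return sum(xs)

def middle(xs: Sequence[int]) -> int:
    # Needs len() and indexing, which only a Sequence guarantees
    return xs[len(xs) // 2]

print(total(reversed([1, 2, 3])))       # 6
print(middle([1, 2, 3]))                # 2
print(issubclass(Sequence, Iterable))   # True: every Sequence is an Iterable
```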
typing - Callable
Callable
Frameworks expecting callback functions of specific signatures might be type hinted using Callable[[Arg1Type, Arg2Type], ReturnType].
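The documentation is easy to follow. One detail to note: the argument types go inside their own pair of brackets, [], as in Callable[[Arg1Type, Arg2Type], ReturnType].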
Type hints cheat sheet - Functions gives this example:
# This is how you annotate a callable (function) value
x: Callable[[int, float], float] = f
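Setting the type hint aside, this line is simply x = f: it binds the variable x to the function f. The annotation Callable[[int, float], float] states that f is a function that takes an int and a float and returns a float.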
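The cheat-sheet line assumes some f already exists. This self-contained sketch defines an f and an apply_twice function, both invented for illustration, to show Callable used as a variable annotation and as a parameter type.

```python
from typing import Callable

def f(a: int, b: float) -> float:
    return a * b

# Same statement as in the cheat sheet: x now refers to the function object f
x: Callable[[int, float], float] = f

def apply_twice(func: Callable[[int, float], float], n: int, start: float) -> float:
    # Calls func twice, feeding the first result back in as the float argument
    return func(n, func(n, start))

print(x is f)                   # True
print(x(2, 1.5))                # 3.0
print(apply_twice(f, 2, 1.5))   # 6.0
```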
typing - Union
typing.Union
Union type; Union[X, Y] is equivalent to X | Y and means either X or Y.
To define a union, use e.g. Union[int, str] or the shorthand int | str. Using that shorthand is recommended.
Union[X, Y] means the type can be either X or Y. Since Python 3.10, the shorter X | Y syntax can be used instead.
The example from Type hints cheat sheet - Useful built-in types:
# On Python 3.10+, use the | operator when something could be one of a few types
x: list[int | str] = [3, 5, "test", "fun"] # Python 3.10+
# On earlier versions, use Union
x: list[Union[int, str]] = [3, 5, "test", "fun"]
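A Union value usually has to be narrowed before type-specific operations are allowed. In this sketch (describe is a made-up name), an isinstance() check tells the type checker which member of Union[int, str] it is dealing with in each branch.

```python
from typing import Union

def describe(value: Union[int, str]) -> str:
    # isinstance() narrows the Union: in this branch value is an int
    if isinstance(value, int):
        return f"int: {value + 1}"
    # Only str is left here, so str methods are allowed
    return f"str: {value.upper()}"

print(describe(41))    # int: 42
print(describe("hi"))  # str: HI
```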
typing - Optional
Optional type.
Optional[X] is equivalent to X | None (or Union[X, None]).
Optional[X] means the value can be either of type X or None.
Type hints cheat sheet - Useful built-in types gives a good example:
# Use Optional[X] for a value that could be None
# Optional[X] is the same as X | None or Union[X, None]
x: Optional[str] = "something" if some_condition() else None
Here x is either a string or None, depending on what some_condition() returns, so Optional[str] is the appropriate hint.
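Note that Optional[X] only says None is an allowed value; it does not make a parameter optional. That is why, in the rand overloads above that mention generator, generator: Optional[Generator] has no default and must still be passed (possibly as None), whereas dtype: Optional[_dtype] = None can be left out. A sketch with made-up function names:

```python
from typing import Optional

def greet(name: Optional[str]) -> str:
    # No default: callers must pass name, but None is an allowed value
    if name is None:
        return "Hello, stranger"
    # After the None check, a type checker narrows name to str
    return f"Hello, {name.title()}"

def greet_default(name: Optional[str] = None) -> str:
    # Optional type plus a default value: now the argument may also be omitted
    return greet(name)

print(greet("alice"))    # Hello, Alice
print(greet(None))       # Hello, stranger
print(greet_default())   # Hello, stranger
```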
mypy - Functions
Annotating parameter and return types:
from typing import Callable, Iterator, Union, Optional
# This is how you annotate a function definition
def stringify(num: int) -> str:
    return str(num)
Multiple parameters:
# And here's how you specify multiple arguments
def plus(num1: int, num2: int) -> int:
    return num1 + num2
A function that returns nothing uses None as its return type, and a parameter's default value goes after its type annotation:
# If a function does not return a value, use None as the return type
# Default value for an argument goes after the type annotation
def show(value: str, excitement: int = 10) -> None:
    print(value + "!" * excitement)
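Parameters without an annotation are dynamically typed (treated as Any), and a function with no annotations at all is not checked: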
# Note that arguments without a type are dynamically typed (treated as Any)
# and that functions without any annotations are not checked
def untyped(x):
    x.anything() + 1 + "string"  # no errors
A function that takes a Callable as a parameter:
# This is how you annotate a callable (function) value
x: Callable[[int, float], float] = f
def register(callback: Callable[[str], int]) -> None: ...
A generator function is effectively a function that returns an Iterator:
# A generator function that yields ints is secretly just a function that
# returns an iterator of ints, so that's how we annotate it
def gen(n: int) -> Iterator[int]:
    i = 0
    while i < n:
        yield i
        i += 1
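When the values sent into a generator or its final return value matter, typing also offers Generator[YieldType, SendType, ReturnType]; Iterator[int] is the simpler hint when they do not. A sketch contrasting the two (gen_with_result is a made-up name):

```python
from typing import Generator, Iterator

def gen(n: int) -> Iterator[int]:
    i = 0
    while i < n:
        yield i
        i += 1

def gen_with_result(n: int) -> Generator[int, None, str]:
    # Yields ints, accepts no sent values, and returns a str when exhausted
    for i in range(n):
        yield i
    return "done"

it = gen(3)
print(isinstance(it, Iterator))  # True
print(list(it))                  # [0, 1, 2]
```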
Splitting a function annotation over multiple lines:
# You can of course split a function annotation over multiple lines
def send_email(address: Union[str, list[str]],
               sender: str,
               cc: Optional[list[str]],
               bcc: Optional[list[str]],
               subject: str = '',
               body: Optional[list[str]] = None
               ) -> bool:
    ...
# Mypy understands positional-only and keyword-only arguments
# Positional-only arguments can also be marked by using a name starting with
# two underscores
def quux(x: int, /, *, y: int) -> None:
    pass
quux(3, y=5) # Ok
quux(3, 5) # error: Too many positional arguments for "quux"
quux(x=3, y=5) # error: Unexpected keyword argument "x" for "quux"
Note the / and * in this parameter list. From What Are Python Asterisk and Slash Special Parameters For?:
| Left side | Divider | Right side |
|---|---|---|
| Positional-only arguments | / | Positional or keyword arguments |
| Positional or keyword arguments | * | Keyword-only arguments |
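In terms of how arguments can be passed, Python parameters come in three kinds: positional-only, keyword-only, and positional-or-keyword (the default, which accepts either style). Parameters to the left of / are positional-only, and parameters to the right of * are keyword-only. So in the example above, x must be passed positionally and y must be passed by keyword. The same rule explains the rand stubs at the top: every parameter after * (or after *size) is keyword-only, so dtype, device and the rest must be passed by name.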
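Besides the mypy errors shown in the comments, these rules are enforced by the interpreter itself. A runnable sketch: here quux returns x + y so the result is visible, and raises_type_error is a helper made up for the demo.

```python
from typing import Callable

def quux(x: int, /, *, y: int) -> int:
    # x is positional-only (left of /); y is keyword-only (right of *)
    return x + y

def raises_type_error(call: Callable[[], object]) -> bool:
    # True if calling call() raises TypeError
    try:
        call()
    except TypeError:
        return True
    return False

print(quux(3, y=5))                               # 8
# Type checkers also flag the two calls below; at runtime they raise TypeError
print(raises_type_error(lambda: quux(3, 5)))      # True: y must be a keyword
print(raises_type_error(lambda: quux(x=3, y=5)))  # True: x cannot be a keyword
```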
Annotating *args and **kwargs, where the annotation gives the type of each individual argument:
# This says each positional arg and each keyword arg is a "str"
def call(self, *args: str, **kwargs: str) -> str:
    reveal_type(args)  # Revealed type is "tuple[str, ...]"
    reveal_type(kwargs)  # Revealed type is "dict[str, str]"
    request = make_request(*args, **kwargs)
    return self.do_api_query(request)
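The fragment above relies on names it does not define (make_request, do_api_query), so here is a self-contained sketch with hypothetical join_all and container_types functions. At runtime args is a tuple and kwargs is a dict, matching the types reveal_type reports.

```python
def join_all(*args: str, **kwargs: str) -> str:
    # args arrives as a tuple of str, kwargs as a dict of str -> str
    parts = list(args) + [f"{key}={value}" for key, value in kwargs.items()]
    return ", ".join(parts)

def container_types(*args: str, **kwargs: str) -> tuple[type, type]:
    # Returns the runtime container types of args and kwargs
    return type(args), type(kwargs)

print(join_all("a", "b", mode="fast"))     # a, b, mode=fast
print(container_types("a", mode="fast"))   # (<class 'tuple'>, <class 'dict'>)
```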
mypy - Classes
class BankAccount:
    # The "__init__" method doesn't return anything, so it gets return
    # type "None" just like any other method that doesn't return anything
    def __init__(self, account_name: str, initial_balance: int = 0) -> None:
        # mypy will infer the correct types for these instance variables
        # based on the types of the parameters.
        self.account_name = account_name
        self.balance = initial_balance
    # For instance methods, omit type for "self"
    def deposit(self, amount: int) -> None:
        self.balance += amount
    def withdraw(self, amount: int) -> None:
        self.balance -= amount
The self parameter of a method needs no annotation.
User-defined classes can be used as types in annotations:
# User-defined classes are valid as types in annotations
account: BankAccount = BankAccount("Alice", 400)
def transfer(src: BankAccount, dst: BankAccount, amount: int) -> None:
    src.withdraw(amount)
    dst.deposit(amount)
# Functions that accept BankAccount also accept any subclass of BankAccount!
class AuditedBankAccount(BankAccount):
    # You can optionally declare instance variables in the class body
    audit_log: list[str]
    def __init__(self, account_name: str, initial_balance: int = 0) -> None:
        super().__init__(account_name, initial_balance)
        self.audit_log: list[str] = []
    def deposit(self, amount: int) -> None:
        self.audit_log.append(f"Deposited {amount}")
        self.balance += amount
    def withdraw(self, amount: int) -> None:
        self.audit_log.append(f"Withdrew {amount}")
        self.balance -= amount
audited = AuditedBankAccount("Bob", 300)
transfer(audited, account, 100) # type checks!
transfer expects a BankAccount as its first argument. AuditedBankAccount is a subclass of BankAccount, so passing audited passes type checking.
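A Python class can have two kinds of variables: class variables and instance variables. To mark an attribute as a class variable, annotate it with ClassVar[type]: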
from typing import ClassVar
# You can use the ClassVar annotation to declare a class variable
class Car:
    seats: ClassVar[int] = 4
    passengers: ClassVar[list[str]]
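ClassVar is mostly information for type checkers, but the standard library's dataclasses module also reads it at runtime: an attribute annotated with ClassVar is not turned into a dataclass field. A sketch reusing the name Car (the model attribute is added for illustration):

```python
from dataclasses import dataclass, fields
from typing import ClassVar

@dataclass
class Car:
    model: str                 # instance variable: becomes an __init__ parameter
    seats: ClassVar[int] = 4   # class variable: shared, and not a dataclass field

civic = Car("Civic")
print([f.name for f in fields(Car)])  # ['model']
print(civic.seats, Car.seats)         # 4 4
```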
# If you want dynamic attributes on your class, have it
# override "__setattr__" or "__getattr__"
class A:
    # This will allow assignment to any A.x, if x is the same type as "value"
    # (use "value: Any" to allow arbitrary types)
    def __setattr__(self, name: str, value: int) -> None: ...
    # This will allow access to any A.x, if x is compatible with the return type
    def __getattr__(self, name: str) -> int: ...
a = A()
a.foo = 42  # Works
a.bar = 'Ex-parrot'  # Fails type checking
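Overriding __setattr__ (called on every attribute assignment) and __getattr__ (called only when normal attribute lookup fails) lets a class accept attributes that are not declared anywhere in its body. mypy then checks such dynamic attributes against the value type of __setattr__ and the return type of __getattr__, which is why a.bar = 'Ex-parrot' is rejected.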
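A runnable variant of the same idea, with a hypothetical Settings class that stores dynamic attributes in a dict. Because __setattr__ intercepts every assignment, the backing dict is created with object.__setattr__; it is then found by ordinary lookup, so __getattr__ only runs for the dynamic names.

```python
from typing import Dict

class Settings:
    _values: Dict[str, int]

    def __init__(self) -> None:
        # Bypass our own __setattr__ so _values becomes a normal instance attribute
        object.__setattr__(self, "_values", {})

    def __setattr__(self, name: str, value: int) -> None:
        # Runs for every assignment such as s.timeout = 30
        self._values[name] = value

    def __getattr__(self, name: str) -> int:
        # Runs only when normal lookup fails, e.g. for s.timeout
        try:
            return self._values[name]
        except KeyError:
            raise AttributeError(name) from None

s = Settings()
s.timeout = 30
print(s.timeout)              # 30
print(hasattr(s, "retries"))  # False
```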
Types defined by PyTorch
torch/types.py
import torch
from typing import Any, List, Sequence, Tuple, Union
import builtins
# Convenience aliases for common composite types that we need
# to talk about in PyTorch
_TensorOrTensors = Union[torch.Tensor, Sequence[torch.Tensor]]
# In some cases, these basic types are shadowed by corresponding
# top-level values. The underscore variants let us refer to these
# types. See https://github.com/python/mypy/issues/4146 for why these
# workarounds is necessary
_int = builtins.int
_float = builtins.float
_bool = builtins.bool
_dtype = torch.dtype
_device = torch.device
_qscheme = torch.qscheme
_size = Union[torch.Size, List[_int], Tuple[_int, ...]]
_layout = torch.layout
_dispatchkey = Union[str, torch._C.DispatchKey]
class SymInt:
    pass
# Meta-type for "numeric" things; matches our docs
Number = Union[builtins.int, builtins.float, builtins.bool]
# Meta-type for "device-like" things. Not to be confused with 'device' (a
# literal device object). This nomenclature is consistent with PythonArgParser.
# None means use the default device (typically CPU)
Device = Union[_device, str, _int, None]
# Storage protocol implemented by ${Type}StorageBase classes
class Storage(object):
    _cdata: int
    device: torch.device
    dtype: torch.dtype
    _torch_load_uninitialized: bool
    def __deepcopy__(self, memo) -> 'Storage':
        ...
    def _new_shared(self, int) -> 'Storage':
        ...
    def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool, element_size: int) -> None:
        ...
    def element_size(self) -> int:
        ...
    def is_shared(self) -> bool:
        ...
    def share_memory_(self) -> 'Storage':
        ...
    def nbytes(self) -> int:
        ...
    def cpu(self) -> 'Storage':
        ...
    def data_ptr(self) -> int:
        ...
    def from_file(self, filename: str, shared: bool = False, nbytes: int = 0) -> 'Storage':
        ...
    def _new_with_file(self, f: Any, element_size: int) -> 'Storage':
        ...
    ...
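In torch.types, _int, _float and _bool are simply Python's built-in builtins.int, builtins.float and builtins.bool. As the comment in the file explains, the aliases exist because the plain names can be shadowed by top-level values: torch.int, torch.float and torch.bool, for example, are dtype objects, so in a namespace where those are defined, a bare float no longer refers to the Python type.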
Number, as defined by PyTorch, is Union[builtins.int, builtins.float, builtins.bool], i.e. any one of _int, _float and _bool.
builtins — Built-in objects
This module provides direct access to all ‘built-in’ identifiers of Python; for example, builtins.open is the full name for the built-in function open().
The builtins module gives direct access to Python's built-in identifiers; for example, the built-in open() function can also be reached as builtins.open.
From What does the Star operator mean in Python?:
Single asterisk as used in function declaration allows variable number of arguments passed from calling environment. Inside the function it behaves as a tuple.
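Prefixing a parameter with * lets a function accept any number of positional arguments; inside the function, that parameter is a tuple. This is what *size: _int means in the rand stubs: in torch.rand(2, 3), size becomes the tuple (2, 3), and the annotation _int describes each element rather than the tuple itself.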
def function(*arg):
    print(type(arg))
    for i in arg:
        print(i)
function(1, 2, 3)
# <class 'tuple'>
# 1
# 2
# 3
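Putting the pieces together, the two calling conventions in the rand stubs (one overload takes size: Sequence[...], the other *size: _int) can be imitated with typing.overload. normalize_size is a hypothetical helper, not part of PyTorch; it only shows how one implementation can serve both overloads. In the first overload, size is marked positional-only so that the implementation's *args signature accepts every call the overloads allow.

```python
from collections.abc import Sequence
from typing import Tuple, Union, overload

@overload
def normalize_size(size: Sequence[int], /) -> Tuple[int, ...]: ...
@overload
def normalize_size(*size: int) -> Tuple[int, ...]: ...

def normalize_size(*args: Union[int, Sequence[int]]) -> Tuple[int, ...]:
    # Called as normalize_size((2, 3)): a single sequence argument
    if len(args) == 1 and isinstance(args[0], Sequence):
        return tuple(args[0])
    # Called as normalize_size(2, 3): *args has collected the ints into a tuple
    result: list[int] = []
    for a in args:
        if not isinstance(a, int):
            raise TypeError(f"expected int, got {type(a).__name__}")
        result.append(a)
    return tuple(result)

print(normalize_size((2, 3)))  # (2, 3)
print(normalize_size(2, 3))    # (2, 3)
```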