PyTorch檔案生成機制中的FileManager.write_with_template

PyTorch檔案生成機制中的FileManager.write_with_template

  • 前言
  • FileManager.write_with_template調用
    • gen_pyi
    • gen_nn_functional
    • write_sharded
  • FileManager.write_with_template實現
    • torchgen/utils.py
      • FileManager.write_with_template
      • FileManager.substitute_with_template
      • _read_template
    • torchgen/code_template.py
      • CodeTemplate
      • CodeTemplate.from_file
      • CodeTemplate.\__init__
      • substitute

前言

PyTorch中有些檔案是在編譯過程中跑腳本生成的,如.pyi檔是由.pyi.in檔生成,torch/csrc/autograd/generated目錄下的.cpp檔則是由tools/autograd/templates下的template .cpp檔生成的。

它們底層都是調用FileManager.write_with_template函數,其功能是對原檔案中的特定字串依照callback function所指示的方式做替換,進而生成對應的.pyi.cpp檔。

本文會先查看FileManager.write_with_template函數是如何被調用的,再細看它的實現。

FileManager.write_with_template調用

gen_pyi

tools/pyi/gen_pyi.py

    fm.write_with_template(
        "torch/_C/__init__.pyi",
        "torch/_C/__init__.pyi.in",
        lambda: {
            "generated_comment": "@" + "generated from torch/_C/__init__.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/_C/_VariableFunctions.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: {
            "generated_comment": "@"
            + "generated from torch/_C/_VariableFunctions.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/_VF.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: {
            "generated_comment": "@"
            + "generated from torch/_C/_VariableFunctions.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/return_types.pyi",
        "torch/_C/return_types.pyi.in",
        lambda: {
            "generated_comment": "@" + "generated from torch/_C/return_types.pyi",
            **env,
        },
    )
    gen_nn_functional(fm)

此處的四個fm.write_with_template會由torch/_C資料夾下的四個.pyi.in檔生成torch/_C資料夾下的__init__.pyi, _VariableFunctions.pyitorch資料夾下的_VF.pyi, return_types.pyi

gen_nn_functional

tools/pyi/gen_pyi.py

def gen_nn_functional(fm: FileManager) -> None:
    # ...
    fm.write_with_template(
        "torch/nn/functional.pyi",
        "torch/nn/functional.pyi.in",
        lambda: {
            "imported_hints": import_code,
            "dispatched_hints": dispatch_code,
        },
    )
    # ...
    fm.write_with_template(
        "torch/_C/_nn.pyi",
        "torch/_C/_nn.pyi.in",
        lambda: {
            "imported_hints": import_code,
            "dispatched_hints": dispatch_code,
        },
    )

此處的兩個fm.write_with_template會由torch/nn/functional.pyi.intorch/_C/_nn.pyi.in生成torch/nn/functional.pyitorch/_C/_nn.pyi.in

write_sharded

torchgen/utils.py

    def write_sharded(
        self,
        filename: str,
        items: Iterable[T],
        *,
        key_fn: Callable[[T], str],
        env_callable: Callable[[T], Dict[str, List[str]]],
        num_shards: int,
        base_env: Optional[Dict[str, Any]] = None,
        sharded_keys: Set[str],
    ) -> None:
        #...
        for shard in all_shards:
            shard_id = shard["shard_id"]
            self.write_with_template(
                f"{base_filename}{shard_id}{extension}", filename, lambda: shard
            )
        #...

其中的all_shards為:

[{'shard_id': 'Everything'}, {'shard_id': '_0'}, {'shard_id': '_1'}, {'shard_id': '_2'}]

所以這裡的write_with_template會由filenamepython_torch_functions.cpp生成python_torch_functionsEverything.cpp, python_torch_functions_0.cpp, python_torch_functions_1.cpppython_torch_functions_2.cpp四個檔案。

注意到上面三個例子中,write_with_template的第三個參數(env_callable)都是一個呼叫後會返回dict的lambda函數。

FileManager.write_with_template實現

torchgen/utils.py

FileManager.write_with_template

write_with_template除了self以外有三個參數:

  • filename:生成的.pyi的檔名或.cpp的檔名
  • template_fn:作為輸入的.pyi.in的檔名或template .cpp的檔名
  • env_callable:在做替換時會用到的callback function
    def write_with_template(
        self,
        filename: str,
        template_fn: str,
        env_callable: Callable[[], Union[str, Dict[str, Any]]],
    ) -> None:
        filename = "{}/{}".format(self.install_dir, filename)
        assert filename not in self.filenames, "duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:
            substitute_out = self.substitute_with_template(
                template_fn=template_fn,
                env_callable=env_callable,
            )
            self._write_if_changed(filename=filename, contents=substitute_out)

可以看到這段代碼最核心的內容就是調用substitute_with_template生成substitute_out

之後再將替換後的結果,也就是substitute_out寫入filename.pyi檔)這個檔案中。

注:在做類型檢查時,callback function是由typing.Callable表示的,詳見Python typing函式庫和torch.types。

FileManager.substitute_with_template

torchgen/utils.py

self外有兩個參數:

  • template_fn:作為輸入的.pyi.in的檔名或template .cpp的檔名
  • env_callable:在做替換時會用到的callback function
    # Read from template file and replace pattern with callable (type could be dict or str).
    def substitute_with_template(
        self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]
    ) -> str:
        template_path = os.path.join(self.template_dir, template_fn)
        env = env_callable()
        if isinstance(env, dict):
            # TODO: Update the comment reference to the correct location
            if "generated_comment" not in env:
                comment = "@" + "generated by torchgen/gen.py"
                comment += " from {}".format(os.path.basename(template_path))
                env["generated_comment"] = comment
            template = _read_template(template_path)
            return template.substitute(env)
        elif isinstance(env, str):
            return env
        else:
            assert_never(env)

env_callable是一個呼叫後會返回dict的lambda函數,所以會進入isinstance(env, dict)這個分支,先由_read_template讀入template檔案(.pyi.in檔或template .cpp檔)後調用template.substitute

_read_template

torchgen/utils.py

參數template_fnpyi或template cpp的檔名。

@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    return CodeTemplate.from_file(template_fn)

讀入template_fn,生成CodeTemplate物件並回傳。

torchgen/code_template.py

CodeTemplate

torchgen/code_template.py

先來看看CodeTemplate類別的作用。

# match $identifier or ${identifier} and replace with value in env
# If this identifier is at the beginning of whitespace on a line
# and its value is a list then it is treated as
# block substitution by indenting to that depth and putting each element
# of the list on its own line
# if the identifier is on a line starting with non-whitespace and a list
# then it is comma separated ${,foo} will insert a comma before the list
# if this list is not empty and ${foo,} will insert one after.


class CodeTemplate:
    substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
    substitution = re.compile(substitution_str, re.MULTILINE)

    pattern: str
    filename: str
    
    # ...

注釋裡說明了CodeTemplate的功用是把模板中${identifier}字樣替換成env中對應的value。

torch/_C/_VariableFunctions.pyi.in中就有以下字樣:

# ${generated_comment}
# ...
${function_hints}

${all_directive}

python_torch_functions.cpp中則有以下字樣:

#ifndef AT_PER_OPERATOR_HEADERS
#include 
#else
$ops_headers
#endif
    
// ...
// generated forward declarations start here

${py_forwards}

// ...
static PyMethodDef torch_functions_shard[] = {
  ${py_method_defs}
};

// ...
// generated methods start here

${py_methods}

CodeTemplate.from_file

torchgen/code_template.py

class CodeTemplate:
    # ...

    @staticmethod
    def from_file(filename: str) -> "CodeTemplate":
        with open(filename, "r") as f:
            return CodeTemplate(f.read(), filename)
        
    # ...

調用CodeTemplate的建構子,傳入filename的內容及名稱。

CodeTemplate._init_

  • filename:作為輸入的.pyi.in的檔名或template .cpp的檔名
  • pattern:在CodeTemplate.from_file中是以CodeTemplate(f.read(), filename)調用CodeTemplate建構子,所以pattern成員變數會被設為從filename檔案裡讀出來的東西
class CodeTemplate:
    # ...
    
    def __init__(self, pattern: str, filename: str = "") -> None:
        self.pattern = pattern
        self.filename = filename
        
    # ...

substitute

torchgen/code_template.py

回顧torchgen/utils.pysubstitute_with_template中的:

            template = _read_template(template_path)

生成了CodeTemplate物件template後繼續調用:

            return template.substitute(env)

其功能是做一些正則替換:

class CodeTemplate:
    # ...
    def substitute(
        self, env: Optional[Mapping[str, object]] = None, **kwargs: object
    ) -> str:
        if env is None:
            env = {}

        def lookup(v: str) -> object:
            assert env is not None
            return kwargs[v] if v in kwargs else env[v]

        def indent_lines(indent: str, v: Sequence[object]) -> str:
            return "".join(
                [indent + l + "\n" for e in v for l in str(e).splitlines()]
            ).rstrip()

        def replace(match: Match[str]) -> str:
            indent = match.group(1)
            key = match.group(2)
            comma_before = ""
            comma_after = ""
            if key[0] == "{":
                key = key[1:-1]
                if key[0] == ",":
                    comma_before = ", "
                    key = key[1:]
                if key[-1] == ",":
                    comma_after = ", "
                    key = key[:-1]
            v = lookup(key)
            if indent is not None:
                if not isinstance(v, list):
                    v = [v]
                return indent_lines(indent, v)
            elif isinstance(v, list):
                middle = ", ".join([str(x) for x in v])
                if len(v) == 0:
                    return middle
                return comma_before + middle + comma_after
            else:
                return str(v)

        return self.substitution.sub(replace, self.pattern)

函數最後的self.substitution.sub(replace, self.pattern)中的self.substitutionCodeTemplate的成員:

    substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
    substitution = re.compile(substitution_str, re.MULTILINE)

re.compile後得到的substitution是一個re.Pattern物件。

先來看看re.Pattern.sub是什麼,參考Passing a function to re.sub in Python及Python: re.compile and re.sub中給出的例子:

import re
substitution = re.compile(r'\d')
number_mapping = {'1': 'one', '2': 'two', '3': 'three'}
s = "1 testing 2 3"
substitution.sub(lambda x: number_mapping[x.group()], s) # 'one testing two three'

re.Pattern.sub的第一個參數是做替換的函數,第二個參數則是欲處理的字串,它會尋找特定樣式的字串(此處是r'\d'),對它們做替換後回傳。

所以self.substitution.sub(replace, self.pattern)這句是在self.pattern(也就是pyi.in或template cpp檔中的內容)中尋找substitution_str樣式的字串,並用replace這個函數所指定的方式做替換。

得到替換後的結果後,回到substitute_with_template函數:

            return template.substitute(env)

那裡繼續將結果回傳,來到write_with_template函數:

            substitute_out = self.substitute_with_template(
                template_fn=template_fn,
                env_callable=env_callable,
            )
            self._write_if_changed(filename=filename, contents=substitute_out)

在那裡會把替換結果substitute_out寫入filename,也就是生成的.pyi的檔名或.cpp的檔名。

來看看torch/_C/_VariableFunctions.pyi中的${generated_comment}

回顧gen_pyi函數中呼叫write_with_template時,與env一同傳入了generated_comment的key value pair:

    fm.write_with_template(
        "torch/_C/_VariableFunctions.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: {
            "generated_comment": "@"
            + "generated from torch/_C/_VariableFunctions.pyi.in",
            **env,
        },
    )

所以到了substitute函數,env參數便是一個包含generated_comment的key value pair的字典。

# ${generated_comment}在做替換後,會變成生成的torch/_C/_VariableFunctions.pyi檔案中的第一行:

# @generated from torch/_C/_VariableFunctions.pyi.in

你可能感兴趣的:(PyTorch,pytorch,python,c++)