PyTorch中的python_torch_functions_i.cpp檔案生成機制

PyTorch中的python_torch_functions_i.cpp檔案生成機制

  • 前言
  • setup.py
    • main
    • build_deps
  • tools/build_pytorch_libs.py
    • build_caffe2
  • caffe2/CMakeLists.txt
  • tools/setup_helpers/generate_code.py
    • main
    • generate_code
  • tools/autograd/gen_autograd.py
    • gen_autograd_python
  • tools/autograd/gen_python_functions.py
    • gen_python_functions.gen
    • create_python_bindings_sharded
  • torchgen/utils.py
    • write_sharded
  • 生成結果

前言

編譯PyTorch後,torch/csrc/autograd/generated/目錄下會有python_torch_functions_0.cpp、python_torch_functions_1.cpp及python_torch_functions_2.cpp等檔案。本文便從setup.py出發,沿著調用鏈依次探討這些檔案是如何生成的。

setup.py

main

setup.py

################################################################################
# Parameters parsed from environment
################################################################################

VERBOSE_SCRIPT = True
# Whether main() should call build_deps(); cleared below for some commands.
RUN_BUILD_DEPS = True

filtered_args = []
for i, arg in enumerate(sys.argv):
    # ...
    # `clean`, `egg_info` and `sdist` turn off the dependency build
    # (presumably because these commands don't need compiled libraries).
    if arg in ['clean', 'egg_info', 'sdist']:
        RUN_BUILD_DEPS = False

# ...

def main():
    # ...
    # Build the native dependencies unless one of the commands above disabled it.
    if RUN_BUILD_DEPS:
        build_deps()

RUN_BUILD_DEPS預設為True,只有當命令列參數中出現clean、egg_info或sdist時才會被設為False。若RUN_BUILD_DEPS為True,main便會運行build_deps函數,推測是用於建構PyTorch的dependencies。

build_deps

setup.py

from tools.build_pytorch_libs import build_caffe2
# ...

# all the work we need to do _before_ setup runs
def build_deps():
    #...

    # The main step: hand off to build_caffe2, which runs cmake.generate and,
    # unless CMAKE_ONLY is set, cmake.build.
    build_caffe2(version=version,
                 cmake_python_library=cmake_python_library,
                 build_python=True,
                 rerun_cmake=RERUN_CMAKE,
                 cmake_only=CMAKE_ONLY,
                 cmake=cmake)

build_deps函數中最主要的部份便是調用build_caffe2

tools/build_pytorch_libs.py

build_caffe2

tools/build_pytorch_libs.py

def build_caffe2(
    version: Optional[str],
    cmake_python_library: Optional[str],
    build_python: bool,
    rerun_cmake: bool,
    cmake_only: bool,
    cmake: CMake,
) -> None:
    """Run CMake for the project, build it, and copy generated proto modules.

    Args:
        version: forwarded to ``cmake.generate``.
        cmake_python_library: forwarded to ``cmake.generate``.
        build_python: forwarded to ``cmake.generate``; when true (and the build
            step ran), every ``*.py`` except ``__init__.py`` found in
            ``<cmake.build_dir>/caffe2/proto`` is copied into ``caffe2/proto``
            relative to the current working directory.
        rerun_cmake: forwarded to ``cmake.generate``.
        cmake_only: if true, return right after ``cmake.generate`` without
            calling ``cmake.build``.
        cmake: the CMake wrapper used for both steps.
    """
    my_env = _create_build_env()
    # Tests are built unless BUILD_TEST is set to a value that
    # check_negative_env_flag treats as negative.
    build_test = not check_negative_env_flag("BUILD_TEST")
    # Run CMake (configure/generate).
    cmake.generate(
        version, cmake_python_library, build_python, build_test, my_env, rerun_cmake
    )
    if cmake_only:
        return
    # Run the actual build.
    cmake.build(my_env)
    if build_python:
        caffe2_proto_dir = os.path.join(cmake.build_dir, "caffe2", "proto")
        for proto_file in glob(os.path.join(caffe2_proto_dir, "*.py")):
            # Copy only the generated modules; skip the package __init__.py.
            if proto_file != os.path.join(caffe2_proto_dir, "__init__.py"):
                shutil.copy(proto_file, os.path.join("caffe2", "proto"))

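其中cmake.generate會執行CMake的配置(configure)步驟,推測caffe2/CMakeLists.txt就是在這一步被CMake處理的;而其中定義的add_custom_command,則要到cmake.build的建構階段才會真正執行。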

caffe2/CMakeLists.txt

caffe2/CMakeLists.txt

#...
# Inputs of the code generator. They are listed under DEPENDS below, so the
# build re-runs generate_code.py whenever any of them changes.
file(GLOB_RECURSE autograd_python "${TOOLS_PATH}/autograd/*.py")
file(GLOB_RECURSE autograd_yaml "${TOOLS_PATH}/autograd/*.yaml")
file(GLOB_RECURSE autograd_templates "${TOOLS_PATH}/autograd/templates/*")
# Produce ${TORCH_GENERATED_CODE} by running tools/setup_helpers/generate_code.py
# from ${TORCH_ROOT}. The two $<$<BOOL:...>:...> generator expressions pass
# --disable-autograd / --selected-op-list-path only when the corresponding
# variable evaluates to true.
add_custom_command(
  OUTPUT
  ${TORCH_GENERATED_CODE}
  COMMAND
  "${PYTHON_EXECUTABLE}" tools/setup_helpers/generate_code.py
    --native-functions-path "aten/src/ATen/native/native_functions.yaml"
    --tags-path "aten/src/ATen/native/tags.yaml"
    $<$<BOOL:${INTERN_DISABLE_AUTOGRAD}>:--disable-autograd>
    $<$<BOOL:${SELECTED_OP_LIST}>:--selected-op-list-path="${SELECTED_OP_LIST}">
    --force_schema_registration
    --gen_lazy_ts_backend
    ${GEN_PER_OPERATOR_FLAG}
  DEPENDS
    "${TORCH_ROOT}/aten/src/ATen/native/native_functions.yaml"
    "${TORCH_ROOT}/aten/src/ATen/native/tags.yaml"
    "${TORCH_ROOT}/aten/src/ATen/native/ts_native_functions.yaml"
    "${TORCH_ROOT}/torch/csrc/lazy/core/shape_inference.h"
    "${TORCH_ROOT}/torch/csrc/lazy/ts_backend/ts_native_functions.cpp"
    "${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.h"
    "${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp"
    "${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h"
    "${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h"
    "${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp"
    ${autograd_python}
    ${autograd_yaml}
    ${autograd_templates}
    ${torchgen_python}
  WORKING_DIRECTORY "${TORCH_ROOT}")
#...

add_custom_command這一段的意思是:執行COMMAND來生成OUTPUT中列出的檔案;DEPENDS中列出的則是COMMAND所依賴的輸入檔案,只要其中任何一個有變動,CMake在建構時就會重新執行COMMAND。

先來看一下DEPENDS當中的${autograd_templates}

# Collect every file under tools/autograd/templates/ (recursively) into autograd_templates.
file(GLOB_RECURSE autograd_templates "${TOOLS_PATH}/autograd/templates/*")

當中的${TOOLS_PATH}是:

# Generate files
# TOOLS_PATH is the repository's tools/ directory.
set(TOOLS_PATH "${TORCH_ROOT}/tools")

所以${autograd_templates}指的是tools/autograd/templates/目錄下的所有檔案,其中就包含了tools/autograd/templates/python_torch_functions.cpp

python_torch_functions.cpp中最核心的一段代碼如下:

// ${py_method_defs} is replaced with one PyMethodDef entry per operator group in this shard.
static PyMethodDef torch_functions_shard[] = {
  ${py_method_defs}
};

其中${py_method_defs}的位置便是為了待會自動生成代碼時預留的空位。注意到torch_functions_shard的型別是PyMethodDef的陣列,詳見PyMethodDef。

接著看COMMAND,它會調用tools/setup_helpers/generate_code.py,由DEPENDS中的檔案(包含aten/src/ATen/native/native_functions.yaml及tools/autograd/templates/目錄下的所有檔案)生成OUTPUT,即${TORCH_GENERATED_CODE}:

# Every file produced by generate_code.py; used as the OUTPUT of the
# add_custom_command that runs it.
set(TORCH_GENERATED_CODE
  ${GENERATED_CXX_TORCH}
  ${GENERATED_H_TORCH}
  ${GENERATED_CXX_PYTHON}
  ${GENERATED_H_PYTHON}
  ${GENERATED_TESTING_PYTHON}
  )

當中的GENERATED_CXX_PYTHON如下:

# Generated C++ sources for the Python bindings, under torch/csrc/autograd/generated/.
# python_torch_functions_{0,1,2}.cpp are the three shards written with
# num_shards=3 by gen_python_functions.gen; the shard count there must match
# the number of files listed here.
set(GENERATED_CXX_PYTHON
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_functions_0.cpp"
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_functions_1.cpp"
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_functions_2.cpp"
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_functions_3.cpp"
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_functions_4.cpp"
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_variable_methods.cpp"
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_torch_functions_0.cpp"
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_torch_functions_1.cpp"
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_torch_functions_2.cpp"
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_nn_functions.cpp"
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_fft_functions.cpp"
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_linalg_functions.cpp"
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_nested_functions.cpp"
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_sparse_functions.cpp"
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp"
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp"
  "${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp"
  )

可以看到當中便包含了python_torch_functions_0.cpp、python_torch_functions_1.cpp及python_torch_functions_2.cpp這三個檔案。

所以這一段command就是調用tools/setup_helpers/generate_code.py,由python_torch_functions.cpp及native_functions.yaml生成python_torch_functions_0.cpp、python_torch_functions_1.cpp及python_torch_functions_2.cpp。

接著深入generate_code.py,看看python_torch_functions_i.cpp具體是如何生成的。

tools/setup_helpers/generate_code.py

main

tools/setup_helpers/generate_code.py

def main() -> None:
    """Command-line entry point: parse options and run the code generators.

    Always calls ``generate_code`` with the parsed options. With
    ``--gen_lazy_ts_backend`` it additionally runs the lazy-tensor TorchScript
    backend generator into ``<install_dir>/lazy/generated``, where
    ``install_dir`` is ``--install_dir`` if given, else ``<gen_dir>/torch/csrc``.
    """
    parser = argparse.ArgumentParser(description="Autogenerate code")
    parser.add_argument("--native-functions-path")
    parser.add_argument("--tags-path")
    parser.add_argument(
        "--gen-dir",
        type=pathlib.Path,
        default=pathlib.Path("."),
        help="Root directory where to install files. Defaults to the current working directory.",
    )
    parser.add_argument(
        "--install_dir",
        help=(
            "Deprecated. Use --gen-dir instead. The semantics are different, do not change "
            "blindly."
        ),
    )
    parser.add_argument(
        "--subset",
        # generate_code() recognizes "pybindings", "libtorch" and "python".
        help='Subset of source files to generate. Can be "libtorch", "pybindings" or "python". '
        "Generates all of them when omitted.",
    )
    parser.add_argument(
        "--disable-autograd",
        default=False,
        action="store_true",
        help="It can skip generating autograd related code when the flag is set",
    )
    parser.add_argument(
        "--selected-op-list-path",
        help="Path to the YAML file that contains the list of operators to include for custom build.",
    )
    parser.add_argument(
        "--operators_yaml_path",
        help="Path to the model YAML file that contains the list of operators to include for custom build.",
    )
    parser.add_argument(
        "--force_schema_registration",
        action="store_true",
        # Adjacent string literals are concatenated, so the trailing space is
        # required; the flag referenced is the real --selected-op-list-path.
        help="force it to generate schema-only registrations for ops that are not "
        "listed on --selected-op-list-path",
    )
    parser.add_argument(
        "--gen_lazy_ts_backend",
        action="store_true",
        help="Enable generation of the torch::lazy TorchScript backend",
    )
    parser.add_argument(
        "--per_operator_headers",
        action="store_true",
        help="Build lazy tensor ts backend with per-operator ATen headers, must match how ATen was built",
    )
    options = parser.parse_args()

    generate_code(
        options.gen_dir,
        options.native_functions_path,
        options.tags_path,
        options.install_dir,
        options.subset,
        options.disable_autograd,
        options.force_schema_registration,
        # options.selected_op_list
        operator_selector=get_selector(
            options.selected_op_list_path, options.operators_yaml_path
        ),
    )

    if options.gen_lazy_ts_backend:
        # Use the same default as generate_code() when --native-functions-path
        # is omitted; os.path.dirname(None) would otherwise raise TypeError.
        native_functions_path = options.native_functions_path or NATIVE_FUNCTIONS_PATH
        # Two levels above native_functions.yaml, e.g. "aten/src/ATen".
        aten_path = os.path.dirname(os.path.dirname(native_functions_path))
        ts_backend_yaml = os.path.join(aten_path, "native/ts_native_functions.yaml")
        ts_native_functions = "torch/csrc/lazy/ts_backend/ts_native_functions.cpp"
        ts_node_base = "torch/csrc/lazy/ts_backend/ts_node.h"
        install_dir = options.install_dir or os.fspath(options.gen_dir / "torch/csrc")
        lazy_install_dir = os.path.join(install_dir, "lazy/generated")
        os.makedirs(lazy_install_dir, exist_ok=True)

        assert os.path.isfile(
            ts_backend_yaml
        ), f"Unable to access ts_backend_yaml: {ts_backend_yaml}"
        assert os.path.isfile(
            ts_native_functions
        ), f"Unable to access {ts_native_functions}"
        from torchgen.dest.lazy_ir import GenTSLazyIR
        from torchgen.gen_lazy_tensor import run_gen_lazy_tensor

        run_gen_lazy_tensor(
            aten_path=aten_path,
            source_yaml=ts_backend_yaml,
            backend_name="TorchScript",
            output_dir=lazy_install_dir,
            dry_run=False,
            impl_path=ts_native_functions,
            node_base="TsNode",
            node_base_hdr=ts_node_base,
            build_in_tree=True,
            lazy_ir_generator=GenTSLazyIR,
            per_operator_headers=options.per_operator_headers,
            gen_forced_fallback_code=True,
        )


if __name__ == "__main__":
    main()

裡面最關鍵的便是generate_code函數。

  • options.gen_dir:預設是'.',即當前工作目錄
  • options.native_functions_path:從CMakeLists.txt傳入,是為aten/src/ATen/native/native_functions.yaml
  • options.tags_path:從CMakeLists.txt傳入,是為aten/src/ATen/native/tags.yaml

generate_code

tools/setup_helpers/generate_code.py

def generate_code(
    gen_dir: pathlib.Path,
    native_functions_path: Optional[str] = None,
    tags_path: Optional[str] = None,
    install_dir: Optional[str] = None,
    subset: Optional[str] = None,
    disable_autograd: bool = False,
    force_schema_registration: bool = False,
    operator_selector: Any = None,
) -> None:
    """Run the autograd / Python-binding code generators.

    Args:
        gen_dir: root under which the default output directories are created.
        native_functions_path: native_functions.yaml; falls back to
            NATIVE_FUNCTIONS_PATH when None/empty.
        tags_path: tags.yaml; falls back to TAGS_PATH when None/empty.
        install_dir: if given, used as the C++ output root *and* as the Python
            output directory; otherwise ``<gen_dir>/torch/csrc`` and
            ``<gen_dir>/torch/testing/_internal/generated`` are used.
        subset: "pybindings" (gen_autograd_python), "libtorch" (gen_autograd)
            or "python" (gen_annotated); None/empty runs all three.
        disable_autograd: forwarded to gen_autograd.
        force_schema_registration: accepted but not referenced in this body.
        operator_selector: forwarded to gen_autograd; defaults to
            ``SelectiveBuilder.get_nop_selector()`` when None.
    """
    from torchgen.selective_build.selector import SelectiveBuilder

    from tools.autograd.gen_annotated_fn_args import gen_annotated
    from tools.autograd.gen_autograd import gen_autograd, gen_autograd_python

    # Build ATen based Variable classes
    if install_dir is None:
        install_dir = os.fspath(gen_dir / "torch/csrc")
        python_install_dir = os.fspath(gen_dir / "torch/testing/_internal/generated")
    else:
        # An explicit install_dir is also used for the Python outputs.
        python_install_dir = install_dir
    autograd_gen_dir = os.path.join(install_dir, "autograd", "generated")
    for d in (autograd_gen_dir, python_install_dir):
        os.makedirs(d, exist_ok=True)
    # tools/autograd, located relative to this file (tools/setup_helpers/).
    autograd_dir = os.fspath(pathlib.Path(__file__).parent.parent / "autograd")

    # "pybindings" subset: Python binding sources into autograd_gen_dir.
    if subset == "pybindings" or not subset:
        gen_autograd_python(
            native_functions_path or NATIVE_FUNCTIONS_PATH,
            tags_path or TAGS_PATH,
            autograd_gen_dir,
            autograd_dir,
        )

    if operator_selector is None:
        operator_selector = SelectiveBuilder.get_nop_selector()

    # "libtorch" subset: gen_autograd into autograd_gen_dir.
    if subset == "libtorch" or not subset:

        gen_autograd(
            native_functions_path or NATIVE_FUNCTIONS_PATH,
            tags_path or TAGS_PATH,
            autograd_gen_dir,
            autograd_dir,
            disable_autograd=disable_autograd,
            operator_selector=operator_selector,
        )

    # "python" subset: gen_annotated into python_install_dir.
    if subset == "python" or not subset:
        gen_annotated(
            native_functions_path or NATIVE_FUNCTIONS_PATH,
            tags_path or TAGS_PATH,
            python_install_dir,
            autograd_dir,
        )

首先設定出以下變數:

  • gen_dir:‘.’
  • native_functions_path :aten/src/ATen/native/native_functions.yaml
  • tags_path:aten/src/ATen/native/tags.yaml
  • install_dir:gen_dir/torch/csrc,即./torch/csrc
  • autograd_gen_dir:install_dir/autograd/generated,即./torch/csrc/autograd/generated
  • autograd_dir: tools/autograd

此處關注的是python_torch_functions_i.cpp,所以接著進入gen_autograd_python函數。

tools/autograd/gen_autograd.py

gen_autograd_python

tools/autograd/gen_autograd.py

def gen_autograd_python(
    native_functions_path: str,
    tags_path: str,
    out: str,
    autograd_dir: str,
) -> None:
    """Generate the Python-facing autograd sources into ``out``.

    Args:
        native_functions_path: path to native_functions.yaml.
        tags_path: path to tags.yaml.
        out: directory the generated files are written to.
        autograd_dir: directory containing derivatives.yaml, deprecated.yaml
            and the templates/ subdirectory.
    """
    # Only the differentiability infos are used here; the second value is discarded.
    differentiability_infos, _ = load_derivatives(
        os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path
    )

    template_path = os.path.join(autograd_dir, "templates")

    # Generate Functions.h/cpp
    # NOTE(review): the generated-file list suggests this step emits the
    # python_functions*.cpp bindings rather than Functions.h/cpp — confirm.
    gen_autograd_functions_python(out, differentiability_infos, template_path)

    # Generate Python bindings
    deprecated_path = os.path.join(autograd_dir, "deprecated.yaml")
    gen_python_functions.gen(
        out, native_functions_path, tags_path, deprecated_path, template_path
    )

首先設定出以下變數:

  • out:./torch/csrc/autograd/generated
  • native_functions_path:aten/src/ATen/native/native_functions.yaml
  • tags_path:aten/src/ATen/native/tags.yaml
  • deprecated_path:tools/autograd/deprecated.yaml
  • template_path:tools/autograd/templates

其中gen_autograd_functions_python函數生成torch/csrc/autograd/generated資料夾下的python_functionsEverything.cpp,以及python_functions_0.cpp至python_functions_4.cpp。

gen_python_functions.gen則生成其它許多檔案,包括我們所關注的torch/csrc/autograd/generated資料夾下的python_torch_functionsEverything.cpp,以及python_torch_functions_0.cpp至python_torch_functions_2.cpp。

接著繼續深入gen_python_functions.gen函數。

tools/autograd/gen_python_functions.py

gen_python_functions.gen

tools/autograd/gen_python_functions.py

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                            Main Function
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def gen(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    template_path: str,
    *,
    symint: bool = True,
) -> None:
    """Generate the Python binding sources for ATen operators into ``out``.

    Parses native_functions.yaml, keeps the entries accepted by
    ``should_generate_py_binding``, and renders templates from
    ``template_path`` into:

    * python_variable_methods.cpp (method=True signatures)
    * python_torch_functions_{0,1,2}.cpp plus python_torch_functionsEverything.cpp
      (``torch`` module, sharded three ways)
    * python_{nn,fft,linalg,nested,sparse,special}_functions.cpp
    * python_return_types.cpp
    * python_enum_tag.cpp (one ``.value(...)`` entry per tag in tags.yaml)
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions
    native_functions = list(filter(should_generate_py_binding, native_functions))

    # Method signatures -> python_variable_methods.cpp.
    methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm,
        methods,
        is_py_variable_method,
        None,
        "python_variable_methods.cpp",
        method=True,
        symint=symint,
    )

    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    #       torch/csrc/autograd/python_torch_functions_manual.cpp
    functions = load_signatures(native_functions, deprecated_yaml_path, method=False)
    create_python_bindings_sharded(
        fm,
        functions,
        is_py_torch_function,
        "torch",
        "python_torch_functions.cpp",
        method=False,
        num_shards=3,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_nn_function,
        "torch.nn",
        "python_nn_functions.cpp",
        method=False,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_fft_function,
        "torch.fft",
        "python_fft_functions.cpp",
        method=False,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_linalg_function,
        "torch.linalg",
        "python_linalg_functions.cpp",
        method=False,
        symint=symint,
    )

    # NOTE(review): unlike the other calls, `symint` is not forwarded here, so
    # create_python_bindings' own default applies — confirm this is intended.
    create_python_bindings(
        fm,
        functions,
        is_py_nested_function,
        "torch.nested",
        "python_nested_functions.cpp",
        method=False,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_sparse_function,
        "torch.sparse",
        "python_sparse_functions.cpp",
        method=False,
        symint=symint,
    )

    create_python_bindings(
        fm,
        functions,
        is_py_special_function,
        "torch.special",
        "python_special_functions.cpp",
        method=False,
        symint=symint,
    )

    # Currently, we only use `functions` to generate `return_types` bindings.
    # All methods which return namedtuple have function variant at this point.
    # If any method only operator with namedtuple is added in the future,
    # we will have to address that.
    create_python_return_type_bindings(
        fm, functions, lambda fn: True, "python_return_types.cpp"
    )

    valid_tags = parse_tags_yaml(tags_yaml_path)

    def gen_tags_enum() -> Dict[str, str]:
        """Template env for python_enum_tag.cpp: one `.value("<tag>", at::Tag::<tag>)` line per tag."""
        return {
            "enum_of_valid_tags": (
                "".join([f'\n.value("{tag}", at::Tag::{tag})' for tag in valid_tags])
            )
        }

    fm.write("python_enum_tag.cpp", gen_tags_enum)

首先創建一個FileManager變數fm,它的前兩個參數如下(詳見FileManager建構子):

  • install_dir:./torch/csrc/autograd/generated
  • template_dir:tools/autograd/templates

此處解析native_functions.yaml後得到native_functions

    # Parse native_functions.yaml, then keep only entries that get Python bindings.
    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions
    native_functions = list(filter(should_generate_py_binding, native_functions))

接著載入它們的函數簽名,得到functions

    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    #       torch/csrc/autograd/python_torch_functions_manual.cpp
    # method=False: the non-method signatures used for torch.* and its submodules.
    functions = load_signatures(native_functions, deprecated_yaml_path, method=False)

後續將解析出來的函數簽名傳入create_python_bindings_sharded(num_shards=3),生成python_torch_functions_0.cpp、python_torch_functions_1.cpp及python_torch_functions_2.cpp這三個檔案。

    # Split the torch.* bindings across num_shards=3 output files.
    create_python_bindings_sharded(
        fm,
        functions,
        is_py_torch_function,
        "torch",
        "python_torch_functions.cpp",
        method=False,
        num_shards=3,
        symint=symint,
    )

繼續深入create_python_bindings_sharded

create_python_bindings_sharded

tools/autograd/gen_python_functions.py

def create_python_bindings_sharded(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
    num_shards: int,
    symint: bool = True,
) -> None:
    """Generates Python bindings to ATen functions

    The (signature, function) pairs accepted by ``pred`` are grouped into
    ``(BaseOperatorName, overloads)`` items by ``group_filter_overloads``.
    ``fm.write_sharded`` then renders the template ``filename`` once per shard
    (``num_shards`` numbered files plus an "Everything" file). Each item is
    assigned to a shard by hashing its base operator name.

    Per item, the following sharded template keys receive one entry each:
        ops_headers:    ``#include <ATen/ops/<name>.h>``, the per-operator header
        py_forwards:    forward declarations from ``forward_decls``
        py_methods:     the binding implementation from ``method_impl``
        py_method_defs: the ``PyMethodDef`` entry from ``method_def``
    """
    grouped = group_filter_overloads(pairs, pred)

    def key_func(
        kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
    ) -> str:
        # The base operator name decides which shard an operator group goes to.
        return kv[0].base

    def env_func(
        kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
    ) -> Dict[str, List[str]]:
        name, fn_pairs = kv
        return {
            # Per-operator header, used when AT_PER_OPERATOR_HEADERS is defined.
            "ops_headers": [f"#include <ATen/ops/{name.base}.h>"],
            "py_forwards": list(forward_decls(name, fn_pairs, method=method)),
            "py_methods": [
                method_impl(name, module, fn_pairs, method=method, symint=symint)
            ],
            "py_method_defs": [method_def(name, module, fn_pairs, method=method)],
        }

    fm.write_sharded(
        filename,
        grouped.items(),
        base_env={
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/{filename}",
        },
        key_fn=key_func,
        env_callable=env_func,
        num_shards=num_shards,
        sharded_keys={"ops_headers", "py_forwards", "py_methods", "py_method_defs"},
    )

write_sharded的參數如下:

  • filename:python_torch_functions.cpp
  • grouped.items():grouped是一個Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]],也就是將運算子名稱對應到(函數簽名, 函數)對列表的字典;grouped.items()則逐一給出其中的(運算子名稱, 列表)鍵值對

繼續深入write_sharded

torchgen/utils.py

write_sharded

torchgen/utils.py

    def write_sharded(
        self,
        filename: str,
        items: Iterable[T],
        *,
        key_fn: Callable[[T], str],
        env_callable: Callable[[T], Dict[str, List[str]]],
        num_shards: int,
        base_env: Optional[Dict[str, Any]] = None,
        sharded_keys: Set[str],
    ) -> None:
        """Render template ``filename`` into ``num_shards`` shard files plus an "Everything" file.

        Each item goes to shard ``string_stable_hash(key_fn(item)) % num_shards``.
        The lists in ``env_callable(item)`` are appended both to that shard's
        environment and to the combined "Everything" environment. Output files
        are named ``<base><shard_id><ext>``, with shard ids ``_0`` ..
        ``_{num_shards-1}`` and ``Everything``. The Everything file is then
        discarded from ``self.filenames``, because it is not meant to be compiled.

        Args:
            filename: template name, also used to derive the output file names.
            items: values to distribute across shards (ignored on dry runs).
            key_fn: returns the string that is hashed to choose an item's shard.
            env_callable: returns ``{sharded_key: [str, ...]}`` for an item.
            num_shards: number of numbered shard files.
            base_env: entries copied into every shard's environment. Any
                sharded key present here must hold a list.
            sharded_keys: keys whose lists are accumulated across items. Any
                other key returned by ``env_callable`` fails an assertion.
        """

        everything: Dict[str, Any] = {"shard_id": "Everything"}
        shards: List[Dict[str, Any]] = [
            {"shard_id": f"_{i}"} for i in range(num_shards)
        ]
        # Holds the same dict objects (not copies), so later updates to
        # `everything` / `shards` are visible through `all_shards`.
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        # Give every shard its own list per sharded key. Lists inherited from
        # base_env are copied so shards don't share (and jointly mutate) them.
        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    assert isinstance(
                        shard[key], list
                    ), "sharded keys in base_env must be a list"
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
            # Extend `into`'s lists in place with `from_`'s; keys must be declared.
            for k, v in from_.items():
                assert k in sharded_keys, f"undeclared sharded key {k}"
                into[k] += v

        if self.dry_run:
            # Dry runs don't write any templates, so incomplete environments are fine
            items = ()

        for item in items:
            key = key_fn(item)
            # string_stable_hash is presumably deterministic across runs (unlike
            # the built-in hash()), keeping shard assignment reproducible.
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)

            merge_env(shards[sid], env)
            merge_env(everything, env)

        # Split "name.ext" into "name" and ".ext" (empty extension if there is no dot).
        dot_pos = filename.rfind(".")
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard["shard_id"]
            # NOTE(review): `lambda: shard` binds the loop variable late. This is
            # only correct if write_with_template calls it before the next
            # iteration; confirm.
            self.write_with_template(
                f"{base_filename}{shard_id}{extension}", filename, lambda: shard
            )

        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}"
        )

注意到all_shards是由[everything]及shards串接而成,是一個List[Dict[str, Any]]。

而以下這一段會把env字典裡的內容併入shards[sid]及everything。由於all_shards = [everything] + shards中存放的正是這些字典物件本身(而非副本),所以這也會間接更新all_shards:

        # Route each item to one shard by hashing its key, and also accumulate
        # it into the combined "Everything" environment.
        for item in items:
            key = key_fn(item)
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)

            merge_env(shards[sid], env)
            merge_env(everything, env)

因此all_shards中的內容除了下面的之外:

[{'shard_id': 'Everything'}, {'shard_id': '_0'}, {'shard_id': '_1'}, {'shard_id': '_2'}]

當中的各字典還包括了base_env(即generated_comment)以及各個env裡的內容(ops_headers、py_forwards、py_methods、py_method_defs)。

最後一段for迴圈遍歷all_shards,對每個shard調用write_with_template:以filename(即python_torch_functions.cpp)為模板,以lambda: shard回傳的字典作為替換用的環境,生成f"{base_filename}{shard_id}{extension}",也就是python_torch_functionsEverything.cpp、python_torch_functions_0.cpp、python_torch_functions_1.cpp及python_torch_functions_2.cpp。其中python_torch_functionsEverything.cpp隨後會從self.filenames中移除,不會被編譯。

write_with_template函數已獨立成篇,詳見PyTorch檔案生成機制中的FileManager.write_with_template。

生成結果

回顧create_python_bindings_sharded,那裡列出了generated_comment, ops_headers, py_forwards, py_method_defs, py_methods等key。

python_torch_functions.cpp中各key會被替換成:

  • generated_comment

    // ${generated_comment}
    

    會被替換成:

    // @generated from ../tools/autograd/templates/python_torch_functions.cpp
    
  • ops_headers

    #ifndef AT_PER_OPERATOR_HEADERS
    #include 
    #else
    $ops_headers
    #endif
    

    被替換成:

    #ifndef AT_PER_OPERATOR_HEADERS
    #include 
    #else
    #include 
    // ...
    #include 
    #endif
    

    表示如果沒有定義AT_PER_OPERATOR_HEADERS這個巨集,就會include torch/include/ATen/Functions.h(或build/aten/src/ATen/Functions.h),否則include各算子專屬的headers。筆者的環境中沒有定義AT_PER_OPERATOR_HEADERS,所以torch/include/ATen/ops資料夾下只有from_blob.h及tensor.h兩個檔案。

  • py_forwards

    // generated forward declarations start here
    
    ${py_forwards}
    

    生成結果如:

    static PyObject * THPVariable__cast_Byte(PyObject* self_, PyObject* args, PyObject* kwargs);
    

    是為C++與Python介接函數的宣告。

  • py_method_defs

    // ${py_method_defs} is replaced with one PyMethodDef entry per operator group in this shard.
    static PyMethodDef torch_functions_shard[] = {
      ${py_method_defs}
    };
    

    生成結果如:

      {"_cast_Byte", castPyCFunctionWithKeywords(THPVariable__cast_Byte), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
    

    是為PyMethodDef結構體,詳見撰寫自己的Python C擴展!和Python/C API - 模組,型別,Tuple,例外和引用計數。

  • py_methods

    // generated methods start here
    
    ${py_methods}
    

    生成結果如下:

    // _cast_Byte
    // Python-callable binding for torch's _cast_Byte. It parses (args, kwargs)
    // against the signature below, defers to handle_torch_function when the
    // parsed arguments report a torch-function override, and otherwise calls
    // at::_cast_Byte and wraps the returned Tensor.
    static PyObject * THPVariable__cast_Byte(PyObject* self_, PyObject* args, PyObject* kwargs)
    {
      HANDLE_TH_ERRORS
      static PythonArgParser parser({
        "_cast_Byte(Tensor input, bool non_blocking=False)",
      }, /*traceable=*/true);
    
      ParsedArgs<2> parsed_args;
      auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
      if(_r.has_torch_function()) {
        return handle_torch_function(_r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
      }
      // aten::_cast_Byte(Tensor self, bool non_blocking=False) -> Tensor
      
      auto dispatch__cast_Byte = [](const at::Tensor & self, bool non_blocking) -> at::Tensor {
        // Release the Python GIL while the ATen operator runs.
        pybind11::gil_scoped_release no_gil;
        return at::_cast_Byte(self, non_blocking);
      };
      return wrap(dispatch__cast_Byte(_r.tensor(0), _r.toBool(1)));
      // Unreachable after the return above; emitted by the code generator.
      Py_RETURN_NONE;
      END_HANDLE_TH_ERRORS
    }
    

    是為介接函數的定義,詳見撰寫自己的Python C擴展!。

你可能感兴趣的:(python,c++,pytorch)