diff options
Diffstat (limited to 'gnu/packages/patches/python-pytorch-fix-codegen.patch')
-rw-r--r-- | gnu/packages/patches/python-pytorch-fix-codegen.patch | 26 |
1 files changed, 13 insertions, 13 deletions
diff --git a/gnu/packages/patches/python-pytorch-fix-codegen.patch b/gnu/packages/patches/python-pytorch-fix-codegen.patch index cb246b25de..b30094de09 100644 --- a/gnu/packages/patches/python-pytorch-fix-codegen.patch +++ b/gnu/packages/patches/python-pytorch-fix-codegen.patch @@ -6,7 +6,7 @@ is later corrected. codegen_external.py is patched to avoid duplicate functions and add the static keyword as in the existing generated file. diff --git a/tools/gen_flatbuffers.sh b/tools/gen_flatbuffers.sh -index cc0263dbbf..ac34e84b82 100644 +index cc0263d..ac34e84 100644 --- a/tools/gen_flatbuffers.sh +++ b/tools/gen_flatbuffers.sh @@ -1,13 +1,13 @@ @@ -32,10 +32,10 @@ index cc0263dbbf..ac34e84b82 100644 -c "$ROOT/torch/csrc/jit/serialization/mobile_bytecode.fbs" echo '// @generated' >> "$ROOT/torch/csrc/jit/serialization/mobile_bytecode_generated.h" diff --git a/torch/csrc/jit/tensorexpr/codegen_external.py b/torch/csrc/jit/tensorexpr/codegen_external.py -index bc69b05162..0f8df81de3 100644 +index 5dcf1b2..0e20b0c 100644 --- a/torch/csrc/jit/tensorexpr/codegen_external.py +++ b/torch/csrc/jit/tensorexpr/codegen_external.py -@@ -20,9 +20,14 @@ def gen_external(native_functions_path, tags_path, external_path): +@@ -21,9 +21,14 @@ def gen_external(native_functions_path, tags_path, external_path): native_functions = parse_native_yaml(native_functions_path, tags_path) func_decls = [] func_registrations = [] @@ -51,7 +51,7 @@ index bc69b05162..0f8df81de3 100644 args = schema.arguments # Only supports extern calls for functions with out variants if not schema.is_out_fn(): -@@ -62,7 +67,7 @@ def gen_external(native_functions_path, tags_path, external_path): +@@ -63,7 +68,7 @@ def gen_external(native_functions_path, tags_path, external_path): # print(tensor_decls, name, arg_names) func_decl = f"""\ @@ -61,7 +61,7 @@ index bc69b05162..0f8df81de3 100644 void** buf_data, int64_t* buf_ranks, diff --git a/torchgen/decompositions/gen_jit_decompositions.py b/torchgen/decompositions/gen_jit_decompositions.py -index 7cfbb803f9..2e69bb1868 100644 +index 7a0024f..6b2445f 100644 --- a/torchgen/decompositions/gen_jit_decompositions.py +++ b/torchgen/decompositions/gen_jit_decompositions.py @@ -1,8 +1,12 @@ @@ -88,12 +88,12 @@ index 7cfbb803f9..2e69bb1868 100644 write_decomposition_util_file(str(upgrader_path)) diff --git a/torchgen/operator_versions/gen_mobile_upgraders.py b/torchgen/operator_versions/gen_mobile_upgraders.py -index dab1568580..55c58715fc 100644 +index 2907076..6866332 100644 --- a/torchgen/operator_versions/gen_mobile_upgraders.py +++ b/torchgen/operator_versions/gen_mobile_upgraders.py -@@ -2,10 +2,12 @@ - import os +@@ -3,10 +3,12 @@ import os from enum import Enum + from operator import itemgetter from pathlib import Path +import sys from typing import Any, Dict, List @@ -106,7 +106,7 @@ index dab1568580..55c58715fc 100644 from torchgen.code_template import CodeTemplate from torchgen.operator_versions.gen_mobile_upgraders_constant import ( -@@ -262,7 +264,10 @@ def construct_register_size(register_size_from_yaml: int) -> str: +@@ -263,7 +265,10 @@ def construct_register_size(register_size_from_yaml: int) -> str: def construct_version_maps( upgrader_bytecode_function_to_index_map: Dict[str, Any] ) -> str: @@ -115,10 +115,10 @@ index dab1568580..55c58715fc 100644 + version_map = torch._C._get_operator_version_map() + else: + version_map = {} - sorted_version_map_ = sorted(version_map.items(), key=lambda item: item[0]) # type: ignore[no-any-return] + sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0)) # type: ignore[no-any-return] sorted_version_map = dict(sorted_version_map_) -@@ -378,7 +383,10 @@ def sort_upgrader(upgrader_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]: +@@ -379,7 +384,10 @@ def sort_upgrader(upgrader_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def main() -> None: @@ -131,12 +131,12 @@ index dab1568580..55c58715fc 100644 for up in sorted_upgrader_list: print("after sort upgrader : ", next(iter(up))) diff --git a/torchgen/shape_functions/gen_jit_shape_functions.py b/torchgen/shape_functions/gen_jit_shape_functions.py -index c6336a6951..34e394d818 100644 +index bdfd5c7..72b237a 100644 --- a/torchgen/shape_functions/gen_jit_shape_functions.py +++ b/torchgen/shape_functions/gen_jit_shape_functions.py @@ -18,16 +18,20 @@ you are in the root directory of the Pytorch git repo""" if not file_path.exists(): - raise Exception(err_msg) + raise Exception(err_msg) # noqa: TRY002 -spec = importlib.util.spec_from_file_location(module_name, file_path) -assert spec is not None |