@@ -114,36 +114,41 @@ def create_return_code_frome_schema(schema, allow_return_ref=True):
114114def create_transform_input_to_cpu_code (fun_config ):
115115 input_process_code = ""
116116 schema = fun_config ["schema" ]
117+ opname = get_op_name_from_schema (schema )
117118 inputs = re .findall ("Tensor +([\w\d_]+)" , schema [: schema .find ("->" )])
118119 for input in inputs :
119120 input_process_code += (
120- f"at::Tensor { input } _cpu = to_cpu_without_diopi ({ input } );\n "
121+ f"at::Tensor { input } _cpu = toCpuTensorWithoutDiopiCopy ({ input } );\n "
121122 )
122123
123124 optional_inputs = re .findall ("Tensor *\? +([\w\d_]+)" , schema [: schema .find ("->" )])
124125 for input in optional_inputs :
125- input_process_code += f"\n c10::optional<at::Tensor> { input } _cpu = { input } .has_value() && { input } .value().defined() ? c10::make_optional<at::Tensor>(to_cpu_without_diopi ({ input } .value())) : { input } ;\n "
126+ input_process_code += f"\n c10::optional<at::Tensor> { input } _cpu = { input } .has_value() && { input } .value().defined() ? c10::make_optional<at::Tensor>(toCpuTensorWithoutDiopiCopy ({ input } .value())) : { input } ;\n "
126127
127128 optional_tensor_list_inputs = re .findall (
128129 "Tensor *\? *\[ *\] +([\w\d_]+)" , schema [: schema .find ("->" )]
129130 )
130131 for input in optional_tensor_list_inputs :
131132 input_process_code += f"\n c10::List<c10::optional<at::Tensor>> { input } _cpu;\n "
132133 input_process_code += f"for (int i = 0; i < { input } .size();++i)" + " {\n "
133- input_process_code += f" { input } _cpu.push_back({ input } [i].has_value() && { input } [i].value().defined() ? c10::make_optional<at::Tensor>(({ input } [i].value())) : { input } [i]);\n "
134+ input_process_code += f" { input } _cpu.push_back({ input } [i].has_value() && { input } [i].value().defined() ? c10::make_optional<at::Tensor>(toCpuTensorWithoutDiopiCopy ({ input } [i].value())) : { input } [i]);\n "
134135 input_process_code += "}\n "
135136
136137 outputs = re .findall (
137138 "Tensor\([a-z]!\)[ ]+([\w\d_]+){1}" , schema [: schema .find ("->" )]
138139 )
139140 for output in outputs :
140- if output .strip ().endswith ("?" ):
141- output = output .replace ("?" , "" )
142- input_process_code += f"\n c10::optional<at::Tensor> { output } _cpu = { output } .has_value() && { output } .value().defined() ? c10::make_optional<at::Tensor>(to_cpu_without_diopi({ output } .value()) : { output } ;\n "
143- else :
144- input_process_code += (
145- f"at::Tensor { output } _cpu = to_cpu_without_diopi({ output } );\n "
146- )
141+ input_process_code += (
142+ f"at::Tensor { output } _cpu = toCpuTensorWithoutDiopiCopy({ output } );\n "
143+ )
144+ if ".out" in opname or "_out" in opname :
145+ for i in range (len (inputs )):
146+ input_process_code += (
147+ f"if (({ inputs [i ]} .data_ptr()) == { output } .data_ptr())"
148+ )
149+ input_process_code += "{\n \t "
150+ input_process_code += f"{ inputs [i ]} _cpu = { output } _cpu;\n \t "
151+ input_process_code += "}\n "
147152
148153 tensors_arrays = re .findall (
149154 "Tensor *\[ *\] * +([\w\d_]+)" , schema [: schema .find ("->" )]
@@ -161,9 +166,8 @@ def create_transform_input_to_cpu_code(fun_config):
161166 )
162167 input_process_code += (
163168 f"std::transform({ tensors_arg } .begin(), { tensors_arg } .end(), { tensors_arg } _cpu.begin(), [](const at::Tensor& tensor)"
164- + "{return to_cpu_without_diopi (tensor);});\n "
169+ + "{return toCpuTensorWithoutDiopiCopy (tensor);});\n "
165170 )
166-
167171 return input_process_code
168172
169173
@@ -487,6 +491,9 @@ def create_call_aten_cpu_cpp_function_code_from_config(fun_config):
487491 code ,
488492 )
489493
494+ if "device" in code :
495+ code = code .replace ("device" , "at::kCPU" )
496+
490497 inputs = re .findall ("Tensor +([\w\d_]+)" , schema [: schema .find ("->" )])
491498 optional_inputs = re .findall ("Tensor *\? +([\w\d_]+)" , schema [: schema .find ("->" )])
492499 outputs = re .findall (
@@ -550,7 +557,6 @@ def create_result_compare_code(fun_config):
550557 for i in range (len (inputs )):
551558 code += separator_code
552559 code += f'std::cout << "autocompare:\t { op_name } \t { inputs [i ]} : " << std::endl << allclose_autocompare({ inputs [i ]} _cpu, { inputs [i ]} ) << std::endl;\n '
553-
554560 return code
555561
556562
@@ -972,9 +978,12 @@ def functions_code_gen(fun_config):
972978
973979
def boolean_string(s):
    """Convert a command-line style boolean string to a ``bool``.

    Matching is case-insensitive and ignores surrounding whitespace.

    Args:
        s: One of "true"/"on" or "false"/"off" in any letter case.
            A real ``bool`` is returned unchanged.

    Returns:
        ``True`` for "true"/"on", ``False`` for "false"/"off".

    Raises:
        ValueError: If ``s`` is not one of the accepted spellings.
    """
    # Let an already-parsed value (e.g. a programmatic default) pass through.
    if isinstance(s, bool):
        return s
    # Report every other non-string input as ValueError, the same error type
    # used for unrecognized strings, instead of failing on ``.lower()``.
    if not isinstance(s, str):
        raise ValueError("Not a valid boolean string.")

    # Normalize once so both membership tests see the same value.
    normalized = s.strip().lower()
    if normalized in ("true", "on"):
        return True
    if normalized in ("false", "off"):
        return False
    raise ValueError("Not a valid boolean string.")
978987
979988
980989def parse_args ():
0 commit comments