3737from paddle .framework import in_dynamic_mode , use_pir_api
3838from paddle .nn .layer import layers
3939from paddle .pir import Value
40- from paddle .pir .core import _convert_into_value , static_op_arg_cast_guard
4140from paddle .utils import flatten , gast
4241
4342from . import error , logging_utils
6665 backend_guard ,
6766 cuda_pinned_tensors_move_to_excepted_place ,
6867 func_to_source_code ,
68+ graph_tracing_guard ,
6969 input_specs_compatible ,
7070 is_paddle_func ,
7171 make_hashable ,
@@ -1265,8 +1265,7 @@ def pir_from_func_spec(
12651265
12661266 with (
12671267 ir_static .program_guard (main_program , startup_program ),
1268- to_static_mode_guard (is_to_static = True ),
1269- static_op_arg_cast_guard (_convert_into_value ),
1268+ graph_tracing_guard (main_program ) as ctx ,
12701269 ):
12711270 # 1. Adds `paddle.static.data` layers for input if needed
12721271 static_inputs , program_inputs = (
@@ -1309,16 +1308,6 @@ def pir_from_func_spec(
13091308 error_data .raise_new_exception ()
13101309 raise
13111310
1312- # 3. Gets all ParamBases and buffered VarBases in the function
1313- from ..pir_dy2static .parameter_recorder import (
1314- _global_inplace_map ,
1315- _global_parameter_recorder ,
1316- )
1317-
1318- all_parameters_and_buffers = _global_parameter_recorder .pop (
1319- main_program
1320- )
1321- _global_inplace_map .pop (main_program )
13221311 if outputs is not None :
13231312 need_wrap_into_list = (
13241313 not isinstance (outputs , (tuple , list )) or len (outputs ) == 1
@@ -1334,7 +1323,7 @@ def pir_from_func_spec(
13341323 return ConcreteProgram (
13351324 inputs = program_inputs ,
13361325 outputs = outputs ,
1337- parameters = all_parameters_and_buffers ,
1326+ parameters = ctx . get_params_with_values () ,
13381327 function = dygraph_function ,
13391328 main_program = main_program ,
13401329 startup_program = startup_program ,
0 commit comments