Skip to content

Commit 64e9cb1

Browse files
DrRyanHuangAlAuAu
authored andcommitted
[SOT] Ensure run AST transformed function under graph_tracing_guard (PaddlePaddle#76198)
1 parent 9c71583 commit 64e9cb1

File tree

12 files changed

+99
-37
lines changed

12 files changed

+99
-37
lines changed

python/paddle/base/variable_index.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def _setitem_static(x, indices, values):
553553
)
554554
if in_pir_mode():
555555
# map var to the new output, for dy2static
556-
from paddle.jit.pir_dy2static.parameter_recorder import (
556+
from paddle.jit.dy2static.parameter_recorder import (
557557
_global_inplace_map,
558558
)
559559

@@ -678,7 +678,7 @@ def _setitem_static(x, indices, values):
678678
decrease_axes,
679679
none_axes,
680680
)
681-
from paddle.jit.pir_dy2static.parameter_recorder import (
681+
from paddle.jit.dy2static.parameter_recorder import (
682682
_global_inplace_map,
683683
)
684684

python/paddle/jit/dy2static/convert_operators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def convert_load(x):
9393

9494
# get the new output of the var
9595
if isinstance(x, Value):
96-
from paddle.jit.pir_dy2static.parameter_recorder import (
96+
from paddle.jit.dy2static.parameter_recorder import (
9797
_global_inplace_map,
9898
)
9999

@@ -449,8 +449,8 @@ def _run_paddle_cond(
449449
_convert_tensor_array_if_necessary(helper, push_pop_names)
450450
pred = cast_bool_if_necessary(pred)
451451
init_args = helper.get(return_name_ids)
452+
from paddle.jit.dy2static.parameter_recorder import _global_inplace_map
452453
from paddle.jit.dy2static.program_translator import ProgramTranslator
453-
from paddle.jit.pir_dy2static.parameter_recorder import _global_inplace_map
454454

455455
if use_pir_api():
456456
inplace_map = _global_inplace_map

python/paddle/jit/dy2static/program_translator.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
from paddle.framework import in_dynamic_mode, use_pir_api
3838
from paddle.nn.layer import layers
3939
from paddle.pir import Value
40-
from paddle.pir.core import _convert_into_value, static_op_arg_cast_guard
4140
from paddle.utils import flatten, gast
4241

4342
from . import error, logging_utils
@@ -66,6 +65,7 @@
6665
backend_guard,
6766
cuda_pinned_tensors_move_to_excepted_place,
6867
func_to_source_code,
68+
graph_tracing_guard,
6969
input_specs_compatible,
7070
is_paddle_func,
7171
make_hashable,
@@ -1265,8 +1265,7 @@ def pir_from_func_spec(
12651265

12661266
with (
12671267
ir_static.program_guard(main_program, startup_program),
1268-
to_static_mode_guard(is_to_static=True),
1269-
static_op_arg_cast_guard(_convert_into_value),
1268+
graph_tracing_guard(main_program) as ctx,
12701269
):
12711270
# 1. Adds `paddle.static.data` layers for input if needed
12721271
static_inputs, program_inputs = (
@@ -1309,16 +1308,6 @@ def pir_from_func_spec(
13091308
error_data.raise_new_exception()
13101309
raise
13111310

1312-
# 3. Gets all ParamBases and buffered VarBases in the function
1313-
from ..pir_dy2static.parameter_recorder import (
1314-
_global_inplace_map,
1315-
_global_parameter_recorder,
1316-
)
1317-
1318-
all_parameters_and_buffers = _global_parameter_recorder.pop(
1319-
main_program
1320-
)
1321-
_global_inplace_map.pop(main_program)
13221311
if outputs is not None:
13231312
need_wrap_into_list = (
13241313
not isinstance(outputs, (tuple, list)) or len(outputs) == 1
@@ -1334,7 +1323,7 @@ def pir_from_func_spec(
13341323
return ConcreteProgram(
13351324
inputs=program_inputs,
13361325
outputs=outputs,
1337-
parameters=all_parameters_and_buffers,
1326+
parameters=ctx.get_params_with_values(),
13381327
function=dygraph_function,
13391328
main_program=main_program,
13401329
startup_program=startup_program,

python/paddle/jit/dy2static/utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,14 @@
4141
import paddle
4242
from paddle.base import backward, core, framework, unique_name
4343
from paddle.base.data_feeder import convert_dtype
44+
from paddle.base.dygraph.base import (
45+
to_static_mode_guard,
46+
)
4447
from paddle.base.layer_helper import LayerHelper
4548
from paddle.base.wrapped_decorator import signature_safe_contextmanager
4649
from paddle.framework import CUDAPinnedPlace
4750
from paddle.jit.utils import OrderedSet
51+
from paddle.pir.core import _convert_into_value, static_op_arg_cast_guard
4852
from paddle.utils import flatten, gast
4953
from paddle.utils.environments import (
5054
BooleanEnvironmentVariable,
@@ -1095,3 +1099,40 @@ def extract_tensor_dynamic_dims(
10951099
f"Expected {DYNAMIC_DIMS_ATTR_NAME} to be a tuple, but got {type(dynamic_dims).__name__}"
10961100
)
10971101
return dynamic_dims
1102+
1103+
1104+
class GraphTracingContext:
1105+
params_with_values: tuple[list[paddle.Tensor], list[paddle.Tensor]] | None
1106+
1107+
def __init__(self):
1108+
self.params_with_values = None
1109+
1110+
def set_params_with_values(
1111+
self,
1112+
params_with_values: tuple[list[paddle.Tensor], list[paddle.Tensor]],
1113+
):
1114+
self.params_with_values = params_with_values
1115+
1116+
def get_params_with_values(
1117+
self,
1118+
) -> tuple[list[paddle.Tensor], list[paddle.Tensor]]:
1119+
assert self.params_with_values is not None
1120+
return self.params_with_values
1121+
1122+
1123+
@contextmanager
1124+
def graph_tracing_guard(main_program: paddle.static.Program):
1125+
ctx = GraphTracingContext()
1126+
with (
1127+
to_static_mode_guard(is_to_static=True),
1128+
static_op_arg_cast_guard(_convert_into_value),
1129+
):
1130+
yield ctx
1131+
1132+
from ..dy2static.parameter_recorder import (
1133+
_global_inplace_map,
1134+
_global_parameter_recorder,
1135+
)
1136+
1137+
ctx.set_params_with_values(_global_parameter_recorder.pop(main_program))
1138+
_global_inplace_map.pop(main_program)

python/paddle/jit/pir_dy2static/__init__.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

python/paddle/jit/sot/infer_meta.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import annotations
1515

1616
import copy
17+
from contextlib import nullcontext
1718
from typing import TYPE_CHECKING, Any, TypeVar
1819

1920
import paddle
@@ -33,7 +34,11 @@
3334
from paddle.distributed.auto_parallel.static.utils import (
3435
convert_to_dims_mapping,
3536
)
36-
from paddle.jit.dy2static.utils import extract_tensor_dynamic_dims
37+
from paddle.jit.dy2static.utils import (
38+
ALREADY_D2S,
39+
extract_tensor_dynamic_dims,
40+
graph_tracing_guard,
41+
)
3742
from paddle.pir import is_fake_value
3843
from paddle.static import InputSpec
3944
from paddle.utils import flatten, is_sequence
@@ -459,6 +464,7 @@ def infer_meta(self, func, *args, **kwargs):
459464
convert_meta_to_variable(kwargs),
460465
)
461466

467+
graph_tracing_context_manager = nullcontext()
462468
with paddle.static.program_guard(
463469
self.main_program, self.startup_program
464470
):
@@ -467,7 +473,12 @@ def infer_meta(self, func, *args, **kwargs):
467473
# Do we need add condition check here?
468474
func = getattr(args[0], func)
469475
args = args[1:]
470-
out = func(*args, **kwargs)
476+
if hasattr(func, ALREADY_D2S):
477+
graph_tracing_context_manager = graph_tracing_guard(
478+
self.main_program
479+
)
480+
with graph_tracing_context_manager:
481+
out = func(*args, **kwargs)
471482
return convert_variable_to_meta_info(out)
472483

473484

python/paddle/jit/sot/opcode_translator/executor/variables/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def _map_dataclass_variable(variable: VariableBase | object):
163163
new_dataclass = dataclass_from_dict(
164164
variable.get_py_type(),
165165
{
166-
fd.name: map_func(variable.getattr(fd.name))
166+
fd.name: _map_variable(variable.getattr(fd.name))
167167
for fd in fields(variable.get_py_type())
168168
},
169169
)

python/paddle/pir/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ def _convert_into_value(tensor):
501501
Convert Tensor into Value.
502502
"""
503503
import paddle
504-
from paddle.jit.pir_dy2static.parameter_recorder import (
504+
from paddle.jit.dy2static.parameter_recorder import (
505505
_global_parameter_recorder,
506506
)
507507

python/setup.py.in

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -937,7 +937,6 @@ packages=['paddle',
937937
'paddle.jit',
938938
'paddle.jit.dy2static',
939939
'paddle.jit.dy2static.transformers',
940-
'paddle.jit.pir_dy2static',
941940
'paddle.jit.sot',
942941
'paddle.jit.sot.opcode_translator',
943942
'paddle.jit.sot.opcode_translator.executor',

0 commit comments

Comments
 (0)