Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/paddle/base/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def _setitem_static(x, indices, values):
)
if in_pir_mode():
# map var to the new output, for dy2static
from paddle.jit.pir_dy2static.parameter_recorder import (
from paddle.jit.dy2static.parameter_recorder import (
_global_inplace_map,
)

Expand Down Expand Up @@ -678,7 +678,7 @@ def _setitem_static(x, indices, values):
decrease_axes,
none_axes,
)
from paddle.jit.pir_dy2static.parameter_recorder import (
from paddle.jit.dy2static.parameter_recorder import (
_global_inplace_map,
)

Expand Down
4 changes: 2 additions & 2 deletions python/paddle/jit/dy2static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def convert_load(x):

# get the new output of the var
if isinstance(x, Value):
from paddle.jit.pir_dy2static.parameter_recorder import (
from paddle.jit.dy2static.parameter_recorder import (
_global_inplace_map,
)

Expand Down Expand Up @@ -449,8 +449,8 @@ def _run_paddle_cond(
_convert_tensor_array_if_necessary(helper, push_pop_names)
pred = cast_bool_if_necessary(pred)
init_args = helper.get(return_name_ids)
from paddle.jit.dy2static.parameter_recorder import _global_inplace_map
from paddle.jit.dy2static.program_translator import ProgramTranslator
from paddle.jit.pir_dy2static.parameter_recorder import _global_inplace_map

if use_pir_api():
inplace_map = _global_inplace_map
Expand Down
17 changes: 3 additions & 14 deletions python/paddle/jit/dy2static/program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from paddle.framework import in_dynamic_mode, use_pir_api
from paddle.nn.layer import layers
from paddle.pir import Value
from paddle.pir.core import _convert_into_value, static_op_arg_cast_guard
from paddle.utils import flatten, gast

from . import error, logging_utils
Expand Down Expand Up @@ -66,6 +65,7 @@
backend_guard,
cuda_pinned_tensors_move_to_excepted_place,
func_to_source_code,
graph_tracing_guard,
input_specs_compatible,
is_paddle_func,
make_hashable,
Expand Down Expand Up @@ -1265,8 +1265,7 @@ def pir_from_func_spec(

with (
ir_static.program_guard(main_program, startup_program),
to_static_mode_guard(is_to_static=True),
static_op_arg_cast_guard(_convert_into_value),
graph_tracing_guard(main_program) as ctx,
):
# 1. Adds `paddle.static.data` layers for input if needed
static_inputs, program_inputs = (
Expand Down Expand Up @@ -1309,16 +1308,6 @@ def pir_from_func_spec(
error_data.raise_new_exception()
raise

# 3. Gets all ParamBases and buffered VarBases in the function
from ..pir_dy2static.parameter_recorder import (
_global_inplace_map,
_global_parameter_recorder,
)

all_parameters_and_buffers = _global_parameter_recorder.pop(
main_program
)
_global_inplace_map.pop(main_program)
if outputs is not None:
need_wrap_into_list = (
not isinstance(outputs, (tuple, list)) or len(outputs) == 1
Expand All @@ -1334,7 +1323,7 @@ def pir_from_func_spec(
return ConcreteProgram(
inputs=program_inputs,
outputs=outputs,
parameters=all_parameters_and_buffers,
parameters=ctx.get_params_with_values(),
function=dygraph_function,
main_program=main_program,
startup_program=startup_program,
Expand Down
41 changes: 41 additions & 0 deletions python/paddle/jit/dy2static/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,14 @@
import paddle
from paddle.base import backward, core, framework, unique_name
from paddle.base.data_feeder import convert_dtype
from paddle.base.dygraph.base import (
to_static_mode_guard,
)
from paddle.base.layer_helper import LayerHelper
from paddle.base.wrapped_decorator import signature_safe_contextmanager
from paddle.framework import CUDAPinnedPlace
from paddle.jit.utils import OrderedSet
from paddle.pir.core import _convert_into_value, static_op_arg_cast_guard
from paddle.utils import flatten, gast
from paddle.utils.environments import (
BooleanEnvironmentVariable,
Expand Down Expand Up @@ -1095,3 +1099,40 @@ def extract_tensor_dynamic_dims(
f"Expected {DYNAMIC_DIMS_ATTR_NAME} to be a tuple, but got {type(dynamic_dims).__name__}"
)
return dynamic_dims


class GraphTracingContext:
params_with_values: tuple[list[paddle.Tensor], list[paddle.Tensor]] | None

def __init__(self):
self.params_with_values = None

def set_params_with_values(
self,
params_with_values: tuple[list[paddle.Tensor], list[paddle.Tensor]],
):
self.params_with_values = params_with_values

def get_params_with_values(
self,
) -> tuple[list[paddle.Tensor], list[paddle.Tensor]]:
assert self.params_with_values is not None
return self.params_with_values


@contextmanager
def graph_tracing_guard(main_program: paddle.static.Program):
ctx = GraphTracingContext()
with (
to_static_mode_guard(is_to_static=True),
static_op_arg_cast_guard(_convert_into_value),
):
yield ctx

from ..dy2static.parameter_recorder import (
_global_inplace_map,
_global_parameter_recorder,
)

ctx.set_params_with_values(_global_parameter_recorder.pop(main_program))
_global_inplace_map.pop(main_program)
13 changes: 0 additions & 13 deletions python/paddle/jit/pir_dy2static/__init__.py

This file was deleted.

15 changes: 13 additions & 2 deletions python/paddle/jit/sot/infer_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import annotations

import copy
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, TypeVar

import paddle
Expand All @@ -33,7 +34,11 @@
from paddle.distributed.auto_parallel.static.utils import (
convert_to_dims_mapping,
)
from paddle.jit.dy2static.utils import extract_tensor_dynamic_dims
from paddle.jit.dy2static.utils import (
ALREADY_D2S,
extract_tensor_dynamic_dims,
graph_tracing_guard,
)
from paddle.pir import is_fake_value
from paddle.static import InputSpec
from paddle.utils import flatten, is_sequence
Expand Down Expand Up @@ -459,6 +464,7 @@ def infer_meta(self, func, *args, **kwargs):
convert_meta_to_variable(kwargs),
)

graph_tracing_context_manager = nullcontext()
with paddle.static.program_guard(
self.main_program, self.startup_program
):
Expand All @@ -467,7 +473,12 @@ def infer_meta(self, func, *args, **kwargs):
# Do we need add condition check here?
func = getattr(args[0], func)
args = args[1:]
out = func(*args, **kwargs)
if hasattr(func, ALREADY_D2S):
graph_tracing_context_manager = graph_tracing_guard(
self.main_program
)
with graph_tracing_context_manager:
out = func(*args, **kwargs)
return convert_variable_to_meta_info(out)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def _map_dataclass_variable(variable: VariableBase | object):
new_dataclass = dataclass_from_dict(
variable.get_py_type(),
{
fd.name: map_func(variable.getattr(fd.name))
fd.name: _map_variable(variable.getattr(fd.name))
for fd in fields(variable.get_py_type())
},
)
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/pir/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def _convert_into_value(tensor):
Convert Tensor into Value.
"""
import paddle
from paddle.jit.pir_dy2static.parameter_recorder import (
from paddle.jit.dy2static.parameter_recorder import (
_global_parameter_recorder,
)

Expand Down
1 change: 0 additions & 1 deletion python/setup.py.in
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,6 @@ packages=['paddle',
'paddle.jit',
'paddle.jit.dy2static',
'paddle.jit.dy2static.transformers',
'paddle.jit.pir_dy2static',
'paddle.jit.sot',
'paddle.jit.sot.opcode_translator',
'paddle.jit.sot.opcode_translator.executor',
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2394,7 +2394,6 @@ def get_setup_parameters():
'paddle.jit',
'paddle.jit.dy2static',
'paddle.jit.dy2static.transformers',
'paddle.jit.pir_dy2static',
'paddle.jit.sot',
'paddle.jit.sot.opcode_translator',
'paddle.jit.sot.opcode_translator.executor',
Expand Down
36 changes: 36 additions & 0 deletions test/sot/test_capture_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)

import paddle
from paddle import nn


@paddle.jit.marker.capture_control_flow
Expand Down Expand Up @@ -66,5 +67,40 @@ def test_case_capture_control_flow(self):
self.assertEqual(ctx.translate_count, 1)


class NetWithCaptureControlFlow(nn.Layer):
def __init__(self):
super().__init__()
self.layer = nn.Linear(8, 8)

@paddle.jit.marker.capture_control_flow
def fn(self, x):
x = self.layer(x)
if x.sum() > 0:
x += paddle.ones_like(x)
else:
x -= paddle.zeros_like(x)
return x

def forward(self, x):
return self.fn(x) + 1


def model_call(x: paddle.Tensor, net: paddle.nn.Layer):
return net(x)


class TestEagerParamsToPirValue(TestCaseBase):
def test_case_without_capture_control_flow(self):
model = NetWithCaptureControlFlow()
with test_instruction_translator_cache_context() as ctx:
self.assertEqual(ctx.translate_count, 0)
x = paddle.randn([4, 8])
self.assert_results(model_call, x, model)
self.assertEqual(ctx.translate_count, 1)
x = paddle.randn([4, 8])
self.assert_results(model_call, x, model)
self.assertEqual(ctx.translate_count, 1)


if __name__ == "__main__":
unittest.main()
Loading