diff --git a/python/paddle/base/variable_index.py b/python/paddle/base/variable_index.py index 242850860a5671..250729e260e6bc 100644 --- a/python/paddle/base/variable_index.py +++ b/python/paddle/base/variable_index.py @@ -553,7 +553,7 @@ def _setitem_static(x, indices, values): ) if in_pir_mode(): # map var to the new output, for dy2static - from paddle.jit.pir_dy2static.parameter_recorder import ( + from paddle.jit.dy2static.parameter_recorder import ( _global_inplace_map, ) @@ -678,7 +678,7 @@ def _setitem_static(x, indices, values): decrease_axes, none_axes, ) - from paddle.jit.pir_dy2static.parameter_recorder import ( + from paddle.jit.dy2static.parameter_recorder import ( _global_inplace_map, ) diff --git a/python/paddle/jit/dy2static/convert_operators.py b/python/paddle/jit/dy2static/convert_operators.py index ed2fac98614836..0ac1bdc883f690 100644 --- a/python/paddle/jit/dy2static/convert_operators.py +++ b/python/paddle/jit/dy2static/convert_operators.py @@ -93,7 +93,7 @@ def convert_load(x): # get the new output of the var if isinstance(x, Value): - from paddle.jit.pir_dy2static.parameter_recorder import ( + from paddle.jit.dy2static.parameter_recorder import ( _global_inplace_map, ) @@ -449,8 +449,8 @@ def _run_paddle_cond( _convert_tensor_array_if_necessary(helper, push_pop_names) pred = cast_bool_if_necessary(pred) init_args = helper.get(return_name_ids) + from paddle.jit.dy2static.parameter_recorder import _global_inplace_map from paddle.jit.dy2static.program_translator import ProgramTranslator - from paddle.jit.pir_dy2static.parameter_recorder import _global_inplace_map if use_pir_api(): inplace_map = _global_inplace_map diff --git a/python/paddle/jit/pir_dy2static/parameter_recorder.py b/python/paddle/jit/dy2static/parameter_recorder.py similarity index 100% rename from python/paddle/jit/pir_dy2static/parameter_recorder.py rename to python/paddle/jit/dy2static/parameter_recorder.py diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index d7a35916c48f4c..10140a0167aa6b 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -37,7 +37,6 @@ from paddle.framework import in_dynamic_mode, use_pir_api from paddle.nn.layer import layers from paddle.pir import Value -from paddle.pir.core import _convert_into_value, static_op_arg_cast_guard from paddle.utils import flatten, gast from . import error, logging_utils @@ -66,6 +65,7 @@ backend_guard, cuda_pinned_tensors_move_to_excepted_place, func_to_source_code, + graph_tracing_guard, input_specs_compatible, is_paddle_func, make_hashable, @@ -1265,8 +1265,7 @@ def pir_from_func_spec( with ( ir_static.program_guard(main_program, startup_program), - to_static_mode_guard(is_to_static=True), - static_op_arg_cast_guard(_convert_into_value), + graph_tracing_guard(main_program) as ctx, ): # 1. Adds `paddle.static.data` layers for input if needed static_inputs, program_inputs = ( @@ -1309,16 +1308,6 @@ def pir_from_func_spec( error_data.raise_new_exception() raise - # 3. Gets all ParamBases and buffered VarBases in the function - from ..pir_dy2static.parameter_recorder import ( - _global_inplace_map, - _global_parameter_recorder, - ) - - all_parameters_and_buffers = _global_parameter_recorder.pop( - main_program - ) - _global_inplace_map.pop(main_program) if outputs is not None: need_wrap_into_list = ( not isinstance(outputs, (tuple, list)) or len(outputs) == 1 @@ -1334,7 +1323,7 @@ def pir_from_func_spec( return ConcreteProgram( inputs=program_inputs, outputs=outputs, - parameters=all_parameters_and_buffers, + parameters=ctx.get_params_with_values(), function=dygraph_function, main_program=main_program, startup_program=startup_program, diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py index 4ed0749b96725d..04be9edadf3e29 100644 --- a/python/paddle/jit/dy2static/utils.py +++ b/python/paddle/jit/dy2static/utils.py @@ -41,10 +41,14 @@ import paddle from paddle.base import backward, core, framework, unique_name from paddle.base.data_feeder import convert_dtype +from paddle.base.dygraph.base import ( + to_static_mode_guard, +) from paddle.base.layer_helper import LayerHelper from paddle.base.wrapped_decorator import signature_safe_contextmanager from paddle.framework import CUDAPinnedPlace from paddle.jit.utils import OrderedSet +from paddle.pir.core import _convert_into_value, static_op_arg_cast_guard from paddle.utils import flatten, gast from paddle.utils.environments import ( BooleanEnvironmentVariable, @@ -1095,3 +1099,40 @@ def extract_tensor_dynamic_dims( f"Expected {DYNAMIC_DIMS_ATTR_NAME} to be a tuple, but got {type(dynamic_dims).__name__}" ) return dynamic_dims + + +class GraphTracingContext: + params_with_values: tuple[list[paddle.Tensor], list[paddle.Tensor]] | None + + def __init__(self): + self.params_with_values = None + + def set_params_with_values( + self, + params_with_values: tuple[list[paddle.Tensor], list[paddle.Tensor]], + ): + self.params_with_values = params_with_values + + def get_params_with_values( + self, + ) -> tuple[list[paddle.Tensor], list[paddle.Tensor]]: + assert self.params_with_values is not None + return self.params_with_values + + +@contextmanager +def graph_tracing_guard(main_program: paddle.static.Program): + ctx = GraphTracingContext() + with ( + to_static_mode_guard(is_to_static=True), + static_op_arg_cast_guard(_convert_into_value), + ): + yield ctx + + from ..dy2static.parameter_recorder import ( + _global_inplace_map, + _global_parameter_recorder, + ) + + ctx.set_params_with_values(_global_parameter_recorder.pop(main_program)) + _global_inplace_map.pop(main_program) diff --git a/python/paddle/jit/pir_dy2static/__init__.py b/python/paddle/jit/pir_dy2static/__init__.py deleted file mode 100644 index 595add0aed9e11..00000000000000 --- a/python/paddle/jit/pir_dy2static/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/python/paddle/jit/sot/infer_meta.py b/python/paddle/jit/sot/infer_meta.py index c448eef86473b1..bf6c9ec70728ea 100644 --- a/python/paddle/jit/sot/infer_meta.py +++ b/python/paddle/jit/sot/infer_meta.py @@ -14,6 +14,7 @@ from __future__ import annotations import copy +from contextlib import nullcontext from typing import TYPE_CHECKING, Any, TypeVar import paddle @@ -33,7 +34,11 @@ from paddle.distributed.auto_parallel.static.utils import ( convert_to_dims_mapping, ) -from paddle.jit.dy2static.utils import extract_tensor_dynamic_dims +from paddle.jit.dy2static.utils import ( + ALREADY_D2S, + extract_tensor_dynamic_dims, + graph_tracing_guard, +) from paddle.pir import is_fake_value from paddle.static import InputSpec from paddle.utils import flatten, is_sequence @@ -459,6 +464,7 @@ def infer_meta(self, func, *args, **kwargs): convert_meta_to_variable(kwargs), ) + graph_tracing_context_manager = nullcontext() with paddle.static.program_guard( self.main_program, self.startup_program ): @@ -467,7 +473,12 @@ def infer_meta(self, func, *args, **kwargs): # Do we need add condition check here? func = getattr(args[0], func) args = args[1:] - out = func(*args, **kwargs) + if hasattr(func, ALREADY_D2S): + graph_tracing_context_manager = graph_tracing_guard( + self.main_program + ) + with graph_tracing_context_manager: + out = func(*args, **kwargs) return convert_variable_to_meta_info(out) diff --git a/python/paddle/jit/sot/opcode_translator/executor/variables/base.py b/python/paddle/jit/sot/opcode_translator/executor/variables/base.py index a0b9a0d9c9ef3b..e62b052520207b 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/variables/base.py +++ b/python/paddle/jit/sot/opcode_translator/executor/variables/base.py @@ -163,7 +163,7 @@ def _map_dataclass_variable(variable: VariableBase | object): new_dataclass = dataclass_from_dict( variable.get_py_type(), { - fd.name: map_func(variable.getattr(fd.name)) + fd.name: _map_variable(variable.getattr(fd.name)) for fd in fields(variable.get_py_type()) }, ) diff --git a/python/paddle/pir/core.py b/python/paddle/pir/core.py index 01bfcb983c3750..25196ac695b7db 100644 --- a/python/paddle/pir/core.py +++ b/python/paddle/pir/core.py @@ -501,7 +501,7 @@ def _convert_into_value(tensor): Convert Tensor into Value. """ import paddle - from paddle.jit.pir_dy2static.parameter_recorder import ( + from paddle.jit.dy2static.parameter_recorder import ( _global_parameter_recorder, ) diff --git a/python/setup.py.in b/python/setup.py.in index 69ba503d18e597..2f55ebd466b2ff 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -937,7 +937,6 @@ packages=['paddle', 'paddle.jit', 'paddle.jit.dy2static', 'paddle.jit.dy2static.transformers', - 'paddle.jit.pir_dy2static', 'paddle.jit.sot', 'paddle.jit.sot.opcode_translator', 'paddle.jit.sot.opcode_translator.executor', diff --git a/setup.py b/setup.py index ffd73650937a6c..351b84999e1088 100644 --- a/setup.py +++ b/setup.py @@ -2394,7 +2394,6 @@ def get_setup_parameters(): 'paddle.jit', 'paddle.jit.dy2static', 'paddle.jit.dy2static.transformers', - 'paddle.jit.pir_dy2static', 'paddle.jit.sot', 'paddle.jit.sot.opcode_translator', 'paddle.jit.sot.opcode_translator.executor', diff --git a/test/sot/test_capture_control_flow.py b/test/sot/test_capture_control_flow.py index 1720d368dd7f71..d76622d8b4269b 100644 --- a/test/sot/test_capture_control_flow.py +++ b/test/sot/test_capture_control_flow.py @@ -20,6 +20,7 @@ ) import paddle +from paddle import nn @paddle.jit.marker.capture_control_flow @@ -66,5 +67,40 @@ def test_case_capture_control_flow(self): self.assertEqual(ctx.translate_count, 1) +class NetWithCaptureControlFlow(nn.Layer): + def __init__(self): + super().__init__() + self.layer = nn.Linear(8, 8) + + @paddle.jit.marker.capture_control_flow + def fn(self, x): + x = self.layer(x) + if x.sum() > 0: + x += paddle.ones_like(x) + else: + x -= paddle.zeros_like(x) + return x + + def forward(self, x): + return self.fn(x) + 1 + + +def model_call(x: paddle.Tensor, net: paddle.nn.Layer): + return net(x) + + +class TestEagerParamsToPirValue(TestCaseBase): + def test_case_without_capture_control_flow(self): + model = NetWithCaptureControlFlow() + with test_instruction_translator_cache_context() as ctx: + self.assertEqual(ctx.translate_count, 0) + x = paddle.randn([4, 8]) + self.assert_results(model_call, x, model) + self.assertEqual(ctx.translate_count, 1) + x = paddle.randn([4, 8]) + self.assert_results(model_call, x, model) + self.assertEqual(ctx.translate_count, 1) + + if __name__ == "__main__": unittest.main()