Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def _map_dataclass_variable(variable: VariableBase | object):
new_dataclass = dataclass_from_dict(
variable.get_py_type(),
{
fd.name: map_func(variable.getattr(fd.name))
fd.name: _map_variable(variable.getattr(fd.name))
for fd in fields(variable.get_py_type())
},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import contextlib
import dataclasses
import dis
import functools
Expand Down Expand Up @@ -395,6 +396,37 @@ def call_function(self, /, *args, **kwargs):
raise InnerError("UserCodeVariable call_function is not implemented.")


def to_unified_fn(fn):
def wrap_as_to_static_fn(fn):
def new_fn(*args, **kwargs):
if paddle.base.dygraph.base.in_to_static_mode():
to_static_guard = contextlib.nullcontext
else:

@contextlib.contextmanager
def to_static_guard():
with (
paddle.base.dygraph.base.to_static_mode_guard(
is_to_static=True
),
paddle.pir.core.static_op_arg_cast_guard(
paddle.pir.core._convert_into_value
),
):
yield

with to_static_guard():
return (
paddle.jit.dy2static.program_translator.convert_to_static(
fn
)(*args, **kwargs)
)

return new_fn

return wrap_as_to_static_fn(fn)


class PaddleApiVariable(FunctionVariable):
"""
PaddleApiVariable is a subclass of FunctionVariable used to wrap a paddlepaddle API function.
Expand Down Expand Up @@ -424,9 +456,6 @@ def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
if callable(value) and need_capture_control_flow(value):
# NOTE(SigureMo): We assume that if a function use AST transform,
# it already be already unified in dynamic and static graph.
to_unified_fn = (
paddle.jit.dy2static.program_translator.convert_to_static
)
unified_fn = to_unified_fn(value)
paddle.jit.marker.unified(unified_fn, for_sot=True)
return PaddleApiVariable(unified_fn, graph, tracker)
Expand Down
Loading