Description
Motivation: Why do you think this is important?
We pass most of our configuration around in Flyte as Python `@dataclass`es, since this makes configuration very easy to manage. Unfortunately, it also means we have to write a lot of wrappers for the inputs and outputs of our Flyte `@task`s and `@workflow`s: we often dereference attributes to pass them along to other inputs, but then need to wrap them again (in lists and other dataclasses). Something like:
```python
from dataclasses import dataclass

from flytekit import workflow

@dataclass
class FunctionAInput:
    x: int
    y: int

@dataclass
class FunctionAOutput:
    a: int
    b: int

@workflow
def my_sub_wf(my_in: FunctionAInput) -> FunctionAOutput:
    return FunctionAOutput(a=my_in.x, b=my_in.y)

@dataclass
class WorkflowInput:
    in_1: int
    in_2: int

@workflow
def run_my_fn(workflow_in: WorkflowInput) -> FunctionAOutput:
    return my_sub_wf(FunctionAInput(x=workflow_in.in_1, y=workflow_in.in_2))
```

This doesn't end up working, because we can't use the attributes of a dataclass promise as inputs when they are nested in containers.
Instead, we're forced to do things like:
```python
from dataclasses import dataclass

from flytekit import task, workflow

@dataclass
class FunctionAInput:
    x: int
    y: int

# Wrapper task whose only job is to build the dataclass from individual values.
@task
def wrap_fn_a_inputs(x: int, y: int) -> FunctionAInput:
    return FunctionAInput(x=x, y=y)

@dataclass
class FunctionAOutput:
    a: int
    b: int

# Another wrapper task, this time for the outputs.
@task
def wrap_fn_a_outputs(a: int, b: int) -> FunctionAOutput:
    return FunctionAOutput(a=a, b=b)

@workflow
def my_sub_wf(my_in: FunctionAInput) -> FunctionAOutput:
    return wrap_fn_a_outputs(a=my_in.x, b=my_in.y)

@dataclass
class WorkflowInput:
    in_1: int
    in_2: int

@workflow
def run_my_fn(workflow_in: WorkflowInput) -> FunctionAOutput:
    return my_sub_wf(wrap_fn_a_inputs(x=workflow_in.in_1, y=workflow_in.in_2))
```

While this doesn't add much to the toy example above, it gets messy fast when the dataclasses have many more fields and more nesting. Changing a dataclass requires updating every wrapper for it, the visualized graph becomes bloated with wrapper nodes, and ease of use goes down.
It's also an issue when trying to wrap a dataclass promise's attribute in a container, such as:
```python
from dataclasses import dataclass

from flytekit import task, workflow
from mashumaro.mixins.json import DataClassJSONMixin

@task
def get_int() -> int:
    return 3

@dataclass
class IntWrapper(DataClassJSONMixin):
    x: int

@task
def get_wrapped_int() -> IntWrapper:
    return IntWrapper(x=3)

@task
def sum_list(input_list: list[int]) -> int:
    return sum(input_list)

@workflow
def convert_list_workflow1() -> int:
    # This workflow is fine: a plain promise can be placed in a list
    promised_int = get_int()
    joined_list = [4, promised_int]
    return sum_list(input_list=joined_list)

@workflow
def convert_list_workflow2() -> int:
    # But this one is not: the attribute of a dataclass promise cannot be placed in a list
    wrapped_int = get_wrapped_int()
    joined_list = [4, wrapped_int.x]
    return sum_list(input_list=joined_list)
```

Goal: What should the final outcome look like, ideally?
Overall, I imagine the approach is to increase the level of support for dataclass attributes across flytekit, doing more to:
- Allow wrapping a dataclass promise attribute in a standard collection, such as `some_task([task_output.x])`.
  a. Importantly, retain the typing as defined by the dataclass, even when only receiving an attribute (an issue for complex types and for `int`s).
- Create tooling that allows Flyte to use promises in the construction of a dataclass, treating dataclasses like other collection types that can be passed around as inputs to workflows and functions.
I think this would also resolve issues like #5427
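To make the "retain the typing" goal concrete: the declared field type lives on the dataclass itself, so when only an attribute is passed along, its type could in principle still be recovered from the parent class. The sketch below is plain Python, not flytekit code; `declared_attr_type` and `coerce_to_declared` are hypothetical helpers, and it assumes the attribute's value comes back through a JSON-like encoding in which every number is a float (a protobuf `Struct`, for example, stores all numbers as doubles).

```python
from dataclasses import dataclass, is_dataclass
from typing import Any, get_type_hints


@dataclass
class IntWrapper:
    x: int


def declared_attr_type(cls: type, attr: str) -> type:
    """Hypothetical helper: look up the type the dataclass declares for `attr`."""
    if not is_dataclass(cls):
        raise TypeError(f"{cls!r} is not a dataclass")
    hints = get_type_hints(cls)
    if attr not in hints:
        raise AttributeError(f"{cls.__name__} has no field {attr!r}")
    return hints[attr]


def coerce_to_declared(cls: type, attr: str, value: Any) -> Any:
    """Hypothetical helper: restore an int that was decoded as a float."""
    expected = declared_attr_type(cls, attr)
    if expected is int and isinstance(value, float) and value.is_integer():
        return int(value)
    # Anything else is passed through unchanged in this sketch.
    return value


# Simulate a number that round-tripped through a JSON/Struct-style encoding.
decoded = 3.0
restored = coerce_to_declared(IntWrapper, "x", decoded)
print(type(restored).__name__, restored)  # int 3
```

This only covers the `int` case; complex field types would need the same lookup plus a real deserializer for the declared type.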
Describe alternatives you've considered
We could try the `@eager` execution mode for this, but that seems to change a lot of Flyte's semantics, and it doesn't have widespread support just yet. Plus, it would require significant refactors throughout the codebase.
Propose: Link/Inline OR Additional context
I've taken first stabs at pieces of this:
- [wip] Updating flytekit to handle dereferencing lists of promises (local) datologyai/flytekit#5 makes it possible, locally, to wrap dataclass promise attributes in simple containers like lists, dicts, and tuples, but I don't even know where to get started on getting similar functionality working for remote execution. Given that it only touches `promise.translate_inputs_to_literals` and `base_task.local_execute`, I think it probably has a significant way to go, but I don't know what I don't know. It's also possible that JSON IDL flytekit#2600 completely resolves this part of my issue, but I'd need to dig into that more.
- First pass at promise logic in dataclasses datologyai/flytekit#2 attempts to make it possible to treat dataclasses like Python-native collections, updating `binding_data_from_python_std` to iterate over dataclasses that contain promises and resolve them using a `BindingDataMap` rather than a `BindingData` scalar. This almost works, but if the `Promise`'s types are not primitive, then during `to_literal` at serialization time we may not have access to the non-primitive type (a similar issue to what the PR above attempts to resolve, and likely what causes [BUG] Accessing attributes fails on complex types #5427).
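The dataclass-as-collection idea can be modelled in plain Python. The sketch below is a toy, not flytekit code: `PromiseRef` and `to_binding` are hypothetical names, and the tagged tuples only stand in for flytekit's binding types (a `"map"` tuple plays roughly the role a `BindingDataMap` would). It walks a value recursively and treats a dataclass instance like a dict keyed by field name, so a promise nested inside it becomes a reference instead of something that must be converted to a literal immediately.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any


@dataclass(frozen=True)
class PromiseRef:
    """Hypothetical stand-in for an unresolved upstream output (node id + output name)."""
    node_id: str
    var: str


def to_binding(value: Any) -> Any:
    """Toy binding builder that recurses through containers and dataclasses.

    Returns nested tuples tagged "promise", "collection", "map" or "scalar".
    """
    # Check promises first: PromiseRef is itself a dataclass.
    if isinstance(value, PromiseRef):
        return ("promise", value.node_id, value.var)
    # Dataclass instances are treated like maps keyed by field name.
    if is_dataclass(value) and not isinstance(value, type):
        return ("map", {f.name: to_binding(getattr(value, f.name)) for f in fields(value)})
    if isinstance(value, (list, tuple)):
        return ("collection", [to_binding(v) for v in value])
    if isinstance(value, dict):
        return ("map", {k: to_binding(v) for k, v in value.items()})
    return ("scalar", value)


@dataclass
class FunctionAInput:
    x: int
    y: int


# x would come from an upstream node; y is a literal known at compile time.
binding = to_binding(FunctionAInput(x=PromiseRef("n0", "o0"), y=4))
print(binding)
# ('map', {'x': ('promise', 'n0', 'o0'), 'y': ('scalar', 4)})
```

Note that the `"map"` produced for `FunctionAInput` no longer records which class it came from; a real implementation would also need to carry the dataclass type so the receiving side can rebuild the instance, which is where the non-primitive type problem described above shows up.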
Overall I think this is a huge quality-of-life win for using Flyte with dataclasses, and I'm happy to actually work out the implementation, but I feel like I'm missing some pieces and context on the approach.
Are you sure this issue hasn't been raised already?
- Yes
Have you read the Code of Conduct?
- Yes