[Core feature] Enhanced promise attribute support for Dataclasses in Python Flytekit #5667

@JackUrb

Description

Motivation: Why do you think this is important?

We pass most of our configuration around in Flyte as Python @dataclass objects, since this makes configuration very easy to manage. Unfortunately, it also means we have to write a ton of wrappers for the inputs and outputs of our Flyte @task and @workflow functions: we often dereference attributes to pass along to other inputs, but first need to wrap them in lists and other dataclasses. Something like:

from dataclasses import dataclass

from flytekit import task, workflow

@dataclass
class FunctionAInput:
    x: int
    y: int

@dataclass
class FunctionAOutput:
    a: int
    b: int

@workflow
def my_sub_wf(my_in: FunctionAInput) -> FunctionAOutput:
    return FunctionAOutput(a=my_in.x, b=my_in.y)

@dataclass
class WorkflowInput:
    in_1: int
    in_2: int

@workflow
def run_my_fn(workflow_in: WorkflowInput) -> FunctionAOutput:
    return my_sub_wf(FunctionAInput(x=workflow_in.in_1, y=workflow_in.in_2))

doesn't end up working, as we can't use the attributes of a promise dataclass nested in containers as inputs.

Instead, we're forced to do things like:

@dataclass
class FunctionAInput:
    x: int
    y: int

@task
def wrap_fn_a_inputs(x: int, y: int) -> FunctionAInput:
    return FunctionAInput(x=x, y=y)

@dataclass
class FunctionAOutput:
    a: int
    b: int

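# The wrapper must be a task: it builds the dataclass from resolved values.
@task
def wrap_fn_a_outputs(a: int, b: int) -> FunctionAOutput:
    return FunctionAOutput(a=a, b=b)

# The sub-workflow dereferences promise attributes and passes them to the wrapper task.
@workflow
def my_sub_wf(my_in: FunctionAInput) -> FunctionAOutput:
    return wrap_fn_a_outputs(a=my_in.x, b=my_in.y)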

@dataclass
class WorkflowInput:
    in_1: int
    in_2: int

@workflow
def run_my_fn(workflow_in: WorkflowInput) -> FunctionAOutput:
    return my_sub_wf(wrap_fn_a_inputs(x=workflow_in.in_1, y=workflow_in.in_2))

While this doesn't add much to the above toy example, when the dataclasses have many more fields and depth, this gets messy fast. Making changes to a dataclass requires us to update all wrappers for it, the actual visualized graph becomes bloated with wrappers, and ease-of-use goes down.

It's also an issue when trying to wrap a dataclass promise's attribute in a container, such as:

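from dataclasses import dataclass

from flytekit import task, workflow
from mashumaro.mixins.json import DataClassJSONMixin

@task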
def get_int() -> int:
    return 3

@dataclass
class IntWrapper(DataClassJSONMixin):
    x: int

@task
def get_wrapped_int() -> IntWrapper:
    return IntWrapper(x=3)

@task
def sum_list(input_list: list[int]) -> int:
    return sum(input_list)

@workflow
def convert_list_workflow1() -> int:
    # This workflow is fine
    promised_int = get_int()
    joined_list = [4, promised_int]
    return sum_list(input_list=joined_list)

@workflow
def convert_list_workflow2() -> int:
    # But this one is not
    wrapped_int = get_wrapped_int()
    joined_list = [4, wrapped_int.x]
    return sum_list(input_list=joined_list)

Goal: What should the final outcome look like, ideally?

Overall I imagine the approach is to increase the level of support for Dataclass attributes across flytekit, doing more to:

  1. Allow wrapping a dataclass promise attribute in a standard collection, such as some_task([task_output.x])
    a. Importantly, retain the typing as defined by the dataclass, even when only receiving an attribute (issue for complex types and for ints)
  2. Create tooling that allows flyte to use promises in the construction of a dataclass, treating dataclasses like other collection types which can be passed around as input to workflows and functions.
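
To make goal 1 concrete, here is a minimal pure-Python sketch of resolving promise attributes that sit inside standard containers. It does not use flytekit at all: `ToyPromise`, `resolve`, and the `outputs` mapping are hypothetical names invented for illustration, and flytekit's real Promise and binding machinery work differently.

```python
from dataclasses import dataclass
from typing import Any


class ToyPromise:
    """Hypothetical stand-in for a not-yet-computed task output (not flytekit's Promise)."""

    def __init__(self, node: str, attr_path: tuple = ()):
        self.node = node                  # name of the node producing the value
        self.attr_path = tuple(attr_path)  # attributes dereferenced so far

    def __getattr__(self, name: str) -> "ToyPromise":
        # Only called when normal lookup fails, i.e. for dereferences like `.x`.
        if name.startswith("_"):
            raise AttributeError(name)
        return ToyPromise(self.node, self.attr_path + (name,))


def resolve(value: Any, outputs: dict) -> Any:
    """Recursively replace toy promises with concrete values from `outputs`."""
    if isinstance(value, ToyPromise):
        result = outputs[value.node]
        for name in value.attr_path:
            result = getattr(result, name)
        return result
    if isinstance(value, (list, tuple)):
        return type(value)(resolve(v, outputs) for v in value)
    if isinstance(value, dict):
        return {k: resolve(v, outputs) for k, v in value.items()}
    return value


@dataclass
class IntWrapper:
    x: int


# Mirrors convert_list_workflow2: a literal and a promise attribute in one list.
wrapped_int = ToyPromise("get_wrapped_int")
joined_list = [4, wrapped_int.x]

# Pretend the upstream task has now run.
outputs = {"get_wrapped_int": IntWrapper(x=3)}
resolved = resolve(joined_list, outputs)
total = sum(resolved)
```

The point of the sketch is only that the resolver has to walk containers and follow the recorded attribute path; retaining the declared type of `x` (goal 1a) is not modeled here.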

I think this would also resolve issues like #5427

Describe alternatives you've considered

We could try @eager execution mode for this, but that seems to change a lot of the semantics of Flyte, and it doesn't have widespread support just yet. Plus, it would require significant refactors throughout the codebase.

Propose: Link/Inline OR Additional context

I've taken first stabs at pieces of this:

  • datologyai/flytekit#5 ([wip] Updating flytekit to handle dereferencing lists of promises (local)) makes it possible locally to wrap dataclass promise attributes in simple containers like lists, dicts, and tuples, but I don't know where to start on getting similar functionality working for remote. Given that it only touches promise.translate_inputs_to_literals and base_task.local_execute, I think it probably has a significant way to go, but I don't know what I don't know. It's also possible that flytekit#2600 (JSON IDL) completely resolves this part of my issue, but I'd need to dig more there.
  • datologyai/flytekit#2 (First pass at promise logic in dataclasses) attempts to make it possible to treat dataclasses like Python-native collections, updating binding_data_from_python_std to iterate over dataclasses that contain promises and resolve them using a BindingDataMap rather than a BindingData scalar. This almost works, but if the promises' types are not primitive, then to_literal may not have access to the non-primitive type during serialization (a similar issue to the one the PR above attempts to resolve, and likely the cause of #5427, [BUG] Accessing attributes fails on complex types).
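
As a rough illustration of the second idea (emitting a map of per-field bindings instead of a single scalar when a dataclass contains promises), here is a self-contained sketch. It is not the code from either PR and does not call flytekit: `ToyPromise`, `contains_promise`, `to_binding`, and the dict-based binding shapes are all hypothetical stand-ins for Promise, binding_data_from_python_std, BindingDataMap, and a scalar BindingData.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any


class ToyPromise:
    """Hypothetical placeholder for a value produced by an upstream node."""

    def __init__(self, node: str, attr: str = ""):
        self.node = node
        self.attr = attr


def _is_dataclass_instance(value: Any) -> bool:
    # is_dataclass() is also true for dataclass types, so exclude classes.
    return is_dataclass(value) and not isinstance(value, type)


def contains_promise(value: Any) -> bool:
    """True if value is a toy promise or a dataclass holding one at any depth."""
    if isinstance(value, ToyPromise):
        return True
    if _is_dataclass_instance(value):
        return any(contains_promise(getattr(value, f.name)) for f in fields(value))
    return False


def to_binding(value: Any) -> dict:
    """Build a toy binding: promise reference, per-field map, or plain scalar."""
    if isinstance(value, ToyPromise):
        return {"promise": (value.node, value.attr)}
    if _is_dataclass_instance(value) and contains_promise(value):
        # Analogous to a map binding: one entry per dataclass field.
        return {"map": {f.name: to_binding(getattr(value, f.name)) for f in fields(value)}}
    # No promises inside: the whole value can be bound as one scalar.
    return {"scalar": value}


@dataclass
class FunctionAInput:
    x: int  # may hold a ToyPromise at workflow-construction time in this sketch
    y: int


nested = to_binding(FunctionAInput(x=ToyPromise("workflow_in", "in_1"), y=2))
plain = to_binding(FunctionAInput(x=1, y=2))
```

Here `nested` becomes a per-field map with a promise reference for `x`, while `plain` stays a single scalar. The sketch deliberately ignores the type information that the non-primitive case would need at serialization time.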

Overall I think this is a huge quality-of-life win for using Flyte with dataclasses, and I'm happy to actually work out the implementation, but I feel like I'm missing some pieces and context on the approach.

Are you sure this issue hasn't been raised already?

  • Yes

Have you read the Code of Conduct?

  • Yes

Metadata

Labels: enhancement (New feature or request)

Status: Done