Skip to content
84 changes: 55 additions & 29 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import lru_cache
from types import NoneType
from typing import Dict, List, NamedTuple, Optional, Type, cast

from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import literals_pb2
from google.protobuf import json_format as _json_format
Expand Down Expand Up @@ -57,6 +57,7 @@
DEFINITIONS = "definitions"
TITLE = "title"

# Module-level cache of generated JSON schemas, keyed by the origin type of the dataclass.
# The values are the dictionaries returned by mashumaro's ``build_json_schema(...).to_dict()``,
# so the schema for a given type only has to be built once per process.
JSON_TRANSFORMER_CACHE = {}

class BatchSize:
"""
Expand Down Expand Up @@ -149,7 +150,7 @@ def type_assertions_enabled(self) -> bool:
return self._type_assertions_enabled

def assert_type(self, t: Type[T], v: T):
    """
    Check that ``v`` is an instance of ``t`` and raise if it is not.

    Types for which ``get_origin(t)`` returns a value (e.g. parameterized generics) are not checked here,
    because ``isinstance`` cannot be applied to them.

    :param t: the expected Python type
    :param v: the value to validate
    :raises TypeTransformerFailedError: if ``t`` has no generic origin and ``v`` is not an instance of ``t``
    """
    # One guard only: the ``get_origin`` test supersedes the older ``hasattr(t, "__origin__")`` check,
    # so the two conditions must not be stacked in front of the same ``raise``.
    if get_origin(t) is None and not isinstance(v, t):
        raise TypeTransformerFailedError(f"Expected value of type {t} but got '{v}' of type {type(v)}")

@abstractmethod
Expand Down Expand Up @@ -452,8 +453,12 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
schema = JSONSchema().dump(s)
else: # DataClassJSONMixin
from mashumaro.jsonschema import build_json_schema

schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_dict()
origin_type = self._get_origin_type_in_annotation(t)
if origin_type not in JSON_TRANSFORMER_CACHE:
schema = build_json_schema(cast(DataClassJSONMixin, origin_type)).to_dict()
JSON_TRANSFORMER_CACHE[origin_type] = schema
else:
schema = JSON_TRANSFORMER_CACHE[origin_type]
except Exception as e:
# https://github.com/lovasoa/marshmallow_dataclass/issues/13
logger.warning(
Expand Down Expand Up @@ -493,22 +498,27 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp

self._make_dataclass_serializable(python_val, python_type)

# The function looks up or creates a JSONEncoder specifically designed for the object's type.
# This encoder is then used to convert a data class into a JSON string.
try:
encoder = self._encoder[python_type]
except KeyError:
encoder = JSONEncoder(python_type)
self._encoder[python_type] = encoder
# The `to_json` function is integrated through either the `dataclasses_json` decorator or by inheriting from `DataClassJsonMixin`.
# It serializes a data class into a JSON string, and provides additional functionality over JSONEncoder
if hasattr(python_val, "to_json"):
json_str = python_val.to_json()
else:
# The function looks up or creates a JSONEncoder specifically designed for the object's type.
# This encoder is then used to convert a data class into a JSON string.
try:
encoder = self._encoder[python_type]
except KeyError:
encoder = JSONEncoder(python_type)
self._encoder[python_type] = encoder

try:
json_str = encoder.encode(python_val)
except NotImplementedError:
# you can refer FlyteFile, FlyteDirectory and StructuredDataset to see how flyte types can be implemented.
raise NotImplementedError(
f"{python_type} should inherit from mashumaro.types.SerializableType"
f" and implement _serialize and _deserialize methods."
)
try:
json_str = encoder.encode(python_val)
except NotImplementedError:
# you can refer FlyteFile, FlyteDirectory and StructuredDataset to see how flyte types can be implemented.
raise NotImplementedError(
f"{python_type} should inherit from mashumaro.types.SerializableType"
f" and implement _serialize and _deserialize methods."
)

return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore

Expand Down Expand Up @@ -652,15 +662,20 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:

json_str = _json_format.MessageToJson(lv.scalar.generic)

# The function looks up or creates a JSONDecoder specifically designed for the object's type.
# This decoder is then used to convert a JSON string into a data class.
try:
decoder = self._decoder[expected_python_type]
except KeyError:
decoder = JSONDecoder(expected_python_type)
self._decoder[expected_python_type] = decoder
# The `from_json` function is integrated through either the `dataclasses_json` decorator or by inheriting from `DataClassJsonMixin`.
# It deserializes a JSON string into a data class, and supports additional functionality over JSONDecoder
if hasattr(expected_python_type, "from_json"):
dc = expected_python_type.from_json(json_str) # type: ignore
else:
# The function looks up or creates a JSONDecoder specifically designed for the object's type.
# This decoder is then used to convert a JSON string into a data class.
try:
decoder = self._decoder[expected_python_type]
except KeyError:
decoder = JSONDecoder(expected_python_type)
self._decoder[expected_python_type] = decoder

dc = decoder.decode(json_str)
dc = decoder.decode(json_str)

dc = self._fix_structured_dataset_type(expected_python_type, dc)
return self._fix_dataclass_int(expected_python_type, dc)
Expand Down Expand Up @@ -696,11 +711,22 @@ def tag(expected_python_type: Type[T]) -> str:

def get_literal_type(self, t: Type[T]) -> LiteralType:
    """
    Return the Flyte literal type used for ``t``: a ``STRUCT`` simple type whose metadata stores the tag
    computed by ``self.tag(t)`` under ``ProtobufTransformer.PB_FIELD_KEY``.
    """
    pb_metadata = {ProtobufTransformer.PB_FIELD_KEY: self.tag(t)}
    return LiteralType(simple=SimpleType.STRUCT, metadata=pb_metadata)

def _handle_list_literal(self, ctx: FlyteContext, elems: list) -> Literal:
    """
    Build a collection ``Literal`` from a list of Python values.

    The Python type of the first element is used to convert every element, so the list is treated as
    homogeneous. An empty list produces an empty collection.

    :param ctx: the Flyte context forwarded to ``TypeEngine.to_literal``
    :param elems: the Python values to convert
    :return: a ``Literal`` wrapping a ``LiteralCollection`` of the converted elements
    """
    converted = []
    if len(elems) != 0:
        # The first element decides both the Python type and the literal type for the whole list.
        elem_type = type(elems[0])
        elem_literal_type = TypeEngine.to_literal_type(elem_type)
        for item in elems:
            converted.append(TypeEngine.to_literal(ctx, item, elem_type, elem_literal_type))
    return Literal(collection=LiteralCollection(literals=converted))

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
    """
    Convert a protobuf message into a Flyte ``Literal``.

    The message is converted with ``_MessageToDict``. A list result is turned into a collection literal via
    ``_handle_list_literal``; anything else is loaded into a ``Struct`` and returned as a generic scalar.

    :param ctx: the Flyte context, forwarded when converting list elements
    :param python_val: the protobuf message to convert
    :param python_type: the declared Python type of ``python_val`` (not used by this method)
    :param expected: the expected literal type (not used by this method)
    :return: a collection ``Literal`` for list results, otherwise a ``Literal`` holding a generic ``Struct``
    :raises TypeTransformerFailedError: if any step of the conversion fails
    """
    struct = Struct()
    try:
        # Convert exactly once and inspect the result *before* touching ``struct``: calling
        # ``struct.update`` ahead of the list check would convert the message twice and make the
        # list branch unreachable, because updating a Struct from a list raises.
        message_dict = _MessageToDict(cast(Message, python_val))
        if isinstance(message_dict, list):
            return self._handle_list_literal(ctx, message_dict)
        struct.update(message_dict)
    except Exception as e:
        # Chain the original error so the root cause is not lost behind the generic message.
        raise TypeTransformerFailedError("Failed to convert to generic protobuf struct") from e
    return Literal(scalar=Scalar(generic=struct))
Expand Down Expand Up @@ -1051,7 +1077,7 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type
"actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then"
"return v.x, instead of v, even if this has a single element"
)
if python_val is None and expected and expected.union_type is None:
if (python_val is None and python_type != NoneType) and expected and expected.union_type is None:
raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}")
transformer = cls.get_transformer(python_type)
if transformer.type_assertions_enabled:
Expand Down