Skip to content

Commit 3fc6fa5

Browse files
committed
v0.1.1: Support of direct model annotations
Now, instead specifying explicit schema, one could write: ``` from django.db import Model from django_pydantic_field import SchemaField class MyModel(models.Model): field: list[int] = SchemaField() ``` The schema will be inferred at the model freezing step. It's not possible to specify not-yet resolved forward reference at the moment.
1 parent cfc7ae8 commit 3fc6fa5

File tree

10 files changed

+86
-35
lines changed

10 files changed

+86
-35
lines changed

Makefile

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ install:
99

1010
.PHONY: activate
1111
activate:
12-
. .env/bin/activate
12+
@ . .env/bin/activate
1313

1414

1515
.PHONY: build
@@ -24,6 +24,12 @@ test: activate
2424
test:
2525
pytest $(A)
2626

27+
.PHONY: lint
28+
lint: A=.
29+
lint: activate
30+
lint:
31+
mypy $(A)
32+
2733

2834
.PHONY: upload
2935
upload: activate

README.md

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,15 @@ class Bar(pydantic.BaseModel):
2727

2828

2929
class MyModel(models.Model):
30-
foo_field: Foo = SchemaField(schema=Foo)
31-
bar_list: list[Bar] = SchemaField(schema=list[Bar])
32-
raw_date_map: dict[date, int] = SchemaField(schema=dict[date, int])
33-
raw_uids: set[UUID] = SchemaField(schema=set[UUID])
30+
# Infer schema from field annotation
31+
foo_field: Foo = SchemaField()
32+
33+
# or pecify schema explicitly
34+
bar_list: typing.Sequence[Bar] = SchemaField(schema=list[Bar])
35+
36+
# Pydantic exportable types are supported
37+
raw_date_map: dict[int, date] = SchemaField()
38+
raw_uids: set[UUID] = SchemaField()
3439

3540
...
3641

django_pydantic_field/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@ def default_error_handler(obj, err):
4242

4343

4444
class SchemaEncoder(DjangoJSONEncoder):
45-
def __init__(self, *args, schema: "ModelType", export_cfg=None, **kwargs):
45+
def __init__(self, *args, schema: "ModelType", export=None, **kwargs):
4646
self.schema = schema
47-
self.export_cfg = export_cfg or {}
47+
self.export_params = export or {}
4848
super().__init__(*args, **kwargs)
4949

5050
def encode(self, obj):
5151
try:
52-
data = self.schema(__root__=obj).json(**self.export_cfg)
52+
data = self.schema(__root__=obj).json(**self.export_params)
5353
except pydantic.ValidationError:
5454
# This branch used for expressions like .filter(data__contains={}).
5555
# We don't want that {} to be parsed as a schema.

django_pydantic_field/fields.py

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from functools import partial
55

6+
from django.core.exceptions import FieldError
67
from django.db.models.fields import NOT_PROVIDED
78
from django.db.models.query_utils import DeferredAttribute
89
from django.db.models import JSONField
@@ -12,7 +13,7 @@
1213

1314
from . import base
1415

15-
__all__ = "SchemaField",
16+
__all__ = ("SchemaField",)
1617

1718

1819
class SchemaDeferredAttribute(DeferredAttribute):
@@ -38,30 +39,35 @@ class PydanticSchemaField(base.SchemaWrapper["base.ST"], JSONField):
3839

3940
def __init__(
4041
self,
41-
schema: t.Union[t.Type["base.ST"], "GenericContainer"],
42-
config: "base.ConfigType" = None,
4342
*args,
43+
schema: t.Union[t.Type["base.ST"], "GenericContainer"] = None,
44+
config: "base.ConfigType" = None,
4445
error_handler=base.default_error_handler,
4546
**kwargs
4647
):
47-
if isinstance(schema, GenericContainer):
48-
schema = t.cast(t.Type["base.ST"], schema.reconstruct_type())
48+
super().__init__(*args, **kwargs)
4949

50-
self.schema = schema
5150
self.config = config
52-
self.export_cfg = self._extract_export_kwargs(kwargs, dict.pop)
53-
54-
field_schema = self._wrap_schema(schema, config)
55-
decoder = partial(base.SchemaDecoder, schema=field_schema, error_handler=error_handler)
56-
encoder = partial(base.SchemaEncoder, schema=field_schema, export_cfg=self.export_cfg)
57-
58-
kwargs.update(decoder=decoder, encoder=encoder)
59-
super().__init__(*args, **kwargs)
51+
self.export_params = self._extract_export_kwargs(kwargs, dict.pop)
52+
self.error_handler = error_handler
53+
self._init_schema(schema)
6054

6155
def __copy__(self):
6256
_, _, args, kwargs = self.deconstruct()
6357
return type(self)(*args, **kwargs)
6458

59+
def contribute_to_class(self, cls, name, private_only=False):
60+
if self.schema is None:
61+
annotated_schema = t.get_type_hints(cls).get(name, None)
62+
if annotated_schema is None:
63+
raise FieldError(
64+
f"{cls._meta.label}.{name} needs to be either annotated "
65+
"or `schema=` field attribute should be explicitly passed"
66+
)
67+
self._init_schema(annotated_schema)
68+
69+
super().contribute_to_class(cls, name, private_only)
70+
6571
def get_default(self):
6672
value = super().get_default()
6773
return self.to_python(value)
@@ -81,6 +87,19 @@ def to_python(self, value) -> "base.SchemaT":
8187
assert self.decoder is not None
8288
return self.decoder().decode(value)
8389

90+
def _init_schema(
91+
self,
92+
schema: t.Union[t.Type["base.ST"], "GenericContainer", None],
93+
):
94+
if isinstance(schema, GenericContainer):
95+
schema = t.cast(t.Type["base.ST"], schema.reconstruct_type())
96+
97+
self.schema = schema
98+
if schema is not None:
99+
serializer = self._wrap_schema(schema, self.config)
100+
self.decoder = partial(base.SchemaDecoder, serializer, self.error_handler) # type: ignore
101+
self.encoder = partial(base.SchemaEncoder, schema=serializer, export=self.export_params) # type: ignore
102+
84103
def _deconstruct_default(self, kwargs):
85104
default = kwargs.get("default", NOT_PROVIDED)
86105

@@ -99,17 +118,20 @@ def _deconstruct_schema(self, kwargs):
99118
kwargs.update(schema=schema)
100119

101120
def _deconstruct_config(self, kwargs):
102-
kwargs.update(self.export_cfg, config=self.config)
121+
kwargs.update(self.export_params, config=self.config)
122+
123+
if self.error_handler is not base.default_error_handler:
124+
kwargs.update(error_handler=self.error_handler)
103125

104126

105127
def SchemaField(
106-
schema: t.Union[t.Type["base.ST"], "GenericContainer"],
107-
config: "base.ConfigType" = None,
108128
*args,
129+
schema: t.Type["base.ST"] = None,
130+
config: "base.ConfigType" = None,
109131
error_handler=base.default_error_handler,
110132
**kwargs
111133
) -> t.Any:
112-
return PydanticSchemaField(schema, config, *args, error_handler=error_handler, **kwargs)
134+
return PydanticSchemaField(*args, schema=schema, config=config, error_handler=error_handler, **kwargs)
113135

114136

115137
# Django Migration serializer helpers

django_pydantic_field/rest_framework.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,15 @@ def __init__(
6969
):
7070
self.schema = field_schema = self._wrap_schema(schema, config)
7171
self.decoder = base.SchemaDecoder[base.ST](field_schema, serializer_error_handler)
72-
self.export_cfg = self._extract_export_kwargs(kwargs, dict.pop)
72+
self.export_params = self._extract_export_kwargs(kwargs, dict.pop)
7373
super().__init__(**kwargs)
7474

7575
def to_internal_value(self, data) -> t.Optional["base.ST"]:
7676
return self.decoder.decode(data)
7777

7878
def to_representation(self, value):
7979
obj = self.schema.parse_obj(value)
80-
raw_obj = obj.dict(**self.export_cfg)
80+
raw_obj = obj.dict(**self.export_params)
8181
return raw_obj["__root__"]
8282

8383

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[metadata]
22
name = django-pydantic-field
3-
version = 0.1.0
3+
version = 0.1.1
44
url = https://github.com/surenkov/django-pydantic-field
55

66
description = Django JSONField with Pydantic models as a Schema

tests/sample_app/migrations/0001_initial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class Migration(migrations.Migration):
2121
('meta_builtin_list', django_pydantic_field.fields.PydanticSchemaField(config=None, default=list, schema=django_pydantic_field.fields.GenericContainer(list, (tests.sample_app.models.BuildingMeta,)))),
2222
('meta_typing_list', django_pydantic_field.fields.PydanticSchemaField(config=None, default=list, schema=django_pydantic_field.fields.GenericContainer(list, (tests.sample_app.models.BuildingMeta,)))),
2323
('meta_untyped_list', django_pydantic_field.fields.PydanticSchemaField(config=None, default=list, schema=list)),
24-
('meta_untyped_builtim_list', django_pydantic_field.fields.PydanticSchemaField(config=None, default=list, schema=list)),
24+
('meta_untyped_builtin_list', django_pydantic_field.fields.PydanticSchemaField(config=None, default=list, schema=list)),
2525
],
2626
),
2727
]

tests/sample_app/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class BuildingMeta(pydantic.BaseModel):
1919
default_meta = BuildingMeta(type=BuildingTypes.FRAME)
2020

2121
class Building(models.Model):
22-
meta: BuildingMeta = SchemaField(schema=BuildingMeta, default=default_meta)
22+
meta: BuildingMeta = SchemaField(default=default_meta)
2323
meta_builtin_list: list[BuildingMeta] = SchemaField(schema=list[BuildingMeta], default=list)
2424
meta_typing_list: t.List[BuildingMeta] = SchemaField(schema=t.List[BuildingMeta], default=list)
2525
meta_untyped_list: list = SchemaField(schema=t.List, default=list)

tests/test_base_marshalling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_schema_encoder_with_raw_dict():
2828

2929

3030
def test_schema_encoder_with_custom_config():
31-
encoder = base.SchemaEncoder(schema=SampleSchema, export_cfg={"exclude": {"__root__": {"stub_list"}}})
31+
encoder = base.SchemaEncoder(schema=SampleSchema, export={"exclude": {"__root__": {"stub_list"}}})
3232
existing_raw = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]}
3333
expected_encoded = '{"stub_str": "abc", "stub_int": 1}'
3434
assert encoder.encode(existing_raw) == expected_encoded

tests/test_django_model_field.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010

1111
from django.db import models
1212
from django.db.migrations.writer import MigrationWriter
13+
from django.core.exceptions import FieldError
1314

1415
from .conftest import InnerSchema, SampleDataclass
1516

1617

17-
1818
def test_sample_field():
1919
sample_field = fields.PydanticSchemaField(schema=InnerSchema)
2020
existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])
@@ -35,12 +35,21 @@ def test_sample_field_with_raw_data():
3535

3636
def test_simple_model_field():
3737
class SampleModel(models.Model):
38-
sample_field = fields.PydanticSchemaField(schema=InnerSchema)
39-
sample_list = fields.PydanticSchemaField(schema=t.List[InnerSchema])
38+
sample_field: InnerSchema = fields.SchemaField()
39+
sample_list: t.List[InnerSchema] = fields.SchemaField()
40+
sample_seq: t.Sequence[InnerSchema] = fields.SchemaField(schema=t.List[InnerSchema])
4041

4142
class Meta:
4243
app_label = "sample_app"
4344

45+
sample_field = SampleModel._meta.get_field("sample_field")
46+
assert sample_field.schema == InnerSchema
47+
48+
sample_list_field = SampleModel._meta.get_field("sample_list")
49+
assert sample_list_field.schema == t.List[InnerSchema]
50+
51+
sample_seq_field = SampleModel._meta.get_field("sample_seq")
52+
assert sample_seq_field.schema == t.List[InnerSchema]
4453

4554
existing_raw_field = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]}
4655
existing_raw_list = [{"stub_str": "abc", "stub_list": []}]
@@ -54,6 +63,15 @@ class Meta:
5463
assert instance.sample_list == expected_list
5564

5665

66+
def test_untyped_model_field_raises():
67+
with pytest.raises(FieldError):
68+
class SampleModel(models.Model):
69+
sample_field = fields.SchemaField()
70+
71+
class Meta:
72+
app_label = "sample_app"
73+
74+
5775
@pytest.mark.parametrize("field", [
5876
fields.PydanticSchemaField(schema=InnerSchema, default=InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])),
5977
fields.PydanticSchemaField(schema=InnerSchema, default=(("stub_str", "abc"), ("stub_list", [date(2022, 7, 1)]))),

0 commit comments

Comments
 (0)