Explorar el Código

Major redesign to handle complex cases with nested lists

theenglishway (time) hace 6 años
padre
commit
ac0ba479b2
Se han modificado 7 ficheros con 317 adiciones y 157 borrados
  1. 2 19
      pydantic_form/form.py
  2. 72 26
      pydantic_form/iterators.py
  3. 51 33
      pydantic_form/translator.py
  4. 20 13
      pydantic_form/utils.py
  5. 135 28
      tests/conftest.py
  6. 27 22
      tests/test_iterators.py
  7. 10 16
      tests/test_process.py

+ 2 - 19
pydantic_form/form.py

@@ -1,23 +1,9 @@
-from collections import defaultdict
-from wtforms import Form, FieldList, FormField, Field
-from pydantic import ValidationError, BaseModel
+from wtforms import Form
 
 from .utils import formdata_demangle
 from .translator import SchemaToForm
 
 
-class PydanticFieldList(FieldList):
-    _baked_instance = None
-
-    def populate_obj(self, obj, name):
-        attr = getattr(obj, name)
-        for n, v in enumerate(self._baked_instance):
-            self.entries[n].form._baked_instance = v
-            model = self.entries[n].form.get_model()
-            new = model(**v.dict())
-            attr.append(new)
-
-
 class PydanticForm(Form):
     _errors = {}
     _schema = None
@@ -44,7 +30,4 @@ class PydanticForm(Form):
             return False
 
         self.translator.set_baked()
-        return True
-
-    def process_obj(self, obj):
-        ...
+        return True

+ 72 - 26
pydantic_form/iterators.py

@@ -1,54 +1,100 @@
-from wtforms import Form, FormField
+from wtforms import Form, FormField, FieldList
+from pydantic.fields import Shape
 
+from .utils import *
 
-def iter_field(parent_class, leafs_only=True, path=()):
-    field = getattr(parent_class, path[-1]) if path else parent_class
 
+def iter_field(field, leafs_only=True, path=(), parent=None):
     if issubclass(field.field_class, FormField):
         field_class = field.kwargs['form_class']
         if path and not leafs_only:
-            yield path, field_class
+            yield path, field_class, parent
 
         for subfield_name, subfield in field_class._unbound_fields:
-            yield from iter_field(field_class, leafs_only, path + (subfield_name,))
+            yield from iter_field(subfield, leafs_only, path + (subfield_name,), parent)
+
+    elif issubclass(field.field_class, FieldList):
+        if path and not leafs_only:
+            yield path, field.field_class, parent
+        yield from iter_field(field.args[0], leafs_only, path, path)
 
     else:
-        yield path, field
+        yield path, field, parent
 
 
 def iter_form_class(form_class, leafs_only=True, path=()):
     field = getattr(form_class, path[-1]) if path else form_class
 
-    if path and not leafs_only:
-        yield path, field
+    if not leafs_only:
+        yield path, field, None
 
     try:
         for subfield_name, subfield in field._unbound_fields:
-            yield from iter_field(field, leafs_only, path + (subfield_name,))
+            yield from iter_field(subfield, leafs_only, path + (subfield_name,))
     except TypeError as e:
         # Ensure that the _unbound_fields is populated (that happens on first
         # instantiation)
         form_class()
-        yield from iter_form_class(field, leafs_only, path )
-
-def iter_form(form, leafs_only=True, path=()):
-    field = getattr(form, path[-1]) if path else form
-    if isinstance(field, Form) or isinstance(field, FormField):
-        if path and not leafs_only:
-            yield path, field
+        yield from iter_form_class(field, leafs_only, path)
 
-        for f in field._fields:
-            yield from iter_form(field, leafs_only, path + (f,))
-    else:
-        yield path, field
+def iter_form(form, leafs_only=True):
+    for key, field, iterate_on in iter_form_class(form.__class__, leafs_only):
+        if iterate_on:
+            value = rgetattr(form, iterate_on)
+            for n, entry in enumerate(value.entries):
+                yield iterate_on + (n, *key[len(iterate_on):]), entry
+            continue
+        yield key, field
 
-def iter_schema(schema, leafs_only=True, path=()):
-    type_ = schema.type_ if path else schema
+def iter_schema_class_field(field, leafs_only=True, path=(), parent=None):
+    type_ = field.type_
     if hasattr(type_, '__fields__'):
         if path and not leafs_only:
-            yield path, schema
+            yield path, field, None
 
-        for key, value in type_.__fields__.items():
-            yield from iter_schema(value, leafs_only, path + (key,))
+        for key, sub_field in type_.__fields__.items():
+            yield from iter_schema_class_field(sub_field, leafs_only, path + (key,), field)
     else:
-        yield path, schema
+        iterate_on = None
+        if parent and parent.shape != Shape.SINGLETON:
+            iterate_on = path[:-1]
+        elif field.shape != Shape.SINGLETON:
+            iterate_on = path
+
+        yield path, field, iterate_on
+
+def iter_schema_class(schema, leafs_only=True, path=()):
+    for key, field in schema.__fields__.items():
+        yield from iter_schema_class_field(field, leafs_only, path + (key,))
+
+
+def iter_schema(schema, leafs_only=True):
+    def get_schema_value(schema, key):
+        try:
+            return rgetattr(schema, key)
+        except AttributeError:
+            try:
+                return recursive_get(getattr(schema, key[0]), *key[1:])
+            except AttributeError:
+                return recursive_get(schema, *key)
+
+    for key, field, iterate_on in iter_schema_class(schema.__class__, leafs_only):
+        if iterate_on:
+            try:
+                sub_values = get_schema_value(schema, iterate_on)
+                if isinstance(sub_values, list):
+                    for n, sub_value in enumerate(sub_values):
+                        yield iterate_on + (n, *key[len(iterate_on):]), get_schema_value(sub_value, key[len(iterate_on):])
+                else:
+                    raise NotImplementedError()
+            except AttributeError as e:
+                if schema.dict() != {}:
+                    raise e
+
+            continue
+
+        try:
+            yield key, get_schema_value(schema, key)
+        except AttributeError as e:
+            if schema.dict() != {}:
+                raise e

+ 51 - 33
pydantic_form/translator.py

@@ -39,46 +39,64 @@ class SchemaToForm:
     def errors(self):
         return self._errors
 
-    def set_data(self):
-        schema = self._schema
-        for src_key, src_value in iter_schema(self.schema):
-            try:
-                value = rgetattr(schema, src_key)
-            except AttributeError:
-                try:
-                    value = recursive_get(getattr(schema, src_key[0]), *src_key[1:])
-                except AttributeError:
-                    continue
+    @classmethod
+    def get_item(cls, dict, key):
+        for k in key:
+            if is_int(k):
+                sub_dict = dict[k]
+            else:
+                sub_dict = dict[k]
+            return cls.get_item(sub_dict, key[1:])
 
-            dest_key = self.lut(src_key)
-            dest_field = rgetattr(self.form, dest_key)
-            if isinstance(dest_field, FieldList):
-                for n, f in enumerate(dest_field.entries):
-                    setattr(f, 'data', value[n])
+        return dict
+
+    @classmethod
+    def get_field(cls, form, key):
+        for k in key:
+            if is_int(k):
+                sub_form = form.entries[k]
+            else:
+                sub_form = getattr(form, k)
+            return cls.get_field(sub_form, key[1:])
+
+        return form
+
+    @classmethod
+    def get_schema(cls, schema, key):
+        for k in key:
+            if is_int(k):
+                sub = schema[k]
             else:
-                setattr(dest_field, 'data', value)
+                sub = getattr(schema, k)
+            return cls.get_schema(sub, key[1:])
+
+        return schema
+
+    def set_data(self):
+        for src_key, value in iter_schema(self.schema):
+            dest_key = self.lut(src_key)
+            dest_field = self.get_field(self.form, dest_key)
+            dest_field.data = value
 
     def set_baked(self):
         for k, field in iter_form(self.form, leafs_only=False):
+            field = self.get_field(self.form, k)
             if isinstance(field, FormField):
-                field._baked_instance = rgetattr(self.schema, k)
+                field._baked_instance = self.get_schema(self.schema, k)
 
     def set_errors(self):
-        for k, error_list in nested_dict_iter(self.errors):
-            try:
-                field = rgetattr(self.form, k)
-                if isinstance(field, FormField):
-                    setattr(field.form, '_errors', error_list)
-                else:
-                    setattr(field, 'errors', error_list)
-            except TypeError:
-                *field_list, n = k
-                rgetattr(self.form, tuple(field_list)).entries[n].errors = error_list
-
-        for k, error_list in nested_dict_iter(self.errors):
-            try:
-                rgetattr(self.form, k)
-            except TypeError:
-                *field_list, n = k
+        for path, error_list in nested_dict_iter(self.errors):
+            field = self.get_field(self.form, path)
+            if isinstance(field, FormField):
+                setattr(field.form, '_errors', error_list)
+            else:
+                setattr(field, 'errors', error_list)
+
+        for path, error_list in nested_dict_iter(self.errors):
+            paths_with_list = [
+                (path[:n], path[n], path[n+1:])
+                for n, val in enumerate(path) if is_int(val)
+            ]
+            for field_list, _, _ in paths_with_list:
                 form = rgetattr(self.form, tuple(field_list))
                 form.errors = [entry.errors for entry in form.entries]

+ 20 - 13
pydantic_form/utils.py

@@ -55,12 +55,18 @@ def default_to_regular(d):
     return d
 
 
-def nested_dict_iter(nested, path=()):
-    for key, value in nested.items():
-        if isinstance(value, abc.Mapping):
-            yield from nested_dict_iter(value, path + (key,))
-        else:
-            yield path + (key,), value
+def nested_dict_iter(nested, path=(), iter_list=False):
+    try:
+        for key, value in nested.items():
+            if isinstance(value, abc.Mapping):
+                yield from nested_dict_iter(value, path + (key,))
+            elif iter_list and isinstance(value, abc.Iterable) and not isinstance(value, str):
+                for n, sub_val in enumerate(value):
+                    yield from nested_dict_iter(sub_val, path + (key, n))
+            else:
+                yield path + (key,), value
+    except AttributeError:
+        yield path, nested
 
 
 def to_formdata_item(data, prefix=''):
@@ -84,17 +90,18 @@ def formdata_mangle(data):
     return ImmutableMultiDict(to_formdata_item(data))
 
 
+def is_int(value):
+    try:
+        int(value)
+        return True
+    except ValueError:
+        return False
+
+
 def formdata_demangle(formdata):
     def iter_by_reverse_key_length(iter):
         return sorted(iter, key=lambda k: len(k[0]), reverse=True)
 
-    def is_int(value):
-        try:
-            int(value)
-            return True
-        except ValueError:
-            return False
-
     if not formdata:
         return None
 

+ 135 - 28
tests/conftest.py

@@ -25,6 +25,10 @@ Scenario = namedtuple(
     ['form', 'schema', 'keys', 'data_factory', 'errors']
 )
 
+Keys = namedtuple(
+    "Keys",
+    ["class_", "instance"]
+)
 
 @pytest.fixture
 def scenario(request):
@@ -38,7 +42,10 @@ class MissingDataFactory(factory.Factory):
 
 # Simplest case
 
-simple_keys = [('integer',), ('string',)]
+simple_keys = Keys(
+    [('integer',), ('string',)],
+    [('integer',), ('string',)]
+)
 
 class SimpleSchema(BaseModel):
     integer: int
@@ -79,12 +86,12 @@ simple_data_factories = DataFactories(
 simple_expected_errors = ExpectedErrors(
     {},
     {
-        ('integer',): 'type_error.integer',
-        ('string',): 'type_error.str'
+        'integer': 'type_error.integer',
+        'string': 'type_error.str'
     },
     {
-        ('integer',): 'value_error.missing',
-        ('string',): 'value_error.missing',
+        'integer': 'value_error.missing',
+        'string': 'value_error.missing',
     }
 )
 
@@ -100,7 +107,10 @@ def scenario_simple():
 
 # Simple list case
 
-simple_list_keys = [('integer',), ('string',)]
+simple_list_keys = Keys(
+    [('integers',), ('strings',)],
+    [('integers', 0), ('integers', 1), ('strings', 0), ('strings', 1)],
+)
 
 class SimpleListSchema(BaseModel):
     integers: List[int]
@@ -140,12 +150,12 @@ simple_list_data_factories = DataFactories(
 simple_list_expected_errors = ExpectedErrors(
     {},
     {
-        ('integers',): ['type_error.integer', 'type_error.integer'],
-        ('strings',): ['type_error.str', 'type_error.str']
+        'integers': ['type_error.integer', 'type_error.integer'],
+        'strings': ['type_error.str', 'type_error.str']
     },
     {
-        ('integers',): 'value_error.missing',
-        ('strings',): 'value_error.missing'
+        'integers': 'value_error.missing',
+        'strings': 'value_error.missing'
     }
 )
 
@@ -162,7 +172,10 @@ def scenario_simple_list():
 
 # Case with one level of nesting
 
-nested_keys = [('integer',), ('nested', 'integer'), ('nested', 'string')]
+nested_keys = Keys(
+    [('integer',), ('nested', 'integer'), ('nested', 'string')],
+    [('integer',), ('nested', 'integer'), ('nested', 'string')]
+)
 
 class NestedSchema(BaseModel):
     integer: int
@@ -204,13 +217,15 @@ nested_data_factories = DataFactories(
 nested_expected_errors = ExpectedErrors(
     {},
     {
-        ('integer',): 'type_error.integer',
-        ('nested', 'integer',): 'type_error.integer',
-        ('nested', 'string',): 'type_error.str'
+        'integer': 'type_error.integer',
+        'nested': {
+            'integer': 'type_error.integer',
+            'string': 'type_error.str'
+        }
     },
     {
-        ('integer',): 'value_error.missing',
-        ('nested',): 'value_error.missing',
+        'integer': 'value_error.missing',
+        'nested': 'value_error.missing',
     }
 )
 
@@ -225,13 +240,101 @@ def scenario_nested():
     )
 
 
+# Case with one level of nesting inside a list
+
+nested_list_keys = Keys(
+    [('integer',), ('nested_list', 'integer'), ('nested_list', 'string')],
+    [
+        ('integer',),
+        ('nested_list', 0, 'integer'),
+        ('nested_list', 1, 'integer'),
+        ('nested_list', 0, 'string'),
+        ('nested_list', 1, 'string')
+    ]
+)
+
+class NestedListSchema(BaseModel):
+    integer: int
+    nested_list: List[SimpleSchema]
+
+
+class NestedListWTForm(Form):
+    _schema = NestedListSchema
+
+    integer = fields.IntegerField()
+    nested_list = fields.FieldList(fields.FormField(form_class=SimpleForm), min_entries=2)
+
+
+class NestedListForm(NestedListWTForm, PydanticForm):
+    _schema = NestedListSchema
+
+
+class NestedListDataFactory(factory.Factory):
+    class Meta:
+        model = dict
+
+    integer = factory.Faker('pyint')
+    nested_list = factory.List([
+        factory.SubFactory(SimpleDataFactory),
+        factory.SubFactory(SimpleDataFactory)
+    ])
+
+
+class NestedListBadDataFactory(factory.Factory):
+    class Meta:
+        model = dict
+
+    integer = factory.Faker('pystr')
+    nested_list = factory.List([
+        factory.SubFactory(SimpleBadDataFactory),
+        factory.SubFactory(SimpleBadDataFactory)
+    ])
+
+nested_list_data_factories = DataFactories(
+    NestedListDataFactory,
+    NestedListBadDataFactory,
+    MissingDataFactory
+)
+
+nested_list_expected_errors = ExpectedErrors(
+    {},
+    {
+        'integer': 'type_error.integer',
+        'nested_list': [
+            {'integer': 'type_error.integer'},
+            {'string': 'type_error.str'}
+        ],
+    },
+    {
+        'integer': 'value_error.missing',
+        'nested_list': 'value_error.missing',
+    }
+)
+
+@pytest.fixture
+def scenario_nested_list():
+    return Scenario(
+        NestedListForm,
+        NestedListSchema,
+        nested_list_keys,
+        nested_list_data_factories,
+        nested_list_expected_errors
+    )
+
 # Case with two levels of nesting
-double_nested_keys = [
-    ('integer',),
-    ('double_nested', 'integer'),
-    ('double_nested', 'nested', 'integer'),
-    ('double_nested', 'nested', 'string')
-]
+double_nested_keys = Keys([
+        ('integer',),
+        ('double_nested', 'integer'),
+        ('double_nested', 'nested', 'integer'),
+        ('double_nested', 'nested', 'string')
+    ],
+    [
+        ('integer',),
+        ('double_nested', 'integer'),
+        ('double_nested', 'nested', 'integer'),
+        ('double_nested', 'nested', 'string')
+    ]
+)
 
 class DoubleNestedSchema(BaseModel):
     integer: int
@@ -274,14 +377,18 @@ double_nested_data_factories = DataFactories(
 double_nested_expected_errors = ExpectedErrors(
     {},
     {
-        ('integer',): 'type_error.integer',
-        ('double_nested', 'integer',): 'type_error.integer',
-        ('double_nested', 'nested', 'integer',): 'type_error.integer',
-        ('double_nested', 'nested', 'string',): 'type_error.str'
+        'integer': 'type_error.integer',
+        'double_nested': {
+            'integer': 'type_error.integer',
+            'nested': {
+                'integer': 'type_error.integer',
+                'string': 'type_error.str'
+            }
+        }
     },
     {
-        ('integer',): 'value_error.missing',
-        ('double_nested',): 'value_error.missing',
+        'integer': 'value_error.missing',
+        'double_nested': 'value_error.missing',
     }
 )
 

+ 27 - 22
tests/test_iterators.py

@@ -1,41 +1,46 @@
 import pytest
-from pydantic_form.iterators import iter_form, iter_schema, iter_form_class
+from pydantic_form.iterators import iter_form, iter_schema, iter_form_class, iter_schema_class
 
 
-@pytest.fixture
-def instance_factory(request):
-    def _factory(scenario, data):
-        instances = dict(
-            schema_class=scenario.schema,
-            schema=scenario.schema(**data),
-            form_class=scenario.wtf_form,
-            form=scenario.wtf_form()
-        )
+@pytest.mark.parametrize(
+    'scenario',
+    [
+        'scenario_simple',
+        'scenario_simple_list',
+        'scenario_nested',
+        'scenario_nested_list',
+        'scenario_double_nested'
+    ], indirect=True
+)
+def test_iterator_schema(scenario):
+    data = scenario.data_factory.valid()
+    keys = scenario.keys
 
-        return instances[request.param]
+    assert [k for k, _, _ in iter_schema_class(scenario.schema)] == keys.class_
+
+    schema_instance = scenario.schema(**data)
+    assert [k for k, _, _ in iter_schema_class(scenario.schema)] == keys.class_
+    assert [k for k, _ in iter_schema(schema_instance)] == keys.instance
 
-    return _factory
+    schema_construct = scenario.schema.construct(data, scenario.schema.__fields__)
+    assert [k for k, _ in iter_schema(schema_construct)] == keys.instance
 
 
 @pytest.mark.parametrize(
     'scenario',
     [
         'scenario_simple',
+        'scenario_simple_list',
         'scenario_nested',
+        'scenario_nested_list',
         'scenario_double_nested'
     ], indirect=True
 )
-def test_iterators(scenario):
-    data = scenario.data_factory.valid()
+def test_iterator_form(scenario):
     keys = scenario.keys
 
-    assert [k for k, _ in iter_form_class(scenario.form)] == keys
+    assert [k for k, _, _ in iter_form_class(scenario.form)] == keys.class_
 
     form = scenario.form()
-    assert [k for k, _ in iter_form_class(scenario.form)] == keys
-    assert [k for k, _ in iter_form(form)] == keys
-
-    assert [k for k, _ in iter_schema(scenario.schema)] == keys
-
-    schema_instance = scenario.schema(**data)
-    assert [k for k, _ in iter_schema(schema_instance)] == keys
+    assert [k for k, _, _ in iter_form_class(scenario.form)] == keys.class_
+    assert [k for k, _ in iter_form(form)] == keys.instance

+ 10 - 16
tests/test_process.py

@@ -1,8 +1,8 @@
 import pytest
 from wtforms import FormField
 
-from pydantic_form.translator import iter_form
-from pydantic_form.utils import recursive_get, formdata_mangle
+from pydantic_form.translator import iter_form, SchemaToForm
+from pydantic_form.utils import recursive_get, formdata_mangle, nested_dict_iter
 
 
 @pytest.fixture
@@ -14,6 +14,7 @@ SCENARIOS = [
     'scenario_simple',
     'scenario_simple_list',
     'scenario_nested',
+    'scenario_nested_list',
     'scenario_double_nested'
 ]
 
@@ -96,7 +97,6 @@ def test_errors_valid(scenario, data):
     form = scenario.form(data=data)
     assert form.errors == {}
 
-
 @pytest.mark.parametrize(
     'data, errors_factory',
     [
@@ -115,18 +115,11 @@ def test_errors_invalid(scenario, data, errors_factory):
 
     form.validate()
     assert form.errors
-
-    for k, expected in errors_factory(scenario).items():
-        field_errors = recursive_get(form.errors, *k)
-
-        if isinstance(expected, list):
-            errors = [e['type'] for e_list in field_errors for e in e_list]
-            assert errors == expected
-
-        else:
-            errors = [e['type'] for e in field_errors]
-            assert len(errors) == 1
-            assert errors[0] == expected
+    for k, expected in nested_dict_iter(errors_factory(scenario), iter_list=True):
+        field_errors = SchemaToForm.get_item(form.errors, k)
+        errors = [e['type'] for e in field_errors]
+        assert len(errors) == 1
+        assert errors[0] == expected
 
 
 @pytest.mark.parametrize(
@@ -149,7 +142,8 @@ def test_baked_instance(scenario, data):
     assert form._baked_instance.dict() == data
 
     for key, sub_form in iter_form(form, leafs_only=False):
+        sub_form = SchemaToForm.get_field(form, key)
         if isinstance(sub_form, FormField):
             assert hasattr(sub_form, '_baked_instance')
             assert sub_form._baked_instance is not None
-            assert sub_form._baked_instance.dict() == recursive_get(data, *key)
+            assert sub_form._baked_instance.dict() == SchemaToForm.get_item(data, key)