Quellcode durchsuchen

Handle ListField

theenglishway (time) vor 6 Jahren
Ursprung
Commit
462138dbf0
5 geänderte Dateien mit 204 neuen und 30 gelöschten Zeilen
  1. 2 1
      pydantic_form/form.py
  2. 24 22
      pydantic_form/translator.py
  3. 96 0
      pydantic_form/utils.py
  4. 63 0
      tests/conftest.py
  5. 19 7
      tests/test_process.py

+ 2 - 1
pydantic_form/form.py

@@ -2,6 +2,7 @@ from collections import defaultdict
 from wtforms import Form, FieldList, FormField, Field
 from pydantic import ValidationError, BaseModel
 
+from .utils import formdata_demangle
 from .translator import SchemaToForm
 
 
@@ -33,7 +34,7 @@ class PydanticForm(Form):
     def process(self, formdata=None, obj=None, data=None, **kwargs):
         super().process()
         formdata = self.meta.wrap_formdata(self, formdata)
-        input = formdata or data or kwargs
+        input = formdata_demangle(formdata) or data or kwargs
         self.translator(input)
         self.translator.set_data()
 

+ 24 - 22
pydantic_form/translator.py

@@ -1,23 +1,9 @@
-from collections import defaultdict, abc
-from .iterators import *
+from wtforms import FieldList, FormField
+from .iterators import iter_schema, iter_form
 from .utils import *
 from pydantic import ValidationError
 
 
-def default_to_regular(d):
-    if isinstance(d, defaultdict):
-        d = {k: default_to_regular(v) for k, v in d.items()}
-    return d
-
-
-def nested_dict_iter(nested, path=()):
-    for key, value in nested.items():
-        if isinstance(value, abc.Mapping):
-            yield from nested_dict_iter(value, path + (key,))
-        else:
-            yield path + (key,), value
-
-
 class SchemaToForm:
     _schema = None
     _errors = None
@@ -66,7 +52,11 @@ class SchemaToForm:
 
             dest_key = self.lut(src_key)
             dest_field = rgetattr(self.form, dest_key)
-            setattr(dest_field, 'data', value)
+            if isinstance(dest_field, FieldList):
+                for n, f in enumerate(dest_field.entries):
+                    setattr(f, 'data', value[n])
+            else:
+                setattr(dest_field, 'data', value)
 
     def set_baked(self):
         for k, field in iter_form(self.form, leafs_only=False):
@@ -75,8 +65,20 @@ class SchemaToForm:
 
     def set_errors(self):
         for k, error_list in nested_dict_iter(self.errors):
-            field = rgetattr(self.form, k)
-            if isinstance(field, FormField):
-                setattr(field.form, '_errors', error_list)
-            else:
-                setattr(field, 'errors', error_list)
+            try:
+                field = rgetattr(self.form, k)
+                if isinstance(field, FormField):
+                    setattr(field.form, '_errors', error_list)
+                else:
+                    setattr(field, 'errors', error_list)
+            except TypeError:
+                *field_list, n = k
+                rgetattr(self.form, tuple(field_list)).entries[n].errors = error_list
+
+        for k, error_list in nested_dict_iter(self.errors):
+            try:
+                rgetattr(self.form, k)
+            except TypeError:
+                *field_list, n = k
+                form = rgetattr(self.form, tuple(field_list))
+                form.errors = [entry.errors for entry in form.entries]

+ 96 - 0
pydantic_form/utils.py

@@ -1,3 +1,5 @@
+from collections import defaultdict, abc
+from werkzeug.datastructures import ImmutableMultiDict
 import functools
 
 
@@ -45,3 +47,97 @@ def rgetattr(obj, attr, *args):
         return getattr(obj, attr, *args)
 
     return functools.reduce(_getattr, (obj,) + attr)
+
+
+def default_to_regular(d):
+    if isinstance(d, defaultdict):
+        d = {k: default_to_regular(v) for k, v in d.items()}
+    return d
+
+
+def nested_dict_iter(nested, path=()):
+    for key, value in nested.items():
+        if isinstance(value, abc.Mapping):
+            yield from nested_dict_iter(value, path + (key,))
+        else:
+            yield path + (key,), value
+
+
+def to_formdata_item(data, prefix=''):
+    def get_prefixed(key):
+        return f'{prefix}-{key}' if prefix else key
+
+    for k, val in data.items():
+        if isinstance(val, list):
+            for n, sub_val in enumerate(val):
+                if isinstance(sub_val, dict):
+                    yield from to_formdata_item(sub_val, get_prefixed(f'{k}-{n}'))
+                else:
+                    yield get_prefixed(f'{k}-{n}'), sub_val
+        elif isinstance(val, dict):
+            yield from to_formdata_item(val, get_prefixed(k))
+        else:
+            yield get_prefixed(k), val
+
+
+def formdata_mangle(data):
+    return ImmutableMultiDict(to_formdata_item(data))
+
+
+def formdata_demangle(formdata):
+    def iter_by_reverse_key_length(iter):
+        return sorted(iter, key=lambda k: len(k[0]), reverse=True)
+
+    def is_int(value):
+        try:
+            int(value)
+            return True
+        except ValueError:
+            return False
+
+    if not formdata:
+        return None
+
+    by_tuple = {tuple(k.split('-')): val for k, val in formdata.items()}
+
+    formdata_by_tuple = defaultdict(defaultdict)
+    for k, val in iter_by_reverse_key_length(by_tuple.items()):
+        recursive_setdefault(formdata_by_tuple, val, *k)
+
+    formdata_leaf_lists = defaultdict(list)
+    for k, val in iter_by_reverse_key_length(nested_dict_iter(formdata_by_tuple)):
+        if is_int(k[-1]):
+            recursive_dict_operation(
+                formdata_leaf_lists,
+                lambda d, k: d.setdefault(k, defaultdict(list)),
+                lambda d, k: d[k].append(val),
+                *k[:-1]
+            )
+        else:
+            recursive_setdefault(formdata_leaf_lists, val, *k)
+
+    formdata_final = defaultdict(list)
+    nested_lists = []
+    for k, val in iter_by_reverse_key_length(nested_dict_iter(formdata_leaf_lists)):
+        for n, s_k in enumerate(reversed(k)):
+            if is_int(s_k):
+                value_key = k[:len(k) - n]
+                if value_key not in nested_lists:
+                    value = recursive_get(formdata_leaf_lists, *value_key)
+                    nested_list_key = value_key[:-1]
+                    recursive_dict_operation(
+                        formdata_final,
+                        lambda d, k: d.setdefault(k, defaultdict(list)),
+                        lambda d, k: d[k].append(default_to_regular(value)),
+                        *nested_list_key
+                    )
+                    nested_lists.append(value_key)
+                break
+        else:
+            recursive_setdefault(
+                formdata_final,
+                val,
+                *k
+            )
+
+    return default_to_regular(formdata_final)

+ 63 - 0
tests/conftest.py

@@ -6,6 +6,7 @@ from werkzeug.datastructures import ImmutableMultiDict
 from pydantic_form import PydanticForm
 from pydantic import BaseModel, ValidationError
 from pydantic.types import StrictStr
+from typing import List
 
 
 
@@ -97,6 +98,68 @@ def scenario_simple():
         simple_expected_errors
     )
 
+# Simple list case
+
+simple_list_keys = [('integer',), ('string',)]
+
+class SimpleListSchema(BaseModel):
+    integers: List[int]
+    strings: List[StrictStr]
+
+class SimpleListWTForm(Form):
+    _schema = SimpleSchema
+
+    integers = fields.FieldList(fields.IntegerField(), min_entries=2)
+    strings = fields.FieldList(fields.StringField(), min_entries=2)
+
+
+class SimpleListForm(SimpleListWTForm, PydanticForm):
+    _schema = SimpleListSchema
+
+
+class SimpleListDataFactory(factory.Factory):
+    class Meta:
+        model = dict
+
+    integers = factory.List([factory.Faker('pyint'), factory.Faker('pyint')])
+    strings = factory.List([factory.Faker('pystr'), factory.Faker('pystr')])
+
+class SimpleListBadDataFactory(factory.Factory):
+    class Meta:
+        model = dict
+
+    integers = factory.List([factory.Faker('pystr'), factory.Faker('pystr')])
+    strings = factory.List([factory.Faker('pyint'), factory.Faker('pyint')])
+
+simple_list_data_factories = DataFactories(
+    SimpleListDataFactory,
+    SimpleListBadDataFactory,
+    MissingDataFactory
+)
+
+simple_list_expected_errors = ExpectedErrors(
+    {},
+    {
+        ('integers',): ['type_error.integer', 'type_error.integer'],
+        ('strings',): ['type_error.str', 'type_error.str']
+    },
+    {
+        ('integers',): 'value_error.missing',
+        ('strings',): 'value_error.missing'
+    }
+)
+
+@pytest.fixture
+def scenario_simple_list():
+    return Scenario(
+        SimpleListForm,
+        SimpleListSchema,
+        simple_list_keys,
+        simple_list_data_factories,
+        simple_list_expected_errors
+    )
+
+
 # Case with one level of nesting
 
 nested_keys = [('integer',), ('nested', 'integer'), ('nested', 'string')]

+ 19 - 7
tests/test_process.py

@@ -1,16 +1,18 @@
-from itertools import product
-from werkzeug.datastructures import ImmutableMultiDict
-from pydantic_form.translator import *
-from pydantic_form.utils import recursive_get
 import pytest
+from wtforms import FormField
+
+from pydantic_form.translator import iter_form
+from pydantic_form.utils import recursive_get, formdata_mangle
 
 
 @pytest.fixture
 def data(request, scenario):
     return getattr(scenario.data_factory, request.param)()
 
+
 SCENARIOS = [
     'scenario_simple',
+    'scenario_simple_list',
     'scenario_nested',
     'scenario_double_nested'
 ]
@@ -19,7 +21,7 @@ SCENARIOS = [
     'kwargs_factory',
     [
         lambda data: {'data': data},
-        lambda data: {'formdata': ImmutableMultiDict(data)},
+        lambda data: {'formdata': formdata_mangle(data)},
         lambda data: data,
     ], ids=[
         'data', 'formdata', 'kwargs'
@@ -113,8 +115,18 @@ def test_errors_invalid(scenario, data, errors_factory):
 
     form.validate()
     assert form.errors
-    for k, error in errors_factory(scenario).items():
-        assert recursive_get(form.errors, *k)[0]['type'] == error
+
+    for k, expected in errors_factory(scenario).items():
+        field_errors = recursive_get(form.errors, *k)
+
+        if isinstance(expected, list):
+            errors = [e['type'] for e_list in field_errors for e in e_list]
+            assert errors == expected
+
+        else:
+            errors = [e['type'] for e in field_errors]
+            assert len(errors) == 1
+            assert errors[0] == expected
 
 
 @pytest.mark.parametrize(