Bläddra i källkod

Fix nested_list

theenglishway (time) 6 år sedan
förälder
incheckning
fceef49db8
4 ändrade filer med 102 tillägg och 29 borttagningar
  1. 27 4
      pydantic_form/iterators.py
  2. 56 14
      pydantic_form/translator.py
  3. 7 7
      tests/conftest.py
  4. 12 4
      tests/test_process.py

+ 27 - 4
pydantic_form/iterators.py

@@ -14,6 +14,8 @@ def iter_field(field, leafs_only=True, path=(), parent=None):
             yield from iter_field(subfield, leafs_only, path + (subfield_name,), parent)
 
     elif issubclass(field.field_class, FieldList):
+        if path and not leafs_only:
+            yield path, field.field_class, parent
         yield from iter_field(field.args[0], leafs_only, path, path)
 
     else:
@@ -23,7 +25,7 @@ def iter_field(field, leafs_only=True, path=(), parent=None):
 def iter_form_class(form_class, leafs_only=True, path=()):
     field = getattr(form_class, path[-1]) if path else form_class
 
-    if path and not leafs_only:
+    if not leafs_only:
         yield path, field, None
 
     try:
@@ -40,7 +42,7 @@ def iter_form(form, leafs_only=True):
         if iterate_on:
             value = rgetattr(form, iterate_on)
             for n, entry in enumerate(value.entries):
-                yield key + (n,), entry
+                yield iterate_on + (n, *key[len(iterate_on):]), entry
             continue
         yield key, field
 
@@ -65,12 +67,32 @@ def iter_schema_class(schema, leafs_only=True, path=()):
     for key, field in schema.__fields__.items():
         yield from iter_schema_class_field(field, leafs_only, path + (key,))
 
+def get_schema_value(schema, key):
+    parent = None
+    idx = None
+    if is_int(key[-1]):
+        *parent, idx = key
+    try:
+        if parent:
+            return rgetattr(getattr(schema, parent[0])[idx], tuple(parent[1:]))
+        else:
+            return rgetattr(schema, key)
+    except AttributeError:
+        if parent:
+            return recursive_get(getattr(schema, parent[0])[idx], *parent[1:])
+        else:
+            return recursive_get(schema, *key)
+
+
 def iter_schema(schema, leafs_only=True):
     def get_schema_value(schema, key):
         try:
             return rgetattr(schema, key)
         except AttributeError:
-            return recursive_get(getattr(schema, key[0]), *key[1:])
+            try:
+                return recursive_get(getattr(schema, key[0]), *key[1:])
+            except AttributeError:
+                return recursive_get(schema, *key)
 
     for key, field, iterate_on in iter_schema_class(schema.__class__, leafs_only):
         if iterate_on:
@@ -78,10 +100,11 @@ def iter_schema(schema, leafs_only=True):
                 sub_values = get_schema_value(schema, iterate_on)
                 if isinstance(sub_values, list):
                     for n, sub_value in enumerate(sub_values):
-                        yield key + (n,), sub_value
+                        yield iterate_on + (n, *key[len(iterate_on):]), get_schema_value(sub_value, key[len(iterate_on):])
                 else:
                     raise NotImplementedError()
             except AttributeError:
+                print('attr error')
                 ...
 
             continue

+ 56 - 14
pydantic_form/translator.py

@@ -1,5 +1,5 @@
 from wtforms import FieldList, FormField
-from .iterators import iter_schema, iter_form
+from .iterators import iter_schema, iter_form, get_schema_value
 from .utils import *
 from pydantic import ValidationError
 
@@ -39,38 +39,80 @@ class SchemaToForm:
     def errors(self):
         return self._errors
 
+    @classmethod
+    def get_field(cls, form, key):
+        for k in key:
+            if is_int(k):
+                sub_form = form.entries[k]
+            else:
+                sub_form = getattr(form, k)
+            return cls.get_field(sub_form, key[1:])
+
+        return form
+
+    @classmethod
+    def get_schema(cls, schema, key):
+        for k in key:
+            if is_int(k):
+                sub = schema[k]
+            else:
+                sub = getattr(schema, k)
+            return cls.get_schema(sub, key[1:])
+
+        return schema
+
     def set_data(self):
         for src_key, value in iter_schema(self.schema):
             dest_key = self.lut(src_key)
             if is_int(dest_key[-1]):
-                *dest_field, idx = dest_key
-                dest_field = rgetattr(self.form, dest_key[:-1])
-                dest_field.entries[idx].data = value
+                *parent_key, idx = dest_key
+
+                try:
+                    dest_field = rgetattr(self.form, tuple(parent_key))
+                    dest_field.entries[idx].data = value
+                except AttributeError:
+                    dest_field = getattr(
+                        rgetattr(self.form, tuple(parent_key[:-1])).entries[idx],
+                        parent_key[-1]
+                    )
+                    dest_field.data = value
+
             else:
-                dest_field = rgetattr(self.form, dest_key)
+                dest_field = self.get_field(self.form, dest_key)
                 setattr(dest_field, 'data', value)
 
     def set_baked(self):
-        for k, field in iter_form(self.form, leafs_only=False):
-            if isinstance(field, FormField):
-                field._baked_instance = rgetattr(self.schema, k)
+        for k, field in iter_form(self.form):
+            field._baked_instance = self.get_schema(self.schema, k)
 
     def set_errors(self):
         for k, error_list in nested_dict_iter(self.errors):
+
             try:
-                field = rgetattr(self.form, k)
+                print(k)
+                field = self.get_field(self.form, k)
+                print(k, field.__class__)
                 if isinstance(field, FormField):
                     setattr(field.form, '_errors', error_list)
                 else:
                     setattr(field, 'errors', error_list)
             except TypeError:
-                *field_list, n = k
-                rgetattr(self.form, tuple(field_list)).entries[n].errors = error_list
+                try:
+                    *field_list, n = k
+                    rgetattr(self.form, tuple(field_list)).entries[n].errors = error_list
+                except TypeError:
+                    *field_list, n, key = k
+                    getattr(rgetattr(self.form, tuple(field_list))[n], key).errors = error_list
+
 
         for k, error_list in nested_dict_iter(self.errors):
             try:
                 rgetattr(self.form, k)
             except TypeError:
-                *field_list, n = k
-                form = rgetattr(self.form, tuple(field_list))
-                form.errors = [entry.errors for entry in form.entries]
+                try:
+                    *field_list, n = k
+                    form = rgetattr(self.form, tuple(field_list))
+                except TypeError:
+                    *field_list, n, key = k
+                    form = rgetattr(self.form, tuple(field_list))
+                form.errors = [entry.errors for entry in form.entries]

+ 7 - 7
tests/conftest.py

@@ -244,10 +244,10 @@ nested_list_keys = Keys(
     [('integer',), ('nested_list', 'integer'), ('nested_list', 'string')],
     [
         ('integer',),
-        ('nested_list', 'integer', 0),
-        ('nested_list', 'integer', 1),
-        ('nested_list', 'string', 0),
-        ('nested_list', 'string', 1)
+        ('nested_list', 0, 'integer'),
+        ('nested_list', 1, 'integer'),
+        ('nested_list', 0, 'string'),
+        ('nested_list', 1, 'string')
     ]
 )
 
@@ -298,12 +298,12 @@ nested_list_expected_errors = ExpectedErrors(
     {},
     {
         ('integer',): 'type_error.integer',
-        ('nested', 'integer',): 'type_error.integer',
-        ('nested', 'string',): 'type_error.str'
+        ('nested_list', 'integer',): ['type_error.integer', 'type_error.integer'],
+        ('nested_list', 'string',): ['type_error.str', 'type_error.str']
     },
     {
         ('integer',): 'value_error.missing',
-        ('nested',): 'value_error.missing',
+        ('nested_list',): 'value_error.missing',
     }
 )
 

+ 12 - 4
tests/test_process.py

@@ -14,6 +14,7 @@ SCENARIOS = [
     'scenario_simple',
     'scenario_simple_list',
     'scenario_nested',
+    'scenario_nested_list',
     'scenario_double_nested'
 ]
 
@@ -96,7 +97,7 @@ def test_errors_valid(scenario, data):
     form = scenario.form(data=data)
     assert form.errors == {}
 
-
+import pprint
 @pytest.mark.parametrize(
     'data, errors_factory',
     [
@@ -115,10 +116,12 @@ def test_errors_invalid(scenario, data, errors_factory):
 
     form.validate()
     assert form.errors
-
+    print()
+    pprint.pprint(form.translator.errors)
+    pprint.pprint(form.errors)
     for k, expected in errors_factory(scenario).items():
+        print(k, expected)
         field_errors = recursive_get(form.errors, *k)
-
         if isinstance(expected, list):
             errors = [e['type'] for e_list in field_errors for e in e_list]
             assert errors == expected
@@ -127,7 +130,7 @@ def test_errors_invalid(scenario, data, errors_factory):
             errors = [e['type'] for e in field_errors]
             assert len(errors) == 1
             assert errors[0] == expected
-
+from pydantic_form.translator import SchemaToForm
 
 @pytest.mark.parametrize(
     'data',
@@ -149,7 +152,12 @@ def test_baked_instance(scenario, data):
     assert form._baked_instance.dict() == data
 
     for key, sub_form in iter_form(form, leafs_only=False):
+
+        print(key, sub_form)
+        sub_form = SchemaToForm.get_field(form, key)
+        print(key, sub_form)
         if isinstance(sub_form, FormField):
+            print('?')
             assert hasattr(sub_form, '_baked_instance')
             assert sub_form._baked_instance is not None
             assert sub_form._baked_instance.dict() == recursive_get(data, *key)