Browse Source

Distinct iterator on schema/schema_class

theenglishway (time) 6 years ago
parent
commit
3c9ff394d8
2 changed files with 34 additions and 9 deletions
  1. 31 6
      pydantic_form/iterators.py
  2. 3 3
      tests/test_iterators.py

+ 31 - 6
pydantic_form/iterators.py

@@ -1,4 +1,7 @@
 from wtforms import Form, FormField
+from pydantic.fields import Shape
+
+from .utils import *
 
 
 def iter_field(parent_class, leafs_only=True, path=()):
@@ -42,13 +45,35 @@ def iter_form(form, leafs_only=True, path=()):
     else:
         yield path, field
 
-def iter_schema(schema, leafs_only=True, path=()):
-    type_ = schema.type_ if path else schema
+def iter_schema_class_field(field, leafs_only=True, path=(), parent=None):
+    type_ = field.type_
     if hasattr(type_, '__fields__'):
         if path and not leafs_only:
-            yield path, schema
+            yield path, field, None
 
-        for key, value in type_.__fields__.items():
-            yield from iter_schema(value, leafs_only, path + (key,))
+        for key, sub_field in type_.__fields__.items():
+            yield from iter_schema_class_field(sub_field, leafs_only, path + (key,), field)
     else:
-        yield path, schema
+        iterate_on = None
+        if parent and parent.shape != Shape.SINGLETON:
+            iterate_on = path[:-1]
+        elif field.shape != Shape.SINGLETON:
+            iterate_on = path
+
+        yield path, field, iterate_on
+
+def iter_schema_class(schema, leafs_only=True, path=()):
+    for key, field in schema.__fields__.items():
+        yield from iter_schema_class_field(field, leafs_only, path + (key,))
+
+def iter_schema(schema, leafs_only=True):
+    for key, field, iterate_on in iter_schema_class(schema.__class__, leafs_only):
+        if iterate_on:
+            sub_values = rgetattr(schema, iterate_on)
+            if isinstance(sub_values, list):
+                for n, sub_value in enumerate(sub_values):
+                    yield key + (n,), sub_value
+            else:
+                raise NotImplementedError()
+            continue
+        yield key, rgetattr(schema, key)

+ 3 - 3
tests/test_iterators.py

@@ -1,5 +1,5 @@
 import pytest
-from pydantic_form.iterators import iter_form, iter_schema, iter_form_class
+from pydantic_form.iterators import iter_form, iter_schema, iter_form_class, iter_schema_class
 
 
 @pytest.fixture
@@ -31,10 +31,10 @@ def test_iterator_schema(scenario):
     data = scenario.data_factory.valid()
     keys = scenario.keys
 
-    assert [k for k, _ in iter_schema(scenario.schema)] == keys.class_
+    assert [k for k, _, _ in iter_schema_class(scenario.schema)] == keys.class_
 
     schema_instance = scenario.schema(**data)
-    assert [k for k, _ in iter_schema(scenario.schema)] == keys.class_
+    assert [k for k, _, _ in iter_schema_class(scenario.schema)] == keys.class_
     assert [k for k, _ in iter_schema(schema_instance)] == keys.instance