瀏覽代碼

Add Translator object

theenglishway (time) 6 年之前
父節點
當前提交
bb2620dc2c
共有 4 個文件被更改,包括 178 次插入6 次删除
  1. 65 0
      pydantic_form/translator.py
  2. 30 0
      pydantic_form/utils.py
  3. 24 6
      tests/conftest.py
  4. 59 0
      tests/test_translator.py

+ 65 - 0
pydantic_form/translator.py

@@ -0,0 +1,65 @@
+from typing import Any
+from .utils import *
+
+
+class FieldTranslator:
+    def __init__(self, src, dest):
+        self.src = src
+        self.dest = dest
+
+    def get(self, key: tuple) -> Any:
+        raise NotImplementedError()
+
+    def transform(self, value: Any) -> Any:
+        raise NotImplementedError()
+
+    def set(self, src_key, value) -> None:
+        raise NotImplementedError()
+
+    def __call__(self, key):
+        return self.set(key, self.transform(self.get(key)))
+
+
+class InstanceTranslator:
+    field_translator_classes = (
+        ('valid', None),
+    )
+
+    def __init__(self, src, dest, keys):
+        self.field_translators = [
+            (k, v(src, dest)) for k, v in self.field_translator_classes
+        ]
+        self.keys = keys
+
+    def __call__(self, *args, **kwargs):
+        for k in self.keys:
+            for _, t in self.field_translators:
+                t(k)
+
+
+class SchemaInstanceToFormDataField(FieldTranslator):
+    def get(self, key: tuple):
+        return recursive_get(self.src.dict(), *key)
+
+    def transform(self, value: Any):
+        return value
+
+    def set(self, src_key, value):
+        rsetattr(self.dest, src_key + ('data',), value)
+
+
+class SchemaInstanceToFormErrorField(FieldTranslator):
+    def get(self, key: tuple):
+        return recursive_get(self.src.dict(), *key)
+
+    def transform(self, value: Any):
+        return value
+
+    def set(self, src_key, value):
+        rsetattr(self.dest, src_key + ('data',), value)
+
+
+class SchemaToForm(InstanceTranslator):
+    field_translator_classes = [
+        ('valid', SchemaInstanceToFormDataField),
+    ]

+ 30 - 0
pydantic_form/utils.py

@@ -0,0 +1,30 @@
+import functools
+
+
+def recursive_get(d, *keys):
+    """https://stackoverflow.com/a/28225747/8783170"""
+    return functools.reduce(lambda c, k: c.get(k, {}), keys, d)
+
+
+def recursive_setdefault(root, value, *keys):
+    """https://stackoverflow.com/a/21025122/8783170"""
+    inner = functools.reduce(lambda d, k: d.setdefault(k, {}), keys[:-1],
+                             root)
+    inner.update({keys[-1]: value})
+    return root
+
+
+def rsetattr(obj, attr, val):
+    """https://stackoverflow.com/a/31174427/8783170"""
+    attr = '.'.join(attr)
+    pre, _, post = attr.rpartition('.')
+    return setattr(rgetattr(obj, pre) if pre else obj, post, val)
+
+
+def rgetattr(obj, attr, *args):
+    """https://stackoverflow.com/a/31174427/8783170"""
+
+    def _getattr(obj, attr):
+        return getattr(obj, attr, *args)
+
+    return functools.reduce(_getattr, [obj] + attr.split('.'))

+ 24 - 6
tests/conftest.py

@@ -4,12 +4,12 @@ from collections import namedtuple
 from wtforms import Form, fields
 from werkzeug.datastructures import ImmutableMultiDict
 from pydantic_form import PydanticForm
-from pydantic import BaseModel
+from pydantic import BaseModel, ValidationError
 
 
 ScenarioClasses = namedtuple(
     "ScenarioClasses",
-    ['wtf_form', 'pydantic_form', 'schema', 'data_factory']
+    ['wtf_form', 'pydantic_form', 'schema', 'data_factory', 'bad_data_factory', 'keys']
 )
 
 ScenarioInstances = namedtuple(
@@ -17,7 +17,8 @@ ScenarioInstances = namedtuple(
     [
         'wtf', 'wtf_formdata',
         'pydantic', 'pydantic_formdata',
-        'formdata', 'data'
+        'formdata', 'data',
+        'schema'
     ]
 )
 
@@ -26,16 +27,22 @@ ScenarioInstances = namedtuple(
 def instance_factory():
     def _factory(scenario_classes, data):
         formdata = ImmutableMultiDict(data)
+        try:
+            schema_instance = scenario_classes.schema(**data)
+        except ValidationError as e:
+            schema_instance = None
         return ScenarioInstances(
             scenario_classes.wtf_form(data=data),
             scenario_classes.wtf_form(formdata=formdata),
             scenario_classes.pydantic_form(data=data),
             scenario_classes.pydantic_form(formdata=formdata),
             formdata,
-            data
+            data,
+            schema_instance
         )
     return _factory
 
+simple_keys = [('integer',), ('string',)]
 
 class SimpleSchema(BaseModel):
     integer: int
@@ -55,7 +62,12 @@ class SimpleForm(SimpleWTForm, PydanticForm):
 
 @pytest.fixture(scope="session")
 def scenario_classes_simple():
-    return ScenarioClasses(SimpleWTForm, SimpleForm, SimpleSchema, SimpleDataFactory)
+    return ScenarioClasses(
+        SimpleWTForm, SimpleForm,
+        SimpleSchema,
+        SimpleDataFactory, SimpleBadDataFactory,
+        simple_keys
+    )
 
 
 class SimpleDataFactory(factory.Factory):
@@ -83,6 +95,7 @@ def scenario_simple_bad(instance_factory, scenario_classes_simple):
     return instance_factory(scenario_classes_simple, SimpleBadDataFactory())
 
 
+nested_keys = [('integer',), ('nested', 'integer'), ('nested', 'string')]
 
 class NestedSchema(BaseModel):
     integer: int
@@ -118,7 +131,12 @@ class NestedBadDataFactory(factory.Factory):
 
 @pytest.fixture(scope="session")
 def scenario_classes_nested():
-    return ScenarioClasses(NestedWTForm, NestedForm, NestedSchema, NestedDataFactory)
+    return ScenarioClasses(
+        NestedWTForm, NestedForm,
+        NestedSchema,
+        NestedDataFactory, NestedBadDataFactory,
+        nested_keys
+    )
 
 
 @pytest.fixture

+ 59 - 0
tests/test_translator.py

@@ -0,0 +1,59 @@
+import pytest
+from pydantic_form.translator import SchemaInstanceToFormDataField, SchemaToForm
+
+
+@pytest.mark.parametrize('translator', [SchemaInstanceToFormDataField])
+@pytest.mark.parametrize(
+    'scenario_name',
+    ['scenario_classes_simple', 'scenario_classes_nested']
+)
+def test_data_translator(request, scenario_name, translator):
+    scenario = request.getfixturevalue(scenario_name)
+    data = scenario.data_factory()
+    schema = scenario.schema(**data)
+    form = scenario.wtf_form()
+    keys = scenario.keys
+
+    t = translator(schema, form)
+    for k in keys:
+        t(k)
+
+    assert schema.dict() == form.data
+
+@pytest.mark.skip
+@pytest.mark.parametrize('translator', [SchemaInstanceToFormDataField])
+@pytest.mark.parametrize(
+    'scenario_name',
+    ['scenario_classes_simple', 'scenario_classes_nested']
+)
+def test_error_translator_bad(request, scenario_name, translator):
+    scenario = request.getfixturevalue(scenario_name)
+    data = scenario.data_factory()
+    schema = scenario.schema(**data)
+    form = scenario.wtf_form()
+    keys = scenario.keys
+
+    t = translator(schema, form)
+    for k in keys:
+        t(k)
+
+    assert schema.dict() == form.data
+
+
+@pytest.mark.parametrize('translator', [SchemaToForm])
+@pytest.mark.parametrize('data_factory_name', ['data_factory', 'bad_data_factory'])
+@pytest.mark.parametrize(
+    'scenario_name',
+    ['scenario_classes_simple', 'scenario_classes_nested']
+)
+def test_translator(request, scenario_name, translator, data_factory_name):
+    scenario = request.getfixturevalue(scenario_name)
+    data = getattr(scenario, data_factory_name)()
+    schema = scenario.schema(**data)
+    form = scenario.wtf_form()
+    keys = scenario.keys
+
+    t = SchemaToForm(schema, form, keys)
+    t()
+
+    assert schema.dict() == form.data