Browse Source

Add "edit" function

theenglishway (time) 6 năm trước cách đây
mục cha
commit
c5f4acd6ed
3 tập tin đã thay đổi với 91 bổ sung21 xóa
  1. 52 11
      pyplanner/cli.py
  2. 18 0
      pyplanner/converters.py
  3. 21 10
      pyplanner/database.py

+ 52 - 11
pyplanner/cli.py

@@ -3,6 +3,8 @@ from .database import Database
 from .output import TerminalOutput
 from .models import *
 from .utils import Term
+from .converters import ConverterYaml
+import io
 
 
 @click.group()
@@ -13,14 +15,24 @@ def main(ctx, db_url):
     ctx.ensure_object(dict)
     ctx.obj['db'] = Database(db_url)
     ctx.obj['output'] = TerminalOutput()
+    converter = ConverterYaml()
+    converter.register(ctx.obj['db'].models.values())
+    ctx.obj['converter'] = converter
 
 
 @main.command()
 @click.option('-f', '--file', type=str, default='./data.yaml', show_default=True)
 @click.pass_context
 def dump(ctx, file):
+    """Dump data into a text file"""
     db = ctx.obj['db']
-    db.dump(file)
+    converter = ctx.obj['converter']
+    instances_dict = db.dump()
+
+    converter.register(instances_dict.keys())
+    with open(file, 'w') as f:
+        converter.dump({k.__name__: v for k, v in instances_dict.items()}, f)
+
     click.echo(Term.success(f"Data dumped to {file}"))
 
 
@@ -28,11 +40,45 @@ def dump(ctx, file):
 @click.option('-f', '--file', type=str, default='./data.yaml', show_default=True)
 @click.pass_context
 def load(ctx, file):
+    """Load data from a text file"""
     db = ctx.obj['db']
-    db.load(file)
+    converter = ctx.obj['converter']
+
+    with open(file) as f:
+        instances_dict = converter.load(f)
+
+    db.load(instances_dict)
     click.echo(Term.success(f"Data loaded from {file}"))
 
 
+@main.command()
+@click.argument('short_uuid')
+@click.pass_context
+def edit(ctx, short_uuid):
+    """Edit an object"""
+    db = ctx.obj['db']
+    converter = ctx.obj['converter']
+
+    obj = db.search(short_uuid)
+    if obj:
+        model = type(obj)
+
+        editable = io.StringIO('')
+        converter.dump(obj, editable)
+        before_edit = editable.getvalue()
+
+        edited = click.edit(before_edit)
+
+        if edited and edited != before_edit:
+            db.edit(model, converter.load(edited))
+            click.echo(Term.success(f"Successfully edited {model.__name__} {short_uuid}"))
+        else:
+            click.echo(Term.mildly_important(f"No changes were made"))
+
+    else:
+        click.echo(Term.failure(f"Could not find UUID {short_uuid}"))
+
+
 @main.group()
 @click.pass_context
 def milestone(ctx):
@@ -70,7 +116,7 @@ def new(ctx, **data):
     model = ctx.obj['model']
     db = ctx.obj['db']
     try:
-        instance = db.add(model, data)
+        instance = db.create(model, data)
         click.echo(Term.success(instance))
     except ValueError as e:
         click.echo(Term.failure(e.args[0]))
@@ -88,14 +134,9 @@ def list(ctx):
         click.echo(output.output_data(i))
 
 
-milestone.add_command(new)
-milestone.add_command(list)
-sprint.add_command(new)
-sprint.add_command(list)
-item.add_command(new)
-item.add_command(list)
-comment.add_command(new)
-comment.add_command(list)
+for group in [milestone, sprint, item, comment]:
+    for command in [new, list]:
+        group.add_command(command)
 
 
 if __name__ == '__main__':

+ 18 - 0
pyplanner/converters.py

@@ -0,0 +1,18 @@
+import ruamel
+
+
+class ConverterYaml:
+    def __init__(self):
+        self.yaml = ruamel.yaml.YAML()
+        self.yaml.default_flow_style = False
+        self.yaml.allow_unicode = True
+
+    def register(self, class_list):
+        for class_ in class_list:
+            self.yaml.register_class(class_)
+
+    def dump(self, obj, file):
+        return self.yaml.dump(obj, file)
+
+    def load(self, file):
+        return self.yaml.load(file)

+ 21 - 10
pyplanner/database.py

@@ -25,7 +25,7 @@ class Database:
             if inspect.isclass(v) and issubclass(v, SQLABase)
         }
 
-    def add(self, model, data):
+    def create(self, model, data):
         session = self.session_factory()
         data.update({'uuid': uuid.uuid4().hex})
 
@@ -38,11 +38,27 @@ class Database:
         else:
             raise ValueError(errors)
 
+    def edit(self, model, data):
+        session = self.session_factory()
+        m = model(**data)
+        session.merge(m)
+        session.commit()
+        return m
+
     def list(self, model):
         session = self.session_factory()
         return session.query(model).all()
 
-    def dump(self, file):
+    def search(self, uuid):
+        session = self.session_factory()
+        obj = None
+        for k, v in self.models.items():
+            obj = session.query(v).filter(v.uuid.startswith(uuid)).one_or_none()
+            if obj:
+                break
+        return obj
+
+    def dump(self):
         session = self.session_factory()
         instances_dict = {}
 
@@ -50,16 +66,11 @@ class Database:
             if v.__subclasses__():
                 continue
             instances = [i for i in session.query(v).all()]
-            instances_dict[k] = instances
-            yaml.register_class(v)
-
-        with open(file, 'w') as f:
-            yaml.dump(instances_dict, f)
+            instances_dict[v] = instances
 
-    def load(self, file):
-        with open(file) as f:
-            instances_dict = yaml.load(f)
+        return instances_dict
 
+    def load(self, instances_dict):
         session = self.session_factory()
         for k, v_list in instances_dict.items():
             model = self.models[k]