Forráskód Böngészése

Use an enum for column types

jherve 1 éve
szülő
commit
5ca6864b59
2 módosított fájl, 28 hozzáadás és 16 törlés
  1. 13 12
      src/media_observer/storage.py
  2. 15 4
      src/media_observer/storage_abstraction.py

+ 13 - 12
src/media_observer/storage.py

@@ -12,6 +12,7 @@ from media_observer.article import (
 from media_observer.storage_abstraction import (
     Table,
     Reference,
+    ColumnType,
     Column,
     UniqueIndex,
     View,
@@ -26,8 +27,8 @@ table_sites = Table(
     name="sites",
     columns=[
         Column(name="id", primary_key=True),
-        Column(name="name", type_="TEXT"),
-        Column(name="original_url", type_="TEXT"),
+        Column(name="name", type_=ColumnType.Text),
+        Column(name="original_url", type_=ColumnType.Url),
     ],
 )
 table_frontpages = Table(
@@ -38,31 +39,31 @@ table_frontpages = Table(
             name="site_id",
             references=Reference("sites", "id", on_delete="cascade"),
         ),
-        Column(name="timestamp", type_="timestamp with time zone"),
-        Column(name="timestamp_virtual", type_="timestamp with time zone"),
-        Column(name="url_original", type_="TEXT"),
-        Column(name="url_snapshot", type_="TEXT"),
+        Column(name="timestamp", type_=ColumnType.TimestampTz),
+        Column(name="timestamp_virtual", type_=ColumnType.TimestampTz),
+        Column(name="url_original", type_=ColumnType.Url),
+        Column(name="url_snapshot", type_=ColumnType.Url),
     ],
 )
 table_articles = Table(
     name="articles",
     columns=[
         Column(name="id", primary_key=True),
-        Column(name="url", type_="TEXT"),
+        Column(name="url", type_=ColumnType.Url),
     ],
 )
 table_titles = Table(
     name="titles",
     columns=[
         Column(name="id", primary_key=True),
-        Column(name="text", type_="TEXT"),
+        Column(name="text", type_=ColumnType.Text),
     ],
 )
 table_main_articles = Table(
     name="main_articles",
     columns=[
         Column(name="id", primary_key=True),
-        Column(name="url", type_="TEXT"),
+        Column(name="url", type_=ColumnType.Url),
         Column(
             name="frontpage_id",
             references=Reference("frontpages", "id", on_delete="cascade"),
@@ -81,8 +82,8 @@ table_top_articles = Table(
     name="top_articles",
     columns=[
         Column(name="id", primary_key=True),
-        Column(name="url", type_="TEXT"),
-        Column(name="rank", type_="INTEGER"),
+        Column(name="url", type_=ColumnType.Url),
+        Column(name="rank", type_=ColumnType.Integer),
         Column(
             name="frontpage_id",
             references=Reference("frontpages", "id", on_delete="cascade"),
@@ -104,7 +105,7 @@ table_embeddings = Table(
         Column(
             name="title_id", references=Reference("titles", "id", on_delete="cascade")
         ),
-        Column(name="vector", type_="bytea"),
+        Column(name="vector", type_=ColumnType.Vector),
     ],
 )
 view_frontpages = View(

+ 15 - 4
src/media_observer/storage_abstraction.py

@@ -1,4 +1,5 @@
 from abc import ABC
+from enum import Enum, auto
 from datetime import datetime
 from attrs import frozen
 
@@ -33,21 +34,31 @@ class Reference:
         return f"{self.table_name} ({self.column_name}) {on_delete}"
 
 
+class ColumnType(Enum):
+    PrimaryKey = "SERIAL PRIMARY KEY"
+    References = "REFERENCES"
+    Text = "TEXT"
+    Url = "TEXT"
+    TimestampTz = "timestamp with time zone"
+    Integer = "INTEGER"
+    Vector = "bytea"
+
+
 @frozen
 class Column:
     name: str
-    type_: str | None = None
+    type_: ColumnType | None = None
     primary_key: bool = False
     references: Reference | None = None
 
     @property
     def attrs(self):
         if self.primary_key:
-            return "SERIAL PRIMARY KEY"
+            return ColumnType.PrimaryKey.value
         elif self.references is not None:
-            return f"INTEGER REFERENCES {self.references.as_sql}"
+            return f"{ColumnType.Integer.value} {ColumnType.References.value} {self.references.as_sql}"
         elif self.type_ is not None:
-            return self.type_
+            return self.type_.value
         else:
             raise ValueError("Missing informations in column")