Parcourir la source

Handle User model in DB

theenglishway (time) il y a 7 ans
Parent
commit
e0db25666d

+ 6 - 1
twhatter/cli.py

@@ -57,7 +57,12 @@ def timeline(ctx, limit, user):
     tweets = [
         Tweet.from_raw(t) for n, t in enumerate(timeline) if n < limit
     ]
-    ctx.obj['db'].add_all(*tweets)
+    profiles = set()
+    for t in timeline:
+        p = ClientProfile(t.screen_name)
+        profiles.add(p)
+    users = [User.from_raw(p.user) for p in profiles]
+    ctx.obj['db'].add_all(*users, *tweets)
 
 
 @db.command()

+ 4 - 0
twhatter/log.py

@@ -31,14 +31,18 @@ def log_setup(verbosity):
     if verbosity == 'verbose':
         logging.getLogger('twhatter.client').setLevel(logging.DEBUG)
         logging.getLogger('twhatter.parser').setLevel(logging.DEBUG)
+        logging.getLogger('twhatter.output').setLevel(logging.DEBUG)
     elif verbosity == 'debug':
         logging.getLogger('twhatter.client').setLevel(logging.DEBUG)
         logging.getLogger('twhatter.parser').setLevel(logging.INFO)
+        logging.getLogger('twhatter.output').setLevel(logging.INFO)
     elif verbosity == 'info':
         logging.getLogger('twhatter.client').setLevel(logging.INFO)
         logging.getLogger('twhatter.parser').setLevel(logging.INFO)
+        logging.getLogger('twhatter.output').setLevel(logging.INFO)
     elif verbosity == 'none':
         logging.getLogger('twhatter.client').setLevel(logging.WARNING)
         logging.getLogger('twhatter.parser').setLevel(logging.WARNING)
+        logging.getLogger('twhatter.output').setLevel(logging.WARNING)
 
     logging.config.dictConfig(LOGGING)

+ 22 - 5
twhatter/output/sqlalchemy/db.py

@@ -1,3 +1,5 @@
+import logging
+
 from sqlalchemy import create_engine
 from sqlalchemy.orm import scoped_session, sessionmaker
 from sqlalchemy.ext.declarative import declarative_base
@@ -10,6 +12,7 @@ Base = declarative_base()
 # Session maker
 Session = scoped_session(sessionmaker(autoflush=False))
 
+logger = logging.getLogger(__name__)
 
 class Database:
     def __init__(self, db_url):
@@ -25,10 +28,24 @@ class Database:
         session.close()
 
     def add_all(self, *objs):
+        logger.info("Adding {} objects".format(len(objs)))
         session = self.session_maker()
-        session.add_all(objs)
-        try:
-            session.commit()
-        except IntegrityError:
-            print("Some objects could not be inserted")
+
+        unique_errors = 0
+        # This is an extremely inefficient way to add objects to the database,
+        # but the only way I've found so far to deal with duplicates
+        for o in objs:
+            session.add(o)
+            try:
+                session.commit()
+            except IntegrityError as e:
+                logger.debug("Error on commit : {}".format(e))
+                unique_errors += 1
+                session.rollback()
+
+        if unique_errors:
+            logger.info(
+                "{} objects were already in the database".format(unique_errors)
+            )
+
         session.close()

+ 1 - 0
twhatter/output/sqlalchemy/models/__init__.py

@@ -1,2 +1,3 @@
 from .tweets import Tweet
 from .user import User
+from .media import Media

+ 26 - 0
twhatter/output/sqlalchemy/models/media.py

@@ -0,0 +1,26 @@
+from dataclasses import asdict
+
+from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
+from sqlalchemy.ext.hybrid import hybrid_property
+from sqlalchemy.orm import relationship
+
+from twhatter.output.sqlalchemy.db import Base
+
+
+class Media(Base):
+    __tablename__ = 'medias'
+
+    id = Column(Integer, primary_key=True)
+    _images_list = Column("images_list", String)
+
+    @hybrid_property
+    def images_list(self):
+        return self._images_list.split(',')
+
+    @images_list.setter
+    def images_list(self, value):
+        self._images_list = ','.join(value)
+
+    @images_list.expression
+    def images_list(cls):
+        return cls._images_list

+ 3 - 1
twhatter/output/sqlalchemy/models/tweets.py

@@ -23,6 +23,7 @@ class Tweet(Base):
     _mention_list = Column(String)
 
     user = relationship('User', backref='tweets')
+    #media = relationship('Media', backref='tweets')
 
     @hybrid_property
     def hashtag_list(self):
@@ -60,7 +61,8 @@ class Tweet(Base):
                 'retweeter', 'retweet_id',
                 'reacted_id', 'reacted_user_id',
                 'link_to', 'soup',
-                'hashtag_list', 'mention_list'
+                'hashtag_list', 'mention_list',
+                'media'
             ]
         }
         return cls(**kwargs)