Explorar o código

Add OutputBase to put output logic where it belongs

theenglishway (time) %!s(int64=7) %!d(string=hai) anos
pai
achega
a57d28fd61

+ 10 - 22
twhatter/cli.py

@@ -6,6 +6,7 @@ import click
 import IPython
 
 from twhatter.client import ClientTimeline, ClientProfile
+from twhatter.output import Print
 from twhatter.output.sqlalchemy import Database, Tweet, User
 from twhatter.log import log_setup
 
@@ -18,25 +19,24 @@ from twhatter.log import log_setup
 def main(ctx, verbosity):
     log_setup(verbosity)
     ctx.ensure_object(dict)
+    ctx.obj['stdout'] = Print()
 
 
 @main.command()
 @click.option('-l', '--limit', type=int, default=100, show_default=True)
 @click.argument('user')
-def timeline(limit, user):
+@click.pass_context
+def timeline(ctx, limit, user):
     """Get some user's Tweets"""
-    timeline = ClientTimeline(user, limit)
-
-    for t in timeline:
-        click.echo(t)
+    ctx.obj['stdout'].output_tweets(user, limit)
 
 
 @main.command()
 @click.argument('user')
-def profile(user):
+@click.pass_context
+def profile(ctx, user):
     """Get basic info about some user"""
-    p = ClientProfile(user)
-    click.echo(p.user)
+    ctx.obj['stdout'].output_user(user)
 
 
 @main.group()
@@ -52,17 +52,7 @@ def db(ctx, db_url):
 @click.pass_context
 def timeline(ctx, limit, user):
     """Push user's Tweets into a database"""
-    timeline = ClientTimeline(user, limit)
-
-    tweets = [
-        Tweet.from_raw(t) for n, t in enumerate(timeline) if n < limit
-    ]
-    profiles = set()
-    for t in timeline:
-        p = ClientProfile(t.screen_name)
-        profiles.add(p)
-    users = [User.from_raw(p.user) for p in profiles]
-    ctx.obj['db'].add_all(*users, *tweets)
+    ctx.obj['db'].output_tweets(user, limit)
 
 
 @db.command()
@@ -70,9 +60,7 @@ def timeline(ctx, limit, user):
 @click.pass_context
 def profile(ctx, user):
     """Push some user into a database"""
-    p = ClientProfile(user)
-
-    ctx.obj['db'].add_all(User.from_raw(p.user))
+    ctx.obj['db'].output_user(user)
 
 
 @db.command()

+ 1 - 0
twhatter/log.py

@@ -27,6 +27,7 @@ LOGGING = {
 
 def log_setup(verbosity):
     logging.getLogger('urllib3').setLevel(logging.WARNING)
+    logging.getLogger('parso').setLevel(logging.WARNING)
 
     if verbosity == 'verbose':
         logging.getLogger('twhatter.client').setLevel(logging.DEBUG)

+ 1 - 0
twhatter/output/__init__.py

@@ -1 +1,2 @@
 from .print import Print
+from .base import OutputBase

+ 10 - 0
twhatter/output/base.py

@@ -0,0 +1,10 @@
+class OutputBase:
+    """Base class for scraper's data output"""
+    def output_tweets(self, user, limit) -> None:
+        raise NotImplementedError()
+
+    def output_user(self, user) -> None:
+        raise NotImplementedError()
+
+    def output_medias(self, user) -> None:
+        raise NotImplementedError()

+ 13 - 5
twhatter/output/print.py

@@ -1,6 +1,14 @@
-class Print:
-    def __init__(self, tweet):
-        self.tweet = tweet
+from .base import OutputBase
+from twhatter.client import ClientTimeline, ClientProfile
 
-    def __call__(self, *args, **kwargs):
-        print(self.tweet)
+
+class Print(OutputBase):
+    def output_tweets(self, user, limit):
+        client_timeline = ClientTimeline(user, limit)
+
+        for t in client_timeline:
+            print(t)
+
+    def output_user(self, user):
+        p = ClientProfile(user)
+        print(p.user)

+ 47 - 17
twhatter/output/sqlalchemy/db.py

@@ -5,16 +5,22 @@ from sqlalchemy.orm import scoped_session, sessionmaker
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.exc import IntegrityError
 
+from twhatter.output import OutputBase
+from twhatter.client import ClientTimeline, ClientProfile
 
+
+# Registry of SQLAlchemy's models
+class_registry = {}
 # Base class for SQLAlchemy models
-Base = declarative_base()
+Base = declarative_base(class_registry=class_registry)
 
 # Session maker
 Session = scoped_session(sessionmaker(autoflush=False))
 
 logger = logging.getLogger(__name__)
 
-class Database:
+
+class Database(OutputBase):
     def __init__(self, db_url):
         engine = create_engine(db_url)
         self.session_maker = Session
@@ -27,25 +33,49 @@ class Database:
     def stop(self, session):
         session.close()
 
-    def add_all(self, *objs):
-        logger.info("Adding {} objects".format(len(objs)))
-        session = self.session_maker()
-
-        unique_errors = 0
+    def _add_no_fail(self, session, obj):
         # This is an extremely unefficient way to add objects to the database,
         # but the only way I've found so far to deal with duplications
-        for o in objs:
-            session.add(o)
-            try:
-                session.commit()
-            except IntegrityError as e:
-                logger.debug("Error on commit : {}".format(e))
-                unique_errors += 1
-                session.rollback()
+        session.add(obj)
+        try:
+            session.commit()
+            return 1
+        except IntegrityError as e:
+            logger.debug("Error on commit : {}".format(e))
+            session.rollback()
+            return 0
+
+    def output_tweets(self, user, limit):
+        client_timeline = ClientTimeline(user, limit)
+        Tweet = class_registry['Tweet']
+        User = class_registry['User']
+        session = self.start()
+        tweets = [Tweet.from_raw(t) for t in client_timeline]
+        logger.info("Adding {} tweets".format(len(tweets)))
+
+        profiles = set()
+        for t in client_timeline:
+            p = ClientProfile(t.screen_name)
+            profiles.add(p)
+        users = [User.from_raw(p.user) for p in profiles]
+
+        unique_errors = 0
+        for u in users:
+            self._add_no_fail(session, u)
+        for t in tweets:
+            unique_errors += self._add_no_fail(session, t)
 
         if unique_errors:
             logger.info(
-                "{} objects were already in the database".format(unique_errors)
+                "{} tweets were already in the database".format(unique_errors)
             )
 
-        session.close()
+        self.stop(session)
+
+    def output_user(self, user):
+        User = class_registry['User']
+        p = ClientProfile(user)
+        session = self.start()
+
+        self._add_no_fail(session, User.from_raw(p.user))
+        self.stop(session)

+ 4 - 2
twhatter/parser/__init__.py

@@ -1,8 +1,8 @@
 from .tweet import (TweetList, TweetBase,
                     tweet_factory,
                     TweetTextOnly, TweetLink, TweetReaction, TweetRetweet)
-from .user import user_factory
-from .media import MediaImage, media_factory
+from .user import User, user_factory
+from .media import MediaBase, MediaImage, media_factory
 
 __all__= [
     "TweetList",
@@ -13,8 +13,10 @@ __all__= [
     "TweetReaction",
     "TweetRetweet",
 
+    "User",
     "user_factory",
 
+    "MediaBase",
     "MediaImage",
     "media_factory"
 ]