Browse Source

Add basic database recording

theenglishway (time) 7 years ago
parent
commit
18ddad72c5

+ 1 - 0
Pipfile

@@ -24,3 +24,4 @@ bs4 = "*"
 lxml = "*"
 requests = "*"
 twhatter = {path = ".",editable = true}
+sqlalchemy = "*"

+ 9 - 2
tests/test_cli.py

@@ -14,7 +14,7 @@ def test_command_line_interface(cli_runner):
 
     help_result = cli_runner.invoke(cli.main, ['--help'])
     assert help_result.exit_code == 0
-    assert '--help  Show this message and exit.' in help_result.output
+    assert 'Show this message and exit.' in help_result.output
 
 
 class TestOwn:
@@ -31,8 +31,15 @@ class TestOwn:
 
     @pytest.mark.send_request
     def test_limit(self, cli_runner, user, tweet_limit):
-        result = cli_runner.invoke(cli.main, ['own', user, '--limit', tweet_limit])
+        result = cli_runner.invoke(cli.main, ['--limit', tweet_limit, 'own', user])
         assert result.exit_code == 0
 
         lines = result.output.split('\n')[:-1]
         assert len(lines) == tweet_limit
+
+
+class TestDb:
+    """CLI tests for the ``db`` command group."""
+
+    @pytest.mark.send_request
+    def test_no_limit(self, cli_runner, user):
+        """``db own <user>`` without ``--limit`` must exit with code 0."""
+        result = cli_runner.invoke(cli.main, ['db', 'own', user])
+        assert result.exit_code == 0

+ 34 - 2
twhatter/cli.py

@@ -3,8 +3,10 @@
 
 """Console script for twhatter."""
 import click
+import IPython
 
 from twhatter.api import ApiUser
+from twhatter.output import Database, Tweet
 
 
 @click.group()
@@ -28,9 +30,39 @@ def own(ctx, user):
             break
 
         click.echo(t)
-            break
 
-        click.echo(t)
+
+@main.group()
+@click.option(
+    '-d', '--db_url',
+    type=str,
+    default="sqlite:////tmp/db.sqlite3",
+    show_default=True,
+)
+@click.pass_context
+def db(ctx, db_url):
+    # Sub-group for the database commands: build the Database once here
+    # and stash it on the shared context object for the sub-commands.
+    # (Kept as a comment, not a docstring, so the --help text is unchanged.)
+    database = Database(db_url)
+    ctx.obj['db'] = database
+
+
+@db.command()
+@click.argument('user')
+@click.pass_context
+def own(ctx, user):
+    """Push user's Tweets into a database"""
+    # Imported locally so this command does not rely on a module-level
+    # import being added elsewhere in the file.
+    from itertools import islice
+
+    a = ApiUser(user)
+
+    # A limit of None means "no limit"; a negative one means "nothing"
+    # (islice itself would reject a negative stop value).
+    limit = ctx.obj['limit']
+    if limit is not None:
+        limit = max(limit, 0)
+
+    # islice stops pulling from the iterator as soon as `limit` tweets
+    # have been taken. Filtering on an enumerate() index instead would
+    # keep consuming iter_tweets() until it is exhausted.
+    tweets = [Tweet.from_raw(t) for t in islice(a.iter_tweets(), limit)]
+    ctx.obj['db'].add_all(*tweets)
+
+
+@db.command()
+@click.pass_context
+def shell(ctx):
+    """Open an IPython shell with `db`, `session` and `Tweet` preloaded"""
+    database = ctx.obj['db']
+    session = database.start()
+    # Names made available inside the interactive namespace.
+    user_ns = {
+        'db': database,
+        'session': session,
+        'Tweet': Tweet
+    }
+    try:
+        IPython.start_ipython(argv=[], user_ns=user_ns)
+    finally:
+        # Close the session even if the shell exits with an exception.
+        database.stop(session)
 
 
 if __name__ == "__main__":

+ 1 - 0
twhatter/output/__init__.py

@@ -1 +1,2 @@
+from .sqlalchemy import *
 from .print import Print

+ 8 - 0
twhatter/output/sqlalchemy/__init__.py

@@ -0,0 +1,8 @@
+from .db import Database
+from .models import Tweet
+
+
+# Names exported by a wildcard import of this package.
+__all__ = [
+    'Database',
+    'Tweet'
+]

+ 34 - 0
twhatter/output/sqlalchemy/db.py

@@ -0,0 +1,34 @@
+from sqlalchemy import create_engine
+from sqlalchemy.orm import scoped_session, sessionmaker
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.exc import IntegrityError
+
+
+# Base class for SQLAlchemy models; Database.__init__ creates the tables
+# collected on Base.metadata.
+Base = declarative_base()
+
+# Scoped session factory, created unbound and with autoflush disabled.
+# The engine is attached later through Session.configure(bind=...).
+Session = scoped_session(sessionmaker(autoflush=False))
+
+
+class Database:
+    """Small wrapper around an engine and the module-level ``Session``.
+
+    Creating an instance builds an engine for ``db_url``, binds the shared
+    ``Session`` factory to it and creates the tables declared on ``Base``.
+    """
+
+    def __init__(self, db_url):
+        """Create the engine for ``db_url`` and set up the schema."""
+        engine = create_engine(db_url)
+        # NOTE: Session is a module-level object, so configuring it here
+        # affects every Database instance in the process.
+        self.session_maker = Session
+        self.session_maker.configure(bind=engine)
+        Base.metadata.create_all(engine)
+
+    def start(self):
+        """Return a session obtained from the session factory."""
+        return self.session_maker()
+
+    def stop(self, session):
+        """Close ``session``."""
+        session.close()
+
+    def add_all(self, *objs):
+        """Add ``objs`` and commit them in a single transaction.
+
+        If the commit raises ``IntegrityError`` the transaction is rolled
+        back and a message is printed; none of ``objs`` is persisted in
+        that case. The session is closed whatever the outcome.
+        """
+        session = self.session_maker()
+        try:
+            session.add_all(objs)
+            session.commit()
+        except IntegrityError:
+            # Leave the session in a clean state before reporting.
+            session.rollback()
+            print("Some objects could not be inserted")
+        finally:
+            session.close()

+ 63 - 0
twhatter/output/sqlalchemy/models.py

@@ -0,0 +1,63 @@
+from dataclasses import asdict
+
+from sqlalchemy import Column, Integer, String, DateTime
+from sqlalchemy.ext.hybrid import hybrid_property
+
+from twhatter.output.sqlalchemy.db import Base
+
+
+class Tweet(Base):
+    """ORM model for a tweet, stored in the ``tweets`` table.
+
+    ``hashtag_list`` and ``mention_list`` are exposed as Python lists but
+    persisted as comma-separated strings.
+    """
+    __tablename__ = 'tweets'
+
+    id = Column(Integer, primary_key=True)
+    screen_name = Column(String)
+    user_id = Column(Integer)
+    comments_nb = Column(Integer)
+    retweets_nb = Column(Integer)
+    likes_nb = Column(Integer)
+    timestamp = Column(DateTime)
+    permalink = Column(String)
+    text = Column(String)
+    # Backing columns for the list-valued hybrid properties below.
+    # NOTE(review): the first is given the explicit column name
+    # "hashtag_list" while the second gets no explicit name; left as-is
+    # so the table layout does not change.
+    _hashtag_list = Column("hashtag_list", String)
+    _mention_list = Column(String)
+
+    @hybrid_property
+    def hashtag_list(self):
+        """Hashtags as a list of strings (empty when nothing is stored)."""
+        raw = self._hashtag_list
+        # ''.split(',') would give [''], and None has no split() at all.
+        return raw.split(',') if raw else []
+
+    @hashtag_list.setter
+    def hashtag_list(self, value):
+        self._hashtag_list = ','.join(value or [])
+
+    @hashtag_list.expression
+    def hashtag_list(cls):
+        # At class level (queries) expose the raw string column.
+        return cls._hashtag_list
+
+    @hybrid_property
+    def mention_list(self):
+        """Mentioned user ids as a list of ints (empty when unset)."""
+        raw = self._mention_list
+        return [int(user_id) for user_id in raw.split(',')] if raw else []
+
+    @mention_list.setter
+    def mention_list(self, value):
+        # The getter returns ints, so accept ints here too: str.join()
+        # only takes strings.
+        self._mention_list = ','.join(
+            str(user_id) for user_id in (value or [])
+        )
+
+    @mention_list.expression
+    def mention_list(cls):
+        # At class level (queries) expose the raw string column.
+        return cls._mention_list
+
+    def __repr__(self):
+        return "<{0} (id={1.id})>".format(self.__class__.__qualname__, self)
+
+    @classmethod
+    def from_raw(cls, raw_tweet):
+        """Build a ``Tweet`` from a raw tweet dataclass instance.
+
+        Every dataclass field of ``raw_tweet`` is forwarded to the
+        constructor as a keyword argument, except the ones listed in
+        ``skipped``.
+        """
+        # NOTE(review): 'hashtag_list' and 'mention_list' are skipped, so
+        # the list properties are not populated here - confirm whether
+        # that is intended.
+        skipped = {
+            'retweeter', 'retweet_id',
+            'reacted_id', 'reacted_user_id',
+            'link_to', 'soup',
+            'hashtag_list', 'mention_list',
+        }
+        kwargs = {
+            k: v
+            for k, v in asdict(raw_tweet).items()
+            if k not in skipped
+        }
+        return cls(**kwargs)