浏览代码

Ensure that db tweets have all the fields from parsed tweets

... except Media, for now
theenglishway (time) 7 年之前
父节点
当前提交
7c8ea927f4
共有 2 个文件被更改,包括 53 次插入和 26 次删除
  1. 34 12
      tests/output/test_db.py
  2. 19 14
      twhatter/output/sqlalchemy/models/tweets.py

+ 34 - 12
tests/output/test_db.py

@@ -1,6 +1,9 @@
 import pytest
+from dataclasses import fields
+
 from twhatter.output import Database
 from twhatter.output.sqlalchemy import Tweet
+from twhatter.parser.tweet import *
 
 
 @pytest.fixture
@@ -19,18 +22,19 @@ def session(output):
     yield session
     session.close()
 
-@pytest.mark.parametrize("fixtures_file", [
-    'tests/fixtures/tweets/text_only_10.yaml',
-    'tests/fixtures/tweets/retweet_10.yaml',
-    'tests/fixtures/tweets/link_10.yaml',
-    'tests/fixtures/tweets/reaction_9.yaml',
-], ids=[
-    "TextOnly",
-    "Retweet",
-    "Link",
-    "Reaction"
-])
-def test_output_tweets_presence(capsys, tweets_factory, output, fixtures_file, session):
+
+@pytest.fixture(scope="function")
+def tweets_output_factory(tweets_factory, output):
+    """Tweets that have been output"""
+    def _tweets_output_factory(fixtures_file):
+        tweets = tweets_factory(fixtures_file)
+        output.start()
+        output.output_tweets(tweets)
+        output.stop()
+        return tweets
+    return _tweets_output_factory
+
+
 testdata = [
     pytest.param('tests/fixtures/tweets/text_only_10.yaml', TweetTextOnly, id="text-only"),
     pytest.param('tests/fixtures/tweets/retweet_10.yaml', TweetRetweet, id="retweets"),
@@ -63,3 +67,21 @@ def test_output_tweets_twice(tweets_factory, output, fixtures_file, session, raw
 
     for t in tweets:
         assert session.query(Tweet).filter(Tweet.id == t.id).one()
+
+
+@pytest.mark.parametrize("field_name, fixtures_file, raw_tweet_cls", [
+    pytest.param(
+        field.name,
+        *td.values,
+        id="{}-{}".format(td.id, field.name)
+    )
+    for td in testdata
+    for field in fields(TweetTextOnly)
+    if field.name != 'media'
+])
+def test_output_tweets_attributes(tweets_output_factory, fixtures_file, session, raw_tweet_cls, field_name):
+    tweets = tweets_output_factory(fixtures_file)
+
+    for t in tweets:
+        db_tweet = session.query(Tweet).filter(Tweet.id == t.id).one()
+        assert getattr(db_tweet, field_name) == getattr(t, field_name)

+ 19 - 14
twhatter/output/sqlalchemy/models/tweets.py

@@ -21,18 +21,26 @@ class Tweet(Base):
     permalink = Column(String)
     text = Column(String)
     _hashtag_list = Column("hashtag_list", String)
-    _mention_list = Column(String)
+    _mention_list = Column("mention_list", String)
+
+    link_to = Column(String)
+
+    retweeter = Column(String)
+    retweet_id = Column(Integer)
+
+    reacted_id = Column(Integer)
+    reacted_user_id = Column(Integer)
 
     user = relationship('User', backref='tweets')
     #media = relationship('Media', backref='tweets')
 
     @hybrid_property
     def hashtag_list(self):
-        return self._hashtag_list.split(',')
+        return self._hashtag_list.split(',') if self._hashtag_list else []
 
     @hashtag_list.setter
-    def hashtag_list(self, value):
-        self._hashtag_list = ','.join(value)
+    def hashtag_list(self, values_list):
+        self._hashtag_list = ','.join(values_list)
 
     @hashtag_list.expression
     def hashtag_list(cls):
@@ -40,11 +48,14 @@ class Tweet(Base):
 
     @hybrid_property
     def mention_list(self):
-        return [int(user_id) for user_id in self._mention_list.split(',')]
+        if self._mention_list:
+            return [int(user_id) for user_id in self._mention_list.split(',')]
+        else:
+            return []
 
     @mention_list.setter
-    def mention_list(self, value):
-        self._mention_list = ','.join(value)
+    def mention_list(self, values_list):
+        self._mention_list = ','.join([str(v) for v in values_list])
 
     @mention_list.expression
     def mention_list(cls):
@@ -58,12 +69,6 @@ class Tweet(Base):
         kwargs = {
             k: v
             for k, v in asdict(raw_tweet).items()
-            if k not in [
-                'retweeter', 'retweet_id',
-                'reacted_id', 'reacted_user_id',
-                'link_to', 'soup',
-                'hashtag_list', 'mention_list',
-                'media'
-            ]
+            if k not in ['soup', 'media']
         }
         return cls(**kwargs)