Przeglądaj źródła

Add tweet subclasses

theenglishway (time) 7 lat temu
rodzic
commit
448e8a740a
3 zmienionych plików z 73 dodań i 10 usunięć
  1. 20 6
      tests/test_parser.py
  2. 11 1
      twhatter/parser/__init__.py
  3. 42 3
      twhatter/parser/tweet.py

+ 20 - 6
tests/test_parser.py

@@ -1,5 +1,5 @@
 import pytest
-from twhatter.parser import TweetList, Tweet
+from twhatter.parser import *
 
 
 class TestTweetList:
@@ -10,11 +10,11 @@ class TestTweetList:
     def test_iter(self, raw_html_user_initial_page):
         t_list = TweetList(raw_html_user_initial_page)
         for t in t_list:
-            assert isinstance(t, Tweet)
+            assert isinstance(t, TweetBase)
 
 
 class TestTweet:
-    @pytest.mark.parametrize("tweet_type", [
+    all_types = [
         "plain",
         "reaction_tweet",
         "with_link",
@@ -22,11 +22,13 @@ class TestTweet:
         "hashtags",
         "mentions",
         "stats",
-    ])
-    def test_plain_tweet(self, raw_tweet_factory, tweet_collection, tweet_type):
+    ]
+
+    @pytest.mark.parametrize("tweet_type", all_types)
+    def test_tweet(self, raw_tweet_factory, tweet_collection, tweet_type):
         tweet_info = tweet_collection[tweet_type]
         raw = raw_tweet_factory(tweet_info)
-        t = Tweet.extract(raw)
+        t = TweetBase.extract(raw)
         assert t
 
         for field, value in tweet_info._asdict().items():
@@ -36,3 +38,15 @@ class TestTweet:
             # not tested
             if value is not None:
                 assert getattr(t, field) == value
+
+    @pytest.mark.parametrize("tweet_type,expected_class", [
+        ('plain', TweetTextOnly),
+        ('reaction_tweet', TweetReaction),
+        ('with_link', TweetLink),
+        ('retweet', TweetRetweet)
+    ])
+    def test_tweet_type(self, raw_tweet_factory, tweet_collection, tweet_type, expected_class):
+        tweet_info = tweet_collection[tweet_type]
+        raw = raw_tweet_factory(tweet_info)
+        t = TweetBase.extract(raw)
+        assert isinstance(t, expected_class)

+ 11 - 1
twhatter/parser/__init__.py

@@ -1 +1,11 @@
-from .tweet import TweetList, Tweet
+from .tweet import (TweetList, TweetBase,
+                    TweetTextOnly, TweetLink, TweetReaction, TweetRetweet)
+
+__all__= [
+    "TweetList",
+    "TweetBase",
+    "TweetTextOnly",
+    "TweetLink",
+    "TweetReaction",
+    "TweetRetweet"
+]

+ 42 - 3
twhatter/parser/tweet.py

@@ -6,7 +6,7 @@ from typing import List
 
 
 @dataclass
-class Tweet:
+class TweetBase:
     #: Tweet ID
     id: int
     #: Handle of the tweet's original author
@@ -49,6 +49,10 @@ class Tweet:
     def __post_init__(self, soup):
         self.soup = soup
 
+    @staticmethod
+    def condition(kwargs):
+        raise NotImplementedError()
+
     @staticmethod
     def _extract_from_span(soup, distinct_span, data_kw):
         return (
@@ -193,7 +197,42 @@ class Tweet:
             return fn(soup)
 
         kwargs = {f.name: _extract_value(f) for f in fields(cls)}
-        return cls(soup=soup, **kwargs)
+
+        for kls in cls.__subclasses__():
+            try:
+                print(kls)
+                if kls.condition(kwargs):
+                    return kls(soup=soup, **kwargs)
+            except NotImplementedError:
+                continue
+        else:
+            return TweetTextOnly(soup=soup, **kwargs)
+
+
+class TweetTextOnly(TweetBase):
+    """An original tweet with only plain text"""
+
+
+class TweetLink(TweetBase):
+    """An original tweet with a link"""
+    @staticmethod
+    def condition(kwargs):
+        print(kwargs)
+        return kwargs['link_to']
+
+
+class TweetRetweet(TweetBase):
+    """A plain retweet"""
+    @staticmethod
+    def condition(kwargs):
+        return kwargs['retweet_id']
+
+
+class TweetReaction(TweetBase):
+    """A reaction to another tweet"""
+    @staticmethod
+    def condition(kwargs):
+        return kwargs['reacted_id']
 
 
 class TweetList:
@@ -205,7 +244,7 @@ class TweetList:
             # Don't know what this u-dir stuff is about but if it's in there,
             # it's not a tweet !
             if not tweet.find_all('p', class_="u-dir"):
-                yield Tweet.extract(tweet)
+                yield TweetBase.extract(tweet)
 
     def __len__(self):
         return len(self.raw_tweets)