# test_db.py (2.5 KB)
  1. import pytest
  2. from dataclasses import fields
  3. from twhatter.output import Database
  4. from twhatter.output.sqlalchemy import Tweet
  5. from twhatter.parser.tweet import *
  6. @pytest.fixture
  7. def db_url():
  8. return "sqlite://"
  9. @pytest.fixture(scope="function")
  10. def output(db_url):
  11. return Database(db_url)
  12. @pytest.fixture(scope="function")
  13. def session(output):
  14. session = output.session_maker()
  15. yield session
  16. session.close()
  17. @pytest.fixture(scope="function")
  18. def tweets_output_factory(tweets_factory, output):
  19. """Tweets that have been output"""
  20. def _tweets_output_factory(fixtures_file):
  21. tweets = tweets_factory(fixtures_file)
  22. output.start()
  23. output.output_tweets(tweets)
  24. output.stop()
  25. return tweets
  26. return _tweets_output_factory
  27. testdata = [
  28. pytest.param('tests/fixtures/tweets/text_only_10.yaml', TweetTextOnly, id="text-only"),
  29. pytest.param('tests/fixtures/tweets/retweet_10.yaml', TweetRetweet, id="retweets"),
  30. pytest.param('tests/fixtures/tweets/link_10.yaml', TweetLink, id="link"),
  31. pytest.param('tests/fixtures/tweets/reaction_9.yaml', TweetReaction, id="reaction"),
  32. ]
  33. @pytest.mark.parametrize("fixtures_file, raw_class", testdata)
  34. def test_output_tweets_presence(tweets_factory, output, fixtures_file, session, raw_class):
  35. tweets = tweets_factory(fixtures_file)
  36. output.start()
  37. output.output_tweets(tweets)
  38. output.stop()
  39. for t in tweets:
  40. assert session.query(Tweet).filter(Tweet.id == t.id).one()
  41. @pytest.mark.parametrize("fixtures_file, raw_class", testdata)
  42. def test_output_tweets_twice(tweets_factory, output, fixtures_file, session, raw_class):
  43. tweets = tweets_factory(fixtures_file)
  44. output.start()
  45. output.output_tweets(tweets)
  46. output.stop()
  47. output.start()
  48. output.output_tweets(tweets)
  49. output.stop()
  50. for t in tweets:
  51. assert session.query(Tweet).filter(Tweet.id == t.id).one()
  52. @pytest.mark.parametrize("field_name, fixtures_file, raw_tweet_cls", [
  53. pytest.param(
  54. field.name,
  55. *td.values,
  56. id="{}-{}".format(td.id, field.name)
  57. )
  58. for td in testdata
  59. for field in fields(TweetTextOnly)
  60. if field.name != 'media'
  61. ])
  62. def test_output_tweets_attributes(tweets_output_factory, fixtures_file, session, raw_tweet_cls, field_name):
  63. tweets = tweets_output_factory(fixtures_file)
  64. for t in tweets:
  65. db_tweet = session.query(Tweet).filter(Tweet.id == t.id).one()
  66. assert getattr(db_tweet, field_name) == getattr(t, field_name)