test_db.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import pytest
  2. from dataclasses import fields
  3. from twhatter.output import Database
  4. from twhatter.output.sqlalchemy import Tweet
  5. from twhatter.parser import *
  6. @pytest.fixture
  7. def db_url():
  8. return "sqlite://"
  9. @pytest.fixture(scope="function")
  10. def output(db_url):
  11. return Database(db_url)
  12. @pytest.fixture(scope="function")
  13. def session(output):
  14. session = output.session_maker()
  15. yield session
  16. session.close()
  17. class TestTweetsOutput:
  18. testdata = [
  19. pytest.param('tests/fixtures/tweets/text_only_10.yaml', TweetTextOnly,
  20. id="text-only"),
  21. pytest.param('tests/fixtures/tweets/retweet_10.yaml', TweetRetweet,
  22. id="retweets"),
  23. pytest.param('tests/fixtures/tweets/link_10.yaml', TweetLink,
  24. id="link"),
  25. pytest.param('tests/fixtures/tweets/reaction_9.yaml', TweetReaction,
  26. id="reaction"),
  27. ]
  28. @pytest.fixture(scope="function")
  29. def tweets_output_factory(self, tweets_factory, output):
  30. """Tweets that have been output"""
  31. def _tweets_output_factory(fixtures_file):
  32. tweets = tweets_factory(fixtures_file)
  33. output.start()
  34. output.output_tweets(tweets)
  35. output.stop()
  36. return tweets
  37. return _tweets_output_factory
  38. @pytest.mark.parametrize("fixtures_file, raw_class", testdata)
  39. def test_presence(self, tweets_factory, output, fixtures_file, session, raw_class):
  40. tweets = tweets_factory(fixtures_file)
  41. output.start()
  42. output.output_tweets(tweets)
  43. output.stop()
  44. for t in tweets:
  45. assert session.query(Tweet).filter(Tweet.id == t.id).one()
  46. @pytest.mark.parametrize("fixtures_file, raw_class", testdata)
  47. def test_twice(self, tweets_factory, output, fixtures_file, session, raw_class):
  48. tweets = tweets_factory(fixtures_file)
  49. output.start()
  50. output.output_tweets(tweets)
  51. output.stop()
  52. output.start()
  53. output.output_tweets(tweets)
  54. output.stop()
  55. for t in tweets:
  56. assert session.query(Tweet).filter(Tweet.id == t.id).one()
  57. @pytest.mark.parametrize("field_name, fixtures_file, raw_tweet_cls", [
  58. pytest.param(
  59. field.name,
  60. *td.values,
  61. id="{}-{}".format(td.id, field.name)
  62. )
  63. for td in testdata
  64. for field in fields(TweetTextOnly)
  65. if field.name != 'media'
  66. ])
  67. def test_attributes(self, tweets_output_factory, fixtures_file, session, raw_tweet_cls, field_name):
  68. tweets = tweets_output_factory(fixtures_file)
  69. for t in tweets:
  70. db_tweet = session.query(Tweet).filter(Tweet.id == t.id).one()
  71. assert getattr(db_tweet, field_name) == getattr(t, field_name)