test_db.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import pytest
  2. from dataclasses import fields
  3. from twhatter.output import Database
  4. from twhatter.output.sqlalchemy import Tweet, User as DbUser
  5. from twhatter.parser import *
  6. @pytest.fixture
  7. def db_url():
  8. return "sqlite://"
  9. @pytest.fixture(scope="function")
  10. def output(db_url):
  11. return Database(db_url)
  12. @pytest.fixture(scope="function")
  13. def session(output):
  14. session = output.session_maker()
  15. yield session
  16. session.close()
  17. class TestTweetsOutput:
  18. testdata = [
  19. pytest.param('tests/fixtures/tweets/text_only_10.yaml', TweetTextOnly,
  20. id="text-only"),
  21. pytest.param('tests/fixtures/tweets/retweet_10.yaml', TweetRetweet,
  22. id="retweets"),
  23. pytest.param('tests/fixtures/tweets/link_10.yaml', TweetLink,
  24. id="link"),
  25. pytest.param('tests/fixtures/tweets/reaction_9.yaml', TweetReaction,
  26. id="reaction"),
  27. ]
  28. @pytest.fixture(scope="function")
  29. def tweets_output_factory(self, tweets_factory, output):
  30. """Tweets that have been output"""
  31. def _tweets_output_factory(fixtures_file):
  32. tweets = tweets_factory(fixtures_file)
  33. output.start()
  34. output.output_tweets(tweets)
  35. output.stop()
  36. return tweets
  37. return _tweets_output_factory
  38. @pytest.mark.parametrize("fixtures_file, raw_class", testdata)
  39. def test_presence(self, tweets_factory, output, fixtures_file, session, raw_class):
  40. tweets = tweets_factory(fixtures_file)
  41. output.start()
  42. output.output_tweets(tweets)
  43. output.stop()
  44. for t in tweets:
  45. assert session.query(Tweet).filter(Tweet.id == t.id).one()
  46. @pytest.mark.parametrize("fixtures_file, raw_class", testdata)
  47. def test_twice(self, tweets_factory, output, fixtures_file, session, raw_class):
  48. tweets = tweets_factory(fixtures_file)
  49. output.start()
  50. output.output_tweets(tweets)
  51. output.stop()
  52. output.start()
  53. output.output_tweets(tweets)
  54. output.stop()
  55. for t in tweets:
  56. assert session.query(Tweet).filter(Tweet.id == t.id).one()
  57. @pytest.mark.parametrize("field_name, fixtures_file, raw_tweet_cls", [
  58. pytest.param(
  59. field.name,
  60. *td.values,
  61. id="{}-{}".format(td.id, field.name)
  62. )
  63. for td in testdata
  64. for field in fields(TweetTextOnly)
  65. if field.name != 'media'
  66. ])
  67. def test_attributes(self, tweets_output_factory, fixtures_file, session, raw_tweet_cls, field_name):
  68. tweets = tweets_output_factory(fixtures_file)
  69. for t in tweets:
  70. db_tweet = session.query(Tweet).filter(Tweet.id == t.id).one()
  71. assert getattr(db_tweet, field_name) == getattr(t, field_name)
  72. class TestUsersOutput:
  73. testdata = [
  74. pytest.param('tests/fixtures/users/users_1.yaml', User,
  75. id="users"),
  76. ]
  77. @pytest.fixture(scope="function")
  78. def users_output_factory(self, users_factory, output):
  79. """Factory for users that have been output"""
  80. def _users_output_factory(fixtures_file):
  81. users = users_factory(fixtures_file)
  82. output.start()
  83. output.output_users(users)
  84. output.stop()
  85. return users
  86. return _users_output_factory
  87. @pytest.mark.parametrize("fixtures_file, raw_class", testdata)
  88. def test_presence(self, users_factory, output, fixtures_file, session, raw_class):
  89. users = users_factory(fixtures_file)
  90. output.start()
  91. output.output_users(users)
  92. output.stop()
  93. for u in users:
  94. assert session.query(DbUser).filter(DbUser.id == u.id).one()
  95. @pytest.mark.parametrize("fixtures_file, raw_class", testdata)
  96. def test_twice(self, users_factory, output, fixtures_file, session, raw_class):
  97. users = users_factory(fixtures_file)
  98. output.start()
  99. output.output_users(users)
  100. output.stop()
  101. output.start()
  102. output.output_users(users)
  103. output.stop()
  104. for u in users:
  105. assert session.query(DbUser).filter(DbUser.id == u.id).one()
  106. @pytest.mark.parametrize("field_name, fixtures_file, raw_tweet_cls", [
  107. pytest.param(
  108. field.name,
  109. *td.values,
  110. id="{}-{}".format(td.id, field.name)
  111. )
  112. for td in testdata
  113. for field in fields(User)
  114. ])
  115. def test_attributes(self, users_output_factory, fixtures_file, session, raw_tweet_cls, field_name):
  116. users = users_output_factory(fixtures_file)
  117. for u in users:
  118. db_tweet = session.query(DbUser).filter(DbUser.id == u.id).one()
  119. assert getattr(db_tweet, field_name) == getattr(u, field_name)