| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192 |
- import pytest
- from dataclasses import fields
- from twhatter.output import Database
- from twhatter.output.sqlalchemy import Tweet
- from twhatter.parser import *
@pytest.fixture
def db_url():
    """Provide the database URL that the ``output`` fixture passes to ``Database``.

    NOTE(review): with SQLAlchemy a bare ``sqlite://`` URL means an in-memory
    database; this presumes ``Database`` forwards the URL unchanged -- confirm.
    """
    url = "sqlite://"
    return url
@pytest.fixture(scope="function")
def output(db_url):
    """Build a new ``Database`` output from ``db_url`` for every test function."""
    database = Database(db_url)
    return database
@pytest.fixture(scope="function")
def session(output):
    """Yield a session obtained from ``output.session_maker()``.

    The session is closed once the requesting test has finished.
    """
    db_session = output.session_maker()
    yield db_session
    # Teardown: everything after the yield runs when the test is done.
    db_session.close()
class TestTweetsOutput:
    """Tests that tweets sent through the ``output`` fixture can be queried back.

    Each test loads tweets from a fixtures file via ``tweets_factory``, pushes
    them through ``output`` and then looks them up again through ``session``.
    """

    # One pytest case per fixtures file, paired with a tweet class. The class
    # is passed to the tests as a parameter, but no test body reads it.
    testdata = [
        pytest.param(fixtures_path, tweet_cls, id=case_id)
        for fixtures_path, tweet_cls, case_id in (
            ('tests/fixtures/tweets/text_only_10.yaml', TweetTextOnly, "text-only"),
            ('tests/fixtures/tweets/retweet_10.yaml', TweetRetweet, "retweets"),
            ('tests/fixtures/tweets/link_10.yaml', TweetLink, "link"),
            ('tests/fixtures/tweets/reaction_9.yaml', TweetReaction, "reaction"),
        )
    ]

    @staticmethod
    def _write(output, tweets):
        """Run one start / output_tweets / stop cycle on ``output``."""
        output.start()
        output.output_tweets(tweets)
        output.stop()

    @pytest.fixture(scope="function")
    def tweets_output_factory(self, tweets_factory, output):
        """Tweets that have been output"""
        def _tweets_output_factory(fixtures_file):
            # Load the tweets, write them once, and hand them back so the
            # caller can compare them with what was stored.
            loaded = tweets_factory(fixtures_file)
            self._write(output, loaded)
            return loaded

        return _tweets_output_factory

    @pytest.mark.parametrize("fixtures_file, raw_class", testdata)
    def test_presence(self, tweets_factory, output, fixtures_file, session, raw_class):
        """Every tweet written once can be fetched back by its id."""
        tweets = tweets_factory(fixtures_file)
        self._write(output, tweets)

        for tweet in tweets:
            stored = session.query(Tweet).filter(Tweet.id == tweet.id).one()
            assert stored

    @pytest.mark.parametrize("fixtures_file, raw_class", testdata)
    def test_twice(self, tweets_factory, output, fixtures_file, session, raw_class):
        """Writing the same tweets in two cycles still allows the ``.one()`` lookup."""
        tweets = tweets_factory(fixtures_file)
        # Two complete start/output/stop cycles with the very same tweets.
        for _ in range(2):
            self._write(output, tweets)

        for tweet in tweets:
            stored = session.query(Tweet).filter(Tweet.id == tweet.id).one()
            assert stored

    # One case per (testdata entry, TweetTextOnly field) pair, skipping the
    # 'media' field; the id joins the testdata id and the field name.
    @pytest.mark.parametrize(
        "field_name, fixtures_file, raw_tweet_cls",
        [
            pytest.param(field.name, *td.values, id=f"{td.id}-{field.name}")
            for td in testdata
            for field in fields(TweetTextOnly)
            if field.name != 'media'
        ],
    )
    def test_attributes(self, tweets_output_factory, fixtures_file, session, raw_tweet_cls, field_name):
        """The stored row equals the source tweet on the attribute ``field_name``."""
        for expected in tweets_output_factory(fixtures_file):
            stored = session.query(Tweet).filter(Tweet.id == expected.id).one()
            assert getattr(stored, field_name) == getattr(expected, field_name)
|