Browse Source

Tweak database sessions

theenglishway (time) 7 years ago
parent
commit
e4e3180631
2 changed files with 26 additions and 12 deletions
  1. 24 10
      tests/output/test_db.py
  2. 2 2
      twhatter/output/sqlalchemy/db.py

+ 24 - 10
tests/output/test_db.py

@@ -13,30 +13,45 @@ def output(db_url):
     return Database(db_url)
 
 
+@pytest.fixture(scope="function")
+def session(output):
+    session = output.session_maker()
+    yield session
+    session.close()
+
 @pytest.mark.parametrize("fixtures_file", [
     'tests/fixtures/tweets/text_only_10.yaml',
     'tests/fixtures/tweets/retweet_10.yaml',
     'tests/fixtures/tweets/link_10.yaml',
     'tests/fixtures/tweets/reaction_9.yaml',
+], ids=[
+    "TextOnly",
+    "Retweet",
+    "Link",
+    "Reaction"
 ])
-def test_output_tweets(capsys, tweets_factory, output, fixtures_file):
+def test_output_tweets_presence(capsys, tweets_factory, output, fixtures_file, session):
+testdata = [
+    pytest.param('tests/fixtures/tweets/text_only_10.yaml', TweetTextOnly, id="text-only"),
+    pytest.param('tests/fixtures/tweets/retweet_10.yaml', TweetRetweet, id="retweets"),
+    pytest.param('tests/fixtures/tweets/link_10.yaml', TweetLink, id="link"),
+    pytest.param('tests/fixtures/tweets/reaction_9.yaml', TweetReaction, id="reaction"),
+]
+
+
+@pytest.mark.parametrize("fixtures_file, raw_class", testdata)
+def test_output_tweets_presence(tweets_factory, output, fixtures_file, session, raw_class):
     tweets = tweets_factory(fixtures_file)
     output.start()
     output.output_tweets(tweets)
     output.stop()
 
-    session = output.session_maker()
     for t in tweets:
         assert session.query(Tweet).filter(Tweet.id == t.id).one()
 
 
-@pytest.mark.parametrize("fixtures_file", [
-    'tests/fixtures/tweets/text_only_10.yaml',
-    'tests/fixtures/tweets/retweet_10.yaml',
-    'tests/fixtures/tweets/link_10.yaml',
-    'tests/fixtures/tweets/reaction_9.yaml',
-])
-def test_output_tweets_twice(capsys, tweets_factory, output, fixtures_file):
+@pytest.mark.parametrize("fixtures_file, raw_class", testdata)
+def test_output_tweets_twice(tweets_factory, output, fixtures_file, session, raw_class):
     tweets = tweets_factory(fixtures_file)
     output.start()
     output.output_tweets(tweets)
@@ -46,6 +61,5 @@ def test_output_tweets_twice(capsys, tweets_factory, output, fixtures_file):
     output.output_tweets(tweets)
     output.stop()
 
-    session = output.session_maker()
     for t in tweets:
         assert session.query(Tweet).filter(Tweet.id == t.id).one()

+ 2 - 2
twhatter/output/sqlalchemy/db.py

@@ -14,7 +14,7 @@ class_registry = {}
 Base = declarative_base(class_registry=class_registry)
 
 # Session maker
-Session = scoped_session(sessionmaker(autoflush=False))
+Session = sessionmaker(autoflush=False)
 
 logger = logging.getLogger(__name__)
 
@@ -52,7 +52,7 @@ class Database(OutputBase):
 
     def output_users(self, users):
         User = class_registry['User']
-        logger.info("Adding {} tweets".format(len(users)))
+        logger.info("Adding {} user".format(len(users)))
 
         self.all_objs += [User.from_raw(t) for t in users]
         #self.session.add_all([User.from_raw(u) for u in users])