db.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import logging
  2. from sqlalchemy import create_engine
  3. from sqlalchemy.orm import scoped_session, sessionmaker
  4. from sqlalchemy.ext.declarative import declarative_base
  5. from sqlalchemy.exc import IntegrityError
  6. from twhatter.output import OutputBase
  7. from twhatter.client import ClientTimeline, ClientProfile
  # Module-wide logger, named after this module
  logger = logging.getLogger(__name__)

  # Mapping handed to the declarative base as the registry of model classes;
  # Database looks the 'Tweet' and 'User' models up in it by name
  class_registry = dict()

  # Declarative base class that the SQLAlchemy models derive from
  Base = declarative_base(class_registry=class_registry)

  # Scoped session factory with autoflush disabled; Database.__init__ binds
  # it to an engine through configure()
  Session = scoped_session(sessionmaker(autoflush=False))
  class Database(OutputBase):
      """Output backend that stores tweets and users through SQLAlchemy.

      Model classes are looked up by name (``'Tweet'``, ``'User'``) in the
      module-level ``class_registry``; sessions come from the module-level
      ``Session`` factory, which ``__init__`` binds to the engine.
      """

      def __init__(self, db_url):
          """Create the engine, bind the session factory, create the tables.

          :param db_url: database URL passed to ``create_engine``.
          """
          engine = create_engine(db_url)
          self.session_maker = Session
          # ``Session`` is a module-level object, so configuring it here
          # affects every user of that factory, not only this instance.
          self.session_maker.configure(bind=engine)
          Base.metadata.create_all(engine)

      def start(self):
          """Return a session obtained from the session factory."""
          return self.session_maker()

      def stop(self, session):
          """Close ``session``."""
          session.close()

      def _add_no_fail(self, session, obj):
          """Add ``obj`` to ``session`` and commit it on its own.

          :returns: ``1`` when the commit succeeded, ``0`` when it raised
              ``IntegrityError`` (the session is rolled back in that case).
          """
          # Committing object by object is extremely inefficient, but it is
          # the only way found so far to deal with duplicates.
          session.add(obj)
          try:
              session.commit()
          except IntegrityError as e:
              logger.debug("Error on commit : %s", e)
              session.rollback()
              return 0
          return 1

      def output_tweets(self, tweets):
          """Store ``tweets`` together with the profiles of their authors.

          :param tweets: iterable of raw tweets. Each one is converted with
              ``Tweet.from_raw`` and its ``username`` attribute is used to
              fetch the author's profile through ``ClientProfile``.
          """
          Tweet = class_registry['Tweet']
          User = class_registry['User']

          # The input is walked twice (tweets, then authors); materialise it
          # so a one-shot iterator is not exhausted by the first pass.
          raw_tweets = list(tweets)
          db_tweets = [Tweet.from_raw(t) for t in raw_tweets]
          logger.info("Adding {} tweets".format(len(db_tweets)))

          # Deduplicate on the username *before* building ClientProfile
          # objects, so each author is looked up only once.
          usernames = list(dict.fromkeys(t.username for t in raw_tweets))
          db_users = [
              User.from_raw(ClientProfile(name).user) for name in usernames
          ]

          session = self.start()
          try:
              for db_user in db_users:
                  self._add_no_fail(session, db_user)

              # _add_no_fail returns 1 on success, so rejected tweets are
              # the ones for which it returned 0.
              unique_errors = 0
              for db_tweet in db_tweets:
                  if not self._add_no_fail(session, db_tweet):
                      unique_errors += 1

              if unique_errors:
                  logger.info(
                      "{} tweets were already in the database".format(
                          unique_errors
                      )
                  )
          finally:
              # Close the session even if a conversion or commit blew up.
              self.stop(session)

      def output_users(self, users):
          """Store ``users``.

          :param users: iterable of raw users, each converted with
              ``User.from_raw``.
              NOTE(review): assumes the items are raw user objects (what
              ``ClientProfile(...).user`` provides), mirroring how
              ``output_tweets`` treats its input -- confirm against callers.
          """
          User = class_registry['User']
          session = self.start()
          try:
              for raw_user in users:
                  self._add_no_fail(session, User.from_raw(raw_user))
          finally:
              self.stop(session)
  63. self.stop(session)