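"""
Scraping pipeline for media_observer.

Front-page snapshots flow through four queue-backed stages
(search -> fetch -> parse -> store), with a web server and an
embeddings worker running alongside.
"""
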
import asyncio
import os
import pickle
import tempfile
import traceback
import urllib.parse
from abc import ABC, abstractmethod
from datetime import date, datetime, time, timedelta
from itertools import islice
from pathlib import Path
from typing import ClassVar
from uuid import UUID, uuid1
from zoneinfo import ZoneInfo

from attrs import frozen
from hypercorn.asyncio import serve
from hypercorn.config import Config
from loguru import logger

from media_observer.article import ArchiveCollection, FrontPage
from media_observer.internet_archive import (
    InternetArchiveClient,
    InternetArchiveSnapshot,
    InternetArchiveSnapshotId,
    SnapshotNotYetAvailable,
)
from media_observer.medias import media_collection
from media_observer.storage import Storage
from media_observer.web import app
from config import settings

# Scratch directory where failing parse jobs dump their snapshot and traceback.
tmpdir = Path(tempfile.mkdtemp(prefix="media_observer"))


@frozen
class Job(ABC):
    """A unit of work; `execute()` returns (result, list of follow-up jobs)."""

    id_: UUID
    queue: ClassVar[asyncio.Queue]

    @abstractmethod
    async def execute(self, **kwargs): ...

    def _log(self, level: str, msg: str):
        logger.log(level, f"[{self.id_}] {msg}")


class StupidJob(Job):
    async def execute(self, *args, **kwargs):
        logger.info(f"Executing job {self.id_}..")


def unique_id():
    return uuid1()


@frozen
class SnapshotSearchJob(Job):
    queue = asyncio.Queue()
    collection: ArchiveCollection
    dt: datetime

    @classmethod
    def create(cls, n_days: int, hours: list[int]):
        return [
            cls(unique_id(), c, d)
            for c in media_collection.values()
            for d in cls.last_n_days_at_hours(n_days, hours, c.tz)
        ]

    @staticmethod
    def last_n_days_at_hours(n: int, hours: list[int], tz: ZoneInfo) -> list[datetime]:
        now = datetime.now(tz)
        return [
            dt
            for i in range(n)
            for h in hours
            if (
                dt := datetime.combine(
                    date.today() - timedelta(days=i), time(hour=h), tzinfo=tz
                )
            )
            < now
        ]
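
    # For illustration: last_n_days_at_hours(2, [8, 20], tz) yields the
    # datetimes for today and yesterday at 08:00 and 20:00 in tz, keeping
    # only those already in the past.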

    async def execute(self, *, storage: Storage, ia_client: InternetArchiveClient):
        collection = self.collection
        dt = self.dt
        if await storage.exists_frontpage(collection.name, dt):
            return None, []

        self._log(
            "DEBUG",
            f"Start handling snap for collection {collection.name} @ {dt}",
        )
        try:
            id_closest = await ia_client.get_snapshot_id_closest_to(
                self.collection.url, self.dt
            )
            delta = self.dt - id_closest.timestamp
            abs_delta = abs(delta)
            if abs_delta.total_seconds() > 3600:
                direction = "before" if delta > timedelta(0) else "after"
                self._log(
                    "WARNING",
                    f"Snapshot is {abs_delta} {direction} the required timestamp ({id_closest.timestamp} instead of {self.dt})",
                )
            self._log("INFO", f"Got snapshot {id_closest}")
            return id_closest, [
                SnapshotFetchJob(self.id_, id_closest, self.collection, self.dt)
            ]
        except SnapshotNotYetAvailable:
            self._log(
                "WARNING",
                f"Snapshot for {collection.name} @ {dt} not yet available",
            )
            raise
        except Exception as e:
            self._log(
                "ERROR",
                f"Error while trying to find snapshot for {collection.name} @ {dt}",
            )
            traceback.print_exception(e)
            raise


@frozen
class SnapshotFetchJob(Job):
    queue = asyncio.Queue()
    snap_id: InternetArchiveSnapshotId
    collection: ArchiveCollection
    dt: datetime

    async def execute(self, ia_client: InternetArchiveClient):
        try:
            closest = await ia_client.fetch(self.snap_id)
            return closest, [
                SnapshotParseJob(self.id_, self.collection, closest, self.dt)
            ]
        except Exception as e:
            self._log("ERROR", f"Error while fetching {self.snap_id}")
            traceback.print_exception(e)
            raise


@frozen
class SnapshotParseJob(Job):
    queue = asyncio.Queue()
    collection: ArchiveCollection
    snapshot: InternetArchiveSnapshot
    dt: datetime

    async def execute(self):
        try:
            main_page = await self.collection.FrontPageClass.from_snapshot(
                self.snapshot
            )
            return main_page, [
                SnapshotStoreJob(self.id_, main_page, self.collection, self.dt)
            ]
        except Exception as e:
            # Dump everything needed to replay the failure offline.
            snapshot = self.snapshot
            sub_dir = (
                tmpdir
                / urllib.parse.quote_plus(snapshot.id.original)
                / urllib.parse.quote_plus(str(snapshot.id.timestamp))
            )
            os.makedirs(sub_dir, exist_ok=True)
            with open(sub_dir / "self.pickle", "wb") as f:
                pickle.dump(self, f)
            with open(sub_dir / "snapshot.html", "w") as f:
                f.write(snapshot.text)
            with open(sub_dir / "exception.txt", "w") as f:
                f.writelines(traceback.format_exception(e))
            with open(sub_dir / "url.txt", "w") as f:
                f.write(snapshot.id.url)
            self._log(
                "ERROR",
                f"Error while parsing snapshot from {snapshot.id.url}, details were written in directory {sub_dir}",
            )
            raise


@frozen
class SnapshotStoreJob(Job):
    queue = asyncio.Queue()
    page: FrontPage
    collection: ArchiveCollection
    dt: datetime

    async def execute(self, storage: Storage):
        try:
            return await storage.add_page(self.collection, self.page, self.dt), []
        except Exception as e:
            self._log(
                "ERROR",
                f"Error while attempting to store {self.page} from {self.collection.name} @ {self.dt}",
            )
            traceback.print_exception(e)
            raise


@frozen
class Worker(ABC):
    async def loop(self):
        logger.info(f"Task {self.__class__.__name__} {id(self)} booting..")
        while True:
            try:
                await self.run()
            except asyncio.CancelledError:
                logger.warning(f"Task {self.__class__.__name__} {id(self)} cancelled")
                return
            except Exception as e:
                traceback.print_exception(e)
                logger.error(
                    f"Task {self.__class__.__name__} {id(self)} failed with {e}"
                )

    @abstractmethod
    async def run(self): ...

    def get_execution_context(self) -> dict: ...


@frozen
class QueueWorker(Worker):
    inbound_queue: asyncio.Queue
    outbound_queue: asyncio.Queue | None

    async def run(self):
        logger.info(f"Task {self.__class__.__name__} {id(self)} waiting for job..")
        job: Job = await self.inbound_queue.get()
        assert isinstance(job, Job)
        ret, further_jobs = await job.execute(**self.get_execution_context())
        try:
            for j in further_jobs:
                await self.outbound_queue.put(j)
        except AttributeError:
            logger.error(
                f"Could not push jobs {further_jobs} because there is no outbound queue"
            )
            raise
        self.inbound_queue.task_done()

    def get_execution_context(self):
        return {}
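

# The concrete workers below bind each pipeline stage to the dependencies its
# jobs need (the storage backend, the Internet Archive client, or both).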


@frozen
class SnapshotWorker(QueueWorker):
    storage: Storage
    ia_client: InternetArchiveClient

    def get_execution_context(self):
        return {"storage": self.storage, "ia_client": self.ia_client}


@frozen
class FetchWorker(QueueWorker):
    ia_client: InternetArchiveClient

    def get_execution_context(self):
        return {"ia_client": self.ia_client}


@frozen
class ParseWorker(QueueWorker): ...


@frozen
class StoreWorker(QueueWorker):
    storage: Storage

    def get_execution_context(self):
        return {"storage": self.storage}


def batched(iterable, n):
    """
    Batch data into tuples of length n. The last batch may be shorter.
    `batched('ABCDEFG', 3) --> ABC DEF G`
    Straight from https://docs.python.org/3.11/library/itertools.html#itertools-recipes
    """
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch
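
# e.g. list(batched("ABCDEFG", 3)) == [("A", "B", "C"), ("D", "E", "F"), ("G",)]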


@frozen
class EmbeddingsWorker(Worker):
    storage: Storage
    model_name: str
    batch_size: int

    async def run(self):
        def load_model():
            logger.info("Starting stuff")
            # Imported lazily so the heavy dependency is only loaded here.
            from sentence_transformers import SentenceTransformer

            return SentenceTransformer(self.model_name)

        def compute_embeddings_for(model, sentences: tuple[tuple[int, str], ...]):
            logger.debug(f"Computing embeddings for {len(sentences)} sentences")
            all_texts = [t[1] for t in sentences]
            all_embeddings = model.encode(all_texts)
            return {sentences[idx][0]: e for idx, e in enumerate(all_embeddings)}

        logger.info("Embeddings watchdog")
        loop = asyncio.get_running_loop()
        # Model loading blocks, so run it in the default executor.
        model = await loop.run_in_executor(None, load_model)
        logger.info("Model loaded")
        all_titles = [
            (t["id"], t["text"])
            for t in await self.storage.list_all_titles_without_embedding()
        ]
        for batch in batched(all_titles, self.batch_size):
            embeddings = compute_embeddings_for(model, batch)
            logger.info(f"embeddings: {embeddings}")
            for i, embed in embeddings.items():
                await self.storage.add_embedding(i, embed)
            logger.debug(f"Stored {len(embeddings)} embeddings")


@frozen
class WebServer(Worker):
    async def run(self):
        shutdown_event = asyncio.Event()
        try:
            logger.info("Web server stuff")
            # Passing a shutdown_trigger, even though it is never set,
            # lets the app shut down gracefully.
            await serve(app, Config(), shutdown_trigger=shutdown_event.wait)
        except asyncio.CancelledError:
            logger.warning("Web server exiting")
            return
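

# main() wires the stages into a pipeline: each worker's outbound queue is the
# inbound queue of the next stage (search -> fetch -> parse -> store).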


async def main():
    tasks = []
    jobs = SnapshotSearchJob.create(
        settings.snapshots.days_in_past, settings.snapshots.hours
    )
    storage = await Storage.create()

    try:
        async with InternetArchiveClient.create() as ia:
            workers = {
                "snapshot": SnapshotWorker(
                    SnapshotSearchJob.queue, SnapshotFetchJob.queue, storage, ia
                ),
                "fetch": FetchWorker(SnapshotFetchJob.queue, SnapshotParseJob.queue, ia),
                "parse": ParseWorker(SnapshotParseJob.queue, SnapshotStoreJob.queue),
                "store": StoreWorker(SnapshotStoreJob.queue, None, storage),
            }
            web_server = WebServer()
            embeds = EmbeddingsWorker(
                storage,
                "dangvantuan/sentence-camembert-large",
                64,
            )
            async with asyncio.TaskGroup() as tg:
                for w in workers.values():
                    tasks.append(tg.create_task(w.loop()))
                # Only the first three search jobs are enqueued here.
                for j in jobs[:3]:
                    await SnapshotSearchJob.queue.put(j)
                tasks.append(tg.create_task(web_server.run()))
                tasks.append(tg.create_task(embeds.run()))
    finally:
        await storage.close()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.warning("Main kbinterrupt")