# storage.py
  1. from typing import Any
  2. from datetime import datetime
  3. import numpy as np
  4. from yarl import URL
  5. from config import settings
  6. from media_observer.article import (
  7. TopArticle,
  8. FeaturedArticleSnapshot,
  9. FeaturedArticle,
  10. )
  11. from media_observer.storage_abstraction import Table, Column, UniqueIndex, View
  12. from media_observer.db.sqlite import SqliteBackend
  13. from media_observer.db.postgres import PostgresBackend
  14. from media_observer.internet_archive import InternetArchiveSnapshotId
  15. class Storage:
  16. tables = [
  17. Table(
  18. name="sites",
  19. columns=[
  20. Column(name="id", primary_key=True),
  21. Column(name="name", type_="TEXT"),
  22. Column(name="original_url", type_="TEXT"),
  23. ],
  24. ),
  25. Table(
  26. name="snapshots",
  27. columns=[
  28. Column(name="id", primary_key=True),
  29. Column(
  30. name="site_id",
  31. references="sites (id) ON DELETE CASCADE",
  32. ),
  33. Column(name="timestamp", type_="timestamp with time zone"),
  34. Column(name="timestamp_virtual", type_="timestamp with time zone"),
  35. Column(name="url_original", type_="TEXT"),
  36. Column(name="url_snapshot", type_="TEXT"),
  37. ],
  38. ),
  39. Table(
  40. name="featured_articles",
  41. columns=[
  42. Column(name="id", primary_key=True),
  43. Column(name="url", type_="TEXT"),
  44. ],
  45. ),
  46. Table(
  47. name="featured_article_snapshots",
  48. columns=[
  49. Column(name="id", primary_key=True),
  50. Column(
  51. name="featured_article_id",
  52. references="featured_articles (id) ON DELETE CASCADE",
  53. ),
  54. Column(name="title", type_="TEXT"),
  55. Column(name="url", type_="TEXT"),
  56. ],
  57. ),
  58. Table(
  59. name="main_articles",
  60. columns=[
  61. Column(name="id", primary_key=True),
  62. Column(
  63. name="snapshot_id",
  64. references="snapshots (id) ON DELETE CASCADE",
  65. ),
  66. Column(
  67. name="featured_article_snapshot_id",
  68. references="featured_article_snapshots (id) ON DELETE CASCADE",
  69. ),
  70. ],
  71. ),
  72. Table(
  73. name="top_articles",
  74. columns=[
  75. Column(name="id", primary_key=True),
  76. Column(
  77. name="snapshot_id",
  78. references="snapshots (id) ON DELETE CASCADE",
  79. ),
  80. Column(
  81. name="featured_article_snapshot_id",
  82. references="featured_article_snapshots (id) ON DELETE CASCADE",
  83. ),
  84. Column(name="rank", type_="INTEGER"),
  85. ],
  86. ),
  87. Table(
  88. name="articles_embeddings",
  89. columns=[
  90. Column(name="id", primary_key=True),
  91. Column(
  92. name="featured_article_snapshot_id",
  93. references="featured_article_snapshots (id) ON DELETE CASCADE",
  94. ),
  95. Column(name="title_embedding", type_="bytea"),
  96. ],
  97. ),
  98. ]
  99. views = [
  100. View(
  101. name="snapshots_view",
  102. column_names=[
  103. "id",
  104. "site_id",
  105. "site_name",
  106. "site_original_url",
  107. "timestamp",
  108. "timestamp_virtual",
  109. ],
  110. create_stmt="""
  111. SELECT
  112. s.id,
  113. si.id AS site_id,
  114. si.name AS site_name,
  115. si.original_url AS site_original_url,
  116. s.timestamp,
  117. s.timestamp_virtual
  118. FROM
  119. snapshots AS s
  120. JOIN
  121. sites AS si ON si.id = s.site_id
  122. """,
  123. ),
  124. View(
  125. name="main_page_apparitions",
  126. column_names=[
  127. "id",
  128. "featured_article_id",
  129. "title",
  130. "url_archive",
  131. "url_article",
  132. "main_in_snapshot_id",
  133. "top_in_snapshot_id",
  134. "rank",
  135. ],
  136. create_stmt="""
  137. SELECT
  138. fas.id,
  139. fas.featured_article_id,
  140. fas.title,
  141. fas.url AS url_archive,
  142. fa.url AS url_article,
  143. m.snapshot_id AS main_in_snapshot_id,
  144. t.snapshot_id AS top_in_snapshot_id,
  145. t.rank
  146. FROM featured_article_snapshots fas
  147. JOIN featured_articles fa ON fa.id = fas.featured_article_id
  148. LEFT JOIN main_articles m ON m.featured_article_snapshot_id = fas.id
  149. LEFT JOIN top_articles t ON t.featured_article_snapshot_id = fas.id
  150. """,
  151. ),
  152. View(
  153. name="snapshot_apparitions",
  154. column_names=[
  155. "snapshot_id",
  156. "site_id",
  157. "site_name",
  158. "site_original_url",
  159. "timestamp",
  160. "timestamp_virtual",
  161. "featured_article_snapshot_id",
  162. "featured_article_id",
  163. "title",
  164. "url_archive",
  165. "url_article",
  166. "is_main",
  167. "rank",
  168. ],
  169. create_stmt="""
  170. SELECT
  171. sv.id as snapshot_id,
  172. sv.site_id,
  173. sv.site_name,
  174. sv.site_original_url,
  175. sv.timestamp,
  176. sv.timestamp_virtual,
  177. mpa.id AS featured_article_snapshot_id,
  178. mpa.featured_article_id,
  179. mpa.title,
  180. mpa.url_archive,
  181. mpa.url_article,
  182. mpa.main_in_snapshot_id IS NOT NULL AS is_main,
  183. mpa.rank
  184. FROM main_page_apparitions mpa
  185. JOIN snapshots_view sv ON sv.id = mpa.main_in_snapshot_id OR sv.id = mpa.top_in_snapshot_id
  186. """,
  187. ),
  188. ]
  189. indexes = [
  190. UniqueIndex(
  191. name="sites_unique_name",
  192. table="sites",
  193. columns=["name"],
  194. ),
  195. UniqueIndex(
  196. name="snapshots_unique_timestamp_virtual_site_id",
  197. table="snapshots",
  198. columns=["timestamp_virtual", "site_id"],
  199. ),
  200. UniqueIndex(
  201. name="main_articles_unique_idx_snapshot_id",
  202. table="main_articles",
  203. columns=["snapshot_id"],
  204. ),
  205. UniqueIndex(
  206. name="featured_articles_unique_url",
  207. table="featured_articles",
  208. columns=["url"],
  209. ),
  210. UniqueIndex(
  211. name="featured_article_snapshots_unique_idx_featured_article_id_url",
  212. table="featured_article_snapshots",
  213. columns=["featured_article_id", "url"],
  214. ),
  215. UniqueIndex(
  216. name="top_articles_unique_idx_snapshot_id_rank",
  217. table="top_articles",
  218. columns=["snapshot_id", "rank"],
  219. ),
  220. UniqueIndex(
  221. name="articles_embeddings_unique_idx_featured_article_snapshot_id",
  222. table="articles_embeddings",
  223. columns=["featured_article_snapshot_id"],
  224. ),
  225. ]
  226. def __init__(self, backend):
  227. self.backend = backend
  228. async def close(self):
  229. await self.backend.close()
  230. @staticmethod
  231. async def create():
  232. # We try to reproduce the scheme used by SQLAlchemy for Database-URLs
  233. # https://docs.sqlalchemy.org/en/20/core/engines.html#database-urls
  234. conn_url = URL(settings.database_url)
  235. backend = None
  236. if conn_url.scheme == "sqlite":
  237. if conn_url.path.startswith("//"):
  238. raise ValueError("Absolute URLs not supported for sqlite")
  239. elif conn_url.path.startswith("/"):
  240. backend = await SqliteBackend.create(conn_url.path[1:])
  241. elif conn_url.scheme == "postgresql":
  242. backend = await PostgresBackend.create(settings.database_url)
  243. else:
  244. raise ValueError("Only the SQLite backend is supported")
  245. storage = Storage(backend)
  246. await storage._create_db()
  247. return storage
  248. async def _create_db(self):
  249. async with self.backend.get_connection() as conn:
  250. for t in self.tables:
  251. await t.create_if_not_exists(conn)
  252. for i in self.indexes:
  253. await i.create_if_not_exists(conn)
  254. for v in self.views:
  255. await v.create_if_not_exists(conn)
  256. async def exists_snapshot(self, name: str, dt: datetime):
  257. async with self.backend.get_connection() as conn:
  258. exists = await conn.execute_fetchall(
  259. """
  260. SELECT 1
  261. FROM snapshots snap
  262. JOIN sites s ON s.id = snap.site_id
  263. WHERE s.name = $1 AND timestamp_virtual = $2
  264. """,
  265. name,
  266. dt,
  267. )
  268. return exists != []
  269. async def list_all_featured_article_snapshots(self):
  270. async with self.backend.get_connection() as conn:
  271. rows = await conn.execute_fetchall(
  272. """
  273. SELECT *
  274. FROM featured_article_snapshots
  275. """,
  276. )
  277. return [
  278. self._from_row(r, self._table_by_name["featured_article_snapshots"])
  279. for r in rows
  280. ]
  281. async def list_snapshot_apparitions(self, featured_article_snapshot_ids: list[int]):
  282. if len(featured_article_snapshot_ids) == 0:
  283. return []
  284. async with self.backend.get_connection() as conn:
  285. rows = await conn.execute_fetchall(
  286. f"""
  287. SELECT *
  288. FROM snapshot_apparitions
  289. WHERE featured_article_snapshot_id IN ({self._placeholders(*featured_article_snapshot_ids)})
  290. """,
  291. *featured_article_snapshot_ids,
  292. )
  293. return [
  294. self._from_row(r, self._view_by_name["snapshot_apparitions"])
  295. for r in rows
  296. ]
  297. @classmethod
  298. def _from_row(cls, r, table_or_view: Table | View):
  299. columns = table_or_view.column_names
  300. return {col: r[idx] for idx, col in enumerate(columns)}
  301. async def list_all_embedded_featured_article_snapshot_ids(self) -> list[int]:
  302. async with self.backend.get_connection() as conn:
  303. rows = await conn.execute_fetchall(
  304. """
  305. SELECT featured_article_snapshot_id
  306. FROM articles_embeddings
  307. """,
  308. )
  309. return [r[0] for r in rows]
  310. async def list_all_articles_embeddings(self):
  311. async with self.backend.get_connection() as conn:
  312. rows = await conn.execute_fetchall(
  313. """
  314. SELECT *
  315. FROM articles_embeddings
  316. """,
  317. )
  318. return [self._from_articles_embeddings_row(r) for r in rows]
  319. @classmethod
  320. def _from_articles_embeddings_row(cls, r):
  321. [embeds_table] = [t for t in cls.tables if t.name == "articles_embeddings"]
  322. d = cls._from_row(r, embeds_table)
  323. d.update(title_embedding=np.frombuffer(d["title_embedding"], dtype="float32"))
  324. return d
  325. async def add_embedding(self, featured_article_snapshot_id: int, embedding):
  326. async with self.backend.get_connection() as conn:
  327. await conn.execute_insert(
  328. self._insert_stmt(
  329. "articles_embeddings",
  330. ["featured_article_snapshot_id", "title_embedding"],
  331. ),
  332. featured_article_snapshot_id,
  333. embedding,
  334. )
  335. async def list_sites(self):
  336. async with self.backend.get_connection() as conn:
  337. sites = await conn.execute_fetchall("SELECT * FROM sites")
  338. return [self._from_row(s, self._table_by_name["sites"]) for s in sites]
  339. async def list_neighbouring_main_articles(
  340. self,
  341. site_id: int,
  342. timestamp: datetime | None = None,
  343. ):
  344. async with self.backend.get_connection() as conn:
  345. if timestamp is None:
  346. [row] = await conn.execute_fetchall(
  347. """
  348. SELECT timestamp_virtual
  349. FROM snapshots_view
  350. WHERE site_id = $1
  351. ORDER BY timestamp_virtual DESC
  352. LIMIT 1
  353. """,
  354. site_id,
  355. )
  356. timestamp = row["timestamp_virtual"]
  357. # This query is the union of 3 queries that respectively fetch :
  358. # * articles published at the same time as the queried article (including the queried article)
  359. # * the article published just after, on the same site
  360. # *the article published just before, on the same site
  361. main_articles = await conn.execute_fetchall(
  362. """
  363. WITH sav_diff AS (
  364. SELECT sav.*, EXTRACT(EPOCH FROM sav.timestamp_virtual - $2) :: integer AS time_diff
  365. FROM snapshot_apparitions sav
  366. )
  367. SELECT * FROM (
  368. SELECT * FROM sav_diff
  369. WHERE is_main AND time_diff = 0
  370. )
  371. UNION ALL
  372. SELECT * FROM (
  373. SELECT * FROM sav_diff
  374. WHERE is_main AND site_id = $1 AND time_diff > 0
  375. ORDER BY time_diff
  376. LIMIT 1
  377. )
  378. UNION ALL
  379. SELECT * FROM (
  380. SELECT * FROM sav_diff
  381. WHERE is_main AND site_id = $1 AND time_diff < 0
  382. ORDER BY time_diff DESC
  383. LIMIT 1
  384. )
  385. """,
  386. site_id,
  387. timestamp,
  388. )
  389. return [
  390. self._from_row(a, self._view_by_name["snapshot_apparitions"])
  391. | {"time_diff": a[13]}
  392. for a in main_articles
  393. ]
  394. async def add_page(self, collection, page, dt):
  395. assert dt.tzinfo is not None
  396. async with self.backend.get_connection() as conn:
  397. async with conn.transaction():
  398. site_id = await self._add_site(conn, collection.name, collection.url)
  399. snapshot_id = await self._add_snapshot(
  400. conn, site_id, page.snapshot.id, dt
  401. )
  402. article_id = await self._add_featured_article(
  403. conn, page.main_article.article.original
  404. )
  405. main_article_snap_id = await self._add_featured_article_snapshot(
  406. conn, article_id, page.main_article.article
  407. )
  408. await self._add_main_article(conn, snapshot_id, main_article_snap_id)
  409. for t in page.top_articles:
  410. article_id = await self._add_featured_article(
  411. conn, t.article.original
  412. )
  413. top_article_snap_id = await self._add_featured_article_snapshot(
  414. conn, article_id, t.article
  415. )
  416. await self._add_top_article(
  417. conn, snapshot_id, top_article_snap_id, t
  418. )
  419. return site_id
  420. async def _add_site(self, conn, name: str, original_url: str) -> int:
  421. return await self._insert_or_get(
  422. conn,
  423. self._insert_stmt("sites", ["name", "original_url"]),
  424. [name, original_url],
  425. "SELECT id FROM sites WHERE name = $1",
  426. [name],
  427. )
  428. async def _add_snapshot(
  429. self, conn, site_id: int, snapshot: InternetArchiveSnapshotId, virtual: datetime
  430. ) -> int:
  431. return await self._insert_or_get(
  432. conn,
  433. self._insert_stmt(
  434. "snapshots",
  435. [
  436. "timestamp",
  437. "site_id",
  438. "timestamp_virtual",
  439. "url_original",
  440. "url_snapshot",
  441. ],
  442. ),
  443. [snapshot.timestamp, site_id, virtual, snapshot.original, snapshot.url],
  444. "SELECT id FROM snapshots WHERE timestamp_virtual = $1 AND site_id = $2",
  445. [virtual, site_id],
  446. )
  447. async def _add_featured_article(self, conn, article: FeaturedArticle):
  448. return await self._insert_or_get(
  449. conn,
  450. self._insert_stmt("featured_articles", ["url"]),
  451. [str(article.url)],
  452. "SELECT id FROM featured_articles WHERE url = $1",
  453. [str(article.url)],
  454. )
  455. async def _add_featured_article_snapshot(
  456. self, conn, featured_article_id: int, article: FeaturedArticleSnapshot
  457. ):
  458. return await self._insert_or_get(
  459. conn,
  460. self._insert_stmt(
  461. "featured_article_snapshots",
  462. ["title", "url", "featured_article_id"],
  463. ),
  464. [article.title, str(article.url), featured_article_id],
  465. "SELECT id FROM featured_article_snapshots WHERE featured_article_id = $1 AND url = $2",
  466. [featured_article_id, str(article.url)],
  467. )
  468. async def _add_main_article(self, conn, snapshot_id: int, article_id: int):
  469. await conn.execute_insert(
  470. self._insert_stmt(
  471. "main_articles", ["snapshot_id", "featured_article_snapshot_id"]
  472. ),
  473. snapshot_id,
  474. article_id,
  475. )
  476. async def _add_top_article(
  477. self, conn, snapshot_id: int, article_id: int, article: TopArticle
  478. ):
  479. await conn.execute_insert(
  480. self._insert_stmt(
  481. "top_articles",
  482. ["snapshot_id", "featured_article_snapshot_id", "rank"],
  483. ),
  484. snapshot_id,
  485. article_id,
  486. article.rank,
  487. )
  488. async def _insert_or_get(
  489. self,
  490. conn,
  491. insert_stmt: str,
  492. insert_args: list[Any],
  493. select_stmt: str,
  494. select_args: list[Any],
  495. ) -> int:
  496. await conn.execute_insert(insert_stmt, *insert_args)
  497. [(id_,)] = await conn.execute_fetchall(select_stmt, *select_args)
  498. return id_
  499. @staticmethod
  500. def _insert_stmt(table, cols):
  501. cols_str = ", ".join(cols)
  502. return f"""
  503. INSERT INTO {table} ({cols_str})
  504. VALUES ({Storage._placeholders(*cols)})
  505. ON CONFLICT DO NOTHING
  506. """
  507. @staticmethod
  508. def _placeholders(*args):
  509. return ", ".join([f"${idx + 1}" for idx, _ in enumerate(args)])
  510. @property
  511. def _table_by_name(self):
  512. return {t.name: t for t in self.tables}
  513. @property
  514. def _view_by_name(self):
  515. return {v.name: v for v in self.views}