|
|
@@ -1,5 +1,6 @@
|
|
|
from datetime import date, datetime, time, timedelta
|
|
|
import asyncio
|
|
|
+from typing import Any
|
|
|
from attrs import frozen
|
|
|
import traceback
|
|
|
from loguru import logger
|
|
|
@@ -162,9 +163,16 @@ class EmbeddingsWorker:
|
|
|
async def compute_embeddings(storage: Storage):
|
|
|
worker = EmbeddingsWorker.create(storage, "dangvantuan/sentence-camembert-large")
|
|
|
all_snapshots = await storage.list_all_featured_article_snapshots()
|
|
|
+ all_embeds_ids = set(
|
|
|
+ await storage.list_all_embedded_featured_article_snapshot_ids()
|
|
|
+ )
|
|
|
+
|
|
|
+ all_snapshots_not_stored = (
|
|
|
+ s for s in all_snapshots if s["id"] not in all_embeds_ids
|
|
|
+ )
|
|
|
|
|
|
batch_size = 64
|
|
|
- for batch in batched(all_snapshots, batch_size):
|
|
|
+ for batch in batched(all_snapshots_not_stored, batch_size):
|
|
|
embeddings_by_id = worker.compute_embeddings_for(
|
|
|
{s["id"]: s["title"] for s in batch}
|
|
|
)
|