Преглед изворни кода

Do not use tricks to modify EmbeddingsWorker instance

jherve пре 1 година
родитељ
комит
baafa29b0f
1 измењених фајлова са 3 додато и 7 уклоњено
  1. 3 7
      src/media_observer/test.py

+ 3 - 7
src/media_observer/test.py

@@ -10,7 +10,7 @@ from typing import Any, ClassVar
 import urllib.parse
 import urllib.parse
 from zoneinfo import ZoneInfo
 from zoneinfo import ZoneInfo
 from loguru import logger
 from loguru import logger
-from attrs import field, frozen
+from attrs import define, field, frozen
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
 from uuid import UUID, uuid1
 from uuid import UUID, uuid1
 from hypercorn.asyncio import serve
 from hypercorn.asyncio import serve
@@ -209,7 +209,6 @@ class SnapshotStoreJob(Job):
             raise e
             raise e
 
 
 
 
-@frozen
 class Worker(ABC):
 class Worker(ABC):
     @abstractmethod
     @abstractmethod
     async def run(self): ...
     async def run(self): ...
@@ -298,7 +297,7 @@ def batched(iterable, n):
         yield batch
         yield batch
 
 
 
 
-@frozen
+@define
 class EmbeddingsWorker(Worker):
 class EmbeddingsWorker(Worker):
     storage: Storage
     storage: Storage
     model_name: str
     model_name: str
@@ -310,10 +309,7 @@ class EmbeddingsWorker(Worker):
         def load_model():
         def load_model():
             from sentence_transformers import SentenceTransformer
             from sentence_transformers import SentenceTransformer
 
 
-            # Quite a dirty trick since the instance is supposed to be "frozen"
-            # but I did not find a better solution to load the model in the
-            # background
-            object.__setattr__(self, "model", SentenceTransformer(self.model_name))
+            self.model = SentenceTransformer(self.model_name)
 
 
         def compute_embeddings_for(sentences: tuple[tuple[int, str]]):
         def compute_embeddings_for(sentences: tuple[tuple[int, str]]):
             logger.debug(f"Computing embeddings for {len(sentences)} sentences")
             logger.debug(f"Computing embeddings for {len(sentences)} sentences")