feat: ollama embedder mit exponential backoff retry
This commit is contained in:
56
app/ingest/embedder.py
Normal file
56
app/ingest/embedder.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from ollama import AsyncClient
|
||||||
|
|
||||||
|
from app.config import get_settings
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Seconds to sleep before each retry of a failed embedding call. The list
# length also fixes the retry count: one initial attempt plus one retry per entry.
_BACKOFF_SECONDS = [1, 2, 4]
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingError(Exception):
    """Raised when an embedding request still fails after all retries."""
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
def _client() -> AsyncClient:
    """Return the shared Ollama client, constructing it on the first call.

    The host comes from ``get_settings().ollama_url``. ``lru_cache`` holds on
    to the single instance so later calls get the same object back.
    """
    ollama_url = get_settings().ollama_url
    return AsyncClient(host=ollama_url)
|
||||||
|
|
||||||
|
|
||||||
|
async def embed_texts(texts: list[str], model: str) -> list[list[float]]:
    """Return one embedding vector per input text, in input order.

    Texts are embedded one after another through ``_embed_one``, which retries
    each individual call with backoff before giving up.
    """
    # The awaits run sequentially, so requests are issued in input order.
    return [await _embed_one(item, model) for item in texts]
|
||||||
|
|
||||||
|
|
||||||
|
async def _embed_one(text: str, model: str) -> list[float]:
    """Embed a single text, retrying failed calls with backoff.

    One initial attempt is made, followed by one retry per entry in
    ``_BACKOFF_SECONDS``; the entry is the number of seconds slept before
    that retry. Each retry is logged as a warning.

    Args:
        text: The text to embed.
        model: Name of the embedding model passed to the client.

    Returns:
        The embedding taken from the response's ``"embedding"`` field,
        copied into a new list.

    Raises:
        EmbeddingError: If every attempt fails. The last underlying
            exception is attached as ``__cause__``.
    """
    client = _client()
    # ``None`` marks the final attempt: there is nothing left to wait for.
    waits = [*_BACKOFF_SECONDS, None]
    last_err: Exception | None = None

    for attempt, wait in enumerate(waits, start=1):
        try:
            response = await client.embeddings(model=model, prompt=text)
            return list(response["embedding"])
        except Exception as exc:
            # Any failure in the call or in reading the response is retried.
            last_err = exc
            if wait is None:
                break
            logger.warning(
                "ollama embed retry",
                extra={
                    "event": "embed_retry",
                    "attempt": attempt,
                    "wait_s": wait,
                    "error": str(exc),
                },
            )
            await asyncio.sleep(wait)

    # Chain the cause so the original traceback survives in logs and debuggers.
    raise EmbeddingError(f"embed failed after retries: {last_err}") from last_err
|
||||||
|
|
||||||
|
|
||||||
|
async def embedding_dimension(model: str) -> int:
    """Return the length of the vector ``model`` produces for a probe string."""
    return len(await _embed_one("dimension probe", model))
|
||||||
52
tests/test_embedder.py
Normal file
52
tests/test_embedder.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.ingest.embedder import embed_texts, EmbeddingError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_embed_texts_returns_vectors():
    """Each input text yields one vector, in order, with one client call apiece."""
    expected = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    stub = MagicMock()
    stub.embeddings = AsyncMock(
        side_effect=[{"embedding": vector} for vector in expected],
    )

    with patch("app.ingest.embedder._client", return_value=stub):
        result = await embed_texts(["hello", "world"], model="qwen3-embedding:0.6b")

    assert result == expected
    assert stub.embeddings.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_embed_texts_retries_on_failure(monkeypatch):
    """Two failed calls followed by a success still return the vector."""
    # Zero delays keep the retries from slowing the test down.
    monkeypatch.setattr("app.ingest.embedder._BACKOFF_SECONDS", [0, 0, 0])

    outcomes = [
        Exception("connection refused"),
        Exception("timeout"),
        {"embedding": [1.0, 2.0]},
    ]
    stub = MagicMock()
    stub.embeddings = AsyncMock(side_effect=outcomes)

    with patch("app.ingest.embedder._client", return_value=stub):
        result = await embed_texts(["hi"], model="m")

    assert result == [[1.0, 2.0]]
    assert stub.embeddings.call_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_embed_texts_raises_after_max_retries(monkeypatch):
    """A client that always fails ends in EmbeddingError after four calls."""
    # Zero delays keep the retries from slowing the test down.
    monkeypatch.setattr("app.ingest.embedder._BACKOFF_SECONDS", [0, 0, 0])

    stub = MagicMock()
    stub.embeddings = AsyncMock(side_effect=Exception("nope"))

    with patch("app.ingest.embedder._client", return_value=stub), pytest.raises(EmbeddingError):
        await embed_texts(["hi"], model="m")

    # One initial attempt plus three retries.
    assert stub.embeddings.call_count == 4
|
||||||
Reference in New Issue
Block a user