feat: qdrant store mit ensure/upsert/delete-by-path
This commit is contained in:
54
app/qdrant_store.py
Normal file
54
app/qdrant_store.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
from qdrant_client.http import models as qm
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ChunkPoint:
|
||||||
|
vector: list[float]
|
||||||
|
payload: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_collection(client: QdrantClient, name: str, vector_size: int) -> None:
|
||||||
|
"""Create the collection if missing. Crash if it exists with wrong dim."""
|
||||||
|
if not client.collection_exists(name):
|
||||||
|
client.create_collection(
|
||||||
|
collection_name=name,
|
||||||
|
vectors_config=qm.VectorParams(size=vector_size, distance=qm.Distance.COSINE),
|
||||||
|
)
|
||||||
|
for field in ("file_path", "semester", "fach"):
|
||||||
|
client.create_payload_index(
|
||||||
|
collection_name=name,
|
||||||
|
field_name=field,
|
||||||
|
field_schema=qm.PayloadSchemaType.KEYWORD,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
info = client.get_collection(name)
|
||||||
|
existing = info.config.params.vectors.size
|
||||||
|
if existing != vector_size:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"qdrant collection '{name}' dimension mismatch: "
|
||||||
|
f"existing={existing}, model={vector_size}. "
|
||||||
|
"Drop the collection manually and run a bulk import."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_chunks(client: QdrantClient, name: str, chunks: list[ChunkPoint]) -> None:
|
||||||
|
points = [
|
||||||
|
qm.PointStruct(id=str(uuid.uuid4()), vector=c.vector, payload=c.payload)
|
||||||
|
for c in chunks
|
||||||
|
]
|
||||||
|
client.upsert(collection_name=name, points=points)
|
||||||
|
|
||||||
|
|
||||||
|
def delete_by_path(client: QdrantClient, name: str, file_path: str) -> None:
|
||||||
|
selector = qm.FilterSelector(
|
||||||
|
filter=qm.Filter(
|
||||||
|
must=[qm.FieldCondition(key="file_path", match=qm.MatchValue(value=file_path))]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
client.delete(collection_name=name, points_selector=selector)
|
||||||
73
tests/test_qdrant_store.py
Normal file
73
tests/test_qdrant_store.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.qdrant_store import (
|
||||||
|
ensure_collection,
|
||||||
|
upsert_chunks,
|
||||||
|
delete_by_path,
|
||||||
|
ChunkPoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_collection_creates_when_missing():
|
||||||
|
fake_client = MagicMock()
|
||||||
|
fake_client.collection_exists.return_value = False
|
||||||
|
|
||||||
|
ensure_collection(fake_client, "rag_test", vector_size=1024)
|
||||||
|
|
||||||
|
fake_client.create_collection.assert_called_once()
|
||||||
|
args, kwargs = fake_client.create_collection.call_args
|
||||||
|
assert kwargs["collection_name"] == "rag_test"
|
||||||
|
# Payload indexes get created
|
||||||
|
assert fake_client.create_payload_index.call_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_collection_skips_when_exists_with_matching_dim():
|
||||||
|
fake_client = MagicMock()
|
||||||
|
fake_client.collection_exists.return_value = True
|
||||||
|
info = MagicMock()
|
||||||
|
info.config.params.vectors.size = 1024
|
||||||
|
fake_client.get_collection.return_value = info
|
||||||
|
|
||||||
|
ensure_collection(fake_client, "rag_test", vector_size=1024)
|
||||||
|
|
||||||
|
fake_client.create_collection.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_collection_raises_on_dim_mismatch():
|
||||||
|
fake_client = MagicMock()
|
||||||
|
fake_client.collection_exists.return_value = True
|
||||||
|
info = MagicMock()
|
||||||
|
info.config.params.vectors.size = 768
|
||||||
|
fake_client.get_collection.return_value = info
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="dimension mismatch"):
|
||||||
|
ensure_collection(fake_client, "rag_test", vector_size=1024)
|
||||||
|
|
||||||
|
|
||||||
|
def test_upsert_chunks_calls_client_upsert():
|
||||||
|
fake_client = MagicMock()
|
||||||
|
points = [
|
||||||
|
ChunkPoint(vector=[0.1] * 4, payload={"file_path": "a", "chunk_index": 0}),
|
||||||
|
ChunkPoint(vector=[0.2] * 4, payload={"file_path": "a", "chunk_index": 1}),
|
||||||
|
]
|
||||||
|
|
||||||
|
upsert_chunks(fake_client, "rag_test", points)
|
||||||
|
|
||||||
|
fake_client.upsert.assert_called_once()
|
||||||
|
kwargs = fake_client.upsert.call_args.kwargs
|
||||||
|
assert kwargs["collection_name"] == "rag_test"
|
||||||
|
assert len(kwargs["points"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_by_path_uses_filter():
|
||||||
|
fake_client = MagicMock()
|
||||||
|
delete_by_path(fake_client, "rag_test", "Documents/x.pdf")
|
||||||
|
|
||||||
|
fake_client.delete.assert_called_once()
|
||||||
|
kwargs = fake_client.delete.call_args.kwargs
|
||||||
|
assert kwargs["collection_name"] == "rag_test"
|
||||||
|
# The filter should target file_path
|
||||||
|
selector = kwargs["points_selector"]
|
||||||
|
# Inspect the FilterSelector → Filter → must → FieldCondition
|
||||||
|
assert selector.filter.must[0].key == "file_path"
|
||||||
Reference in New Issue
Block a user