A 5e storytelling engine with an LLM DM
1"""Tests for vector search index."""
2
3from pathlib import Path
4
5import pytest
6from conftest import _fake_embed
7
8from storied.search import SearchHit, VectorIndex, age_decay, chunk_document
9
10
@pytest.fixture
def index(tmp_path: Path):
    """Yield a VectorIndex stored in a temporary DB file, using the fake embedder.

    ``_embed_fn`` is replaced with ``_fake_embed`` so no embedding model is
    needed. The index is closed during fixture teardown.
    """
    vector_index = VectorIndex(tmp_path / "test.search.db")
    vector_index._embed_fn = _fake_embed
    yield vector_index
    vector_index.close()
19
20
@pytest.fixture
def srd_tree(tmp_path: Path) -> Path:
    """Create a miniature SRD-like tree under *tmp_path* for reindex tests.

    Layout (five markdown files in two folders)::

        spells/   fireball.md, magic-missile.md, cure-wounds.md
        monsters/ goblin.md, ancient-red-dragon.md

    Returns *tmp_path*, the root of the tree.
    """
    layout = {
        "spells": {
            "fireball.md": (
                "# Fireball\n\n_Level 3 Evocation (Sorcerer, Wizard)_\n\n"
                "Each creature in a 20-foot-radius sphere must make a Dexterity "
                "saving throw. A target takes 8d6 fire damage on a failed save.\n"
            ),
            "magic-missile.md": (
                "# Magic Missile\n\n_Level 1 Evocation (Sorcerer, Wizard)_\n\n"
                "You create three glowing darts of magical force. Each dart hits "
                "automatically and deals 1d4+1 force damage.\n"
            ),
            "cure-wounds.md": (
                "# Cure Wounds\n\n_Level 1 Abjuration (Cleric, Druid)_\n\n"
                "A creature you touch regains hit points equal to 1d8 + your "
                "spellcasting ability modifier.\n"
            ),
        },
        "monsters": {
            "goblin.md": (
                "# Goblin\n\n_Small Humanoid, Neutral Evil_\n\n"
                "**AC** 15 (leather armor, shield) **HP** 7 (2d6)\n"
            ),
            "ancient-red-dragon.md": (
                "# Ancient Red Dragon\n\n_Gargantuan Dragon, Chaotic Evil_\n\n"
                "**AC** 22 (natural armor) **HP** 546 (28d20 + 252)\n\n"
                "## Fire Breath\nThe dragon exhales fire in a 90-foot cone.\n"
            ),
        },
    }
    for folder_name, files in layout.items():
        folder = tmp_path / folder_name
        folder.mkdir()
        for file_name, body in files.items():
            (folder / file_name).write_text(body)
    return tmp_path
55
56
@pytest.fixture
def large_doc_content() -> str:
    """Return a markdown document large enough to force chunking.

    It is an H1 intro followed by ten H2 sections, each repeating one sentence
    50 times; all parts are joined with newlines.
    """
    intro = "# Equipment\n\nRules for adventuring gear.\n"
    bodies = [
        f"\n## Section {n}\n\n"
        + f"This is section {n} with enough content to matter. " * 50
        for n in range(10)
    ]
    return "\n".join([intro, *bodies])
66
67
68# --- Chunking ---
69
70
class TestChunkDocument:
    """Splitting markdown documents into ``(chunk_index, text)`` pairs."""

    def test_small_file_single_chunk(self):
        content = "# Fireball\n\nA big explosion spell.\n"
        chunks = chunk_document(Path("fireball.md"), content)
        assert len(chunks) == 1
        assert chunks[0][0] == 0
        assert "Fireball" in chunks[0][1]

    def test_large_file_splits_on_h2(self, large_doc_content: str):
        chunks = chunk_document(Path("equipment.md"), large_doc_content)
        assert len(chunks) > 1
        # Each chunk should have the title prepended for context
        for _, text in chunks:
            assert "Equipment" in text

    def test_chunk_index_sequential(self, large_doc_content: str):
        chunks = chunk_document(Path("equipment.md"), large_doc_content)
        indices = [idx for idx, _ in chunks]
        assert indices == list(range(len(chunks)))

    def test_no_empty_chunks(self, large_doc_content: str):
        chunks = chunk_document(Path("equipment.md"), large_doc_content)
        for _, text in chunks:
            assert text.strip()

    def test_file_with_no_h1_still_works(self):
        content = "Just some content without any headers at all.\n"
        chunks = chunk_document(Path("notes.md"), content)
        assert len(chunks) == 1
        assert "some content" in chunks[0][1]

    def test_splits_on_h4_headings(self):
        content = "# Rogue\n\n" + "\n".join(
            f"#### **Feature {i}**\n\n{'x ' * 400}\n" for i in range(5)
        )
        chunks = chunk_document(Path("rogue.md"), content)
        assert len(chunks) >= 5

    def test_splits_on_bold_definitions(self):
        content = "# Rules Glossary\n\n## Definitions\n\n" + "\n".join(
            f"**Term {i}**\nDefinition of term {i}. {'y ' * 400}\n" for i in range(5)
        )
        chunks = chunk_document(Path("glossary.md"), content)
        assert len(chunks) >= 5

    def test_class_features_split_by_level(self):
        content = "# Fighter\n\n#### **Fighter Class Features**\n\n"
        content += "**Level 1: Fighting Style**\n" + "Choose a style. " * 100 + "\n\n"
        content += "**Level 1: Second Wind**\n" + "Heal yourself. " * 100 + "\n\n"
        content += "**Level 2: Action Surge**\n" + "Extra action. " * 100 + "\n"
        chunks = chunk_document(Path("fighter.md"), content)
        texts = [t for _, t in chunks]
        assert any("Second Wind" in t for t in texts)
        assert any("Action Surge" in t for t in texts)
        # Presence alone would also pass if nothing were split; the point of
        # this test is that separate "**Level N: ...**" entries land in
        # separate chunks.
        assert len(texts) > 1
        assert not any("Second Wind" in t and "Action Surge" in t for t in texts)

    def test_oversized_with_no_headings_splits_on_paragraphs(self):
        content = "# Blob\n\n" + "\n\n".join("word " * 200 for _ in range(5))
        chunks = chunk_document(Path("blob.md"), content)
        assert len(chunks) > 1
        assert any("word" in text for _, text in chunks)

    def test_chunks_get_title_context(self):
        content = "# Rogue\n\n" + "\n".join(
            f"#### **Section {i}**\n\n{'z ' * 400}\n" for i in range(5)
        )
        chunks = chunk_document(Path("rogue.md"), content)
        # Without this guard the loop below checks nothing when the document
        # comes back as a single chunk.
        assert len(chunks) > 1
        for _, text in chunks[1:]:
            assert text.startswith("# Rogue")
139
140
141# --- Age Decay ---
142
143
class TestAgeDecay:
    """Age-based score multiplier: halves every ``half_life`` days, never boosts."""

    @staticmethod
    def _close(actual: float, expected: float) -> bool:
        """Return True when *actual* is within 1e-9 of *expected*."""
        return abs(actual - expected) < 1e-9

    def test_same_day_no_decay(self):
        assert age_decay(doc_day=5, current_day=5) == 1.0

    def test_half_life_halves_score(self):
        decayed = age_decay(doc_day=5, current_day=8, half_life=3)
        assert self._close(decayed, 0.5)

    def test_two_half_lives(self):
        decayed = age_decay(doc_day=5, current_day=11, half_life=3)
        assert self._close(decayed, 0.25)

    def test_future_doc_no_boost(self):
        # A document dated after the reference day is not scored above 1.0.
        assert age_decay(doc_day=5, current_day=3) == 1.0

    def test_custom_half_life(self):
        decayed = age_decay(doc_day=0, current_day=10, half_life=5)
        assert self._close(decayed, 0.25)
163
164
165# --- VectorIndex CRUD ---
166
167
class TestVectorIndexUpsert:
    """Insert, overwrite, and delete behaviour of VectorIndex."""

    @staticmethod
    def _meta(source: str, content_type: str, path: str, title: str) -> dict[str, str]:
        """Build a fresh metadata mapping in the shape passed to ``upsert``."""
        return {
            "source": source,
            "content_type": content_type,
            "path": path,
            "title": title,
        }

    def test_upsert_and_search(self, index: VectorIndex):
        index.upsert(
            "srd:spells/fireball.md:0",
            "Fireball: 8d6 fire damage in a 20-foot radius",
            self._meta("srd", "spells", "/tmp/fireball.md", "Fireball"),
        )
        results = index.search("fireball")
        assert len(results) >= 1
        assert results[0].doc_id == "srd:spells/fireball.md:0"

    def test_upsert_overwrites(self, index: VectorIndex):
        doc_id = "world:npcs/vex.md:0"
        index.upsert(
            doc_id,
            "Captain Vex the pirate",
            self._meta("world", "npcs", "/tmp/vex.md", "Captain Vex"),
        )
        index.upsert(
            doc_id,
            "Captain Vex the reformed merchant",
            self._meta("world", "npcs", "/tmp/vex.md", "Captain Vex"),
        )
        assert index.stats()["total_documents"] == 1
        # The second upsert replaced the first: exactly one searchable entry.
        assert [hit.doc_id for hit in index.search("merchant")] == [doc_id]

    def test_delete(self, index: VectorIndex):
        doc_id = "srd:spells/fireball.md:0"
        index.upsert(
            doc_id,
            "Fireball spell",
            self._meta("srd", "spells", "/tmp/fireball.md", "Fireball"),
        )
        index.delete(doc_id)
        assert index.stats()["total_documents"] == 0

    def test_delete_nonexistent_is_noop(self, index: VectorIndex):
        index.upsert(
            "srd:spells/fireball.md:0",
            "Fireball spell",
            self._meta("srd", "spells", "/tmp/fireball.md", "Fireball"),
        )
        index.delete("nonexistent:doc:0")
        # Deleting an unknown id must neither raise nor remove existing entries.
        assert index.stats()["total_documents"] == 1
227
228
229# --- Search ---
230
231
class TestVectorIndexSearch:
    """Result shape, limits, source filtering, and age decay of ``search``."""

    @pytest.fixture(autouse=True)
    def _populate(self, index: VectorIndex):
        docs = [
            (
                "srd:spells/fireball.md:0",
                "Fireball: 8d6 fire damage in a 20-foot radius sphere",
                {
                    "source": "srd",
                    "content_type": "spells",
                    "path": "/tmp/spells/fireball.md",
                    "title": "Fireball",
                },
            ),
            (
                "srd:spells/cure-wounds.md:0",
                "Cure Wounds: restore hit points by touch",
                {
                    "source": "srd",
                    "content_type": "spells",
                    "path": "/tmp/spells/cure-wounds.md",
                    "title": "Cure Wounds",
                },
            ),
            (
                "world:npcs/vex.md:0",
                "Captain Vex is a notorious pirate who sails the Shattered Coast",
                {
                    "source": "world",
                    "content_type": "npcs",
                    "path": "/tmp/npcs/vex.md",
                    "title": "Captain Vex",
                },
            ),
            (
                "transcript:transcripts/day+001.md:0",
                "Player asked about the harbor. DM described ships at dock.",
                {
                    "source": "transcript",
                    "content_type": "transcripts",
                    "path": "/tmp/transcripts/day+001.md",
                    "title": "Day 1",
                    "game_day": 1,
                },
            ),
        ]
        for doc_id, text, meta in docs:
            index.upsert(doc_id, text, meta)

    def test_returns_search_hits(self, index: VectorIndex):
        results = index.search("fire damage")
        assert len(results) >= 1
        assert all(isinstance(r, SearchHit) for r in results)

    def test_limit(self, index: VectorIndex):
        results = index.search("magic", limit=2)
        assert len(results) <= 2

    def test_source_filter(self, index: VectorIndex):
        results = index.search("pirate", source_filter="world")
        # Require at least one hit; ``all`` over an empty list passes vacuously.
        assert results
        assert all(r.source == "world" for r in results)

    def test_source_filter_excludes(self, index: VectorIndex):
        fireball = "srd:spells/fireball.md:0"
        # Precondition: without a filter (and a limit above the 4 indexed
        # docs) the SRD spell is reachable for this query...
        assert fireball in {r.doc_id for r in index.search("fire", limit=10)}
        # ...so its absence below is due to the filter, not to ranking.
        results = index.search("fire", limit=10, source_filter="world")
        assert fireball not in {r.doc_id for r in results}
        assert all(r.source == "world" for r in results)

    def test_empty_results(self, index: VectorIndex):
        # The index is populated and, with fake embeddings, every document has
        # some similarity; this only checks that an unrelated query still
        # returns a list rather than failing.
        results = index.search("completely unrelated query xyz")
        assert isinstance(results, list)

    def test_search_hit_fields(self, index: VectorIndex):
        results = index.search("fireball")
        hit = results[0]
        assert hit.path is not None
        assert hit.source in ("srd", "world", "transcript", "player")
        assert hit.content_type is not None
        assert isinstance(hit.score, float)

    def test_age_decay_applied(self, index: VectorIndex):
        # Restricting to transcripts guarantees both result lists contain the
        # day-1 transcript, so the score comparison can never be skipped.
        query = "harbor ships"
        fresh = index.search(query, source_filter="transcript")
        decayed = index.search(query, source_filter="transcript", decay_ref=100)
        assert fresh and decayed
        assert decayed[0].score < fresh[0].score
320
321
322# --- Reindex Directory ---
323
324
class TestReindexDirectory:
    """Bulk indexing of a markdown directory tree via ``reindex_directory``."""

    # srd_tree holds 3 spells + 2 monsters, so a full reindex reports 5 docs.
    _EXPECTED_DOCS = 5

    def test_reindex_counts(self, index: VectorIndex, srd_tree: Path):
        indexed = index.reindex_directory(srd_tree, source="srd")
        assert indexed == self._EXPECTED_DOCS

    def test_reindex_searchable(self, index: VectorIndex, srd_tree: Path):
        index.reindex_directory(srd_tree, source="srd")
        hits = index.search("goblin")
        assert len(hits) >= 1

    def test_reindex_idempotent(self, index: VectorIndex, srd_tree: Path):
        for _ in range(2):
            index.reindex_directory(srd_tree, source="srd")
        assert index.stats()["total_documents"] == self._EXPECTED_DOCS

    def test_reindex_updates_changed_files(self, index: VectorIndex, srd_tree: Path):
        index.reindex_directory(srd_tree, source="srd")
        edited = srd_tree / "spells" / "fireball.md"
        edited.write_text("# Fireball\n\nNow deals 10d6 fire damage!\n")
        index.reindex_directory(srd_tree, source="srd")
        # The edited file replaces its previous entry instead of duplicating it.
        assert index.stats()["total_documents"] == self._EXPECTED_DOCS
350
351
352# --- Stats ---
353
354
class TestStats:
    """Totals and per-source breakdown reported by ``VectorIndex.stats``."""

    def test_empty_index(self, index: VectorIndex):
        assert index.stats()["total_documents"] == 0

    def test_stats_by_source(self, index: VectorIndex):
        entries = [
            (
                "srd:a:0",
                "text a",
                {"source": "srd", "content_type": "spells", "path": "/a", "title": "A"},
            ),
            (
                "world:b:0",
                "text b",
                {"source": "world", "content_type": "npcs", "path": "/b", "title": "B"},
            ),
        ]
        for doc_id, text, meta in entries:
            index.upsert(doc_id, text, meta)

        summary = index.stats()
        assert summary["total_documents"] == 2
        per_source = summary["by_source"]
        assert per_source["srd"] == 1
        assert per_source["world"] == 1
375
376
377# --- Corrupt / Missing DB ---
378
379
class TestReindexOnCorrupt:
    """VectorIndex must start fresh when its DB file is missing or unreadable."""

    @staticmethod
    def _upsert_and_count(db_path: Path) -> int:
        """Open an index at *db_path*, add one document, and return the doc count.

        The index is closed in ``finally`` so its database connection is
        released even when ``upsert`` or ``stats`` raises.
        """
        idx = VectorIndex(db_path)
        try:
            idx._embed_fn = _fake_embed
            idx.upsert(
                "test:a:0",
                "hello",
                {"source": "test", "content_type": "misc", "path": "/a", "title": "A"},
            )
            return idx.stats()["total_documents"]
        finally:
            idx.close()

    def test_missing_db_creates_fresh(self, tmp_path: Path):
        # The parent directory does not exist yet either.
        db_path = tmp_path / "nonexistent" / "test.db"
        assert self._upsert_and_count(db_path) == 1

    def test_corrupt_db_recreates(self, tmp_path: Path):
        db_path = tmp_path / "corrupt.db"
        db_path.write_bytes(b"this is not a sqlite database")
        assert self._upsert_and_count(db_path) == 1
403
404
405# --- Seed From ---
406
407
class TestSeedFrom:
    """Creating a new index from a copy of an existing index database."""

    def test_seed_copies_documents(self, index: VectorIndex, tmp_path: Path):
        index.upsert(
            "srd:spells/fireball.md:0",
            "Fireball spell",
            {
                "source": "srd",
                "content_type": "spells",
                "path": "/tmp/fireball.md",
                "title": "Fireball",
            },
        )
        index.upsert(
            "srd:monsters/goblin.md:0",
            "Goblin monster",
            {
                "source": "srd",
                "content_type": "monsters",
                "path": "/tmp/goblin.md",
                "title": "Goblin",
            },
        )
        # NOTE(review): the source is closed before seeding, presumably so all
        # writes are on disk before the file is copied — confirm.
        index.close()

        dest = tmp_path / "world" / "search.db"
        seeded = VectorIndex.seed_from(index._db_path, dest)
        try:
            seeded._embed_fn = _fake_embed
            assert seeded.stats()["total_documents"] == 2
        finally:
            # Release the copy's connection even if the assertion fails.
            seeded.close()

    def test_seed_allows_additional_upserts(self, index: VectorIndex, tmp_path: Path):
        index.upsert(
            "srd:spells/fireball.md:0",
            "Fireball spell",
            {
                "source": "srd",
                "content_type": "spells",
                "path": "/tmp/fireball.md",
                "title": "Fireball",
            },
        )
        index.close()

        dest = tmp_path / "world" / "search.db"
        seeded = VectorIndex.seed_from(index._db_path, dest)
        try:
            seeded._embed_fn = _fake_embed
            seeded.upsert(
                "world:npcs/vex.md:0",
                "Captain Vex",
                {
                    "source": "world",
                    "content_type": "npcs",
                    "path": "/tmp/vex.md",
                    "title": "Vex",
                },
            )
            stats = seeded.stats()
            assert stats["total_documents"] == 2
            assert stats["by_source"]["srd"] == 1
            assert stats["by_source"]["world"] == 1
        finally:
            seeded.close()
468
469
class TestThreadSafety:
    """Tests for cross-thread access (MCP server runs on a background thread)."""

    def test_search_from_different_thread(self, index: VectorIndex):
        import threading

        index.upsert(
            "world:npcs/vex.md:0",
            "Captain Vex, harbor master",
            {
                "source": "world",
                "content_type": "npcs",
                "path": "/tmp/vex.md",
                "title": "Captain Vex",
            },
        )

        results: list[list[SearchHit]] = []
        errors: list[Exception] = []

        def search_on_thread() -> None:
            try:
                results.append(index.search("harbor master"))
            except Exception as exc:  # re-surfaced on the main thread below
                errors.append(exc)

        # Daemon thread + bounded join: a deadlock inside the index must fail
        # this test rather than hang the entire test run.
        worker = threading.Thread(target=search_on_thread, daemon=True)
        worker.start()
        worker.join(timeout=30)

        assert not worker.is_alive(), "Cross-thread search did not finish (deadlock?)"
        assert not errors, f"Cross-thread search failed: {errors[0]}"
        assert len(results) == 1
        assert len(results[0]) == 1