A 5e storytelling engine with an LLM DM
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 501 lines 17 kB view raw
"""Tests for vector search index."""

import threading
from pathlib import Path

import pytest
from conftest import _fake_embed

from storied.search import SearchHit, VectorIndex, age_decay, chunk_document


@pytest.fixture
def index(tmp_path: Path):
    """VectorIndex with fake embedder (no model download)."""
    db_path = tmp_path / "test.search.db"
    idx = VectorIndex(db_path)
    idx._embed_fn = _fake_embed
    yield idx
    idx.close()


@pytest.fixture
def srd_tree(tmp_path: Path) -> Path:
    """Small SRD-like directory tree for reindex tests.

    Layout: ``spells/`` with 3 files and ``monsters/`` with 2 files, so a full
    reindex is expected to produce 5 documents.
    """
    spells = tmp_path / "spells"
    spells.mkdir()
    (spells / "fireball.md").write_text(
        "# Fireball\n\n_Level 3 Evocation (Sorcerer, Wizard)_\n\n"
        "Each creature in a 20-foot-radius sphere must make a Dexterity "
        "saving throw. A target takes 8d6 fire damage on a failed save.\n"
    )
    (spells / "magic-missile.md").write_text(
        "# Magic Missile\n\n_Level 1 Evocation (Sorcerer, Wizard)_\n\n"
        "You create three glowing darts of magical force. Each dart hits "
        "automatically and deals 1d4+1 force damage.\n"
    )
    (spells / "cure-wounds.md").write_text(
        "# Cure Wounds\n\n_Level 1 Abjuration (Cleric, Druid)_\n\n"
        "A creature you touch regains hit points equal to 1d8 + your "
        "spellcasting ability modifier.\n"
    )

    monsters = tmp_path / "monsters"
    monsters.mkdir()
    (monsters / "goblin.md").write_text(
        "# Goblin\n\n_Small Humanoid, Neutral Evil_\n\n"
        "**AC** 15 (leather armor, shield) **HP** 7 (2d6)\n"
    )
    (monsters / "ancient-red-dragon.md").write_text(
        "# Ancient Red Dragon\n\n_Gargantuan Dragon, Chaotic Evil_\n\n"
        "**AC** 22 (natural armor) **HP** 546 (28d20 + 252)\n\n"
        "## Fire Breath\nThe dragon exhales fire in a 90-foot cone.\n"
    )

    return tmp_path


@pytest.fixture
def large_doc_content() -> str:
    """A document large enough to trigger chunking."""
    sections = ["# Equipment\n\nRules for adventuring gear.\n"]
    for i in range(10):
        section = f"\n## Section {i}\n\n"
        section += f"This is section {i} with enough content to matter. " * 50
        sections.append(section)
    return "\n".join(sections)


# --- Chunking ---


class TestChunkDocument:
    def test_small_file_single_chunk(self):
        content = "# Fireball\n\nA big explosion spell.\n"
        chunks = chunk_document(Path("fireball.md"), content)
        assert len(chunks) == 1
        assert chunks[0][0] == 0
        assert "Fireball" in chunks[0][1]

    def test_large_file_splits_on_h2(self, large_doc_content: str):
        chunks = chunk_document(Path("equipment.md"), large_doc_content)
        assert len(chunks) > 1
        # Each chunk should have the title prepended for context
        for _, text in chunks:
            assert "Equipment" in text

    def test_chunk_index_sequential(self, large_doc_content: str):
        chunks = chunk_document(Path("equipment.md"), large_doc_content)
        indices = [idx for idx, _ in chunks]
        assert indices == list(range(len(chunks)))

    def test_no_empty_chunks(self, large_doc_content: str):
        chunks = chunk_document(Path("equipment.md"), large_doc_content)
        for _, text in chunks:
            assert text.strip()

    def test_file_with_no_h1_still_works(self):
        content = "Just some content without any headers at all.\n"
        chunks = chunk_document(Path("notes.md"), content)
        assert len(chunks) == 1
        assert "some content" in chunks[0][1]

    def test_splits_on_h4_headings(self):
        content = "# Rogue\n\n" + "\n".join(
            f"#### **Feature {i}**\n\n{'x ' * 400}\n" for i in range(5)
        )
        chunks = chunk_document(Path("rogue.md"), content)
        assert len(chunks) >= 5

    def test_splits_on_bold_definitions(self):
        content = "# Rules Glossary\n\n## Definitions\n\n" + "\n".join(
            f"**Term {i}**\nDefinition of term {i}. {'y ' * 400}\n" for i in range(5)
        )
        chunks = chunk_document(Path("glossary.md"), content)
        assert len(chunks) >= 5

    def test_class_features_split_by_level(self):
        content = "# Fighter\n\n#### **Fighter Class Features**\n\n"
        content += "**Level 1: Fighting Style**\n" + "Choose a style. " * 100 + "\n\n"
        content += "**Level 1: Second Wind**\n" + "Heal yourself. " * 100 + "\n\n"
        content += "**Level 2: Action Surge**\n" + "Extra action. " * 100 + "\n"
        chunks = chunk_document(Path("fighter.md"), content)
        texts = [t for _, t in chunks]
        assert any("Second Wind" in t for t in texts)
        assert any("Action Surge" in t for t in texts)

    def test_oversized_with_no_headings_splits_on_paragraphs(self):
        content = "# Blob\n\n" + "\n\n".join("word " * 200 for _ in range(5))
        chunks = chunk_document(Path("blob.md"), content)
        assert len(chunks) > 1
        assert any("word" in text for _, text in chunks)

    def test_chunks_get_title_context(self):
        content = "# Rogue\n\n" + "\n".join(
            f"#### **Section {i}**\n\n{'z ' * 400}\n" for i in range(5)
        )
        chunks = chunk_document(Path("rogue.md"), content)
        for _, text in chunks[1:]:
            assert text.startswith("# Rogue")


# --- Age Decay ---


class TestAgeDecay:
    def test_same_day_no_decay(self):
        assert age_decay(current_day=5, doc_day=5) == 1.0

    def test_half_life_halves_score(self):
        result = age_decay(current_day=8, doc_day=5, half_life=3)
        assert abs(result - 0.5) < 1e-9

    def test_two_half_lives(self):
        result = age_decay(current_day=11, doc_day=5, half_life=3)
        assert abs(result - 0.25) < 1e-9

    def test_future_doc_no_boost(self):
        result = age_decay(current_day=3, doc_day=5)
        assert result == 1.0

    def test_custom_half_life(self):
        result = age_decay(current_day=10, doc_day=0, half_life=5)
        assert abs(result - 0.25) < 1e-9


# --- VectorIndex CRUD ---


class TestVectorIndexUpsert:
    def test_upsert_and_search(self, index: VectorIndex):
        index.upsert(
            "srd:spells/fireball.md:0",
            "Fireball: 8d6 fire damage in a 20-foot radius",
            {
                "source": "srd",
                "content_type": "spells",
                "path": "/tmp/fireball.md",
                "title": "Fireball",
            },
        )
        results = index.search("fireball")
        assert len(results) >= 1
        assert results[0].doc_id == "srd:spells/fireball.md:0"

    def test_upsert_overwrites(self, index: VectorIndex):
        doc_id = "world:npcs/vex.md:0"
        index.upsert(
            doc_id,
            "Captain Vex the pirate",
            {
                "source": "world",
                "content_type": "npcs",
                "path": "/tmp/vex.md",
                "title": "Captain Vex",
            },
        )
        index.upsert(
            doc_id,
            "Captain Vex the reformed merchant",
            {
                "source": "world",
                "content_type": "npcs",
                "path": "/tmp/vex.md",
                "title": "Captain Vex",
            },
        )
        stats = index.stats()
        assert stats["total_documents"] == 1

    def test_delete(self, index: VectorIndex):
        doc_id = "srd:spells/fireball.md:0"
        index.upsert(
            doc_id,
            "Fireball spell",
            {
                "source": "srd",
                "content_type": "spells",
                "path": "/tmp/fireball.md",
                "title": "Fireball",
            },
        )
        index.delete(doc_id)
        stats = index.stats()
        assert stats["total_documents"] == 0

    def test_delete_nonexistent_is_noop(self, index: VectorIndex):
        index.delete("nonexistent:doc:0")
        # Deleting an unknown id must neither raise nor disturb the index.
        assert index.stats()["total_documents"] == 0


# --- Search ---


class TestVectorIndexSearch:
    @pytest.fixture(autouse=True)
    def _populate(self, index: VectorIndex):
        docs = [
            (
                "srd:spells/fireball.md:0",
                "Fireball: 8d6 fire damage in a 20-foot radius sphere",
                {
                    "source": "srd",
                    "content_type": "spells",
                    "path": "/tmp/spells/fireball.md",
                    "title": "Fireball",
                },
            ),
            (
                "srd:spells/cure-wounds.md:0",
                "Cure Wounds: restore hit points by touch",
                {
                    "source": "srd",
                    "content_type": "spells",
                    "path": "/tmp/spells/cure-wounds.md",
                    "title": "Cure Wounds",
                },
            ),
            (
                "world:npcs/vex.md:0",
                "Captain Vex is a notorious pirate who sails the Shattered Coast",
                {
                    "source": "world",
                    "content_type": "npcs",
                    "path": "/tmp/npcs/vex.md",
                    "title": "Captain Vex",
                },
            ),
            (
                "transcript:transcripts/day+001.md:0",
                "Player asked about the harbor. DM described ships at dock.",
                {
                    "source": "transcript",
                    "content_type": "transcripts",
                    "path": "/tmp/transcripts/day+001.md",
                    "title": "Day 1",
                    "game_day": 1,
                },
            ),
        ]
        for doc_id, text, meta in docs:
            index.upsert(doc_id, text, meta)

    def test_returns_search_hits(self, index: VectorIndex):
        results = index.search("fire damage")
        assert len(results) >= 1
        assert all(isinstance(r, SearchHit) for r in results)

    def test_limit(self, index: VectorIndex):
        results = index.search("magic", limit=2)
        assert len(results) <= 2

    def test_source_filter(self, index: VectorIndex):
        results = index.search("pirate", source_filter="world")
        assert all(r.source == "world" for r in results)

    def test_source_filter_excludes(self, index: VectorIndex):
        results = index.search("fire", source_filter="world")
        assert all(r.source == "world" for r in results)

    def test_empty_results(self, index: VectorIndex):
        # With fake embeddings, everything has some similarity, but
        # we should still get results (the index isn't empty)
        results = index.search("completely unrelated query xyz")
        assert isinstance(results, list)

    def test_search_hit_fields(self, index: VectorIndex):
        results = index.search("fireball")
        hit = results[0]
        assert hit.path is not None
        assert hit.source in ("srd", "world", "transcript", "player")
        assert hit.content_type is not None
        assert isinstance(hit.score, float)

    def test_age_decay_applied(self, index: VectorIndex):
        # Search with decay_ref far from day 1 transcript
        results_no_decay = index.search("harbor ships")
        results_decayed = index.search("harbor ships", decay_ref=100)
        transcript_no_decay = [r for r in results_no_decay if r.source == "transcript"]
        transcript_decayed = [r for r in results_decayed if r.source == "transcript"]
        # NOTE(review): this guard lets the test pass vacuously when the
        # transcript is missing from either result set; tighten it once the
        # fake embedder's ranking is known to always surface the transcript.
        if transcript_no_decay and transcript_decayed:
            assert transcript_decayed[0].score < transcript_no_decay[0].score


# --- Reindex Directory ---


class TestReindexDirectory:
    def test_reindex_counts(self, index: VectorIndex, srd_tree: Path):
        count = index.reindex_directory(srd_tree, source="srd")
        assert count == 5  # 3 spells + 2 monsters

    def test_reindex_searchable(self, index: VectorIndex, srd_tree: Path):
        index.reindex_directory(srd_tree, source="srd")
        results = index.search("goblin")
        assert len(results) >= 1

    def test_reindex_idempotent(self, index: VectorIndex, srd_tree: Path):
        index.reindex_directory(srd_tree, source="srd")
        index.reindex_directory(srd_tree, source="srd")
        stats = index.stats()
        assert stats["total_documents"] == 5

    def test_reindex_updates_changed_files(self, index: VectorIndex, srd_tree: Path):
        index.reindex_directory(srd_tree, source="srd")
        # Modify a file
        fireball = srd_tree / "spells" / "fireball.md"
        fireball.write_text("# Fireball\n\nNow deals 10d6 fire damage!\n")
        index.reindex_directory(srd_tree, source="srd")
        # Should still have 5 docs total (updated, not duplicated)
        stats = index.stats()
        assert stats["total_documents"] == 5


# --- Stats ---


class TestStats:
    def test_empty_index(self, index: VectorIndex):
        stats = index.stats()
        assert stats["total_documents"] == 0

    def test_stats_by_source(self, index: VectorIndex):
        index.upsert(
            "srd:a:0",
            "text a",
            {"source": "srd", "content_type": "spells", "path": "/a", "title": "A"},
        )
        index.upsert(
            "world:b:0",
            "text b",
            {"source": "world", "content_type": "npcs", "path": "/b", "title": "B"},
        )
        stats = index.stats()
        assert stats["total_documents"] == 2
        assert stats["by_source"]["srd"] == 1
        assert stats["by_source"]["world"] == 1


# --- Corrupt / Missing DB ---


class TestReindexOnCorrupt:
    def test_missing_db_creates_fresh(self, tmp_path: Path):
        db_path = tmp_path / "nonexistent" / "test.db"
        idx = VectorIndex(db_path)
        # These tests bypass the `index` fixture, so they must release the
        # DB handle themselves, even when an assertion fails.
        try:
            idx._embed_fn = _fake_embed
            idx.upsert(
                "test:a:0",
                "hello",
                {"source": "test", "content_type": "misc", "path": "/a", "title": "A"},
            )
            assert idx.stats()["total_documents"] == 1
        finally:
            idx.close()

    def test_corrupt_db_recreates(self, tmp_path: Path):
        db_path = tmp_path / "corrupt.db"
        db_path.write_bytes(b"this is not a sqlite database")
        idx = VectorIndex(db_path)
        try:
            idx._embed_fn = _fake_embed
            idx.upsert(
                "test:a:0",
                "hello",
                {"source": "test", "content_type": "misc", "path": "/a", "title": "A"},
            )
            assert idx.stats()["total_documents"] == 1
        finally:
            idx.close()


# --- Seed From ---


class TestSeedFrom:
    def test_seed_copies_documents(self, index: VectorIndex, tmp_path: Path):
        index.upsert(
            "srd:spells/fireball.md:0",
            "Fireball spell",
            {
                "source": "srd",
                "content_type": "spells",
                "path": "/tmp/fireball.md",
                "title": "Fireball",
            },
        )
        index.upsert(
            "srd:monsters/goblin.md:0",
            "Goblin monster",
            {
                "source": "srd",
                "content_type": "monsters",
                "path": "/tmp/goblin.md",
                "title": "Goblin",
            },
        )
        # Close the source first so its contents are settled on disk before
        # seed_from reads the file.
        index.close()

        dest = tmp_path / "world" / "search.db"
        seeded = VectorIndex.seed_from(index._db_path, dest)
        try:
            seeded._embed_fn = _fake_embed
            assert seeded.stats()["total_documents"] == 2
        finally:
            seeded.close()

    def test_seed_allows_additional_upserts(self, index: VectorIndex, tmp_path: Path):
        index.upsert(
            "srd:spells/fireball.md:0",
            "Fireball spell",
            {
                "source": "srd",
                "content_type": "spells",
                "path": "/tmp/fireball.md",
                "title": "Fireball",
            },
        )
        index.close()

        dest = tmp_path / "world" / "search.db"
        seeded = VectorIndex.seed_from(index._db_path, dest)
        try:
            seeded._embed_fn = _fake_embed
            seeded.upsert(
                "world:npcs/vex.md:0",
                "Captain Vex",
                {
                    "source": "world",
                    "content_type": "npcs",
                    "path": "/tmp/vex.md",
                    "title": "Vex",
                },
            )
            assert seeded.stats()["total_documents"] == 2
            assert seeded.stats()["by_source"]["srd"] == 1
            assert seeded.stats()["by_source"]["world"] == 1
        finally:
            seeded.close()


class TestThreadSafety:
    """Tests for cross-thread access (MCP server runs on a background thread)."""

    def test_search_from_different_thread(self, index: VectorIndex):
        index.upsert(
            "world:npcs/vex.md:0",
            "Captain Vex, harbor master",
            {
                "source": "world",
                "content_type": "npcs",
                "path": "/tmp/vex.md",
                "title": "Captain Vex",
            },
        )

        results: list[list[SearchHit]] = []
        error: list[Exception] = []

        def search_on_thread():
            # Capture the exception instead of letting it die with the thread,
            # so the main thread can report it as a test failure.
            try:
                results.append(index.search("harbor master"))
            except Exception as e:
                error.append(e)

        t = threading.Thread(target=search_on_thread)
        t.start()
        t.join()

        assert not error, f"Cross-thread search failed: {error[0]}"
        assert len(results[0]) == 1