"""Tests for vector search index.""" from pathlib import Path import pytest from conftest import _fake_embed from storied.search import SearchHit, VectorIndex, age_decay, chunk_document @pytest.fixture def index(tmp_path: Path): """VectorIndex with fake embedder (no model download).""" db_path = tmp_path / "test.search.db" idx = VectorIndex(db_path) idx._embed_fn = _fake_embed yield idx idx.close() @pytest.fixture def srd_tree(tmp_path: Path) -> Path: """Small SRD-like directory tree for reindex tests.""" spells = tmp_path / "spells" spells.mkdir() (spells / "fireball.md").write_text( "# Fireball\n\n_Level 3 Evocation (Sorcerer, Wizard)_\n\n" "Each creature in a 20-foot-radius sphere must make a Dexterity " "saving throw. A target takes 8d6 fire damage on a failed save.\n" ) (spells / "magic-missile.md").write_text( "# Magic Missile\n\n_Level 1 Evocation (Sorcerer, Wizard)_\n\n" "You create three glowing darts of magical force. Each dart hits " "automatically and deals 1d4+1 force damage.\n" ) (spells / "cure-wounds.md").write_text( "# Cure Wounds\n\n_Level 1 Abjuration (Cleric, Druid)_\n\n" "A creature you touch regains hit points equal to 1d8 + your " "spellcasting ability modifier.\n" ) monsters = tmp_path / "monsters" monsters.mkdir() (monsters / "goblin.md").write_text( "# Goblin\n\n_Small Humanoid, Neutral Evil_\n\n" "**AC** 15 (leather armor, shield) **HP** 7 (2d6)\n" ) (monsters / "ancient-red-dragon.md").write_text( "# Ancient Red Dragon\n\n_Gargantuan Dragon, Chaotic Evil_\n\n" "**AC** 22 (natural armor) **HP** 546 (28d20 + 252)\n\n" "## Fire Breath\nThe dragon exhales fire in a 90-foot cone.\n" ) return tmp_path @pytest.fixture def large_doc_content() -> str: """A document large enough to trigger chunking.""" sections = ["# Equipment\n\nRules for adventuring gear.\n"] for i in range(10): section = f"\n## Section {i}\n\n" section += f"This is section {i} with enough content to matter. " * 50 sections.append(section) return "\n".join(sections) # --- Chunking --- class TestChunkDocument: def test_small_file_single_chunk(self): content = "# Fireball\n\nA big explosion spell.\n" chunks = chunk_document(Path("fireball.md"), content) assert len(chunks) == 1 assert chunks[0][0] == 0 assert "Fireball" in chunks[0][1] def test_large_file_splits_on_h2(self, large_doc_content: str): chunks = chunk_document(Path("equipment.md"), large_doc_content) assert len(chunks) > 1 # Each chunk should have the title prepended for context for _, text in chunks: assert "Equipment" in text def test_chunk_index_sequential(self, large_doc_content: str): chunks = chunk_document(Path("equipment.md"), large_doc_content) indices = [idx for idx, _ in chunks] assert indices == list(range(len(chunks))) def test_no_empty_chunks(self, large_doc_content: str): chunks = chunk_document(Path("equipment.md"), large_doc_content) for _, text in chunks: assert text.strip() def test_file_with_no_h1_still_works(self): content = "Just some content without any headers at all.\n" chunks = chunk_document(Path("notes.md"), content) assert len(chunks) == 1 assert "some content" in chunks[0][1] def test_splits_on_h4_headings(self): content = "# Rogue\n\n" + "\n".join( f"#### **Feature {i}**\n\n{'x ' * 400}\n" for i in range(5) ) chunks = chunk_document(Path("rogue.md"), content) assert len(chunks) >= 5 def test_splits_on_bold_definitions(self): content = "# Rules Glossary\n\n## Definitions\n\n" + "\n".join( f"**Term {i}**\nDefinition of term {i}. {'y ' * 400}\n" for i in range(5) ) chunks = chunk_document(Path("glossary.md"), content) assert len(chunks) >= 5 def test_class_features_split_by_level(self): content = "# Fighter\n\n#### **Fighter Class Features**\n\n" content += "**Level 1: Fighting Style**\n" + "Choose a style. " * 100 + "\n\n" content += "**Level 1: Second Wind**\n" + "Heal yourself. " * 100 + "\n\n" content += "**Level 2: Action Surge**\n" + "Extra action. " * 100 + "\n" chunks = chunk_document(Path("fighter.md"), content) texts = [t for _, t in chunks] assert any("Second Wind" in t for t in texts) assert any("Action Surge" in t for t in texts) def test_oversized_with_no_headings_splits_on_paragraphs(self): content = "# Blob\n\n" + "\n\n".join("word " * 200 for _ in range(5)) chunks = chunk_document(Path("blob.md"), content) assert len(chunks) > 1 assert any("word" in text for _, text in chunks) def test_chunks_get_title_context(self): content = "# Rogue\n\n" + "\n".join( f"#### **Section {i}**\n\n{'z ' * 400}\n" for i in range(5) ) chunks = chunk_document(Path("rogue.md"), content) for _, text in chunks[1:]: assert text.startswith("# Rogue") # --- Age Decay --- class TestAgeDecay: def test_same_day_no_decay(self): assert age_decay(current_day=5, doc_day=5) == 1.0 def test_half_life_halves_score(self): result = age_decay(current_day=8, doc_day=5, half_life=3) assert abs(result - 0.5) < 1e-9 def test_two_half_lives(self): result = age_decay(current_day=11, doc_day=5, half_life=3) assert abs(result - 0.25) < 1e-9 def test_future_doc_no_boost(self): result = age_decay(current_day=3, doc_day=5) assert result == 1.0 def test_custom_half_life(self): result = age_decay(current_day=10, doc_day=0, half_life=5) assert abs(result - 0.25) < 1e-9 # --- VectorIndex CRUD --- class TestVectorIndexUpsert: def test_upsert_and_search(self, index: VectorIndex): index.upsert( "srd:spells/fireball.md:0", "Fireball: 8d6 fire damage in a 20-foot radius", { "source": "srd", "content_type": "spells", "path": "/tmp/fireball.md", "title": "Fireball", }, ) results = index.search("fireball") assert len(results) >= 1 assert results[0].doc_id == "srd:spells/fireball.md:0" def test_upsert_overwrites(self, index: VectorIndex): doc_id = "world:npcs/vex.md:0" index.upsert( doc_id, "Captain Vex the pirate", { "source": "world", "content_type": "npcs", "path": "/tmp/vex.md", "title": "Captain Vex", }, ) index.upsert( doc_id, "Captain Vex the reformed merchant", { "source": "world", "content_type": "npcs", "path": "/tmp/vex.md", "title": "Captain Vex", }, ) stats = index.stats() assert stats["total_documents"] == 1 def test_delete(self, index: VectorIndex): doc_id = "srd:spells/fireball.md:0" index.upsert( doc_id, "Fireball spell", { "source": "srd", "content_type": "spells", "path": "/tmp/fireball.md", "title": "Fireball", }, ) index.delete(doc_id) stats = index.stats() assert stats["total_documents"] == 0 def test_delete_nonexistent_is_noop(self, index: VectorIndex): index.delete("nonexistent:doc:0") # --- Search --- class TestVectorIndexSearch: @pytest.fixture(autouse=True) def _populate(self, index: VectorIndex): docs = [ ( "srd:spells/fireball.md:0", "Fireball: 8d6 fire damage in a 20-foot radius sphere", { "source": "srd", "content_type": "spells", "path": "/tmp/spells/fireball.md", "title": "Fireball", }, ), ( "srd:spells/cure-wounds.md:0", "Cure Wounds: restore hit points by touch", { "source": "srd", "content_type": "spells", "path": "/tmp/spells/cure-wounds.md", "title": "Cure Wounds", }, ), ( "world:npcs/vex.md:0", "Captain Vex is a notorious pirate who sails the Shattered Coast", { "source": "world", "content_type": "npcs", "path": "/tmp/npcs/vex.md", "title": "Captain Vex", }, ), ( "transcript:transcripts/day+001.md:0", "Player asked about the harbor. DM described ships at dock.", { "source": "transcript", "content_type": "transcripts", "path": "/tmp/transcripts/day+001.md", "title": "Day 1", "game_day": 1, }, ), ] for doc_id, text, meta in docs: index.upsert(doc_id, text, meta) def test_returns_search_hits(self, index: VectorIndex): results = index.search("fire damage") assert len(results) >= 1 assert all(isinstance(r, SearchHit) for r in results) def test_limit(self, index: VectorIndex): results = index.search("magic", limit=2) assert len(results) <= 2 def test_source_filter(self, index: VectorIndex): results = index.search("pirate", source_filter="world") assert all(r.source == "world" for r in results) def test_source_filter_excludes(self, index: VectorIndex): results = index.search("fire", source_filter="world") assert all(r.source == "world" for r in results) def test_empty_results(self, index: VectorIndex): # With fake embeddings, everything has some similarity, but # we should still get results (the index isn't empty) results = index.search("completely unrelated query xyz") assert isinstance(results, list) def test_search_hit_fields(self, index: VectorIndex): results = index.search("fireball") hit = results[0] assert hit.path is not None assert hit.source in ("srd", "world", "transcript", "player") assert hit.content_type is not None assert isinstance(hit.score, float) def test_age_decay_applied(self, index: VectorIndex): # Search with decay_ref far from day 1 transcript results_no_decay = index.search("harbor ships") results_decayed = index.search("harbor ships", decay_ref=100) transcript_no_decay = [r for r in results_no_decay if r.source == "transcript"] transcript_decayed = [r for r in results_decayed if r.source == "transcript"] if transcript_no_decay and transcript_decayed: assert transcript_decayed[0].score < transcript_no_decay[0].score # --- Reindex Directory --- class TestReindexDirectory: def test_reindex_counts(self, index: VectorIndex, srd_tree: Path): count = index.reindex_directory(srd_tree, source="srd") assert count == 5 # 3 spells + 2 monsters def test_reindex_searchable(self, index: VectorIndex, srd_tree: Path): index.reindex_directory(srd_tree, source="srd") results = index.search("goblin") assert len(results) >= 1 def test_reindex_idempotent(self, index: VectorIndex, srd_tree: Path): index.reindex_directory(srd_tree, source="srd") index.reindex_directory(srd_tree, source="srd") stats = index.stats() assert stats["total_documents"] == 5 def test_reindex_updates_changed_files(self, index: VectorIndex, srd_tree: Path): index.reindex_directory(srd_tree, source="srd") # Modify a file fireball = srd_tree / "spells" / "fireball.md" fireball.write_text("# Fireball\n\nNow deals 10d6 fire damage!\n") index.reindex_directory(srd_tree, source="srd") # Should still have 5 docs total (updated, not duplicated) stats = index.stats() assert stats["total_documents"] == 5 # --- Stats --- class TestStats: def test_empty_index(self, index: VectorIndex): stats = index.stats() assert stats["total_documents"] == 0 def test_stats_by_source(self, index: VectorIndex): index.upsert( "srd:a:0", "text a", {"source": "srd", "content_type": "spells", "path": "/a", "title": "A"}, ) index.upsert( "world:b:0", "text b", {"source": "world", "content_type": "npcs", "path": "/b", "title": "B"}, ) stats = index.stats() assert stats["total_documents"] == 2 assert stats["by_source"]["srd"] == 1 assert stats["by_source"]["world"] == 1 # --- Corrupt / Missing DB --- class TestReindexOnCorrupt: def test_missing_db_creates_fresh(self, tmp_path: Path): db_path = tmp_path / "nonexistent" / "test.db" idx = VectorIndex(db_path) idx._embed_fn = _fake_embed idx.upsert( "test:a:0", "hello", {"source": "test", "content_type": "misc", "path": "/a", "title": "A"}, ) assert idx.stats()["total_documents"] == 1 def test_corrupt_db_recreates(self, tmp_path: Path): db_path = tmp_path / "corrupt.db" db_path.write_bytes(b"this is not a sqlite database") idx = VectorIndex(db_path) idx._embed_fn = _fake_embed idx.upsert( "test:a:0", "hello", {"source": "test", "content_type": "misc", "path": "/a", "title": "A"}, ) assert idx.stats()["total_documents"] == 1 # --- Seed From --- class TestSeedFrom: def test_seed_copies_documents(self, index: VectorIndex, tmp_path: Path): index.upsert( "srd:spells/fireball.md:0", "Fireball spell", { "source": "srd", "content_type": "spells", "path": "/tmp/fireball.md", "title": "Fireball", }, ) index.upsert( "srd:monsters/goblin.md:0", "Goblin monster", { "source": "srd", "content_type": "monsters", "path": "/tmp/goblin.md", "title": "Goblin", }, ) index.close() dest = tmp_path / "world" / "search.db" seeded = VectorIndex.seed_from(index._db_path, dest) seeded._embed_fn = _fake_embed assert seeded.stats()["total_documents"] == 2 seeded.close() def test_seed_allows_additional_upserts(self, index: VectorIndex, tmp_path: Path): index.upsert( "srd:spells/fireball.md:0", "Fireball spell", { "source": "srd", "content_type": "spells", "path": "/tmp/fireball.md", "title": "Fireball", }, ) index.close() dest = tmp_path / "world" / "search.db" seeded = VectorIndex.seed_from(index._db_path, dest) seeded._embed_fn = _fake_embed seeded.upsert( "world:npcs/vex.md:0", "Captain Vex", { "source": "world", "content_type": "npcs", "path": "/tmp/vex.md", "title": "Vex", }, ) assert seeded.stats()["total_documents"] == 2 assert seeded.stats()["by_source"]["srd"] == 1 assert seeded.stats()["by_source"]["world"] == 1 seeded.close() class TestThreadSafety: """Tests for cross-thread access (MCP server runs on a background thread).""" def test_search_from_different_thread(self, index: VectorIndex): index.upsert( "world:npcs/vex.md:0", "Captain Vex, harbor master", { "source": "world", "content_type": "npcs", "path": "/tmp/vex.md", "title": "Captain Vex", }, ) import threading results: list[list[SearchHit]] = [] error: list[Exception] = [] def search_on_thread(): try: results.append(index.search("harbor master")) except Exception as e: error.append(e) t = threading.Thread(target=search_on_thread) t.start() t.join() assert not error, f"Cross-thread search failed: {error[0]}" assert len(results[0]) == 1