A D&D 5e (fifth edition) storytelling engine with an LLM acting as the Dungeon Master (DM)
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 388 lines 13 kB view raw
"""Tests for the character advancement evaluation system."""

import json
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

from storied.advancement import (
    AdvancementResult,
    BackgroundAdvancement,
    build_advancement_context,
    evaluate_advancement,
)
from storied.character import save_character
from storied.log import CampaignLog
from storied.session import save_session
from storied.testing import call_tool
from storied.tools import ToolContext
from storied.tools.scene import notify_dm as _notify_dm


def notify_dm(message: str, ctx: ToolContext) -> str:
    """Test shim: drop legacy `ctx` arg and resolve Dependency params.

    `ctx` is accepted only so call sites can keep the old two-argument
    form; it is not forwarded to the tool.
    """
    return call_tool(_notify_dm, message=message)


# --- Fixtures ---


@pytest.fixture
def character(ctx: ToolContext) -> dict:
    """Create a level 3 rogue character.

    The sheet is saved under the "default" player id and the same dict is
    returned so tests can mutate and re-save it.
    """
    data = {
        "identity": {
            "name": "Kira",
            "race": "Human",
            "classes": [{"class": "Rogue", "subclass": "Thief", "level": 3}],
            "background": "Criminal",
        },
        "abilities": {
            "strength": 10,
            "dexterity": 16,
            "constitution": 14,
            "intelligence": 12,
            "wisdom": 10,
            "charisma": 14,
        },
        "state": {
            "hp": {"current": 24, "max": 24, "temp": 0},
            "ac": 14,
            "speed": 30,
        },
        "features": [
            {"source": "Rogue Lv1", "name": "Sneak Attack", "text": "2d6"},
            {"source": "Rogue Lv2", "name": "Cunning Action", "text": ""},
        ],
    }
    save_character("default", data)
    return data


@pytest.fixture
def campaign_with_events(ctx: ToolContext) -> CampaignLog:
    """Campaign log with some adventure events.

    Appends seven entries (one tagged "combat") to `ctx.campaign_log` and
    returns that log.
    """
    log = ctx.campaign_log
    log.append_entry("Arrived in Millford", "30 min")
    log.append_entry("Investigated the warehouse", "1 hour")
    log.append_entry("Fought three thugs", "3 rounds", tags=["combat"])
    log.append_entry("Discovered smuggling ring", "30 min")
    log.append_entry("Traveled to Thornwall", "4 hours")
    log.append_entry("Infiltrated the manor", "2 hours")
    log.append_entry("Stole the merchant's ledger", "30 min")
    return log


# --- notify_dm tool ---


class TestNotifyDM:
    """The notify_dm tool writes to the world's dm_notifications.md file."""

    def test_appends_notification(self, ctx: ToolContext, tmp_path: Path) -> None:
        """A single call reports "queued" and the message lands in the file."""
        result = notify_dm("Test message", ctx)

        assert "queued" in result.lower()
        path = tmp_path / "worlds" / ctx.world_id / "dm_notifications.md"
        assert path.exists()
        assert "Test message" in path.read_text()

    def test_multiple_notifications(self, ctx: ToolContext, tmp_path: Path) -> None:
        """Two calls leave both messages in the file."""
        notify_dm("First", ctx)
        notify_dm("Second", ctx)

        path = tmp_path / "worlds" / ctx.world_id / "dm_notifications.md"
        content = path.read_text()
        assert "First" in content
        assert "Second" in content


# --- build_advancement_context ---


class TestBuildAdvancementContext:
    """Contents of the text returned by build_advancement_context."""

    def test_returns_none_without_character(
        self, ctx: ToolContext, tmp_path: Path
    ) -> None:
        """No saved character (the fixture is not requested) yields None."""
        result = build_advancement_context(ctx.world_id, ctx.player_id)
        assert result is None

    def test_includes_character_info(self, ctx: ToolContext, character: dict) -> None:
        """The context names the character, its class and its level."""
        context = build_advancement_context(ctx.world_id, ctx.player_id)
        assert context is not None
        assert "Kira" in context
        assert "Rogue" in context
        assert "Level 3" in context

    def test_includes_log_entries(
        self,
        ctx: ToolContext,
        character: dict,
        campaign_with_events: CampaignLog,
    ) -> None:
        """Campaign log events appear in the context."""
        context = build_advancement_context(ctx.world_id, ctx.player_id)
        assert context is not None
        assert "warehouse" in context
        assert "smuggling ring" in context

    def test_includes_entries_since_level_tag(
        self, ctx: ToolContext, character: dict
    ) -> None:
        """Events before the last "level"-tagged entry are left out."""
        log = ctx.campaign_log
        log.append_entry("Old event before level-up", "1 hour")
        log.append_entry("Leveled up to 3", "5 min", tags=["level"])
        log.append_entry("New adventure begins", "30 min")
        log.append_entry("Fought a dragon", "5 rounds", tags=["combat"])

        context = build_advancement_context(ctx.world_id, ctx.player_id)
        assert context is not None
        assert "Old event before level-up" not in context
        assert "New adventure begins" in context
        assert "Fought a dragon" in context

    def test_includes_advancement_history(
        self, ctx: ToolContext, character: dict
    ) -> None:
        """Every "level"-tagged entry is listed under "Advancement History"."""
        log = ctx.campaign_log
        log.append_entry("Reached level 2", "5 min", tags=["level"])
        log.append_entry("Adventured more", "2 hours")
        log.append_entry("Reached level 3", "5 min", tags=["level"])
        log.append_entry("Recent events", "1 hour")

        context = build_advancement_context(ctx.world_id, ctx.player_id)
        assert context is not None
        assert "Advancement History" in context
        assert "Reached level 2" in context
        assert "Reached level 3" in context

    def test_includes_session_state(self, ctx: ToolContext, character: dict) -> None:
        """Text from the saved session body shows up in the context."""
        save_session(
            "default",
            {
                "location": "Town Square",
                "body": "## Open Threads\n- Find the missing merchant",
            },
        )

        context = build_advancement_context(ctx.world_id, ctx.player_id)
        assert context is not None
        assert "missing merchant" in context


# --- CampaignLog tag methods ---


class TestLogTagMethods:
    """CampaignLog queries: get_entries_since_tag, find_tag_entries, get_all_entries."""

    def test_get_entries_since_tag_returns_all_when_no_tag(
        self, ctx: ToolContext
    ) -> None:
        """With no entry carrying the tag, every entry is returned."""
        log = ctx.campaign_log
        log.append_entry("Event one", "10 min")
        log.append_entry("Event two", "10 min")

        entries = log.get_entries_since_tag("level")
        assert len(entries) == 2

    def test_get_entries_since_tag_returns_after_last_tag(
        self, ctx: ToolContext
    ) -> None:
        """Only entries after the tagged one are returned; the tagged entry is excluded."""
        log = ctx.campaign_log
        log.append_entry("Before level", "10 min")
        log.append_entry("Level up!", "5 min", tags=["level"])
        log.append_entry("After level", "10 min")

        entries = log.get_entries_since_tag("level")
        assert len(entries) == 1
        assert entries[0].event == "After level"

    def test_get_entries_since_tag_uses_last_occurrence(
        self, ctx: ToolContext
    ) -> None:
        """When the tag occurs twice, the cut-off is the later occurrence."""
        log = ctx.campaign_log
        log.append_entry("First level", "5 min", tags=["level"])
        log.append_entry("Between levels", "1 hour")
        log.append_entry("Second level", "5 min", tags=["level"])
        log.append_entry("After second", "30 min")

        entries = log.get_entries_since_tag("level")
        assert len(entries) == 1
        assert entries[0].event == "After second"

    def test_find_tag_entries(self, ctx: ToolContext) -> None:
        """Tagged entries come back in the order they were appended."""
        log = ctx.campaign_log
        log.append_entry("Normal event", "10 min")
        log.append_entry("Level 2", "5 min", tags=["level"])
        log.append_entry("More stuff", "1 hour")
        log.append_entry("Level 3", "5 min", tags=["level"])

        entries = log.find_tag_entries("level")
        assert len(entries) == 2
        assert entries[0].event == "Level 2"
        assert entries[1].event == "Level 3"

    def test_find_tag_entries_empty(self, ctx: ToolContext) -> None:
        """No matching tag yields an empty list."""
        log = ctx.campaign_log
        log.append_entry("Normal event", "10 min")

        entries = log.find_tag_entries("level")
        assert entries == []

    def test_get_all_entries(self, ctx: ToolContext) -> None:
        """All three appended entries are returned, first and last in append order."""
        log = ctx.campaign_log
        log.append_entry("Day 1 morning", "10 min")
        log.append_entry("Day 1 travel", "18 hours")
        log.append_entry("Day 2 event", "10 min")

        entries = log.get_all_entries()
        assert len(entries) == 3
        assert entries[0].event == "Day 1 morning"
        assert entries[2].event == "Day 2 event"


# --- evaluate_advancement ---


class TestEvaluateAdvancement:
    """Behaviour of the synchronous evaluate_advancement entry point."""

    def test_skips_without_character(self, ctx: ToolContext) -> None:
        """Without a saved character the result is marked as not evaluated."""
        result = evaluate_advancement(
            world_id=ctx.world_id,
            player_id=ctx.player_id,
        )
        assert result.evaluated is False

    def test_posts_reminder_when_advancement_pending(
        self,
        ctx: ToolContext,
        character: dict,
        tmp_path: Path,
    ) -> None:
        """A sheet with `advancement_ready` set gets a DM reminder, not an evaluation."""
        character["advancement_ready"] = 4
        save_character("default", character)

        result = evaluate_advancement(
            world_id=ctx.world_id,
            player_id=ctx.player_id,
        )

        assert result.evaluated is False
        path = tmp_path / "worlds" / ctx.world_id / "dm_notifications.md"
        assert path.exists()
        contents = path.read_text()
        assert "Kira" in contents
        assert "level 4" in contents

    @patch("storied.claude.subprocess.Popen")
    def test_calls_claude_when_character_exists(
        self,
        mock_popen: MagicMock,
        ctx: ToolContext,
        character: dict,
        campaign_with_events: CampaignLog,
    ) -> None:
        """With a character present, the patched Popen is invoked exactly once.

        The fake process emits one JSON "result" line; the test checks that
        its input token count is surfaced on the returned result.
        """
        result_line = json.dumps(
            {
                "type": "result",
                "session_id": "sess-adv",
                "usage": {"input_tokens": 500, "output_tokens": 100},
                "duration_ms": 2000,
            }
        )
        mock_proc = MagicMock()
        mock_proc.stdin = MagicMock()
        # stdout is iterated line by line: one newline-terminated bytes line.
        mock_proc.stdout = iter([result_line.encode() + b"\n"])
        mock_proc.stderr = iter([])
        mock_proc.wait.return_value = 0
        mock_proc.returncode = 0
        mock_popen.return_value = mock_proc

        result = evaluate_advancement(
            world_id=ctx.world_id,
            player_id=ctx.player_id,
        )

        assert result.evaluated is True
        assert result.input_tokens == 500
        mock_popen.assert_called_once()


# --- BackgroundAdvancement ---


class TestBackgroundAdvancement:
    """Turn counting, triggering and result hand-off of BackgroundAdvancement."""

    def test_does_not_trigger_before_interval(self) -> None:
        """Fewer turns than `interval` leave `_thread` unset."""
        adv = BackgroundAdvancement(
            world_id="test",
            player_id="default",
            interval=5,
        )
        # 4 turns should not trigger
        for _ in range(4):
            adv.on_turn()
        assert adv._thread is None

    def test_on_combat_end_triggers_immediately(
        self, ctx: ToolContext, character: dict
    ) -> None:
        """on_combat_end runs an evaluation even though `interval` is far away."""
        adv = BackgroundAdvancement(
            world_id=ctx.world_id,
            player_id=ctx.player_id,
            interval=100,
        )

        with patch("storied.advancement.evaluate_advancement") as mock_eval:
            mock_eval.return_value = AdvancementResult()
            adv.on_combat_end()
            # Give the thread a moment; join inside the `with` so the patch
            # is still active while the worker runs.
            if adv._thread is not None:
                adv._thread.join(timeout=2)
            mock_eval.assert_called_once()

    def test_turn_counter_resets_after_evaluation(self) -> None:
        """Reaching `interval` turns leaves `_turn_count` back at zero."""
        adv = BackgroundAdvancement(
            world_id="test",
            player_id="default",
            interval=5,
        )
        with patch("storied.advancement.evaluate_advancement") as mock_eval:
            mock_eval.return_value = AdvancementResult()
            for _ in range(5):
                adv.on_turn()
            if adv._thread is not None:
                adv._thread.join(timeout=2)
            assert adv._turn_count == 0

    def test_pop_result_returns_none_when_no_thread(self) -> None:
        """pop_result on a fresh instance returns None."""
        adv = BackgroundAdvancement(
            world_id="test",
            player_id="default",
        )
        assert adv.pop_result() is None

    def test_pop_result_returns_and_clears_after_completion(self) -> None:
        """A finished evaluation is returned once, then pop_result yields None."""
        adv = BackgroundAdvancement(
            world_id="test",
            player_id="default",
            interval=1,
        )
        with patch("storied.advancement.evaluate_advancement") as mock_eval:
            mock_eval.return_value = AdvancementResult(evaluated=True)
            adv.on_turn()
            if adv._thread is not None:
                adv._thread.join(timeout=2)

        first = adv.pop_result()
        assert first is not None
        assert first.evaluated is True
        # Second pop returns None — result was consumed
        assert adv.pop_result() is None

    def test_maybe_evaluate_skips_when_already_running(self) -> None:
        """_maybe_evaluate leaves `_thread` untouched while it reports alive."""
        adv = BackgroundAdvancement(
            world_id="test",
            player_id="default",
        )
        # Stub a fake "still running" thread on the instance
        fake_thread = MagicMock()
        fake_thread.is_alive.return_value = True
        adv._thread = fake_thread
        # Should early-return without spawning a new thread
        adv._maybe_evaluate()
        # The fake thread is still the only one
        assert adv._thread is fake_thread

    # NOTE(review): this exercises evaluate_advancement rather than
    # BackgroundAdvancement; it is kept in this class so its test ID is stable.
    def test_evaluate_advancement_default_base_path(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """The fall-through `base_path = Path.cwd()` branch."""
        monkeypatch.chdir(tmp_path)  # cwd has no character → returns early
        result = evaluate_advancement(world_id="test", player_id="default")
        assert result.evaluated is False