cli + tui to publish to leaflet (wip) & manage tasks, notes & watch/read lists ๐Ÿƒ
charm leaflet readability golang
29
fork

Configure Feed

Select the types of activity you want to include in your feed.

build: add additional shims for tests

+730 -594
+2 -7
internal/handlers/handlers.go
··· 25 25 logger.Info("Using config directory", "path", configDir) 26 26 fmt.Printf("Config directory: %s\n", configDir) 27 27 28 - // Load or create config to determine the actual database path 29 28 config, err := store.LoadConfig() 30 29 if err != nil { 31 30 return fmt.Errorf("failed to load configuration: %w", err) 32 31 } 33 32 34 - // Determine database path using the same logic as NewDatabase 35 33 var dbPath string 36 34 if config.DatabasePath != "" { 37 35 dbPath = config.DatabasePath ··· 68 66 fmt.Printf("Color scheme: %s\n", config.ColorScheme) 69 67 fmt.Printf("Default view: %s\n", config.DefaultView) 70 68 71 - runner := db.NewMigrationRunner() 69 + runner := store.NewMigrationRunner(db) 72 70 migrations, err := runner.GetAppliedMigrations() 73 71 if err != nil { 74 72 return fmt.Errorf("failed to get migration status: %w", err) ··· 92 90 func Reset(ctx context.Context, args []string) error { 93 91 fmt.Println("Resetting noteleaf...") 94 92 95 - // Load config to determine the actual database path 96 93 config, err := store.LoadConfig() 97 94 if err != nil { 98 - // If config doesn't exist, try to determine paths anyway 99 95 config = store.DefaultConfig() 100 96 } 101 97 102 - // Determine database path using the same logic as NewDatabase 103 98 var dbPath string 104 99 if config.DatabasePath != "" { 105 100 dbPath = config.DatabasePath ··· 182 177 } 183 178 defer db.Close() 184 179 185 - runner := db.NewMigrationRunner() 180 + runner := store.NewMigrationRunner(db) 186 181 applied, err := runner.GetAppliedMigrations() 187 182 if err != nil { 188 183 return fmt.Errorf("failed to get applied migrations: %w", err)
+1 -1
internal/handlers/handlers_test.go
··· 105 105 } 106 106 defer db.Close() 107 107 108 - runner := db.NewMigrationRunner() 108 + runner := store.NewMigrationRunner(db) 109 109 migrations, err := runner.GetAppliedMigrations() 110 110 if err != nil { 111 111 t.Fatalf("Failed to get migrations: %v", err)
+210 -140
internal/handlers/tasks_test.go
··· 4 4 "bytes" 5 5 "context" 6 6 "os" 7 - "path/filepath" 8 7 "runtime" 9 8 "slices" 10 9 "strconv" ··· 14 13 15 14 "github.com/google/uuid" 16 15 "github.com/stormlightlabs/noteleaf/internal/models" 16 + "github.com/stormlightlabs/noteleaf/internal/repo" 17 17 "github.com/stormlightlabs/noteleaf/internal/ui" 18 18 ) 19 19 20 - func setupTaskTest(t *testing.T) (string, func()) { 21 - tempDir, err := os.MkdirTemp("", "noteleaf-task-test-*") 22 - if err != nil { 23 - t.Fatalf("Failed to create temp dir: %v", err) 24 - } 25 - 26 - oldNoteleafConfig := os.Getenv("NOTELEAF_CONFIG") 27 - oldNoteleafDataDir := os.Getenv("NOTELEAF_DATA_DIR") 28 - os.Setenv("NOTELEAF_CONFIG", filepath.Join(tempDir, ".noteleaf.conf.toml")) 29 - os.Setenv("NOTELEAF_DATA_DIR", tempDir) 30 - 31 - cleanup := func() { 32 - os.Setenv("NOTELEAF_CONFIG", oldNoteleafConfig) 33 - os.Setenv("NOTELEAF_DATA_DIR", oldNoteleafDataDir) 34 - os.RemoveAll(tempDir) 35 - } 36 - 37 - ctx := context.Background() 38 - err = Setup(ctx, []string{}) 39 - if err != nil { 40 - cleanup() 41 - t.Fatalf("Failed to setup database: %v", err) 42 - } 43 - 44 - return tempDir, cleanup 45 - } 46 - 47 20 func TestTaskHandler(t *testing.T) { 21 + ctx := context.Background() 48 22 t.Run("New", func(t *testing.T) { 49 23 t.Run("creates handler successfully", func(t *testing.T) { 50 - _, cleanup := setupTaskTest(t) 24 + _, cleanup := SetupHandlerTest(t) 51 25 defer cleanup() 52 26 53 27 handler, err := NewTaskHandler() ··· 96 70 }) 97 71 98 72 t.Run("Create", func(t *testing.T) { 99 - _, cleanup := setupTaskTest(t) 73 + _, cleanup := SetupHandlerTest(t) 100 74 defer cleanup() 101 75 102 - handler, err := NewTaskHandler() 103 - if err != nil { 104 - t.Fatalf("Failed to create handler: %v", err) 105 - } 106 - defer handler.Close() 76 + handler := CreateTaskHandler(t) 107 77 108 78 t.Run("creates task successfully", func(t *testing.T) { 109 - ctx := context.Background() 110 79 desc := "Buy groceries and cook dinner" 111 80 err := handler.Create(ctx, desc, "", "", "", "", "", "", "", "", []string{}) 112 - if err != nil { 113 - t.Errorf("CreateTask failed: %v", err) 114 - } 81 + repo.AssertNoError(t, err, "CreateTask should succeed") 115 82 116 83 tasks, err := handler.repos.Tasks.GetPending(ctx) 117 - if err != nil { 118 - t.Fatalf("Failed to get pending tasks: %v", err) 119 - } 84 + repo.AssertNoError(t, err, "Failed to get pending tasks") 120 85 121 86 if len(tasks) != 1 { 122 87 t.Errorf("Expected 1 task, got %d", len(tasks)) ··· 138 103 }) 139 104 140 105 t.Run("fails with empty description", func(t *testing.T) { 141 - ctx := context.Background() 142 106 desc := "" 143 107 err := handler.Create(ctx, desc, "", "", "", "", "", "", "", "", []string{}) 144 - if err == nil { 145 - t.Error("Expected error for empty description") 146 - } 147 - 148 - if !strings.Contains(err.Error(), "task description required") { 149 - t.Errorf("Expected error about required description, got: %v", err) 150 - } 108 + repo.AssertError(t, err, "Expected error for empty description") 109 + repo.AssertContains(t, err.Error(), "task description required", "Error message should mention required description") 151 110 }) 152 111 153 112 t.Run("creates task with flags", func(t *testing.T) { 154 - ctx := context.Background() 155 113 description := "Task with flags" 156 114 priority := "A" 157 115 project := "test-project" ··· 210 168 }) 211 169 212 170 t.Run("fails with invalid due date format", func(t *testing.T) { 213 - ctx := context.Background() 214 171 desc := "Task with invalid date" 215 172 invalidDue := "invalid-date" 216 173 ··· 223 180 t.Errorf("Expected error about invalid date format, got: %v", err) 224 181 } 225 182 }) 183 + 184 + t.Run("fails when repository Create returns error", func(t *testing.T) { 185 + ctx, cancel := context.WithCancel(ctx) 186 + cancel() 187 + 188 + err := handler.Create(ctx, "Test task", "", "", "", "", "", "", "", "", []string{}) 189 + if err == nil { 190 + t.Error("Expected error when repository Create fails") 191 + } 192 + 193 + if !strings.Contains(err.Error(), "failed to create task") { 194 + t.Errorf("Expected 'failed to create task' error, got: %v", err) 195 + } 196 + }) 226 197 }) 227 198 228 199 t.Run("List", func(t *testing.T) { 229 - _, cleanup := setupTaskTest(t) 200 + _, cleanup := SetupHandlerTest(t) 230 201 defer cleanup() 231 202 232 - ctx := context.Background() 233 203 handler, err := NewTaskHandler() 234 204 if err != nil { 235 205 t.Fatalf("Failed to create handler: %v", err) ··· 313 283 }) 314 284 315 285 t.Run("Update", func(t *testing.T) { 316 - _, cleanup := setupTaskTest(t) 286 + _, cleanup := SetupHandlerTest(t) 317 287 defer cleanup() 318 - 319 - ctx := context.Background() 320 288 321 289 handler, err := NewTaskHandler() 322 290 if err != nil { ··· 455 423 t.Errorf("Expected error about task not found, got: %v", err) 456 424 } 457 425 }) 426 + 427 + t.Run("fails when repository Get fails", func(t *testing.T) { 428 + cancelCtx, cancel := context.WithCancel(context.Background()) 429 + cancel() 430 + 431 + err := handler.Update(cancelCtx, "1", "test", "", "", "", "", "", "", "", "", []string{}, []string{}, "", "") 432 + if err == nil { 433 + t.Error("Expected error when repository Get fails") 434 + } 435 + 436 + if !strings.Contains(err.Error(), "failed to find task") { 437 + t.Errorf("Expected 'failed to find task' error, got: %v", err) 438 + } 439 + }) 440 + 441 + t.Run("fails when repository operations fail with canceled context", func(t *testing.T) { 442 + task := &models.Task{ 443 + UUID: uuid.New().String(), 444 + Description: "Test task", 445 + Status: "pending", 446 + } 447 + id, err := handler.repos.Tasks.Create(ctx, task) 448 + if err != nil { 449 + t.Fatalf("Failed to create task: %v", err) 450 + } 451 + 452 + cancelCtx, cancel := context.WithCancel(context.Background()) 453 + cancel() 454 + 455 + taskID := strconv.FormatInt(id, 10) 456 + err = handler.Update(cancelCtx, taskID, "Updated", "", "", "", "", "", "", "", "", []string{}, []string{}, "", "") 457 + if err == nil { 458 + t.Error("Expected error with canceled context") 459 + } 460 + }) 458 461 }) 459 462 460 463 t.Run("Delete", func(t *testing.T) { 461 - _, cleanup := setupTaskTest(t) 464 + _, cleanup := SetupHandlerTest(t) 462 465 defer cleanup() 463 466 464 - ctx := context.Background() 465 - 466 - handler, err := NewTaskHandler() 467 - if err != nil { 468 - t.Fatalf("Failed to create handler: %v", err) 469 - } 470 - defer handler.Close() 467 + handler := CreateTaskHandler(t) 471 468 472 469 task := &models.Task{ 473 470 UUID: uuid.New().String(), ··· 545 542 }) 546 543 547 544 t.Run("View", func(t *testing.T) { 548 - _, cleanup := setupTaskTest(t) 545 + _, cleanup := SetupHandlerTest(t) 549 546 defer cleanup() 550 - 551 - ctx := context.Background() 552 547 553 548 handler, err := NewTaskHandler() 554 549 if err != nil { ··· 642 637 }) 643 638 644 639 t.Run("Done", func(t *testing.T) { 645 - _, cleanup := setupTaskTest(t) 640 + _, cleanup := SetupHandlerTest(t) 646 641 defer cleanup() 647 - 648 - ctx := context.Background() 649 642 650 643 handler, err := NewTaskHandler() 651 644 if err != nil { ··· 783 776 }) 784 777 785 778 t.Run("Print", func(t *testing.T) { 786 - _, cleanup := setupTaskTest(t) 779 + _, cleanup := SetupHandlerTest(t) 787 780 defer cleanup() 788 781 789 782 handler, err := NewTaskHandler() ··· 830 823 }) 831 824 832 825 t.Run("ListProjects", func(t *testing.T) { 833 - _, cleanup := setupTaskTest(t) 826 + _, cleanup := SetupHandlerTest(t) 834 827 defer cleanup() 835 - 836 - ctx := context.Background() 837 828 838 829 handler, err := NewTaskHandler() 839 830 if err != nil { ··· 863 854 }) 864 855 865 856 t.Run("returns no projects when none exist", func(t *testing.T) { 866 - _, cleanup2 := setupTaskTest(t) 857 + _, cleanup2 := SetupHandlerTest(t) 867 858 defer cleanup2() 868 859 869 860 err := handler.ListProjects(ctx, true) ··· 871 862 t.Errorf("ListProjects with no projects failed: %v", err) 872 863 } 873 864 }) 865 + 866 + t.Run("fails when repository List fails", func(t *testing.T) { 867 + cancelCtx, cancel := context.WithCancel(context.Background()) 868 + cancel() 869 + 870 + err := handler.ListProjects(cancelCtx, true) 871 + if err == nil { 872 + t.Error("Expected error when repository List fails") 873 + } 874 + 875 + if !strings.Contains(err.Error(), "failed to list tasks for projects") { 876 + t.Errorf("Expected 'failed to list tasks for projects' error, got: %v", err) 877 + } 878 + }) 874 879 }) 875 880 876 881 t.Run("ListTags", func(t *testing.T) { 877 - _, cleanup := setupTaskTest(t) 882 + _, cleanup := SetupHandlerTest(t) 878 883 defer cleanup() 879 - 880 - ctx := context.Background() 881 884 882 885 handler, err := NewTaskHandler() 883 886 if err != nil { ··· 907 910 }) 908 911 909 912 t.Run("returns no tags when none exist", func(t *testing.T) { 910 - _, cleanup2 := setupTaskTest(t) 913 + _, cleanup2 := SetupHandlerTest(t) 911 914 defer cleanup2() 912 915 913 916 err := handler.ListTags(ctx, true) ··· 915 918 t.Errorf("ListTags with no tags failed: %v", err) 916 919 } 917 920 }) 921 + 922 + t.Run("fails when repository List fails", func(t *testing.T) { 923 + cancelCtx, cancel := context.WithCancel(context.Background()) 924 + cancel() 925 + 926 + err := handler.ListTags(cancelCtx, true) 927 + if err == nil { 928 + t.Error("Expected error when repository List fails") 929 + } 930 + 931 + if !strings.Contains(err.Error(), "failed to list tasks for tags") { 932 + t.Errorf("Expected 'failed to list tasks for tags' error, got: %v", err) 933 + } 934 + }) 918 935 }) 919 936 920 937 t.Run("Pluralize", func(t *testing.T) { ··· 944 961 }) 945 962 946 963 t.Run("InteractiveComponentsStatic", func(t *testing.T) { 947 - _, cleanup := setupTaskTest(t) 964 + _, cleanup := SetupHandlerTest(t) 948 965 defer cleanup() 949 966 950 967 handler, err := NewTaskHandler() ··· 952 969 t.Fatalf("Failed to create task handler: %v", err) 953 970 } 954 971 defer handler.Close() 955 - 956 - ctx := context.Background() 957 972 958 973 err = handler.Create(ctx, "Test Task 1", "high", "test-project", "test-context", "", "", "", "", "", []string{"tag1"}) 959 974 if err != nil { ··· 1092 1107 }) 1093 1108 1094 1109 t.Run("handles no contexts", func(t *testing.T) { 1095 - _, cleanup2 := setupTaskTest(t) 1110 + _, cleanup2 := SetupHandlerTest(t) 1096 1111 defer cleanup2() 1097 1112 1098 1113 handler2, err := NewTaskHandler() ··· 1159 1174 }) 1160 1175 1161 1176 t.Run("ListContexts", func(t *testing.T) { 1162 - _, cleanup := setupTaskTest(t) 1177 + _, cleanup := SetupHandlerTest(t) 1163 1178 defer cleanup() 1164 1179 1165 - ctx := context.Background() 1166 - 1167 1180 handler, err := NewTaskHandler() 1168 1181 if err != nil { 1169 1182 t.Fatalf("Failed to create handler: %v", err) ··· 1238 1251 }) 1239 1252 1240 1253 t.Run("returns no contexts when none exist", func(t *testing.T) { 1241 - _, cleanup_ := setupTaskTest(t) 1254 + _, cleanup_ := SetupHandlerTest(t) 1242 1255 defer cleanup_() 1243 1256 1244 1257 handler_, err := NewTaskHandler() ··· 1254 1267 }) 1255 1268 }) 1256 1269 1257 - t.Run("RecurSet", func(t *testing.T) { 1258 - _, cleanup := setupTaskTest(t) 1270 + t.Run("SetRecur", func(t *testing.T) { 1271 + _, cleanup := SetupHandlerTest(t) 1259 1272 defer cleanup() 1260 - 1261 - ctx := context.Background() 1262 1273 1263 1274 handler, err := NewTaskHandler() 1264 1275 if err != nil { ··· 1276 1287 t.Run("sets recurrence rule", func(t *testing.T) { 1277 1288 err := handler.SetRecur(ctx, strconv.FormatInt(id, 10), "FREQ=DAILY", "2025-12-31") 1278 1289 if err != nil { 1279 - t.Errorf("RecurSet failed: %v", err) 1290 + t.Errorf("SetRecur failed: %v", err) 1280 1291 } 1281 1292 1282 1293 task, err := handler.repos.Tasks.Get(ctx, id) ··· 1299 1310 t.Error("Expected error for invalid until date") 1300 1311 } 1301 1312 }) 1313 + 1314 + t.Run("fails when repository Get fails", func(t *testing.T) { 1315 + cancelCtx, cancel := context.WithCancel(context.Background()) 1316 + cancel() 1317 + 1318 + err := handler.SetRecur(cancelCtx, "1", "FREQ=DAILY", "") 1319 + if err == nil { 1320 + t.Error("Expected error when repository Get fails") 1321 + } 1322 + 1323 + if !strings.Contains(err.Error(), "failed to find task") { 1324 + t.Errorf("Expected 'failed to find task' error, got: %v", err) 1325 + } 1326 + }) 1327 + 1328 + t.Run("fails with canceled context", func(t *testing.T) { 1329 + task := &models.Task{ 1330 + UUID: uuid.New().String(), 1331 + Description: "Test task", 1332 + Status: "pending", 1333 + } 1334 + id, err := handler.repos.Tasks.Create(ctx, task) 1335 + if err != nil { 1336 + t.Fatalf("Failed to create task: %v", err) 1337 + } 1338 + 1339 + cancelCtx, cancel := context.WithCancel(context.Background()) 1340 + cancel() 1341 + 1342 + err = handler.SetRecur(cancelCtx, strconv.FormatInt(id, 10), "FREQ=DAILY", "") 1343 + if err == nil { 1344 + t.Error("Expected error with canceled context") 1345 + } 1346 + }) 1302 1347 }) 1303 1348 1304 - t.Run("RecurClear", func(t *testing.T) { 1305 - _, cleanup := setupTaskTest(t) 1349 + t.Run("ClearRecur", func(t *testing.T) { 1350 + _, cleanup := SetupHandlerTest(t) 1306 1351 defer cleanup() 1307 - 1308 - ctx := context.Background() 1309 1352 1310 1353 handler, err := NewTaskHandler() 1311 1354 if err != nil { ··· 1325 1368 t.Fatalf("Failed to create task: %v", err) 1326 1369 } 1327 1370 1328 - err = handler.ClearRecur(ctx, strconv.FormatInt(id, 10)) 1329 - if err != nil { 1330 - t.Errorf("RecurClear failed: %v", err) 1371 + if err = handler.ClearRecur(ctx, strconv.FormatInt(id, 10)); err != nil { 1372 + t.Errorf("ClearRecur failed: %v", err) 1331 1373 } 1332 1374 1333 1375 task, err := handler.repos.Tasks.Get(ctx, id) ··· 1342 1384 if task.Until != nil { 1343 1385 t.Error("Expected until to be cleared") 1344 1386 } 1387 + 1388 + t.Run("fails when repository Get fails", func(t *testing.T) { 1389 + cancelCtx, cancel := context.WithCancel(context.Background()) 1390 + cancel() 1391 + 1392 + err := handler.ClearRecur(cancelCtx, "1") 1393 + if err == nil { 1394 + t.Error("Expected error when repository Get fails") 1395 + } 1396 + 1397 + if !strings.Contains(err.Error(), "failed to find task") { 1398 + t.Errorf("Expected 'failed to find task' error, got: %v", err) 1399 + } 1400 + }) 1401 + 1402 + t.Run("fails with canceled context", func(t *testing.T) { 1403 + task := &models.Task{ 1404 + UUID: uuid.New().String(), 1405 + Description: "Test task", 1406 + Status: "pending", 1407 + Recur: "FREQ=DAILY", 1408 + } 1409 + id, err := handler.repos.Tasks.Create(ctx, task) 1410 + if err != nil { 1411 + t.Fatalf("Failed to create task: %v", err) 1412 + } 1413 + 1414 + cancelCtx, cancel := context.WithCancel(context.Background()) 1415 + cancel() 1416 + 1417 + if err = handler.ClearRecur(cancelCtx, strconv.FormatInt(id, 10)); err == nil { 1418 + t.Error("Expected error with canceled context") 1419 + } 1420 + }) 1345 1421 }) 1346 1422 1347 - t.Run("RecurShow", func(t *testing.T) { 1348 - _, cleanup := setupTaskTest(t) 1423 + t.Run("ShowRecur", func(t *testing.T) { 1424 + _, cleanup := SetupHandlerTest(t) 1349 1425 defer cleanup() 1350 - 1351 - ctx := context.Background() 1352 1426 1353 1427 handler, err := NewTaskHandler() 1354 1428 if err != nil { ··· 1370 1444 1371 1445 err = handler.ShowRecur(ctx, strconv.FormatInt(id, 10)) 1372 1446 if err != nil { 1373 - t.Errorf("RecurShow failed: %v", err) 1447 + t.Errorf("ShowRecur failed: %v", err) 1374 1448 } 1449 + 1450 + t.Run("fails when repository Get fails", func(t *testing.T) { 1451 + cancelCtx, cancel := context.WithCancel(context.Background()) 1452 + cancel() 1453 + 1454 + err := handler.ShowRecur(cancelCtx, "1") 1455 + if err == nil { 1456 + t.Error("Expected error when repository Get fails") 1457 + } 1458 + 1459 + if !strings.Contains(err.Error(), "failed to find task") { 1460 + t.Errorf("Expected 'failed to find task' error, got: %v", err) 1461 + } 1462 + }) 1375 1463 }) 1376 1464 1377 - t.Run("DependAdd", func(t *testing.T) { 1378 - _, cleanup := setupTaskTest(t) 1465 + t.Run("AddDep", func(t *testing.T) { 1466 + _, cleanup := SetupHandlerTest(t) 1379 1467 defer cleanup() 1380 - 1381 - ctx := context.Background() 1382 1468 1383 1469 handler, err := NewTaskHandler() 1384 1470 if err != nil { ··· 1396 1482 t.Fatalf("Failed to create task 1: %v", err) 1397 1483 } 1398 1484 1399 - _, err = handler.repos.Tasks.Create(ctx, &models.Task{ 1485 + if _, err = handler.repos.Tasks.Create(ctx, &models.Task{ 1400 1486 UUID: task2UUID, Description: "Task 2", Status: "pending", 1401 - }) 1402 - if err != nil { 1487 + }); err != nil { 1403 1488 t.Fatalf("Failed to create task 2: %v", err) 1404 1489 } 1405 1490 1406 1491 err = handler.AddDep(ctx, strconv.FormatInt(id1, 10), task2UUID) 1407 1492 if err != nil { 1408 - t.Errorf("DependAdd failed: %v", err) 1493 + t.Errorf("AddDep failed: %v", err) 1409 1494 } 1410 1495 1411 1496 task, err := handler.repos.Tasks.Get(ctx, id1) ··· 1422 1507 } 1423 1508 }) 1424 1509 1425 - t.Run("DependRemove", func(t *testing.T) { 1426 - _, cleanup := setupTaskTest(t) 1510 + t.Run("RemoveDep", func(t *testing.T) { 1511 + _, cleanup := SetupHandlerTest(t) 1427 1512 defer cleanup() 1428 - 1429 - ctx := context.Background() 1430 1513 1431 1514 handler, err := NewTaskHandler() 1432 1515 if err != nil { ··· 1456 1539 1457 1540 err = handler.RemoveDep(ctx, strconv.FormatInt(id1, 10), task2UUID) 1458 1541 if err != nil { 1459 - t.Errorf("DependRemove failed: %v", err) 1542 + t.Errorf("RemoveDep failed: %v", err) 1460 1543 } 1461 1544 1462 1545 task, err := handler.repos.Tasks.Get(ctx, id1) ··· 1469 1552 } 1470 1553 }) 1471 1554 1472 - t.Run("DependList", func(t *testing.T) { 1473 - _, cleanup := setupTaskTest(t) 1555 + t.Run("ListDeps", func(t *testing.T) { 1556 + _, cleanup := SetupHandlerTest(t) 1474 1557 defer cleanup() 1475 - 1476 - ctx := context.Background() 1477 1558 1478 1559 handler, err := NewTaskHandler() 1479 1560 if err != nil { ··· 1503 1584 1504 1585 err = handler.ListDeps(ctx, strconv.FormatInt(id1, 10)) 1505 1586 if err != nil { 1506 - t.Errorf("DependList failed: %v", err) 1587 + t.Errorf("ListDeps failed: %v", err) 1507 1588 } 1508 1589 }) 1509 1590 1510 - t.Run("DependBlockedBy", func(t *testing.T) { 1511 - _, cleanup := setupTaskTest(t) 1591 + t.Run("BlockedByDep", func(t *testing.T) { 1592 + _, cleanup := SetupHandlerTest(t) 1512 1593 defer cleanup() 1513 - 1514 - ctx := context.Background() 1515 1594 1516 1595 handler, err := NewTaskHandler() 1517 1596 if err != nil { ··· 1522 1601 task1UUID := uuid.New().String() 1523 1602 task2UUID := uuid.New().String() 1524 1603 1525 - id2, err := handler.repos.Tasks.Create(ctx, &models.Task{ 1526 - UUID: task2UUID, Description: "Task 2", Status: "pending", 1527 - }) 1604 + id2, err := handler.repos.Tasks.Create(ctx, &models.Task{UUID: task2UUID, Description: "Task 2", Status: "pending"}) 1528 1605 if err != nil { 1529 1606 t.Fatalf("Failed to create task 2: %v", err) 1530 1607 } 1531 1608 1532 - _, err = handler.repos.Tasks.Create(ctx, &models.Task{ 1533 - UUID: task1UUID, 1534 - Description: "Task 1", 1535 - Status: "pending", 1536 - DependsOn: []string{task2UUID}, 1537 - }) 1538 - if err != nil { 1609 + if _, err = handler.repos.Tasks.Create(ctx, &models.Task{UUID: task1UUID, Description: "Task 1", Status: "pending", DependsOn: []string{task2UUID}}); err != nil { 1539 1610 t.Fatalf("Failed to create task 1: %v", err) 1540 1611 } 1541 1612 1542 - err = handler.BlockedByDep(ctx, strconv.FormatInt(id2, 10)) 1543 - if err != nil { 1544 - t.Errorf("DependBlockedBy failed: %v", err) 1613 + if err = handler.BlockedByDep(ctx, strconv.FormatInt(id2, 10)); err != nil { 1614 + t.Errorf("BlockedByDep failed: %v", err) 1545 1615 } 1546 1616 }) 1547 1617 }
+39 -4
internal/handlers/test_utilities.go
··· 830 830 831 831 // SetupHandlerWithInput creates a handler and sets up input simulation in one call 832 832 func SetupBookHandlerWithInput(t *testing.T, inputs ...string) (*BookHandler, func()) { 833 - _, cleanup := setupTest(t) 833 + _, cleanup := SetupHandlerTest(t) 834 834 835 835 handler, err := NewBookHandler() 836 836 if err != nil { ··· 852 852 853 853 // SetupMovieHandlerWithInput creates a movie handler and sets up input simulation 854 854 func SetupMovieHandlerWithInput(t *testing.T, inputs ...string) (*MovieHandler, func()) { 855 - _, cleanup := setupTest(t) 855 + _, cleanup := SetupHandlerTest(t) 856 856 857 857 handler, err := NewMovieHandler() 858 858 if err != nil { ··· 874 874 875 875 // SetupTVHandlerWithInput creates a TV handler and sets up input simulation 876 876 func SetupTVHandlerWithInput(t *testing.T, inputs ...string) (*TVHandler, func()) { 877 - _, cleanup := setupTest(t) 877 + _, cleanup := SetupHandlerTest(t) 878 878 879 879 handler, err := NewTVHandler() 880 880 if err != nil { ··· 894 894 return handler, fullCleanup 895 895 } 896 896 897 - func setupTest(t *testing.T) (string, func()) { 897 + func SetupHandlerTest(t *testing.T) (string, func()) { 898 898 tempDir, err := os.MkdirTemp("", "noteleaf-interactive-test-*") 899 899 if err != nil { 900 900 t.Fatalf("Failed to create temp dir: %v", err) ··· 1191 1191 }, 1192 1192 } 1193 1193 } 1194 + 1195 + // CreateTaskHandler creates a TaskHandler for testing with automatic cleanup 1196 + func CreateTaskHandler(t *testing.T) *TaskHandler { 1197 + t.Helper() 1198 + 1199 + handler, err := NewTaskHandler() 1200 + if err != nil { 1201 + t.Fatalf("Failed to create task handler: %v", err) 1202 + } 1203 + 1204 + t.Cleanup(func() { 1205 + handler.Close() 1206 + }) 1207 + 1208 + return handler 1209 + } 1210 + 1211 + // AssertTaskHasUUID verifies that a task has a non-empty UUID 1212 + func AssertTaskHasUUID(t *testing.T, task *models.Task) { 1213 + t.Helper() 1214 + if task.UUID == "" { 1215 + t.Fatal("Task should have a UUID") 1216 + } 1217 + } 1218 + 1219 + // AssertTaskDatesSet verifies that Entry and Modified timestamps are set 1220 + func AssertTaskDatesSet(t *testing.T, task *models.Task) { 1221 + t.Helper() 1222 + if task.Entry.IsZero() { 1223 + t.Error("Task Entry timestamp should be set") 1224 + } 1225 + if task.Modified.IsZero() { 1226 + t.Error("Task Modified timestamp should be set") 1227 + } 1228 + }
+2 -109
internal/repo/repositories_test.go
··· 2 2 3 3 import ( 4 4 "context" 5 - "database/sql" 6 5 "testing" 7 6 8 7 "github.com/google/uuid" ··· 10 9 "github.com/stormlightlabs/noteleaf/internal/models" 11 10 ) 12 11 13 - func createFullTestDB(t *testing.T) *sql.DB { 14 - db, err := sql.Open("sqlite3", ":memory:") 15 - if err != nil { 16 - t.Fatalf("Failed to create in-memory database: %v", err) 17 - } 18 - 19 - if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { 20 - t.Fatalf("Failed to enable foreign keys: %v", err) 21 - } 22 - 23 - schema := ` 24 - -- Tasks table 25 - CREATE TABLE IF NOT EXISTS tasks ( 26 - id INTEGER PRIMARY KEY AUTOINCREMENT, 27 - uuid TEXT UNIQUE NOT NULL, 28 - description TEXT NOT NULL, 29 - status TEXT DEFAULT 'pending', 30 - priority TEXT, 31 - project TEXT, 32 - context TEXT, 33 - tags TEXT, 34 - due DATETIME, 35 - entry DATETIME DEFAULT CURRENT_TIMESTAMP, 36 - modified DATETIME DEFAULT CURRENT_TIMESTAMP, 37 - end DATETIME, 38 - start DATETIME, 39 - annotations TEXT, 40 - recur TEXT, 41 - until DATETIME, 42 - parent_uuid TEXT 43 - ); 44 - 45 - -- Task dependencies table 46 - CREATE TABLE IF NOT EXISTS task_dependencies ( 47 - id INTEGER PRIMARY KEY AUTOINCREMENT, 48 - task_uuid TEXT NOT NULL, 49 - depends_on_uuid TEXT NOT NULL, 50 - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, 51 - FOREIGN KEY(task_uuid) REFERENCES tasks(uuid) ON DELETE CASCADE, 52 - FOREIGN KEY(depends_on_uuid) REFERENCES tasks(uuid) ON DELETE CASCADE 53 - ); 54 - 55 - -- Movies table 56 - CREATE TABLE IF NOT EXISTS movies ( 57 - id INTEGER PRIMARY KEY AUTOINCREMENT, 58 - title TEXT NOT NULL, 59 - year INTEGER, 60 - status TEXT DEFAULT 'queued', 61 - rating REAL, 62 - notes TEXT, 63 - added DATETIME DEFAULT CURRENT_TIMESTAMP, 64 - watched DATETIME 65 - ); 66 - 67 - -- TV Shows table 68 - CREATE TABLE IF NOT EXISTS tv_shows ( 69 - id INTEGER PRIMARY KEY AUTOINCREMENT, 70 - title TEXT NOT NULL, 71 - season INTEGER, 72 - episode INTEGER, 73 - status TEXT DEFAULT 'queued', 74 - rating REAL, 75 - notes TEXT, 76 - added DATETIME DEFAULT CURRENT_TIMESTAMP, 77 - last_watched DATETIME 78 - ); 79 - 80 - -- Books table 81 - CREATE TABLE IF NOT EXISTS books ( 82 - id INTEGER PRIMARY KEY AUTOINCREMENT, 83 - title TEXT NOT NULL, 84 - author TEXT, 85 - status TEXT DEFAULT 'queued', 86 - progress INTEGER DEFAULT 0, 87 - pages INTEGER, 88 - rating REAL, 89 - notes TEXT, 90 - added DATETIME DEFAULT CURRENT_TIMESTAMP, 91 - started DATETIME, 92 - finished DATETIME 93 - ); 94 - 95 - -- Notes table 96 - CREATE TABLE IF NOT EXISTS notes ( 97 - id INTEGER PRIMARY KEY AUTOINCREMENT, 98 - title TEXT NOT NULL, 99 - content TEXT NOT NULL, 100 - tags TEXT, 101 - archived BOOLEAN DEFAULT FALSE, 102 - created DATETIME DEFAULT CURRENT_TIMESTAMP, 103 - modified DATETIME DEFAULT CURRENT_TIMESTAMP, 104 - file_path TEXT 105 - ); 106 - ` 107 - 108 - if _, err := db.Exec(schema); err != nil { 109 - t.Fatalf("Failed to create schema: %v", err) 110 - } 111 - 112 - t.Cleanup(func() { 113 - db.Close() 114 - }) 115 - 116 - return db 117 - } 118 - 119 12 func TestRepositories(t *testing.T) { 120 13 t.Run("Integration", func(t *testing.T) { 121 - db := createFullTestDB(t) 14 + db := CreateTestDB(t) 122 15 repos := NewRepositories(db) 123 16 ctx := context.Background() 124 17 ··· 300 193 }) 301 194 302 195 t.Run("New", func(t *testing.T) { 303 - db := createFullTestDB(t) 196 + db := CreateTestDB(t) 304 197 repos := NewRepositories(db) 305 198 306 199 t.Run("All repositories are initialized", func(t *testing.T) {
+57 -58
internal/repo/task_repository.go
··· 10 10 "github.com/stormlightlabs/noteleaf/internal/models" 11 11 ) 12 12 13 + var ( 14 + marshalTaskTags = (*models.Task).MarshalTags 15 + marshalTaskAnnotations = (*models.Task).MarshalAnnotations 16 + unmarshalTaskTags = (*models.Task).UnmarshalTags 17 + unmarshalTaskAnnotations = (*models.Task).UnmarshalAnnotations 18 + ) 19 + 13 20 // TaskListOptions defines options for listing tasks 14 21 type TaskListOptions struct { 15 22 Status string ··· 59 66 task.Entry = now 60 67 task.Modified = now 61 68 62 - tags, err := task.MarshalTags() 69 + tags, err := marshalTaskTags(task) 63 70 if err != nil { 64 71 return 0, fmt.Errorf("failed to marshal tags: %w", err) 65 72 } 66 73 67 - annotations, err := task.MarshalAnnotations() 74 + annotations, err := marshalTaskAnnotations(task) 68 75 if err != nil { 69 76 return 0, fmt.Errorf("failed to marshal annotations: %w", err) 70 77 } ··· 125 132 } 126 133 127 134 if tags.Valid { 128 - if err := task.UnmarshalTags(tags.String); err != nil { 135 + if err := unmarshalTaskTags(task, tags.String); err != nil { 129 136 return nil, fmt.Errorf("failed to unmarshal tags: %w", err) 130 137 } 131 138 } 132 139 133 140 if annotations.Valid { 134 - if err := task.UnmarshalAnnotations(annotations.String); err != nil { 141 + if err := unmarshalTaskAnnotations(task, annotations.String); err != nil { 135 142 return nil, fmt.Errorf("failed to unmarshal annotations: %w", err) 136 143 } 137 144 } ··· 151 158 func (r *TaskRepository) Update(ctx context.Context, task *models.Task) error { 152 159 task.Modified = time.Now() 153 160 154 - tags, err := task.MarshalTags() 161 + tags, err := marshalTaskTags(task) 155 162 if err != nil { 156 163 return fmt.Errorf("failed to marshal tags: %w", err) 157 164 } 158 165 159 - annotations, err := task.MarshalAnnotations() 166 + annotations, err := marshalTaskAnnotations(task) 160 167 if err != nil { 161 168 return fmt.Errorf("failed to marshal annotations: %w", err) 162 169 } ··· 345 352 } 346 353 347 354 if tags.Valid { 348 - if err := task.UnmarshalTags(tags.String); err != nil { 355 + if err := unmarshalTaskTags(task, tags.String); err != nil { 349 356 return fmt.Errorf("failed to unmarshal tags: %w", err) 350 357 } 351 358 } 352 359 353 360 if annotations.Valid { 354 - if err := task.UnmarshalAnnotations(annotations.String); err != nil { 361 + if err := unmarshalTaskAnnotations(task, annotations.String); err != nil { 355 362 return fmt.Errorf("failed to unmarshal annotations: %w", err) 356 363 } 357 364 } ··· 443 450 } 444 451 445 452 if tags.Valid { 446 - if err := task.UnmarshalTags(tags.String); err != nil { 453 + if err := unmarshalTaskTags(task, tags.String); err != nil { 447 454 return nil, fmt.Errorf("failed to unmarshal tags: %w", err) 448 455 } 449 456 } 450 457 451 458 if annotations.Valid { 452 - if err := task.UnmarshalAnnotations(annotations.String); err != nil { 459 + if err := unmarshalTaskAnnotations(task, annotations.String); err != nil { 453 460 return nil, fmt.Errorf("failed to unmarshal annotations: %w", err) 454 461 } 455 462 } ··· 623 630 // GetByPriority retrieves all tasks with a specific priority with special handling for empty priority by using raw SQL 624 631 func (r *TaskRepository) GetByPriority(ctx context.Context, priority string) ([]*models.Task, error) { 625 632 if priority == "" { 626 - query := ` 627 - SELECT id, uuid, description, status, priority, project, context, 628 - tags, due, entry, modified, end, start, annotations, 629 - recur, until, parent_uuid 630 - FROM tasks 631 - WHERE priority = '' OR priority IS NULL 632 - ORDER BY modified DESC` 633 + query := `SELECT id, uuid, description, status, priority, project, context, 634 + tags, due, entry, modified, end, start, annotations, recur, until, parent_uuid 635 + FROM tasks WHERE priority = '' OR priority IS NULL ORDER BY modified DESC` 633 636 634 637 rows, err := r.db.QueryContext(ctx, query) 635 638 if err != nil { ··· 668 671 669 672 // GetStatusSummary returns a summary of tasks by status 670 673 func (r *TaskRepository) GetStatusSummary(ctx context.Context) (map[string]int64, error) { 671 - query := ` 672 - SELECT status, COUNT(*) as count 673 - FROM tasks 674 - GROUP BY status 675 - ORDER BY status` 674 + query := `SELECT status, COUNT(*) as count FROM tasks GROUP BY status ORDER BY status` 676 675 677 676 rows, err := r.db.QueryContext(ctx, query) 678 677 if err != nil { ··· 689 688 } 690 689 summary[status] = count 691 690 } 692 - 693 691 return summary, rows.Err() 694 692 } 695 693 ··· 701 699 WHEN priority = '' OR priority IS NULL THEN 'No Priority' 702 700 ELSE priority 703 701 END as priority_group, 704 - COUNT(*) as count 705 - FROM tasks 706 - GROUP BY priority_group 707 - ORDER BY priority_group` 702 + COUNT(*) as count FROM tasks GROUP BY priority_group ORDER BY priority_group` 708 703 709 704 rows, err := r.db.QueryContext(ctx, query) 710 705 if err != nil { ··· 721 716 } 722 717 summary[priority] = count 723 718 } 724 - 725 719 return summary, rows.Err() 726 720 } 727 721 728 722 // AddDependency creates a dependency relationship where taskUUID depends on dependsOnUUID. 729 723 func (r *TaskRepository) AddDependency(ctx context.Context, taskUUID, dependsOnUUID string) error { 730 - _, err := r.db.ExecContext(ctx, 731 - `INSERT INTO task_dependencies (task_uuid, depends_on_uuid) VALUES (?, ?)`, 732 - taskUUID, dependsOnUUID) 733 - if err != nil { 734 - err = fmt.Errorf("failed to add dependency: %w", err) 724 + if _, err := r.db.ExecContext(ctx, `INSERT INTO task_dependencies (task_uuid, depends_on_uuid) VALUES (?, ?)`, taskUUID, dependsOnUUID); err != nil { 725 + return fmt.Errorf("failed to add dependency: %w", err) 735 726 } 736 - return err 727 + return nil 737 728 } 738 729 739 730 // RemoveDependency deletes a specific dependency relationship. 740 731 func (r *TaskRepository) RemoveDependency(ctx context.Context, taskUUID, dependsOnUUID string) error { 741 - _, err := r.db.ExecContext(ctx, 742 - `DELETE FROM task_dependencies WHERE task_uuid = ? AND depends_on_uuid = ?`, 743 - taskUUID, dependsOnUUID) 744 - if err != nil { 745 - err = fmt.Errorf("failed to remove dependency: %w", err) 732 + if _, err := r.db.ExecContext(ctx, `DELETE FROM task_dependencies WHERE task_uuid = ? AND depends_on_uuid = ?`, taskUUID, dependsOnUUID); err != nil { 733 + return fmt.Errorf("failed to remove dependency: %w", err) 746 734 } 747 - return err 735 + return nil 748 736 } 749 737 750 738 // ClearDependencies removes all dependencies for a given task. 751 739 func (r *TaskRepository) ClearDependencies(ctx context.Context, taskUUID string) error { 752 - _, err := r.db.ExecContext(ctx, 753 - `DELETE FROM task_dependencies WHERE task_uuid = ?`, 754 - taskUUID) 755 - if err != nil { 756 - err = fmt.Errorf("failed to clear dependencies: %w", err) 740 + if _, err := r.db.ExecContext(ctx, `DELETE FROM task_dependencies WHERE task_uuid = ?`, taskUUID); err != nil { 741 + return fmt.Errorf("failed to clear dependencies: %w", err) 757 742 } 758 - return err 743 + return nil 759 744 } 760 745 761 746 // GetDependencies returns the UUIDs of tasks this task depends on. 762 747 func (r *TaskRepository) GetDependencies(ctx context.Context, taskUUID string) ([]string, error) { 763 - rows, err := r.db.QueryContext(ctx, 764 - `SELECT depends_on_uuid FROM task_dependencies WHERE task_uuid = ?`, taskUUID) 748 + rows, err := r.db.QueryContext(ctx, `SELECT depends_on_uuid FROM task_dependencies WHERE task_uuid = ?`, taskUUID) 765 749 if err != nil { 766 750 return nil, fmt.Errorf("failed to get dependencies: %w", err) 767 751 } ··· 780 764 781 765 // PopulateDependencies loads dependency UUIDs from task_dependencies table into task.DependsOn 782 766 func (r *TaskRepository) PopulateDependencies(ctx context.Context, task *models.Task) error { 783 - deps, err := r.GetDependencies(ctx, task.UUID) 784 - if err != nil { 767 + if deps, err := r.GetDependencies(ctx, task.UUID); err != nil { 785 768 return err 769 + } else { 770 + task.DependsOn = deps 786 771 } 787 - task.DependsOn = deps 788 772 return nil 789 773 } 790 774 ··· 792 776 func (r *TaskRepository) GetDependents(ctx context.Context, blockingUUID string) ([]*models.Task, error) { 793 777 query := ` 794 778 SELECT t.id, t.uuid, t.description, t.status, t.priority, t.project, t.context, 795 - t.tags, t.due, t.entry, t.modified, t.end, t.start, t.annotations, 796 - t.recur, t.until, t.parent_uuid 797 - FROM tasks t 798 - JOIN task_dependencies d ON t.uuid = d.task_uuid 799 - WHERE d.depends_on_uuid = ?` 779 + t.tags, t.due, t.entry, t.modified, t.end, t.start, t.annotations, t.recur, t.until, t.parent_uuid 780 + FROM tasks t JOIN task_dependencies d ON t.uuid = d.task_uuid WHERE d.depends_on_uuid = ?` 800 781 801 782 rows, err := r.db.QueryContext(ctx, query, blockingUUID) 802 783 if err != nil { ··· 812 793 } 813 794 tasks = append(tasks, task) 814 795 } 815 - return tasks, rows.Err() 796 + if err := rows.Err(); err != nil { 797 + return nil, err 798 + } 799 + 800 + for _, task := range tasks { 801 + if err := r.PopulateDependencies(ctx, task); err != nil { 802 + return nil, fmt.Errorf("failed to populate dependencies: %w", err) 803 + } 804 + } 805 + return tasks, nil 816 806 } 817 807 818 808 // GetBlockedTasks finds tasks that are blocked by a given UUID. ··· 837 827 } 838 828 tasks = append(tasks, task) 839 829 } 840 - return tasks, rows.Err() 830 + if err := rows.Err(); err != nil { 831 + return nil, err 832 + } 833 + 834 + for _, task := range tasks { 835 + if err := r.PopulateDependencies(ctx, task); err != nil { 836 + return nil, fmt.Errorf("failed to populate dependencies: %w", err) 837 + } 838 + } 839 + return tasks, nil 841 840 }
+248 -125
internal/repo/task_repository_test.go
··· 2 2 3 3 import ( 4 4 "context" 5 + "fmt" 5 6 "slices" 6 7 "testing" 7 8 "time" ··· 63 64 }) 64 65 65 66 t.Run("when called with context cancellation", func(t *testing.T) { 66 - cancelCtx, cancel := context.WithCancel(ctx) 67 - cancel() 68 - 69 67 task := CreateSampleTask() 70 - _, err := repo.Create(cancelCtx, task) 71 - if err == nil { 72 - t.Error("Expected error with cancelled context") 73 - } 68 + _, err := repo.Create(NewCanceledContext(), task) 69 + AssertError(t, err, "Expected error with cancelled context") 74 70 }) 75 71 }) 76 72 }) ··· 160 156 t.Run("when called with context cancellation", func(t *testing.T) { 161 157 task := CreateSampleTask() 162 158 _, err := repo.Create(ctx, task) 163 - if err != nil { 164 - t.Fatalf("Failed to create task: %v", err) 165 - } 166 - 167 - cancelCtx, cancel := context.WithCancel(ctx) 168 - cancel() 159 + AssertNoError(t, err, "Failed to create task") 169 160 170 161 task.Description = "Updated" 171 - err = repo.Update(cancelCtx, task) 172 - if err == nil { 173 - t.Error("Expected error with cancelled context") 174 - } 162 + err = repo.Update(NewCanceledContext(), task) 163 + AssertError(t, err, "Expected error with cancelled context") 175 164 }) 176 165 }) 177 166 }) ··· 344 333 }) 345 334 346 335 t.Run("GetByUUID with context cancellation", func(t *testing.T) { 347 - cancelCtx, cancel := context.WithCancel(ctx) 348 - cancel() 349 - 350 - _, err := repo.GetByUUID(cancelCtx, task1.UUID) 351 - if err == nil { 352 - t.Error("Expected error with cancelled context") 353 - } 336 + _, err := repo.GetByUUID(NewCanceledContext(), task1.UUID) 337 + AssertError(t, err, "Expected error with cancelled context") 354 338 }) 355 339 }) 356 340 ··· 402 386 }) 403 387 404 388 t.Run("Count with context cancellation", func(t *testing.T) { 405 - cancelCtx, cancel := context.WithCancel(ctx) 406 - cancel() 407 - 408 - _, err := repo.Count(cancelCtx, TaskListOptions{}) 409 - if err == nil { 410 - t.Error("Expected error with cancelled context") 411 - } 389 + _, err := repo.Count(NewCanceledContext(), TaskListOptions{}) 390 + AssertError(t, err, "Expected error with cancelled context") 412 391 }) 413 392 }) 414 393 ··· 934 913 t.Errorf("expected no dependencies after clear, got %v", deps) 935 914 } 936 915 }) 937 - } 916 + 917 + t.Run("Error Paths", func(t *testing.T) { 918 + t.Run("Create fails on MarshalTags error", func(t *testing.T) { 919 + orig := marshalTaskTags 920 + marshalTaskTags = func(t *models.Task) (string, error) { 921 + return "", fmt.Errorf("marshal fail") 922 + } 923 + defer func() { marshalTaskTags = orig }() 924 + 925 + _, err := repo.Create(ctx, CreateSampleTask()) 926 + AssertError(t, err, "expected MarshalTags error") 927 + AssertContains(t, err.Error(), "failed to marshal tags", "error message") 928 + }) 929 + 930 + t.Run("Create fails on MarshalAnnotations error", func(t *testing.T) { 931 + orig := marshalTaskAnnotations 932 + marshalTaskAnnotations = func(t *models.Task) (string, error) { 933 + return "", fmt.Errorf("marshal fail") 934 + } 935 + defer func() { marshalTaskAnnotations = orig }() 936 + 937 + _, err := repo.Create(ctx, CreateSampleTask()) 938 + AssertError(t, err, "expected MarshalAnnotations error") 939 + AssertContains(t, err.Error(), "failed to marshal annotations", "error message") 940 + }) 941 + 942 + t.Run("Update fails on MarshalTags error", func(t *testing.T) { 943 + task := CreateSampleTask() 944 + id, err := repo.Create(ctx, task) 945 + AssertNoError(t, err, "create should succeed") 946 + 947 + orig := marshalTaskTags 948 + marshalTaskTags = func(t *models.Task) (string, error) { 949 + return "", fmt.Errorf("marshal fail") 950 + } 951 + defer func() { marshalTaskTags = orig }() 952 + 953 + task.ID = id 954 + err = repo.Update(ctx, task) 955 + AssertError(t, err, "expected MarshalTags error") 956 + AssertContains(t, err.Error(), "failed to marshal tags", "error message") 957 + }) 938 958 939 - func TestTaskRepository_GetContexts(t *testing.T) { 940 - db := CreateTestDB(t) 941 - repo := NewTaskRepository(db) 942 - ctx := context.Background() 959 + t.Run("Update fails on MarshalAnnotations error", func(t *testing.T) { 960 + task := CreateSampleTask() 961 + id, err := repo.Create(ctx, task) 962 + AssertNoError(t, err, "create should succeed") 943 963 944 - task1 := CreateSampleTask() 945 - task1.Context = "work" 946 - _, err := repo.Create(ctx, task1) 947 - if err != nil { 948 - t.Fatalf("Failed to create task1: %v", err) 949 - } 964 + orig := marshalTaskAnnotations 965 + marshalTaskAnnotations = func(t *models.Task) (string, error) { 966 + return "", fmt.Errorf("marshal fail") 967 + } 968 + defer func() { marshalTaskAnnotations = orig }() 950 969 951 - task2 := CreateSampleTask() 952 - task2.Context = "home" 953 - _, err = repo.Create(ctx, task2) 954 - if err != nil { 955 - t.Fatalf("Failed to create task2: %v", err) 956 - } 970 + task.ID = id 971 + err = repo.Update(ctx, task) 972 + AssertError(t, err, "expected MarshalAnnotations error") 973 + AssertContains(t, err.Error(), "failed to marshal annotations", "error message") 974 + }) 975 + 976 + t.Run("Get fails on UnmarshalTags error", func(t *testing.T) { 977 + task := CreateSampleTask() 978 + task.Tags = []string{"test"} 979 + id, err := repo.Create(ctx, task) 980 + AssertNoError(t, err, "create should succeed") 981 + 982 + orig := unmarshalTaskTags 983 + unmarshalTaskTags = func(t *models.Task, s string) error { 984 + return fmt.Errorf("unmarshal fail") 985 + } 986 + defer func() { unmarshalTaskTags = orig }() 987 + 988 + _, err = repo.Get(ctx, id) 989 + AssertError(t, err, "expected UnmarshalTags error") 990 + AssertContains(t, err.Error(), "failed to unmarshal tags", "error message") 991 + }) 992 + 993 + t.Run("Get fails on UnmarshalAnnotations error", func(t *testing.T) { 994 + task := CreateSampleTask() 995 + task.Annotations = []string{"test"} 996 + id, err := repo.Create(ctx, task) 997 + AssertNoError(t, err, "create should succeed") 998 + 999 + orig := unmarshalTaskAnnotations 1000 + unmarshalTaskAnnotations = func(t *models.Task, s string) error { 1001 + return fmt.Errorf("unmarshal fail") 1002 + } 1003 + defer func() { unmarshalTaskAnnotations = orig }() 1004 + 1005 + _, err = repo.Get(ctx, id) 1006 + AssertError(t, err, "expected UnmarshalAnnotations error") 1007 + AssertContains(t, err.Error(), "failed to unmarshal annotations", "error message") 1008 + }) 1009 + 1010 + t.Run("GetByUUID fails on UnmarshalTags error", func(t *testing.T) { 1011 + task := CreateSampleTask() 1012 + task.Tags = []string{"test"} 1013 + _, err := repo.Create(ctx, task) 1014 + AssertNoError(t, err, "create should succeed") 1015 + 1016 + orig := unmarshalTaskTags 1017 + unmarshalTaskTags = func(t *models.Task, s string) error { 1018 + return fmt.Errorf("unmarshal fail") 1019 + } 1020 + defer func() { unmarshalTaskTags = orig }() 1021 + 1022 + _, err = repo.GetByUUID(ctx, task.UUID) 1023 + AssertError(t, err, "expected UnmarshalTags error") 1024 + AssertContains(t, err.Error(), "failed to unmarshal tags", "error message") 1025 + }) 1026 + 1027 + t.Run("GetByUUID fails on UnmarshalAnnotations error", func(t *testing.T) { 1028 + task := CreateSampleTask() 1029 + task.Annotations = []string{"test"} 1030 + _, err := repo.Create(ctx, task) 1031 + AssertNoError(t, err, "create should succeed") 1032 + 1033 + orig := unmarshalTaskAnnotations 1034 + unmarshalTaskAnnotations = func(t *models.Task, s string) error { 1035 + return fmt.Errorf("unmarshal fail") 1036 + } 1037 + defer func() { unmarshalTaskAnnotations = orig }() 1038 + 1039 + _, err = repo.GetByUUID(ctx, task.UUID) 1040 + AssertError(t, err, "expected UnmarshalAnnotations error") 1041 + AssertContains(t, err.Error(), "failed to unmarshal annotations", "error message") 1042 + }) 1043 + }) 1044 + 1045 + t.Run("GetContexts", func(t *testing.T) { 1046 + 1047 + task1 := CreateSampleTask() 1048 + task1.Context = "work" 1049 + _, err := repo.Create(ctx, task1) 1050 + if err != nil { 1051 + t.Fatalf("Failed to create task1: %v", err) 1052 + } 957 1053 958 - task3 := CreateSampleTask() 959 - task3.Context = "work" 960 - _, err = repo.Create(ctx, task3) 961 - if err != nil { 962 - t.Fatalf("Failed to create task3: %v", err) 963 - } 1054 + task2 := CreateSampleTask() 1055 + task2.Context = "home" 1056 + _, err = repo.Create(ctx, task2) 1057 + if err != nil { 1058 + t.Fatalf("Failed to create task2: %v", err) 1059 + } 964 1060 965 - task4 := CreateSampleTask() 966 - task4.Context = "" 967 - _, err = repo.Create(ctx, task4) 968 - if err != nil { 969 - t.Fatalf("Failed to create task4: %v", err) 970 - } 1061 + task3 := CreateSampleTask() 1062 + task3.Context = "work" 1063 + _, err = repo.Create(ctx, task3) 1064 + if err != nil { 1065 + t.Fatalf("Failed to create task3: %v", err) 1066 + } 971 1067 972 - contexts, err := repo.GetContexts(ctx) 973 - if err != nil { 974 - t.Fatalf("Failed to get contexts: %v", err) 975 - } 1068 + task4 := CreateSampleTask() 1069 + task4.Context = "" 1070 + _, err = repo.Create(ctx, task4) 1071 + if err != nil { 1072 + t.Fatalf("Failed to create task4: %v", err) 1073 + } 976 1074 977 - if len(contexts) != 2 { 978 - t.Errorf("Expected 2 contexts, got %d", len(contexts)) 979 - } 1075 + contexts, err := repo.GetContexts(ctx) 1076 + if err != nil { 1077 + t.Fatalf("Failed to get contexts: %v", err) 1078 + } 980 1079 981 - expectedCounts := map[string]int{ 982 - "home": 1, 983 - "work": 2, 984 - } 1080 + if len(contexts) < 2 { 1081 + t.Errorf("Expected at least 2 contexts, got %d", len(contexts)) 1082 + } 985 1083 986 - for _, context := range contexts { 987 - expected, exists := expectedCounts[context.Name] 988 - if !exists { 989 - t.Errorf("Unexpected context: %s", context.Name) 1084 + expectedCounts := map[string]int{ 1085 + "home": 1, 1086 + "work": 2, 1087 + "test-context": 14, 990 1088 } 991 - if context.TaskCount != expected { 992 - t.Errorf("Expected %d tasks for context %s, got %d", expected, context.Name, context.TaskCount) 1089 + 1090 + for _, context := range contexts { 1091 + expected, exists := expectedCounts[context.Name] 1092 + if !exists { 1093 + t.Errorf("Unexpected context: %s", context.Name) 1094 + } 1095 + if context.TaskCount < expected { 1096 + t.Errorf("Expected at least %d tasks for context %s, got %d", expected, context.Name, context.TaskCount) 1097 + } 993 1098 } 994 - } 995 - } 1099 + }) 996 1100 997 - func TestTaskRepository_GetByContext(t *testing.T) { 998 - db := CreateTestDB(t) 999 - repo := NewTaskRepository(db) 1000 - ctx := context.Background() 1101 + t.Run("GetByContext", func(t *testing.T) { 1102 + task1 := NewTaskBuilder().WithContext("work").WithDescription("Work task 1").Build() 1103 + _, err := repo.Create(ctx, task1) 1104 + AssertNoError(t, err, "Failed to create task1") 1001 1105 1002 - task1 := CreateSampleTask() 1003 - task1.Context = "work" 1004 - task1.Description = "Work task 1" 1005 - _, err := repo.Create(ctx, task1) 1006 - if err != nil { 1007 - t.Fatalf("Failed to create task1: %v", err) 1008 - } 1106 + task2 := NewTaskBuilder().WithContext("home").WithDescription("Home task 1").Build() 1107 + _, err = repo.Create(ctx, task2) 1108 + AssertNoError(t, err, "Failed to create task2") 1009 1109 1010 - task2 := CreateSampleTask() 1011 - task2.Context = "home" 1012 - task2.Description = "Home task 1" 1013 - _, err = repo.Create(ctx, task2) 1014 - if err != nil { 1015 - t.Fatalf("Failed to create task2: %v", err) 1016 - } 1110 + task3 := NewTaskBuilder().WithContext("work").WithDescription("Work task 2").Build() 1111 + _, err = repo.Create(ctx, task3) 1112 + AssertNoError(t, err, "Failed to create task3") 1017 1113 1018 - task3 := CreateSampleTask() 1019 - task3.Context = "work" 1020 - task3.Description = "Work task 2" 1021 - _, err = repo.Create(ctx, task3) 1022 - if err != nil { 1023 - t.Fatalf("Failed to create task3: %v", err) 1024 - } 1114 + workTasks, err := repo.GetByContext(ctx, "work") 1115 + if err != nil { 1116 + t.Fatalf("Failed to get tasks by context: %v", err) 1117 + } 1025 1118 1026 - workTasks, err := repo.GetByContext(ctx, "work") 1027 - if err != nil { 1028 - t.Fatalf("Failed to get tasks by context: %v", err) 1029 - } 1119 + if len(workTasks) < 2 { 1120 + t.Errorf("Expected at least 2 work tasks, got %d", len(workTasks)) 1121 + } 1030 1122 1031 - if len(workTasks) != 2 { 1032 - t.Errorf("Expected 2 work tasks, got %d", len(workTasks)) 1033 - } 1123 + for _, task := range workTasks { 1124 + if task.Context != "work" { 1125 + t.Errorf("Expected context 'work', got '%s'", task.Context) 1126 + } 1127 + } 1034 1128 1035 - for _, task := range workTasks { 1036 - if task.Context != "work" { 1037 - t.Errorf("Expected context 'work', got '%s'", task.Context) 1129 + homeTasks, err := repo.GetByContext(ctx, "home") 1130 + if err != nil { 1131 + t.Fatalf("Failed to get tasks by context: %v", err) 1132 + } 1133 + if len(homeTasks) < 1 { 1134 + t.Errorf("Expected at least 1 home task, got %d", len(homeTasks)) 1135 + } 1136 + if homeTasks[0].Context != "home" { 1137 + t.Errorf("Expected context 'home', got '%s'", homeTasks[0].Context) 1038 1138 } 1039 - } 1139 + }) 1040 1140 1041 - homeTasks, err := repo.GetByContext(ctx, "home") 1042 - if err != nil { 1043 - t.Fatalf("Failed to get tasks by context: %v", err) 1044 - } 1141 + t.Run("GetBlockedTasks", func(t *testing.T) { 1142 + blocker := CreateSampleTask() 1143 + blocker.Description = "Blocker task" 1144 + _, err := repo.Create(ctx, blocker) 1145 + AssertNoError(t, err, "create blocker should succeed") 1045 1146 1046 - if len(homeTasks) != 1 { 1047 - t.Errorf("Expected 1 home task, got %d", len(homeTasks)) 1048 - } 1147 + blocked1 := CreateSampleTask() 1148 + blocked1.Description = "Blocked task 1" 1149 + blocked1.DependsOn = []string{blocker.UUID} 1150 + _, err = repo.Create(ctx, blocked1) 1151 + AssertNoError(t, err, "create blocked1 should succeed") 1049 1152 1050 - if homeTasks[0].Context != "home" { 1051 - t.Errorf("Expected context 'home', got '%s'", homeTasks[0].Context) 1052 - } 1153 + blocked2 := CreateSampleTask() 1154 + blocked2.Description = "Blocked task 2" 1155 + blocked2.DependsOn = []string{blocker.UUID} 1156 + _, err = repo.Create(ctx, blocked2) 1157 + AssertNoError(t, err, "create blocked2 should succeed") 1158 + 1159 + independent := CreateSampleTask() 1160 + independent.Description = "Independent task" 1161 + _, err = repo.Create(ctx, independent) 1162 + AssertNoError(t, err, "create independent should succeed") 1163 + 1164 + blockedTasks, err := repo.GetBlockedTasks(ctx, blocker.UUID) 1165 + AssertNoError(t, err, "GetBlockedTasks should succeed") 1166 + AssertEqual(t, 2, len(blockedTasks), "should find 2 blocked tasks") 1167 + 1168 + for _, task := range blockedTasks { 1169 + AssertTrue(t, slices.Contains(task.DependsOn, blocker.UUID), "task should depend on blocker") 1170 + } 1171 + 1172 + emptyBlocked, err := repo.GetBlockedTasks(ctx, independent.UUID) 1173 + AssertNoError(t, err, "GetBlockedTasks for independent should succeed") 1174 + AssertEqual(t, 0, len(emptyBlocked), "independent task should not block anything") 1175 + }) 1053 1176 }
+127 -106
internal/repo/test_utilities.go
··· 12 12 "github.com/jaswdr/faker/v2" 13 13 _ "github.com/mattn/go-sqlite3" 14 14 "github.com/stormlightlabs/noteleaf/internal/models" 15 + "github.com/stormlightlabs/noteleaf/internal/store" 15 16 ) 16 17 17 18 var fake = faker.New() 18 19 19 - const testSchema string = ` 20 - CREATE TABLE IF NOT EXISTS tasks ( 21 - id INTEGER PRIMARY KEY AUTOINCREMENT, 22 - uuid TEXT UNIQUE NOT NULL, 23 - description TEXT NOT NULL, 24 - status TEXT DEFAULT 'pending', 25 - priority TEXT, 26 - project TEXT, 27 - context TEXT, 28 - tags TEXT, 29 - due DATETIME, 30 - entry DATETIME DEFAULT CURRENT_TIMESTAMP, 31 - modified DATETIME DEFAULT CURRENT_TIMESTAMP, 32 - end DATETIME, 33 - start DATETIME, 34 - annotations TEXT, 35 - recur TEXT, 36 - until DATETIME, 37 - parent_uuid TEXT 38 - ); 39 - 40 - CREATE TABLE IF NOT EXISTS task_dependencies ( 41 - id INTEGER PRIMARY KEY AUTOINCREMENT, 42 - task_uuid TEXT NOT NULL, 43 - depends_on_uuid TEXT NOT NULL, 44 - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, 45 - 46 - FOREIGN KEY(task_uuid) REFERENCES tasks(uuid) ON DELETE CASCADE, 47 - FOREIGN KEY(depends_on_uuid) REFERENCES tasks(uuid) ON DELETE CASCADE 48 - ); 49 - 50 - CREATE TABLE IF NOT EXISTS books ( 51 - id INTEGER PRIMARY KEY AUTOINCREMENT, 52 - title TEXT NOT NULL, 53 - author TEXT, 54 - status TEXT DEFAULT 'queued', 55 - progress INTEGER DEFAULT 0, 56 - pages INTEGER, 57 - rating REAL, 58 - notes TEXT, 59 - added DATETIME DEFAULT CURRENT_TIMESTAMP, 60 - started DATETIME, 61 - finished DATETIME 62 - ); 63 - 64 - CREATE TABLE IF NOT EXISTS movies ( 65 - id INTEGER PRIMARY KEY AUTOINCREMENT, 66 - title TEXT NOT NULL, 67 - year INTEGER, 68 - status TEXT DEFAULT 'queued', 69 - rating REAL, 70 - notes TEXT, 71 - added DATETIME DEFAULT CURRENT_TIMESTAMP, 72 - watched DATETIME 73 - ); 74 - 75 - CREATE TABLE IF NOT EXISTS tv_shows ( 76 - id INTEGER PRIMARY KEY AUTOINCREMENT, 77 - title TEXT NOT NULL, 78 - season INTEGER, 79 - episode INTEGER, 80 - status TEXT DEFAULT 'queued', 81 - rating REAL, 82 - notes TEXT, 83 - added DATETIME DEFAULT CURRENT_TIMESTAMP, 84 - last_watched DATETIME 85 - ); 86 - 87 - CREATE TABLE IF NOT EXISTS notes ( 88 - id INTEGER PRIMARY KEY AUTOINCREMENT, 89 - title TEXT NOT NULL, 90 - content TEXT, 91 - tags TEXT, 92 - archived BOOLEAN DEFAULT FALSE, 93 - created DATETIME DEFAULT CURRENT_TIMESTAMP, 94 - modified DATETIME DEFAULT CURRENT_TIMESTAMP, 95 - file_path TEXT 96 - ); 97 - 98 - CREATE TABLE IF NOT EXISTS time_entries ( 99 - id INTEGER PRIMARY KEY AUTOINCREMENT, 100 - task_id INTEGER NOT NULL, 101 - start_time DATETIME NOT NULL, 102 - end_time DATETIME, 103 - duration_seconds INTEGER, 104 - description TEXT, 105 - created DATETIME DEFAULT CURRENT_TIMESTAMP, 106 - modified DATETIME DEFAULT CURRENT_TIMESTAMP, 107 - FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE 108 - ); 109 - 110 - CREATE TABLE IF NOT EXISTS articles ( 111 - id INTEGER PRIMARY KEY AUTOINCREMENT, 112 - url TEXT UNIQUE NOT NULL, 113 - title TEXT NOT NULL, 114 - author TEXT, 115 - date TEXT, 116 - markdown_path TEXT NOT NULL, 117 - html_path TEXT NOT NULL, 118 - created DATETIME DEFAULT CURRENT_TIMESTAMP, 119 - modified DATETIME DEFAULT CURRENT_TIMESTAMP 120 - ); 121 - ` 122 - 123 20 // CreateTestDB creates an in-memory SQLite database with the full schema for testing 124 21 func CreateTestDB(t *testing.T) *sql.DB { 22 + t.Helper() 125 23 db, err := sql.Open("sqlite3", ":memory:") 126 24 if err != nil { 127 25 t.Fatalf("Failed to create in-memory database: %v", err) ··· 131 29 t.Fatalf("Failed to enable foreign keys: %v", err) 132 30 } 133 31 134 - if _, err := db.Exec(testSchema); err != nil { 135 - t.Fatalf("Failed to create schema: %v", err) 32 + // if _, err := db.Exec(testSchema); err != nil { 33 + // t.Fatalf("Failed to create schema: %v", err) 34 + // } 35 + 36 + mr := store.NewMigrationRunner(&store.Database{DB: db}) 37 + if err := mr.RunMigrations(); err != nil { 38 + t.Errorf("failed to run migrations %v", err) 136 39 } 137 40 138 41 t.Cleanup(func() { ··· 317 220 if !strings.Contains(str, substr) { 318 221 t.Fatalf("%s: expected string '%s' to contain '%s'", msg, str, substr) 319 222 } 223 + } 224 + 225 + func AssertNil(t *testing.T, value interface{}, msg string) { 226 + t.Helper() 227 + if value != nil { 228 + t.Fatalf("%s: expected nil, got %v", msg, value) 229 + } 230 + } 231 + 232 + func AssertNotNil(t *testing.T, value interface{}, msg string) { 233 + t.Helper() 234 + if value == nil { 235 + t.Fatalf("%s: expected non-nil value", msg) 236 + } 237 + } 238 + 239 + func AssertGreaterThan[T interface{ int | int64 | float64 }](t *testing.T, actual, threshold T, msg string) { 240 + t.Helper() 241 + if actual <= threshold { 242 + t.Fatalf("%s: expected %v > %v", msg, actual, threshold) 243 + } 244 + } 245 + 246 + func AssertLessThan[T interface{ int | int64 | float64 }](t *testing.T, actual, threshold T, msg string) { 247 + t.Helper() 248 + if actual >= threshold { 249 + t.Fatalf("%s: expected %v < %v", msg, actual, threshold) 250 + } 251 + } 252 + 253 + func AssertStringContains(t *testing.T, str, substr, msg string) { 254 + t.Helper() 255 + if !strings.Contains(str, substr) { 256 + t.Fatalf("%s: expected string to contain '%s', got '%s'", msg, substr, str) 257 + } 258 + } 259 + 260 + // NewCanceledContext returns a pre-canceled context for testing error conditions 261 + func NewCanceledContext() context.Context { 262 + ctx, cancel := context.WithCancel(context.Background()) 263 + cancel() 264 + return ctx 265 + } 266 + 267 + // TaskBuilder provides a fluent interface for building test tasks 268 + type TaskBuilder struct { 269 + task *models.Task 270 + } 271 + 272 + // NewTaskBuilder creates a new TaskBuilder with sensible defaults 273 + func NewTaskBuilder() *TaskBuilder { 274 + return &TaskBuilder{ 275 + task: &models.Task{ 276 + UUID: uuid.New().String(), 277 + Status: "pending", 278 + Entry: time.Now(), 279 + Modified: time.Now(), 280 + }, 281 + } 282 + } 283 + 284 + func (b *TaskBuilder) WithUUID(uuid string) *TaskBuilder { 285 + b.task.UUID = uuid 286 + return b 287 + } 288 + 289 + func (b *TaskBuilder) WithDescription(desc string) *TaskBuilder { 290 + b.task.Description = desc 291 + return b 292 + } 293 + 294 + func (b *TaskBuilder) WithStatus(status string) *TaskBuilder { 295 + b.task.Status = status 296 + return b 297 + } 298 + 299 + func (b *TaskBuilder) WithPriority(priority string) *TaskBuilder { 300 + b.task.Priority = priority 301 + return b 302 + } 303 + 304 + func (b *TaskBuilder) WithProject(project string) *TaskBuilder { 305 + b.task.Project = project 306 + return b 307 + } 308 + 309 + func (b *TaskBuilder) WithContext(ctx string) *TaskBuilder { 310 + b.task.Context = ctx 311 + return b 312 + } 313 + 314 + func (b *TaskBuilder) WithTags(tags []string) *TaskBuilder { 315 + b.task.Tags = tags 316 + return b 317 + } 318 + 319 + func (b *TaskBuilder) WithDue(due time.Time) *TaskBuilder { 320 + b.task.Due = &due 321 + return b 322 + } 323 + 324 + func (b *TaskBuilder) WithEnd(end time.Time) *TaskBuilder { 325 + b.task.End = &end 326 + return b 327 + } 328 + 329 + func (b *TaskBuilder) WithRecur(recur string) *TaskBuilder { 330 + b.task.Recur = models.RRule(recur) 331 + return b 332 + } 333 + 334 + func (b *TaskBuilder) WithDependsOn(deps []string) *TaskBuilder { 335 + b.task.DependsOn = deps 336 + return b 337 + } 338 + 339 + func (b *TaskBuilder) Build() *models.Task { 340 + return b.task 320 341 } 321 342 322 343 // SetupTestData creates sample data in the database and returns the repositories
internal/repo/time_entries.go internal/repo/time_entry_repository.go
internal/repo/time_entries_test.go internal/repo/time_entry_repository_test.go
+11 -11
internal/store/database.go
··· 12 12 ) 13 13 14 14 var ( 15 - sqlOpen = sql.Open 16 - pragmaExec = func(db *sql.DB, stmt string) (sql.Result, error) { return db.Exec(stmt) } 17 - newMigrationRunner = NewMigrationRunner 18 - getRuntime = func() string { return runtime.GOOS } 19 - getHomeDir = os.UserHomeDir 20 - mkdirAll = os.MkdirAll 15 + sqlOpen = sql.Open 16 + pragmaExec = func(db *sql.DB, stmt string) (sql.Result, error) { return db.Exec(stmt) } 17 + createMigrationRunner = CreateMigrationRunner 18 + getRuntime = func() string { return runtime.GOOS } 19 + getHomeDir = os.UserHomeDir 20 + mkdirAll = os.MkdirAll 21 21 ) 22 22 23 23 //go:embed sql/migrations ··· 113 113 return NewDatabaseWithConfig(nil) 114 114 } 115 115 116 - // NewDatabaseWithConfig creates and initializes a new database connection using the provided config 116 + // NewDatabaseWithConfig creates and initializes a new [Database] connection using the provided [Config] 117 117 func NewDatabaseWithConfig(config *Config) (*Database, error) { 118 118 if config == nil { 119 119 var err error ··· 156 156 } 157 157 158 158 database := &Database{DB: db, path: dbPath} 159 - runner := newMigrationRunner(db, migrationFiles) 159 + runner := createMigrationRunner(db, migrationFiles) 160 160 if err := runner.RunMigrations(); err != nil { 161 161 db.Close() 162 162 return nil, fmt.Errorf("failed to run migrations: %w", err) ··· 165 165 return database, nil 166 166 } 167 167 168 - // NewMigrationRunnerFromDB creates a new migration runner from a Database instance 169 - func (db *Database) NewMigrationRunner() *MigrationRunner { 170 - return newMigrationRunner(db.DB, migrationFiles) 168 + // NewMigrationRunner creates a new migration runner from a Database instance 169 + func NewMigrationRunner(db *Database) *MigrationRunner { 170 + return createMigrationRunner(db.DB, migrationFiles) 171 171 } 172 172 173 173 // GetPath returns the database file path
+3 -3
internal/store/database_test.go
··· 118 118 }) 119 119 120 120 t.Run("migration runner fails", func(t *testing.T) { 121 - orig := newMigrationRunner 122 - newMigrationRunner = func(db *sql.DB, fs FileSystem) *MigrationRunner { 121 + orig := createMigrationRunner 122 + createMigrationRunner = func(db *sql.DB, fs FileSystem) *MigrationRunner { 123 123 return &MigrationRunner{runFn: func() error { return fmt.Errorf("migration fail") }} 124 124 } 125 - t.Cleanup(func() { newMigrationRunner = orig }) 125 + t.Cleanup(func() { createMigrationRunner = orig }) 126 126 127 127 _, err := NewDatabase() 128 128 if err == nil || !strings.Contains(err.Error(), "failed to run migrations") {
+2 -2
internal/store/migration.go
··· 31 31 runFn func() error // inject for testing 32 32 } 33 33 34 - // NewMigrationRunner creates a new migration runner 35 - func NewMigrationRunner(db *sql.DB, files FileSystem) *MigrationRunner { 34 + // CreateMigrationRunner creates a new migration runner 35 + func CreateMigrationRunner(db *sql.DB, files FileSystem) *MigrationRunner { 36 36 mr := &MigrationRunner{ 37 37 db: db, 38 38 migrationFiles: files,
+28 -28
internal/store/migration_test.go
··· 83 83 func TestNewMigrationRunner(t *testing.T) { 84 84 db := createTestDB(t) 85 85 86 - runner := NewMigrationRunner(db, testMigrationFiles) 86 + runner := CreateMigrationRunner(db, testMigrationFiles) 87 87 if runner == nil { 88 88 t.Fatal("NewMigrationRunner should not return nil") 89 89 } ··· 96 96 func TestMigrationRunner_RunMigrations(t *testing.T) { 97 97 t.Run("runs migrations successfully", func(t *testing.T) { 98 98 db := createTestDB(t) 99 - runner := NewMigrationRunner(db, testMigrationFiles) 99 + runner := CreateMigrationRunner(db, testMigrationFiles) 100 100 101 101 err := runner.RunMigrations() 102 102 if err != nil { ··· 128 128 db := createTestDB(t) 129 129 130 130 emptyFS := embed.FS{} 131 - runner := NewMigrationRunner(db, emptyFS) 131 + runner := CreateMigrationRunner(db, emptyFS) 132 132 133 133 err := runner.RunMigrations() 134 134 if err == nil { ··· 140 140 db := createTestDB(t) 141 141 db.Close() 142 142 143 - runner := NewMigrationRunner(db, testMigrationFiles) 143 + runner := CreateMigrationRunner(db, testMigrationFiles) 144 144 err := runner.RunMigrations() 145 145 if err == nil { 146 146 t.Error("RunMigrations should fail when database connection is closed") ··· 151 151 db := createTestDB(t) 152 152 153 153 fakeFS := &fakeMigrationFS{shouldFailRead: true, hasNewMigrations: true} 154 - runner := NewMigrationRunner(db, fakeFS) 154 + runner := CreateMigrationRunner(db, fakeFS) 155 155 156 156 err := runner.RunMigrations() 157 157 if err == nil { ··· 163 163 db := createTestDB(t) 164 164 165 165 fakeFS := &fakeMigrationFS{invalidSQL: true, hasNewMigrations: true} 166 - runner := NewMigrationRunner(db, fakeFS) 166 + runner := CreateMigrationRunner(db, fakeFS) 167 167 168 168 err := runner.RunMigrations() 169 169 if err == nil { ··· 173 173 174 174 t.Run("handles migration record insertion failure", func(t *testing.T) { 175 175 db := createTestDB(t) 176 - runner := NewMigrationRunner(db, testMigrationFiles) 176 + runner := CreateMigrationRunner(db, testMigrationFiles) 177 177 178 178 err := runner.RunMigrations() 179 179 if err != nil { ··· 198 198 199 199 t.Run("skips already applied migrations", func(t *testing.T) { 200 200 db := createTestDB(t) 201 - runner := NewMigrationRunner(db, testMigrationFiles) 201 + runner := CreateMigrationRunner(db, testMigrationFiles) 202 202 203 203 err := runner.RunMigrations() 204 204 if err != nil { ··· 229 229 230 230 t.Run("creates expected tables", func(t *testing.T) { 231 231 db := createTestDB(t) 232 - runner := NewMigrationRunner(db, testMigrationFiles) 232 + runner := CreateMigrationRunner(db, testMigrationFiles) 233 233 234 234 err := runner.RunMigrations() 235 235 if err != nil { ··· 255 255 func TestMigrationRunner_GetAppliedMigrations(t *testing.T) { 256 256 t.Run("returns empty list when no migrations table", func(t *testing.T) { 257 257 db := createTestDB(t) 258 - runner := NewMigrationRunner(db, testMigrationFiles) 258 + runner := CreateMigrationRunner(db, testMigrationFiles) 259 259 260 260 migrations, err := runner.GetAppliedMigrations() 261 261 if err != nil { ··· 270 270 t.Run("handles database connection failure", func(t *testing.T) { 271 271 db := createTestDB(t) 272 272 db.Close() 273 - runner := NewMigrationRunner(db, testMigrationFiles) 273 + runner := CreateMigrationRunner(db, testMigrationFiles) 274 274 275 275 _, err := runner.GetAppliedMigrations() 276 276 if err == nil { ··· 280 280 281 281 t.Run("handles query execution failure", func(t *testing.T) { 282 282 db := createTestDB(t) 283 - runner := NewMigrationRunner(db, testMigrationFiles) 283 + runner := CreateMigrationRunner(db, testMigrationFiles) 284 284 285 285 err := runner.RunMigrations() 286 286 if err != nil { ··· 298 298 299 299 t.Run("handles row scan failure", func(t *testing.T) { 300 300 db := createTestDB(t) 301 - runner := NewMigrationRunner(db, testMigrationFiles) 301 + runner := CreateMigrationRunner(db, testMigrationFiles) 302 302 303 303 err := runner.RunMigrations() 304 304 if err != nil { ··· 319 319 320 320 t.Run("returns applied migrations", func(t *testing.T) { 321 321 db := createTestDB(t) 322 - runner := NewMigrationRunner(db, testMigrationFiles) 322 + runner := CreateMigrationRunner(db, testMigrationFiles) 323 323 324 324 // Run migrations first 325 325 err := runner.RunMigrations() ··· 359 359 func TestMigrationRunner_GetAvailableMigrations(t *testing.T) { 360 360 t.Run("returns available migrations from embedded files", func(t *testing.T) { 361 361 db := createTestDB(t) 362 - runner := NewMigrationRunner(db, testMigrationFiles) 362 + runner := CreateMigrationRunner(db, testMigrationFiles) 363 363 364 364 migrations, err := runner.GetAvailableMigrations() 365 365 if err != nil { ··· 391 391 db := createTestDB(t) 392 392 393 393 emptyFS := embed.FS{} 394 - runner := NewMigrationRunner(db, emptyFS) 394 + runner := CreateMigrationRunner(db, emptyFS) 395 395 396 396 _, err := runner.GetAvailableMigrations() 397 397 if err == nil { ··· 403 403 db := createTestDB(t) 404 404 405 405 fakeFS := &fakeMigrationFS{shouldFailRead: true} 406 - runner := NewMigrationRunner(db, fakeFS) 406 + runner := CreateMigrationRunner(db, fakeFS) 407 407 408 408 _, err := runner.GetAvailableMigrations() 409 409 if err == nil { ··· 413 413 414 414 t.Run("includes both up and down SQL when available", func(t *testing.T) { 415 415 db := createTestDB(t) 416 - runner := NewMigrationRunner(db, testMigrationFiles) 416 + runner := CreateMigrationRunner(db, testMigrationFiles) 417 417 418 418 migrations, err := runner.GetAvailableMigrations() 419 419 if err != nil { ··· 437 437 func TestMigrationRunner_Rollback(t *testing.T) { 438 438 t.Run("fails when no migrations to rollback", func(t *testing.T) { 439 439 db := createTestDB(t) 440 - runner := NewMigrationRunner(db, testMigrationFiles) 440 + runner := CreateMigrationRunner(db, testMigrationFiles) 441 441 442 442 err := runner.Rollback() 443 443 if err == nil { ··· 447 447 448 448 t.Run("handles database connection failure", func(t *testing.T) { 449 449 db := createTestDB(t) 450 - runner := NewMigrationRunner(db, testMigrationFiles) 450 + runner := CreateMigrationRunner(db, testMigrationFiles) 451 451 452 452 err := runner.RunMigrations() 453 453 if err != nil { ··· 464 464 465 465 t.Run("handles migration directory read failure during rollback", func(t *testing.T) { 466 466 db := createTestDB(t) 467 - runner := NewMigrationRunner(db, testMigrationFiles) 467 + runner := CreateMigrationRunner(db, testMigrationFiles) 468 468 469 469 err := runner.RunMigrations() 470 470 if err != nil { ··· 482 482 483 483 t.Run("handles missing down migration file", func(t *testing.T) { 484 484 db := createTestDB(t) 485 - runner := NewMigrationRunner(db, testMigrationFiles) 485 + runner := CreateMigrationRunner(db, testMigrationFiles) 486 486 487 487 err := runner.RunMigrations() 488 488 if err != nil { ··· 502 502 db := createTestDB(t) 503 503 504 504 fakeFS := &fakeMigrationFS{} 505 - runner := NewMigrationRunner(db, fakeFS) 505 + runner := CreateMigrationRunner(db, fakeFS) 506 506 507 507 err := runner.RunMigrations() 508 508 if err != nil { ··· 521 521 db := createTestDB(t) 522 522 523 523 fakeFS := &fakeMigrationFS{} 524 - runner := NewMigrationRunner(db, fakeFS) 524 + runner := CreateMigrationRunner(db, fakeFS) 525 525 526 526 err := runner.RunMigrations() 527 527 if err != nil { ··· 538 538 539 539 t.Run("handles migration record deletion failure", func(t *testing.T) { 540 540 db := createTestDB(t) 541 - runner := NewMigrationRunner(db, testMigrationFiles) 541 + runner := CreateMigrationRunner(db, testMigrationFiles) 542 542 543 543 err := runner.RunMigrations() 544 544 if err != nil { ··· 558 558 559 559 t.Run("rolls back last migration", func(t *testing.T) { 560 560 db := createTestDB(t) 561 - runner := NewMigrationRunner(db, testMigrationFiles) 561 + runner := CreateMigrationRunner(db, testMigrationFiles) 562 562 563 563 err := runner.RunMigrations() 564 564 if err != nil { ··· 638 638 func TestMigrationIntegration(t *testing.T) { 639 639 t.Run("full migration lifecycle", func(t *testing.T) { 640 640 db := createTestDB(t) 641 - runner := NewMigrationRunner(db, testMigrationFiles) 641 + runner := CreateMigrationRunner(db, testMigrationFiles) 642 642 643 643 available, err := runner.GetAvailableMigrations() 644 644 if err != nil { ··· 682 682 683 683 t.Run("migration runner works with real database", func(t *testing.T) { 684 684 db := createTestDB(t) 685 - runner := NewMigrationRunner(db, migrationFiles) 685 + runner := CreateMigrationRunner(db, migrationFiles) 686 686 687 687 err := runner.RunMigrations() 688 688 if err != nil {