A container registry that uses the AT Protocol for manifest storage and S3 for blob storage.
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

at codeberg-source 533 lines 15 kB view raw
1package db 2 3import ( 4 "context" 5 "net/http" 6 "net/http/httptest" 7 "strings" 8 "testing" 9 "time" 10) 11 12// setupSessionTestDB creates an in-memory SQLite database for testing 13func setupSessionTestDB(t *testing.T) *SessionStore { 14 t.Helper() 15 // Use file::memory: with cache=shared to ensure all connections share the same in-memory DB 16 db, err := InitDB("file::memory:?cache=shared", true) 17 if err != nil { 18 t.Fatalf("Failed to initialize test database: %v", err) 19 } 20 // Limit to single connection to avoid race conditions in tests 21 db.SetMaxOpenConns(1) 22 t.Cleanup(func() { 23 db.Close() 24 }) 25 return NewSessionStore(db) 26} 27 28// createSessionTestUser creates a test user in the database 29func createSessionTestUser(t *testing.T, store *SessionStore, did, handle string) { 30 t.Helper() 31 _, err := store.db.Exec(` 32 INSERT OR IGNORE INTO users (did, handle, pds_endpoint, last_seen) 33 VALUES (?, ?, ?, datetime('now')) 34 `, did, handle, "https://pds.example.com") 35 if err != nil { 36 t.Fatalf("Failed to create test user: %v", err) 37 } 38} 39 40func TestSession_Struct(t *testing.T) { 41 sess := &Session{ 42 ID: "test-session", 43 DID: "did:plc:test", 44 Handle: "alice.bsky.social", 45 PDSEndpoint: "https://bsky.social", 46 OAuthSessionID: "oauth-123", 47 ExpiresAt: time.Now().Add(1 * time.Hour), 48 } 49 50 if sess.DID != "did:plc:test" { 51 t.Errorf("Expected DID, got %q", sess.DID) 52 } 53} 54 55// TestSessionStore_Create tests session creation without OAuth 56func TestSessionStore_Create(t *testing.T) { 57 store := setupSessionTestDB(t) 58 createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social") 59 60 sessionID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour) 61 if err != nil { 62 t.Fatalf("Create() error = %v", err) 63 } 64 65 if sessionID == "" { 66 t.Error("Create() returned empty session ID") 67 } 68 69 // Verify session can be retrieved 70 sess, found := store.Get(sessionID) 71 if !found { 72 t.Error("Created session not found") 73 } 74 if sess == nil { 75 t.Fatal("Session is nil") 76 } 77 if sess.DID != "did:plc:alice123" { 78 t.Errorf("DID = %v, want did:plc:alice123", sess.DID) 79 } 80 if sess.Handle != "alice.bsky.social" { 81 t.Errorf("Handle = %v, want alice.bsky.social", sess.Handle) 82 } 83 if sess.OAuthSessionID != "" { 84 t.Errorf("OAuthSessionID should be empty, got %v", sess.OAuthSessionID) 85 } 86} 87 88// TestSessionStore_CreateWithOAuth tests session creation with OAuth 89func TestSessionStore_CreateWithOAuth(t *testing.T) { 90 store := setupSessionTestDB(t) 91 createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social") 92 93 oauthSessionID := "oauth-123" 94 sessionID, err := store.CreateWithOAuth("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", oauthSessionID, 1*time.Hour) 95 if err != nil { 96 t.Fatalf("CreateWithOAuth() error = %v", err) 97 } 98 99 if sessionID == "" { 100 t.Error("CreateWithOAuth() returned empty session ID") 101 } 102 103 // Verify session has OAuth session ID 104 sess, found := store.Get(sessionID) 105 if !found { 106 t.Error("Created session not found") 107 } 108 if sess.OAuthSessionID != oauthSessionID { 109 t.Errorf("OAuthSessionID = %v, want %v", sess.OAuthSessionID, oauthSessionID) 110 } 111} 112 113// TestSessionStore_Get tests retrieving sessions 114func TestSessionStore_Get(t *testing.T) { 115 store := setupSessionTestDB(t) 116 createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social") 117 118 // Create a valid session 119 validID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour) 120 if err != nil { 121 t.Fatalf("Create() error = %v", err) 122 } 123 124 // Create a session and manually expire it 125 expiredID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour) 126 if err != nil { 127 t.Fatalf("Create() error = %v", err) 128 } 129 130 // Manually update expiration to the past 131 _, err = store.db.Exec(` 132 UPDATE ui_sessions 133 SET expires_at = datetime('now', '-1 hour') 134 WHERE id = ? 135 `, expiredID) 136 if err != nil { 137 t.Fatalf("Failed to update expiration: %v", err) 138 } 139 140 tests := []struct { 141 name string 142 sessionID string 143 wantFound bool 144 }{ 145 { 146 name: "valid session", 147 sessionID: validID, 148 wantFound: true, 149 }, 150 { 151 name: "expired session", 152 sessionID: expiredID, 153 wantFound: false, 154 }, 155 { 156 name: "non-existent session", 157 sessionID: "non-existent-id", 158 wantFound: false, 159 }, 160 } 161 162 for _, tt := range tests { 163 t.Run(tt.name, func(t *testing.T) { 164 sess, found := store.Get(tt.sessionID) 165 if found != tt.wantFound { 166 t.Errorf("Get() found = %v, want %v", found, tt.wantFound) 167 } 168 if tt.wantFound && sess == nil { 169 t.Error("Expected session, got nil") 170 } 171 }) 172 } 173} 174 175// TestSessionStore_Extend tests extending session expiration 176func TestSessionStore_Extend(t *testing.T) { 177 store := setupSessionTestDB(t) 178 createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social") 179 180 sessionID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour) 181 if err != nil { 182 t.Fatalf("Create() error = %v", err) 183 } 184 185 // Get initial expiration 186 sess1, _ := store.Get(sessionID) 187 initialExpiry := sess1.ExpiresAt 188 189 // Wait a bit to ensure time difference 190 time.Sleep(10 * time.Millisecond) 191 192 // Extend session 193 err = store.Extend(sessionID, 2*time.Hour) 194 if err != nil { 195 t.Errorf("Extend() error = %v", err) 196 } 197 198 // Verify expiration was updated 199 sess2, found := store.Get(sessionID) 200 if !found { 201 t.Fatal("Session not found after extend") 202 } 203 if !sess2.ExpiresAt.After(initialExpiry) { 204 t.Error("ExpiresAt should be later after extend") 205 } 206 207 // Test extending non-existent session 208 err = store.Extend("non-existent-id", 1*time.Hour) 209 if err == nil { 210 t.Error("Expected error when extending non-existent session") 211 } 212 if err != nil && !strings.Contains(err.Error(), "not found") { 213 t.Errorf("Expected 'not found' error, got %v", err) 214 } 215} 216 217// TestSessionStore_Delete tests deleting a session 218func TestSessionStore_Delete(t *testing.T) { 219 store := setupSessionTestDB(t) 220 createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social") 221 222 sessionID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour) 223 if err != nil { 224 t.Fatalf("Create() error = %v", err) 225 } 226 227 // Verify session exists 228 _, found := store.Get(sessionID) 229 if !found { 230 t.Fatal("Session should exist before delete") 231 } 232 233 // Delete session 234 store.Delete(sessionID) 235 236 // Verify session is gone 237 _, found = store.Get(sessionID) 238 if found { 239 t.Error("Session should not exist after delete") 240 } 241 242 // Deleting non-existent session should not error 243 store.Delete("non-existent-id") 244} 245 246// TestSessionStore_DeleteByDID tests deleting all sessions for a DID 247func TestSessionStore_DeleteByDID(t *testing.T) { 248 store := setupSessionTestDB(t) 249 did := "did:plc:alice123" 250 createSessionTestUser(t, store, did, "alice.bsky.social") 251 createSessionTestUser(t, store, "did:plc:bob123", "bob.bsky.social") 252 253 // Create multiple sessions for alice 254 sessionIDs := make([]string, 3) 255 for i := 0; i < 3; i++ { 256 id, err := store.Create(did, "alice.bsky.social", "https://pds.example.com", 1*time.Hour) 257 if err != nil { 258 t.Fatalf("Create() error = %v", err) 259 } 260 sessionIDs[i] = id 261 } 262 263 // Create a session for bob 264 bobSessionID, err := store.Create("did:plc:bob123", "bob.bsky.social", "https://pds.example.com", 1*time.Hour) 265 if err != nil { 266 t.Fatalf("Create() error = %v", err) 267 } 268 269 // Delete all sessions for alice 270 store.DeleteByDID(did) 271 272 // Verify alice's sessions are gone 273 for _, id := range sessionIDs { 274 _, found := store.Get(id) 275 if found { 276 t.Errorf("Session %v should have been deleted", id) 277 } 278 } 279 280 // Verify bob's session still exists 281 _, found := store.Get(bobSessionID) 282 if !found { 283 t.Error("Bob's session should still exist") 284 } 285 286 // Deleting sessions for non-existent DID should not error 287 store.DeleteByDID("did:plc:nonexistent") 288} 289 290// TestSessionStore_Cleanup tests removing expired sessions 291func TestSessionStore_Cleanup(t *testing.T) { 292 store := setupSessionTestDB(t) 293 createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social") 294 295 // Create valid session by inserting directly with SQLite datetime format 296 validID := "valid-session-id" 297 _, err := store.db.Exec(` 298 INSERT INTO ui_sessions (id, did, handle, pds_endpoint, oauth_session_id, expires_at, created_at) 299 VALUES (?, ?, ?, ?, ?, datetime('now', '+1 hour'), datetime('now')) 300 `, validID, "did:plc:alice123", "alice.bsky.social", "https://pds.example.com", "") 301 if err != nil { 302 t.Fatalf("Failed to create valid session: %v", err) 303 } 304 305 // Create expired session 306 expiredID := "expired-session-id" 307 _, err = store.db.Exec(` 308 INSERT INTO ui_sessions (id, did, handle, pds_endpoint, oauth_session_id, expires_at, created_at) 309 VALUES (?, ?, ?, ?, ?, datetime('now', '-1 hour'), datetime('now')) 310 `, expiredID, "did:plc:alice123", "alice.bsky.social", "https://pds.example.com", "") 311 if err != nil { 312 t.Fatalf("Failed to create expired session: %v", err) 313 } 314 315 // Verify we have 2 sessions before cleanup 316 var countBefore int 317 err = store.db.QueryRow("SELECT COUNT(*) FROM ui_sessions").Scan(&countBefore) 318 if err != nil { 319 t.Fatalf("Query error: %v", err) 320 } 321 if countBefore != 2 { 322 t.Fatalf("Expected 2 sessions before cleanup, got %d", countBefore) 323 } 324 325 // Run cleanup 326 store.Cleanup() 327 328 // Verify valid session still exists in database 329 var countValid int 330 err = store.db.QueryRow("SELECT COUNT(*) FROM ui_sessions WHERE id = ?", validID).Scan(&countValid) 331 if err != nil { 332 t.Fatalf("Query error: %v", err) 333 } 334 if countValid != 1 { 335 t.Errorf("Valid session should still exist in database, count = %d", countValid) 336 } 337 338 // Verify expired session was cleaned up 339 var countExpired int 340 err = store.db.QueryRow("SELECT COUNT(*) FROM ui_sessions WHERE id = ?", expiredID).Scan(&countExpired) 341 if err != nil { 342 t.Fatalf("Query error: %v", err) 343 } 344 if countExpired != 0 { 345 t.Error("Expired session should have been deleted from database") 346 } 347 348 // Verify we can still get the valid session 349 _, found := store.Get(validID) 350 if !found { 351 t.Error("Valid session should be retrievable after cleanup") 352 } 353} 354 355// TestSessionStore_CleanupContext tests context-aware cleanup 356func TestSessionStore_CleanupContext(t *testing.T) { 357 store := setupSessionTestDB(t) 358 createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social") 359 360 // Create a session and manually expire it 361 expiredID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour) 362 if err != nil { 363 t.Fatalf("Create() error = %v", err) 364 } 365 366 // Manually update expiration to the past 367 _, err = store.db.Exec(` 368 UPDATE ui_sessions 369 SET expires_at = datetime('now', '-1 hour') 370 WHERE id = ? 371 `, expiredID) 372 if err != nil { 373 t.Fatalf("Failed to update expiration: %v", err) 374 } 375 376 // Run context-aware cleanup 377 ctx := context.Background() 378 err = store.CleanupContext(ctx) 379 if err != nil { 380 t.Errorf("CleanupContext() error = %v", err) 381 } 382 383 // Verify expired session was cleaned up 384 var count int 385 err = store.db.QueryRow("SELECT COUNT(*) FROM ui_sessions WHERE id = ?", expiredID).Scan(&count) 386 if err != nil { 387 t.Fatalf("Query error: %v", err) 388 } 389 if count != 0 { 390 t.Error("Expired session should have been deleted from database") 391 } 392} 393 394// TestSetCookie tests setting session cookie 395func TestSetCookie(t *testing.T) { 396 w := httptest.NewRecorder() 397 sessionID := "test-session-id" 398 maxAge := 3600 399 400 SetCookie(w, sessionID, maxAge) 401 402 cookies := w.Result().Cookies() 403 if len(cookies) != 1 { 404 t.Fatalf("Expected 1 cookie, got %d", len(cookies)) 405 } 406 407 cookie := cookies[0] 408 if cookie.Name != "atcr_session" { 409 t.Errorf("Name = %v, want atcr_session", cookie.Name) 410 } 411 if cookie.Value != sessionID { 412 t.Errorf("Value = %v, want %v", cookie.Value, sessionID) 413 } 414 if cookie.MaxAge != maxAge { 415 t.Errorf("MaxAge = %v, want %v", cookie.MaxAge, maxAge) 416 } 417 if !cookie.HttpOnly { 418 t.Error("HttpOnly should be true") 419 } 420 if !cookie.Secure { 421 t.Error("Secure should be true") 422 } 423 if cookie.SameSite != http.SameSiteLaxMode { 424 t.Errorf("SameSite = %v, want Lax", cookie.SameSite) 425 } 426 if cookie.Path != "/" { 427 t.Errorf("Path = %v, want /", cookie.Path) 428 } 429} 430 431// TestClearCookie tests clearing session cookie 432func TestClearCookie(t *testing.T) { 433 w := httptest.NewRecorder() 434 435 ClearCookie(w) 436 437 cookies := w.Result().Cookies() 438 if len(cookies) != 1 { 439 t.Fatalf("Expected 1 cookie, got %d", len(cookies)) 440 } 441 442 cookie := cookies[0] 443 if cookie.Name != "atcr_session" { 444 t.Errorf("Name = %v, want atcr_session", cookie.Name) 445 } 446 if cookie.Value != "" { 447 t.Errorf("Value should be empty, got %v", cookie.Value) 448 } 449 if cookie.MaxAge != -1 { 450 t.Errorf("MaxAge = %v, want -1", cookie.MaxAge) 451 } 452 if !cookie.HttpOnly { 453 t.Error("HttpOnly should be true") 454 } 455 if !cookie.Secure { 456 t.Error("Secure should be true") 457 } 458} 459 460// TestGetSessionID tests retrieving session ID from cookie 461func TestGetSessionID(t *testing.T) { 462 tests := []struct { 463 name string 464 cookie *http.Cookie 465 wantID string 466 wantFound bool 467 }{ 468 { 469 name: "valid cookie", 470 cookie: &http.Cookie{ 471 Name: "atcr_session", 472 Value: "test-session-id", 473 }, 474 wantID: "test-session-id", 475 wantFound: true, 476 }, 477 { 478 name: "no cookie", 479 cookie: nil, 480 wantID: "", 481 wantFound: false, 482 }, 483 { 484 name: "wrong cookie name", 485 cookie: &http.Cookie{ 486 Name: "other_cookie", 487 Value: "test-value", 488 }, 489 wantID: "", 490 wantFound: false, 491 }, 492 } 493 494 for _, tt := range tests { 495 t.Run(tt.name, func(t *testing.T) { 496 req := httptest.NewRequest("GET", "/", nil) 497 if tt.cookie != nil { 498 req.AddCookie(tt.cookie) 499 } 500 501 id, found := GetSessionID(req) 502 if found != tt.wantFound { 503 t.Errorf("GetSessionID() found = %v, want %v", found, tt.wantFound) 504 } 505 if id != tt.wantID { 506 t.Errorf("GetSessionID() id = %v, want %v", id, tt.wantID) 507 } 508 }) 509 } 510} 511 512// TestSessionStore_SessionIDUniqueness tests that generated session IDs are unique 513func TestSessionStore_SessionIDUniqueness(t *testing.T) { 514 store := setupSessionTestDB(t) 515 createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social") 516 517 // Generate multiple session IDs 518 ids := make(map[string]bool) 519 for i := 0; i < 100; i++ { 520 id, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour) 521 if err != nil { 522 t.Fatalf("Create() error = %v", err) 523 } 524 if ids[id] { 525 t.Errorf("Duplicate session ID generated: %v", id) 526 } 527 ids[id] = true 528 } 529 530 if len(ids) != 100 { 531 t.Errorf("Expected 100 unique IDs, got %d", len(ids)) 532 } 533}