package db import ( "context" "net/http" "net/http/httptest" "strings" "testing" "time" ) // setupSessionTestDB creates an in-memory SQLite database for testing func setupSessionTestDB(t *testing.T) *SessionStore { t.Helper() // Use file::memory: with cache=shared to ensure all connections share the same in-memory DB db, err := InitDB("file::memory:?cache=shared", true) if err != nil { t.Fatalf("Failed to initialize test database: %v", err) } // Limit to single connection to avoid race conditions in tests db.SetMaxOpenConns(1) t.Cleanup(func() { db.Close() }) return NewSessionStore(db) } // createSessionTestUser creates a test user in the database func createSessionTestUser(t *testing.T, store *SessionStore, did, handle string) { t.Helper() _, err := store.db.Exec(` INSERT OR IGNORE INTO users (did, handle, pds_endpoint, last_seen) VALUES (?, ?, ?, datetime('now')) `, did, handle, "https://pds.example.com") if err != nil { t.Fatalf("Failed to create test user: %v", err) } } func TestSession_Struct(t *testing.T) { sess := &Session{ ID: "test-session", DID: "did:plc:test", Handle: "alice.bsky.social", PDSEndpoint: "https://bsky.social", OAuthSessionID: "oauth-123", ExpiresAt: time.Now().Add(1 * time.Hour), } if sess.DID != "did:plc:test" { t.Errorf("Expected DID, got %q", sess.DID) } } // TestSessionStore_Create tests session creation without OAuth func TestSessionStore_Create(t *testing.T) { store := setupSessionTestDB(t) createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social") sessionID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour) if err != nil { t.Fatalf("Create() error = %v", err) } if sessionID == "" { t.Error("Create() returned empty session ID") } // Verify session can be retrieved sess, found := store.Get(sessionID) if !found { t.Error("Created session not found") } if sess == nil { t.Fatal("Session is nil") } if sess.DID != "did:plc:alice123" { t.Errorf("DID = %v, want did:plc:alice123", sess.DID) } if sess.Handle != "alice.bsky.social" { t.Errorf("Handle = %v, want alice.bsky.social", sess.Handle) } if sess.OAuthSessionID != "" { t.Errorf("OAuthSessionID should be empty, got %v", sess.OAuthSessionID) } } // TestSessionStore_CreateWithOAuth tests session creation with OAuth func TestSessionStore_CreateWithOAuth(t *testing.T) { store := setupSessionTestDB(t) createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social") oauthSessionID := "oauth-123" sessionID, err := store.CreateWithOAuth("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", oauthSessionID, 1*time.Hour) if err != nil { t.Fatalf("CreateWithOAuth() error = %v", err) } if sessionID == "" { t.Error("CreateWithOAuth() returned empty session ID") } // Verify session has OAuth session ID sess, found := store.Get(sessionID) if !found { t.Error("Created session not found") } if sess.OAuthSessionID != oauthSessionID { t.Errorf("OAuthSessionID = %v, want %v", sess.OAuthSessionID, oauthSessionID) } } // TestSessionStore_Get tests retrieving sessions func TestSessionStore_Get(t *testing.T) { store := setupSessionTestDB(t) createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social") // Create a valid session validID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour) if err != nil { t.Fatalf("Create() error = %v", err) } // Create a session and manually expire it expiredID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour) if err != nil { t.Fatalf("Create() error = %v", err) } // Manually update expiration to the past _, err = store.db.Exec(` UPDATE ui_sessions SET expires_at = datetime('now', '-1 hour') WHERE id = ? `, expiredID) if err != nil { t.Fatalf("Failed to update expiration: %v", err) } tests := []struct { name string sessionID string wantFound bool }{ { name: "valid session", sessionID: validID, wantFound: true, }, { name: "expired session", sessionID: expiredID, wantFound: false, }, { name: "non-existent session", sessionID: "non-existent-id", wantFound: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { sess, found := store.Get(tt.sessionID) if found != tt.wantFound { t.Errorf("Get() found = %v, want %v", found, tt.wantFound) } if tt.wantFound && sess == nil { t.Error("Expected session, got nil") } }) } } // TestSessionStore_Extend tests extending session expiration func TestSessionStore_Extend(t *testing.T) { store := setupSessionTestDB(t) createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social") sessionID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour) if err != nil { t.Fatalf("Create() error = %v", err) } // Get initial expiration sess1, _ := store.Get(sessionID) initialExpiry := sess1.ExpiresAt // Wait a bit to ensure time difference time.Sleep(10 * time.Millisecond) // Extend session err = store.Extend(sessionID, 2*time.Hour) if err != nil { t.Errorf("Extend() error = %v", err) } // Verify expiration was updated sess2, found := store.Get(sessionID) if !found { t.Fatal("Session not found after extend") } if !sess2.ExpiresAt.After(initialExpiry) { t.Error("ExpiresAt should be later after extend") } // Test extending non-existent session err = store.Extend("non-existent-id", 1*time.Hour) if err == nil { t.Error("Expected error when extending non-existent session") } if err != nil && !strings.Contains(err.Error(), "not found") { t.Errorf("Expected 'not found' error, got %v", err) } } // TestSessionStore_Delete tests deleting a session func TestSessionStore_Delete(t *testing.T) { store := setupSessionTestDB(t) createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social") sessionID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour) if err != nil { t.Fatalf("Create() error = %v", err) } // Verify session exists _, found := store.Get(sessionID) if !found { t.Fatal("Session should exist before delete") } // Delete session store.Delete(sessionID) // Verify session is gone _, found = store.Get(sessionID) if found { t.Error("Session should not exist after delete") } // Deleting non-existent session should not error store.Delete("non-existent-id") } // TestSessionStore_DeleteByDID tests deleting all sessions for a DID func TestSessionStore_DeleteByDID(t *testing.T) { store := setupSessionTestDB(t) did := "did:plc:alice123" createSessionTestUser(t, store, did, "alice.bsky.social") createSessionTestUser(t, store, "did:plc:bob123", "bob.bsky.social") // Create multiple sessions for alice sessionIDs := make([]string, 3) for i := 0; i < 3; i++ { id, err := store.Create(did, "alice.bsky.social", "https://pds.example.com", 1*time.Hour) if err != nil { t.Fatalf("Create() error = %v", err) } sessionIDs[i] = id } // Create a session for bob bobSessionID, err := store.Create("did:plc:bob123", "bob.bsky.social", "https://pds.example.com", 1*time.Hour) if err != nil { t.Fatalf("Create() error = %v", err) } // Delete all sessions for alice store.DeleteByDID(did) // Verify alice's sessions are gone for _, id := range sessionIDs { _, found := store.Get(id) if found { t.Errorf("Session %v should have been deleted", id) } } // Verify bob's session still exists _, found := store.Get(bobSessionID) if !found { t.Error("Bob's session should still exist") } // Deleting sessions for non-existent DID should not error store.DeleteByDID("did:plc:nonexistent") } // TestSessionStore_Cleanup tests removing expired sessions func TestSessionStore_Cleanup(t *testing.T) { store := setupSessionTestDB(t) createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social") // Create valid session by inserting directly with SQLite datetime format validID := "valid-session-id" _, err := store.db.Exec(` INSERT INTO ui_sessions (id, did, handle, pds_endpoint, oauth_session_id, expires_at, created_at) VALUES (?, ?, ?, ?, ?, datetime('now', '+1 hour'), datetime('now')) `, validID, "did:plc:alice123", "alice.bsky.social", "https://pds.example.com", "") if err != nil { t.Fatalf("Failed to create valid session: %v", err) } // Create expired session expiredID := "expired-session-id" _, err = store.db.Exec(` INSERT INTO ui_sessions (id, did, handle, pds_endpoint, oauth_session_id, expires_at, created_at) VALUES (?, ?, ?, ?, ?, datetime('now', '-1 hour'), datetime('now')) `, expiredID, "did:plc:alice123", "alice.bsky.social", "https://pds.example.com", "") if err != nil { t.Fatalf("Failed to create expired session: %v", err) } // Verify we have 2 sessions before cleanup var countBefore int err = store.db.QueryRow("SELECT COUNT(*) FROM ui_sessions").Scan(&countBefore) if err != nil { t.Fatalf("Query error: %v", err) } if countBefore != 2 { t.Fatalf("Expected 2 sessions before cleanup, got %d", countBefore) } // Run cleanup store.Cleanup() // Verify valid session still exists in database var countValid int err = store.db.QueryRow("SELECT COUNT(*) FROM ui_sessions WHERE id = ?", validID).Scan(&countValid) if err != nil { t.Fatalf("Query error: %v", err) } if countValid != 1 { t.Errorf("Valid session should still exist in database, count = %d", countValid) } // Verify expired session was cleaned up var countExpired int err = store.db.QueryRow("SELECT COUNT(*) FROM ui_sessions WHERE id = ?", expiredID).Scan(&countExpired) if err != nil { t.Fatalf("Query error: %v", err) } if countExpired != 0 { t.Error("Expired session should have been deleted from database") } // Verify we can still get the valid session _, found := store.Get(validID) if !found { t.Error("Valid session should be retrievable after cleanup") } } // TestSessionStore_CleanupContext tests context-aware cleanup func TestSessionStore_CleanupContext(t *testing.T) { store := setupSessionTestDB(t) createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social") // Create a session and manually expire it expiredID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour) if err != nil { t.Fatalf("Create() error = %v", err) } // Manually update expiration to the past _, err = store.db.Exec(` UPDATE ui_sessions SET expires_at = datetime('now', '-1 hour') WHERE id = ? `, expiredID) if err != nil { t.Fatalf("Failed to update expiration: %v", err) } // Run context-aware cleanup ctx := context.Background() err = store.CleanupContext(ctx) if err != nil { t.Errorf("CleanupContext() error = %v", err) } // Verify expired session was cleaned up var count int err = store.db.QueryRow("SELECT COUNT(*) FROM ui_sessions WHERE id = ?", expiredID).Scan(&count) if err != nil { t.Fatalf("Query error: %v", err) } if count != 0 { t.Error("Expired session should have been deleted from database") } } // TestSetCookie tests setting session cookie func TestSetCookie(t *testing.T) { w := httptest.NewRecorder() sessionID := "test-session-id" maxAge := 3600 SetCookie(w, sessionID, maxAge) cookies := w.Result().Cookies() if len(cookies) != 1 { t.Fatalf("Expected 1 cookie, got %d", len(cookies)) } cookie := cookies[0] if cookie.Name != "atcr_session" { t.Errorf("Name = %v, want atcr_session", cookie.Name) } if cookie.Value != sessionID { t.Errorf("Value = %v, want %v", cookie.Value, sessionID) } if cookie.MaxAge != maxAge { t.Errorf("MaxAge = %v, want %v", cookie.MaxAge, maxAge) } if !cookie.HttpOnly { t.Error("HttpOnly should be true") } if !cookie.Secure { t.Error("Secure should be true") } if cookie.SameSite != http.SameSiteLaxMode { t.Errorf("SameSite = %v, want Lax", cookie.SameSite) } if cookie.Path != "/" { t.Errorf("Path = %v, want /", cookie.Path) } } // TestClearCookie tests clearing session cookie func TestClearCookie(t *testing.T) { w := httptest.NewRecorder() ClearCookie(w) cookies := w.Result().Cookies() if len(cookies) != 1 { t.Fatalf("Expected 1 cookie, got %d", len(cookies)) } cookie := cookies[0] if cookie.Name != "atcr_session" { t.Errorf("Name = %v, want atcr_session", cookie.Name) } if cookie.Value != "" { t.Errorf("Value should be empty, got %v", cookie.Value) } if cookie.MaxAge != -1 { t.Errorf("MaxAge = %v, want -1", cookie.MaxAge) } if !cookie.HttpOnly { t.Error("HttpOnly should be true") } if !cookie.Secure { t.Error("Secure should be true") } } // TestGetSessionID tests retrieving session ID from cookie func TestGetSessionID(t *testing.T) { tests := []struct { name string cookie *http.Cookie wantID string wantFound bool }{ { name: "valid cookie", cookie: &http.Cookie{ Name: "atcr_session", Value: "test-session-id", }, wantID: "test-session-id", wantFound: true, }, { name: "no cookie", cookie: nil, wantID: "", wantFound: false, }, { name: "wrong cookie name", cookie: &http.Cookie{ Name: "other_cookie", Value: "test-value", }, wantID: "", wantFound: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) if tt.cookie != nil { req.AddCookie(tt.cookie) } id, found := GetSessionID(req) if found != tt.wantFound { t.Errorf("GetSessionID() found = %v, want %v", found, tt.wantFound) } if id != tt.wantID { t.Errorf("GetSessionID() id = %v, want %v", id, tt.wantID) } }) } } // TestSessionStore_SessionIDUniqueness tests that generated session IDs are unique func TestSessionStore_SessionIDUniqueness(t *testing.T) { store := setupSessionTestDB(t) createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social") // Generate multiple session IDs ids := make(map[string]bool) for i := 0; i < 100; i++ { id, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour) if err != nil { t.Fatalf("Create() error = %v", err) } if ids[id] { t.Errorf("Duplicate session ID generated: %v", id) } ids[id] = true } if len(ids) != 100 { t.Errorf("Expected 100 unique IDs, got %d", len(ids)) } }