A container registry that uses the AT Protocol for manifest storage and S3 for blob storage.
1package db
2
3import (
4 "context"
5 "net/http"
6 "net/http/httptest"
7 "strings"
8 "testing"
9 "time"
10)
11
12// setupSessionTestDB creates an in-memory SQLite database for testing
13func setupSessionTestDB(t *testing.T) *SessionStore {
14 t.Helper()
15 // Use file::memory: with cache=shared to ensure all connections share the same in-memory DB
16 db, err := InitDB("file::memory:?cache=shared", true)
17 if err != nil {
18 t.Fatalf("Failed to initialize test database: %v", err)
19 }
20 // Limit to single connection to avoid race conditions in tests
21 db.SetMaxOpenConns(1)
22 t.Cleanup(func() {
23 db.Close()
24 })
25 return NewSessionStore(db)
26}
27
28// createSessionTestUser creates a test user in the database
29func createSessionTestUser(t *testing.T, store *SessionStore, did, handle string) {
30 t.Helper()
31 _, err := store.db.Exec(`
32 INSERT OR IGNORE INTO users (did, handle, pds_endpoint, last_seen)
33 VALUES (?, ?, ?, datetime('now'))
34 `, did, handle, "https://pds.example.com")
35 if err != nil {
36 t.Fatalf("Failed to create test user: %v", err)
37 }
38}
39
40func TestSession_Struct(t *testing.T) {
41 sess := &Session{
42 ID: "test-session",
43 DID: "did:plc:test",
44 Handle: "alice.bsky.social",
45 PDSEndpoint: "https://bsky.social",
46 OAuthSessionID: "oauth-123",
47 ExpiresAt: time.Now().Add(1 * time.Hour),
48 }
49
50 if sess.DID != "did:plc:test" {
51 t.Errorf("Expected DID, got %q", sess.DID)
52 }
53}
54
55// TestSessionStore_Create tests session creation without OAuth
56func TestSessionStore_Create(t *testing.T) {
57 store := setupSessionTestDB(t)
58 createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
59
60 sessionID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour)
61 if err != nil {
62 t.Fatalf("Create() error = %v", err)
63 }
64
65 if sessionID == "" {
66 t.Error("Create() returned empty session ID")
67 }
68
69 // Verify session can be retrieved
70 sess, found := store.Get(sessionID)
71 if !found {
72 t.Error("Created session not found")
73 }
74 if sess == nil {
75 t.Fatal("Session is nil")
76 }
77 if sess.DID != "did:plc:alice123" {
78 t.Errorf("DID = %v, want did:plc:alice123", sess.DID)
79 }
80 if sess.Handle != "alice.bsky.social" {
81 t.Errorf("Handle = %v, want alice.bsky.social", sess.Handle)
82 }
83 if sess.OAuthSessionID != "" {
84 t.Errorf("OAuthSessionID should be empty, got %v", sess.OAuthSessionID)
85 }
86}
87
88// TestSessionStore_CreateWithOAuth tests session creation with OAuth
89func TestSessionStore_CreateWithOAuth(t *testing.T) {
90 store := setupSessionTestDB(t)
91 createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
92
93 oauthSessionID := "oauth-123"
94 sessionID, err := store.CreateWithOAuth("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", oauthSessionID, 1*time.Hour)
95 if err != nil {
96 t.Fatalf("CreateWithOAuth() error = %v", err)
97 }
98
99 if sessionID == "" {
100 t.Error("CreateWithOAuth() returned empty session ID")
101 }
102
103 // Verify session has OAuth session ID
104 sess, found := store.Get(sessionID)
105 if !found {
106 t.Error("Created session not found")
107 }
108 if sess.OAuthSessionID != oauthSessionID {
109 t.Errorf("OAuthSessionID = %v, want %v", sess.OAuthSessionID, oauthSessionID)
110 }
111}
112
113// TestSessionStore_Get tests retrieving sessions
114func TestSessionStore_Get(t *testing.T) {
115 store := setupSessionTestDB(t)
116 createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
117
118 // Create a valid session
119 validID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour)
120 if err != nil {
121 t.Fatalf("Create() error = %v", err)
122 }
123
124 // Create a session and manually expire it
125 expiredID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour)
126 if err != nil {
127 t.Fatalf("Create() error = %v", err)
128 }
129
130 // Manually update expiration to the past
131 _, err = store.db.Exec(`
132 UPDATE ui_sessions
133 SET expires_at = datetime('now', '-1 hour')
134 WHERE id = ?
135 `, expiredID)
136 if err != nil {
137 t.Fatalf("Failed to update expiration: %v", err)
138 }
139
140 tests := []struct {
141 name string
142 sessionID string
143 wantFound bool
144 }{
145 {
146 name: "valid session",
147 sessionID: validID,
148 wantFound: true,
149 },
150 {
151 name: "expired session",
152 sessionID: expiredID,
153 wantFound: false,
154 },
155 {
156 name: "non-existent session",
157 sessionID: "non-existent-id",
158 wantFound: false,
159 },
160 }
161
162 for _, tt := range tests {
163 t.Run(tt.name, func(t *testing.T) {
164 sess, found := store.Get(tt.sessionID)
165 if found != tt.wantFound {
166 t.Errorf("Get() found = %v, want %v", found, tt.wantFound)
167 }
168 if tt.wantFound && sess == nil {
169 t.Error("Expected session, got nil")
170 }
171 })
172 }
173}
174
175// TestSessionStore_Extend tests extending session expiration
176func TestSessionStore_Extend(t *testing.T) {
177 store := setupSessionTestDB(t)
178 createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
179
180 sessionID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour)
181 if err != nil {
182 t.Fatalf("Create() error = %v", err)
183 }
184
185 // Get initial expiration
186 sess1, _ := store.Get(sessionID)
187 initialExpiry := sess1.ExpiresAt
188
189 // Wait a bit to ensure time difference
190 time.Sleep(10 * time.Millisecond)
191
192 // Extend session
193 err = store.Extend(sessionID, 2*time.Hour)
194 if err != nil {
195 t.Errorf("Extend() error = %v", err)
196 }
197
198 // Verify expiration was updated
199 sess2, found := store.Get(sessionID)
200 if !found {
201 t.Fatal("Session not found after extend")
202 }
203 if !sess2.ExpiresAt.After(initialExpiry) {
204 t.Error("ExpiresAt should be later after extend")
205 }
206
207 // Test extending non-existent session
208 err = store.Extend("non-existent-id", 1*time.Hour)
209 if err == nil {
210 t.Error("Expected error when extending non-existent session")
211 }
212 if err != nil && !strings.Contains(err.Error(), "not found") {
213 t.Errorf("Expected 'not found' error, got %v", err)
214 }
215}
216
217// TestSessionStore_Delete tests deleting a session
218func TestSessionStore_Delete(t *testing.T) {
219 store := setupSessionTestDB(t)
220 createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
221
222 sessionID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour)
223 if err != nil {
224 t.Fatalf("Create() error = %v", err)
225 }
226
227 // Verify session exists
228 _, found := store.Get(sessionID)
229 if !found {
230 t.Fatal("Session should exist before delete")
231 }
232
233 // Delete session
234 store.Delete(sessionID)
235
236 // Verify session is gone
237 _, found = store.Get(sessionID)
238 if found {
239 t.Error("Session should not exist after delete")
240 }
241
242 // Deleting non-existent session should not error
243 store.Delete("non-existent-id")
244}
245
246// TestSessionStore_DeleteByDID tests deleting all sessions for a DID
247func TestSessionStore_DeleteByDID(t *testing.T) {
248 store := setupSessionTestDB(t)
249 did := "did:plc:alice123"
250 createSessionTestUser(t, store, did, "alice.bsky.social")
251 createSessionTestUser(t, store, "did:plc:bob123", "bob.bsky.social")
252
253 // Create multiple sessions for alice
254 sessionIDs := make([]string, 3)
255 for i := 0; i < 3; i++ {
256 id, err := store.Create(did, "alice.bsky.social", "https://pds.example.com", 1*time.Hour)
257 if err != nil {
258 t.Fatalf("Create() error = %v", err)
259 }
260 sessionIDs[i] = id
261 }
262
263 // Create a session for bob
264 bobSessionID, err := store.Create("did:plc:bob123", "bob.bsky.social", "https://pds.example.com", 1*time.Hour)
265 if err != nil {
266 t.Fatalf("Create() error = %v", err)
267 }
268
269 // Delete all sessions for alice
270 store.DeleteByDID(did)
271
272 // Verify alice's sessions are gone
273 for _, id := range sessionIDs {
274 _, found := store.Get(id)
275 if found {
276 t.Errorf("Session %v should have been deleted", id)
277 }
278 }
279
280 // Verify bob's session still exists
281 _, found := store.Get(bobSessionID)
282 if !found {
283 t.Error("Bob's session should still exist")
284 }
285
286 // Deleting sessions for non-existent DID should not error
287 store.DeleteByDID("did:plc:nonexistent")
288}
289
290// TestSessionStore_Cleanup tests removing expired sessions
291func TestSessionStore_Cleanup(t *testing.T) {
292 store := setupSessionTestDB(t)
293 createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
294
295 // Create valid session by inserting directly with SQLite datetime format
296 validID := "valid-session-id"
297 _, err := store.db.Exec(`
298 INSERT INTO ui_sessions (id, did, handle, pds_endpoint, oauth_session_id, expires_at, created_at)
299 VALUES (?, ?, ?, ?, ?, datetime('now', '+1 hour'), datetime('now'))
300 `, validID, "did:plc:alice123", "alice.bsky.social", "https://pds.example.com", "")
301 if err != nil {
302 t.Fatalf("Failed to create valid session: %v", err)
303 }
304
305 // Create expired session
306 expiredID := "expired-session-id"
307 _, err = store.db.Exec(`
308 INSERT INTO ui_sessions (id, did, handle, pds_endpoint, oauth_session_id, expires_at, created_at)
309 VALUES (?, ?, ?, ?, ?, datetime('now', '-1 hour'), datetime('now'))
310 `, expiredID, "did:plc:alice123", "alice.bsky.social", "https://pds.example.com", "")
311 if err != nil {
312 t.Fatalf("Failed to create expired session: %v", err)
313 }
314
315 // Verify we have 2 sessions before cleanup
316 var countBefore int
317 err = store.db.QueryRow("SELECT COUNT(*) FROM ui_sessions").Scan(&countBefore)
318 if err != nil {
319 t.Fatalf("Query error: %v", err)
320 }
321 if countBefore != 2 {
322 t.Fatalf("Expected 2 sessions before cleanup, got %d", countBefore)
323 }
324
325 // Run cleanup
326 store.Cleanup()
327
328 // Verify valid session still exists in database
329 var countValid int
330 err = store.db.QueryRow("SELECT COUNT(*) FROM ui_sessions WHERE id = ?", validID).Scan(&countValid)
331 if err != nil {
332 t.Fatalf("Query error: %v", err)
333 }
334 if countValid != 1 {
335 t.Errorf("Valid session should still exist in database, count = %d", countValid)
336 }
337
338 // Verify expired session was cleaned up
339 var countExpired int
340 err = store.db.QueryRow("SELECT COUNT(*) FROM ui_sessions WHERE id = ?", expiredID).Scan(&countExpired)
341 if err != nil {
342 t.Fatalf("Query error: %v", err)
343 }
344 if countExpired != 0 {
345 t.Error("Expired session should have been deleted from database")
346 }
347
348 // Verify we can still get the valid session
349 _, found := store.Get(validID)
350 if !found {
351 t.Error("Valid session should be retrievable after cleanup")
352 }
353}
354
355// TestSessionStore_CleanupContext tests context-aware cleanup
356func TestSessionStore_CleanupContext(t *testing.T) {
357 store := setupSessionTestDB(t)
358 createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
359
360 // Create a session and manually expire it
361 expiredID, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour)
362 if err != nil {
363 t.Fatalf("Create() error = %v", err)
364 }
365
366 // Manually update expiration to the past
367 _, err = store.db.Exec(`
368 UPDATE ui_sessions
369 SET expires_at = datetime('now', '-1 hour')
370 WHERE id = ?
371 `, expiredID)
372 if err != nil {
373 t.Fatalf("Failed to update expiration: %v", err)
374 }
375
376 // Run context-aware cleanup
377 ctx := context.Background()
378 err = store.CleanupContext(ctx)
379 if err != nil {
380 t.Errorf("CleanupContext() error = %v", err)
381 }
382
383 // Verify expired session was cleaned up
384 var count int
385 err = store.db.QueryRow("SELECT COUNT(*) FROM ui_sessions WHERE id = ?", expiredID).Scan(&count)
386 if err != nil {
387 t.Fatalf("Query error: %v", err)
388 }
389 if count != 0 {
390 t.Error("Expired session should have been deleted from database")
391 }
392}
393
394// TestSetCookie tests setting session cookie
395func TestSetCookie(t *testing.T) {
396 w := httptest.NewRecorder()
397 sessionID := "test-session-id"
398 maxAge := 3600
399
400 SetCookie(w, sessionID, maxAge)
401
402 cookies := w.Result().Cookies()
403 if len(cookies) != 1 {
404 t.Fatalf("Expected 1 cookie, got %d", len(cookies))
405 }
406
407 cookie := cookies[0]
408 if cookie.Name != "atcr_session" {
409 t.Errorf("Name = %v, want atcr_session", cookie.Name)
410 }
411 if cookie.Value != sessionID {
412 t.Errorf("Value = %v, want %v", cookie.Value, sessionID)
413 }
414 if cookie.MaxAge != maxAge {
415 t.Errorf("MaxAge = %v, want %v", cookie.MaxAge, maxAge)
416 }
417 if !cookie.HttpOnly {
418 t.Error("HttpOnly should be true")
419 }
420 if !cookie.Secure {
421 t.Error("Secure should be true")
422 }
423 if cookie.SameSite != http.SameSiteLaxMode {
424 t.Errorf("SameSite = %v, want Lax", cookie.SameSite)
425 }
426 if cookie.Path != "/" {
427 t.Errorf("Path = %v, want /", cookie.Path)
428 }
429}
430
431// TestClearCookie tests clearing session cookie
432func TestClearCookie(t *testing.T) {
433 w := httptest.NewRecorder()
434
435 ClearCookie(w)
436
437 cookies := w.Result().Cookies()
438 if len(cookies) != 1 {
439 t.Fatalf("Expected 1 cookie, got %d", len(cookies))
440 }
441
442 cookie := cookies[0]
443 if cookie.Name != "atcr_session" {
444 t.Errorf("Name = %v, want atcr_session", cookie.Name)
445 }
446 if cookie.Value != "" {
447 t.Errorf("Value should be empty, got %v", cookie.Value)
448 }
449 if cookie.MaxAge != -1 {
450 t.Errorf("MaxAge = %v, want -1", cookie.MaxAge)
451 }
452 if !cookie.HttpOnly {
453 t.Error("HttpOnly should be true")
454 }
455 if !cookie.Secure {
456 t.Error("Secure should be true")
457 }
458}
459
460// TestGetSessionID tests retrieving session ID from cookie
461func TestGetSessionID(t *testing.T) {
462 tests := []struct {
463 name string
464 cookie *http.Cookie
465 wantID string
466 wantFound bool
467 }{
468 {
469 name: "valid cookie",
470 cookie: &http.Cookie{
471 Name: "atcr_session",
472 Value: "test-session-id",
473 },
474 wantID: "test-session-id",
475 wantFound: true,
476 },
477 {
478 name: "no cookie",
479 cookie: nil,
480 wantID: "",
481 wantFound: false,
482 },
483 {
484 name: "wrong cookie name",
485 cookie: &http.Cookie{
486 Name: "other_cookie",
487 Value: "test-value",
488 },
489 wantID: "",
490 wantFound: false,
491 },
492 }
493
494 for _, tt := range tests {
495 t.Run(tt.name, func(t *testing.T) {
496 req := httptest.NewRequest("GET", "/", nil)
497 if tt.cookie != nil {
498 req.AddCookie(tt.cookie)
499 }
500
501 id, found := GetSessionID(req)
502 if found != tt.wantFound {
503 t.Errorf("GetSessionID() found = %v, want %v", found, tt.wantFound)
504 }
505 if id != tt.wantID {
506 t.Errorf("GetSessionID() id = %v, want %v", id, tt.wantID)
507 }
508 })
509 }
510}
511
512// TestSessionStore_SessionIDUniqueness tests that generated session IDs are unique
513func TestSessionStore_SessionIDUniqueness(t *testing.T) {
514 store := setupSessionTestDB(t)
515 createSessionTestUser(t, store, "did:plc:alice123", "alice.bsky.social")
516
517 // Generate multiple session IDs
518 ids := make(map[string]bool)
519 for i := 0; i < 100; i++ {
520 id, err := store.Create("did:plc:alice123", "alice.bsky.social", "https://pds.example.com", 1*time.Hour)
521 if err != nil {
522 t.Fatalf("Create() error = %v", err)
523 }
524 if ids[id] {
525 t.Errorf("Duplicate session ID generated: %v", id)
526 }
527 ids[id] = true
528 }
529
530 if len(ids) != 100 {
531 t.Errorf("Expected 100 unique IDs, got %d", len(ids))
532 }
533}