A container registry that uses the AT Protocol for manifest storage and S3 for blob storage.
1package token
2
3import (
4 "context"
5 "crypto/tls"
6 "database/sql"
7 "encoding/base64"
8 "encoding/json"
9 "net/http"
10 "net/http/httptest"
11 "os"
12 "path/filepath"
13 "strings"
14 "sync"
15 "testing"
16 "time"
17
18 "atcr.io/pkg/appview/db"
19)
20
21// Shared test key to avoid generating a new RSA key for each test
22// Generating a 2048-bit RSA key takes ~0.15s, so reusing one key saves ~4.5s for 32 tests
23var (
24 sharedTestKeyPath string
25 sharedTestKeyOnce sync.Once
26 sharedTestKeyDir string
27)
28
29// getSharedTestKey returns a shared RSA key and its file path for all tests
30// The key is generated once and reused across all tests in this package
31func getSharedTestKey(t *testing.T) string {
32 sharedTestKeyOnce.Do(func() {
33 // Create a persistent temp directory for the shared key
34 var err error
35 sharedTestKeyDir, err = os.MkdirTemp("", "atcr-test-keys-*")
36 if err != nil {
37 t.Fatalf("Failed to create test key directory: %v", err)
38 }
39
40 sharedTestKeyPath = filepath.Join(sharedTestKeyDir, "test-key.pem")
41
42 // Generate the key once (this is the expensive operation we want to avoid repeating)
43 // This will also generate the certificate via NewIssuer
44 _, err = NewIssuer(sharedTestKeyPath, "atcr.io", "registry", 15*time.Minute)
45 if err != nil {
46 t.Fatalf("Failed to generate shared test key: %v", err)
47 }
48 })
49
50 return sharedTestKeyPath
51}
52
53// setupTestDeviceStore creates an in-memory SQLite database for testing
54func setupTestDeviceStore(t *testing.T) (*db.DeviceStore, *sql.DB) {
55 testDB, err := db.InitDB(":memory:", db.LibsqlConfig{})
56 if err != nil {
57 t.Fatalf("Failed to initialize test database: %v", err)
58 }
59 return db.NewDeviceStore(testDB), testDB
60}
61
62// createTestDevice creates a device in the test database and returns its secret
63// Requires both DeviceStore and sql.DB to insert user record first
64func createTestDevice(t *testing.T, store *db.DeviceStore, testDB *sql.DB, did, handle string) string {
65 // First create a user record (required by foreign key constraint)
66 user := &db.User{
67 DID: did,
68 Handle: handle,
69 PDSEndpoint: "https://pds.example.com",
70 }
71 err := db.UpsertUser(testDB, user)
72 if err != nil {
73 t.Fatalf("Failed to create user: %v", err)
74 }
75
76 // Create pending authorization
77 pending, err := store.CreatePendingAuth("Test Device", "127.0.0.1", "test-agent")
78 if err != nil {
79 t.Fatalf("Failed to create pending auth: %v", err)
80 }
81
82 // Approve the pending authorization
83 secret, err := store.ApprovePending(pending.UserCode, did, handle)
84 if err != nil {
85 t.Fatalf("Failed to approve pending auth: %v", err)
86 }
87
88 return secret
89}
90
91func TestNewHandler(t *testing.T) {
92 keyPath := getSharedTestKey(t)
93
94 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
95 if err != nil {
96 t.Fatalf("NewIssuer() error = %v", err)
97 }
98
99 handler := NewHandler(issuer, nil)
100 if handler == nil {
101 t.Fatal("Expected non-nil handler")
102 }
103
104 if handler.issuer == nil {
105 t.Error("Expected issuer to be set")
106 }
107
108 if handler.validator == nil {
109 t.Error("Expected validator to be initialized")
110 }
111}
112
113func TestHandler_SetPostAuthCallback(t *testing.T) {
114 keyPath := getSharedTestKey(t)
115
116 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
117 if err != nil {
118 t.Fatalf("NewIssuer() error = %v", err)
119 }
120
121 handler := NewHandler(issuer, nil)
122
123 handler.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, token string) error {
124 return nil
125 })
126
127 if handler.postAuthCallback == nil {
128 t.Error("Expected post-auth callback to be set")
129 }
130}
131
132func TestHandler_ServeHTTP_NoAuth(t *testing.T) {
133 keyPath := getSharedTestKey(t)
134
135 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
136 if err != nil {
137 t.Fatalf("NewIssuer() error = %v", err)
138 }
139
140 handler := NewHandler(issuer, nil)
141
142 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil)
143 w := httptest.NewRecorder()
144
145 handler.ServeHTTP(w, req)
146
147 if w.Code != http.StatusUnauthorized {
148 t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
149 }
150
151 // Check for WWW-Authenticate header
152 if w.Header().Get("WWW-Authenticate") == "" {
153 t.Error("Expected WWW-Authenticate header")
154 }
155}
156
157func TestHandler_ServeHTTP_WrongMethod(t *testing.T) {
158 keyPath := getSharedTestKey(t)
159
160 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
161 if err != nil {
162 t.Fatalf("NewIssuer() error = %v", err)
163 }
164
165 handler := NewHandler(issuer, nil)
166
167 // Try POST instead of GET
168 req := httptest.NewRequest(http.MethodPost, "/auth/token", nil)
169 w := httptest.NewRecorder()
170
171 handler.ServeHTTP(w, req)
172
173 if w.Code != http.StatusMethodNotAllowed {
174 t.Errorf("Expected status %d, got %d", http.StatusMethodNotAllowed, w.Code)
175 }
176}
177
178func TestHandler_ServeHTTP_DeviceAuth_Valid(t *testing.T) {
179 keyPath := getSharedTestKey(t)
180
181 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
182 if err != nil {
183 t.Fatalf("NewIssuer() error = %v", err)
184 }
185
186 // Create real device store with in-memory database
187 deviceStore, database := setupTestDeviceStore(t)
188 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:user123", "alice.bsky.social")
189
190 handler := NewHandler(issuer, deviceStore)
191
192 // Create request with device secret
193 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull,push", nil)
194 req.SetBasicAuth("alice.bsky.social", deviceSecret)
195 w := httptest.NewRecorder()
196
197 handler.ServeHTTP(w, req)
198
199 if w.Code != http.StatusOK {
200 t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
201 t.Logf("Response body: %s", w.Body.String())
202 }
203
204 // Parse response
205 var resp TokenResponse
206 if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
207 t.Fatalf("Failed to decode response: %v", err)
208 }
209
210 if resp.Token == "" {
211 t.Error("Expected non-empty token")
212 }
213
214 if resp.AccessToken == "" {
215 t.Error("Expected non-empty access_token")
216 }
217
218 if resp.ExpiresIn == 0 {
219 t.Error("Expected non-zero expires_in")
220 }
221
222 // Verify token and access_token are the same
223 if resp.Token != resp.AccessToken {
224 t.Error("Expected token and access_token to be the same")
225 }
226}
227
228func TestHandler_ServeHTTP_DeviceAuth_Invalid(t *testing.T) {
229 keyPath := getSharedTestKey(t)
230
231 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
232 if err != nil {
233 t.Fatalf("NewIssuer() error = %v", err)
234 }
235
236 // Create device store but don't add any devices
237 deviceStore, _ := setupTestDeviceStore(t)
238
239 handler := NewHandler(issuer, deviceStore)
240
241 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil)
242 req.SetBasicAuth("alice", "atcr_device_invalid")
243 w := httptest.NewRecorder()
244
245 handler.ServeHTTP(w, req)
246
247 if w.Code != http.StatusUnauthorized {
248 t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code)
249 }
250}
251
252func TestHandler_ServeHTTP_InvalidScope(t *testing.T) {
253 keyPath := getSharedTestKey(t)
254
255 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
256 if err != nil {
257 t.Fatalf("NewIssuer() error = %v", err)
258 }
259
260 deviceStore, database := setupTestDeviceStore(t)
261 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:user123", "alice.bsky.social")
262
263 handler := NewHandler(issuer, deviceStore)
264
265 // Invalid scope format (missing colons)
266 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=invalid", nil)
267 req.SetBasicAuth("alice", deviceSecret)
268 w := httptest.NewRecorder()
269
270 handler.ServeHTTP(w, req)
271
272 if w.Code != http.StatusBadRequest {
273 t.Errorf("Expected status %d, got %d", http.StatusBadRequest, w.Code)
274 }
275
276 body := w.Body.String()
277 if !strings.Contains(body, "invalid scope") {
278 t.Errorf("Expected error message to contain 'invalid scope', got: %s", body)
279 }
280}
281
282func TestHandler_ServeHTTP_AccessDenied(t *testing.T) {
283 keyPath := getSharedTestKey(t)
284
285 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
286 if err != nil {
287 t.Fatalf("NewIssuer() error = %v", err)
288 }
289
290 deviceStore, database := setupTestDeviceStore(t)
291 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
292
293 handler := NewHandler(issuer, deviceStore)
294
295 // Try to push to someone else's repository
296 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:bob.bsky.social/myapp:push", nil)
297 req.SetBasicAuth("alice", deviceSecret)
298 w := httptest.NewRecorder()
299
300 handler.ServeHTTP(w, req)
301
302 if w.Code != http.StatusForbidden {
303 t.Errorf("Expected status %d, got %d", http.StatusForbidden, w.Code)
304 }
305
306 body := w.Body.String()
307 if !strings.Contains(body, "access denied") {
308 t.Errorf("Expected error message to contain 'access denied', got: %s", body)
309 }
310}
311
312func TestHandler_ServeHTTP_WithCallback(t *testing.T) {
313 keyPath := getSharedTestKey(t)
314
315 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
316 if err != nil {
317 t.Fatalf("NewIssuer() error = %v", err)
318 }
319
320 deviceStore, database := setupTestDeviceStore(t)
321 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:user123", "alice.bsky.social")
322
323 handler := NewHandler(issuer, deviceStore)
324
325 // Set callback to track if it's called
326 callbackCalled := false
327 handler.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, token string) error {
328 callbackCalled = true
329 // Note: We don't check the values because callback shouldn't be called for device auth
330 return nil
331 })
332
333 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull", nil)
334 req.SetBasicAuth("alice", deviceSecret)
335 w := httptest.NewRecorder()
336
337 handler.ServeHTTP(w, req)
338
339 // Note: Callback is only called for app password auth, not device auth
340 // So callbackCalled should be false for this test
341 if callbackCalled {
342 t.Error("Expected callback NOT to be called for device auth")
343 }
344}
345
346func TestHandler_ServeHTTP_MultipleScopes(t *testing.T) {
347 keyPath := getSharedTestKey(t)
348
349 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
350 if err != nil {
351 t.Fatalf("NewIssuer() error = %v", err)
352 }
353
354 deviceStore, database := setupTestDeviceStore(t)
355 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
356
357 handler := NewHandler(issuer, deviceStore)
358
359 // Multiple scopes separated by space (URL encoded)
360 scopes := "repository%3Aalice.bsky.social%2Fapp1%3Apull+repository%3Aalice.bsky.social%2Fapp2%3Apush"
361 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope="+scopes, nil)
362 req.SetBasicAuth("alice", deviceSecret)
363 w := httptest.NewRecorder()
364
365 handler.ServeHTTP(w, req)
366
367 if w.Code != http.StatusOK {
368 t.Errorf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
369 }
370}
371
372func TestHandler_ServeHTTP_WildcardScope(t *testing.T) {
373 keyPath := getSharedTestKey(t)
374
375 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
376 if err != nil {
377 t.Fatalf("NewIssuer() error = %v", err)
378 }
379
380 deviceStore, database := setupTestDeviceStore(t)
381 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
382
383 handler := NewHandler(issuer, deviceStore)
384
385 // Wildcard scope should be allowed
386 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:*:pull,push", nil)
387 req.SetBasicAuth("alice", deviceSecret)
388 w := httptest.NewRecorder()
389
390 handler.ServeHTTP(w, req)
391
392 if w.Code != http.StatusOK {
393 t.Errorf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
394 }
395}
396
397func TestHandler_ServeHTTP_NoScope(t *testing.T) {
398 keyPath := getSharedTestKey(t)
399
400 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
401 if err != nil {
402 t.Fatalf("NewIssuer() error = %v", err)
403 }
404
405 deviceStore, database := setupTestDeviceStore(t)
406 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
407
408 handler := NewHandler(issuer, deviceStore)
409
410 // No scope parameter - should still work (empty access)
411 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil)
412 req.SetBasicAuth("alice", deviceSecret)
413 w := httptest.NewRecorder()
414
415 handler.ServeHTTP(w, req)
416
417 if w.Code != http.StatusOK {
418 t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code)
419 }
420
421 var resp TokenResponse
422 if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
423 t.Fatalf("Failed to decode response: %v", err)
424 }
425
426 if resp.Token == "" {
427 t.Error("Expected non-empty token even with no scope")
428 }
429}
430
431func TestGetBaseURL(t *testing.T) {
432 tests := []struct {
433 name string
434 host string
435 headers map[string]string
436 expectedURL string
437 }{
438 {
439 name: "simple host",
440 host: "registry.example.com",
441 headers: map[string]string{},
442 expectedURL: "http://registry.example.com",
443 },
444 {
445 name: "with TLS",
446 host: "registry.example.com",
447 headers: map[string]string{},
448 expectedURL: "https://registry.example.com", // Would need TLS in request
449 },
450 {
451 name: "with X-Forwarded-Host",
452 host: "internal-host",
453 headers: map[string]string{
454 "X-Forwarded-Host": "registry.example.com",
455 },
456 expectedURL: "http://registry.example.com",
457 },
458 {
459 name: "with X-Forwarded-Proto",
460 host: "registry.example.com",
461 headers: map[string]string{
462 "X-Forwarded-Proto": "https",
463 },
464 expectedURL: "https://registry.example.com",
465 },
466 {
467 name: "with both forwarded headers",
468 host: "internal",
469 headers: map[string]string{
470 "X-Forwarded-Host": "registry.example.com",
471 "X-Forwarded-Proto": "https",
472 },
473 expectedURL: "https://registry.example.com",
474 },
475 }
476
477 for _, tt := range tests {
478 t.Run(tt.name, func(t *testing.T) {
479 req := httptest.NewRequest(http.MethodGet, "/", nil)
480 req.Host = tt.host
481
482 for key, value := range tt.headers {
483 req.Header.Set(key, value)
484 }
485
486 // For TLS test
487 if tt.expectedURL == "https://registry.example.com" && len(tt.headers) == 0 {
488 req.TLS = &tls.ConnectionState{} // Non-nil TLS indicates HTTPS
489 }
490
491 baseURL := getBaseURL(req)
492
493 if baseURL != tt.expectedURL {
494 t.Errorf("Expected URL %q, got %q", tt.expectedURL, baseURL)
495 }
496 })
497 }
498}
499
500func TestTokenResponse_JSONFormat(t *testing.T) {
501 resp := TokenResponse{
502 Token: "jwt_token_here",
503 AccessToken: "jwt_token_here",
504 ExpiresIn: 900,
505 IssuedAt: "2025-01-01T00:00:00Z",
506 }
507
508 data, err := json.Marshal(resp)
509 if err != nil {
510 t.Fatalf("Failed to marshal response: %v", err)
511 }
512
513 // Verify JSON structure
514 var decoded map[string]any
515 if err := json.Unmarshal(data, &decoded); err != nil {
516 t.Fatalf("Failed to unmarshal JSON: %v", err)
517 }
518
519 if decoded["token"] != "jwt_token_here" {
520 t.Error("Expected token field in JSON")
521 }
522
523 if decoded["access_token"] != "jwt_token_here" {
524 t.Error("Expected access_token field in JSON")
525 }
526
527 if decoded["expires_in"] != float64(900) {
528 t.Error("Expected expires_in field in JSON")
529 }
530
531 if decoded["issued_at"] != "2025-01-01T00:00:00Z" {
532 t.Error("Expected issued_at field in JSON")
533 }
534}
535
536func TestHandler_ServeHTTP_AuthHeader(t *testing.T) {
537 keyPath := getSharedTestKey(t)
538
539 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
540 if err != nil {
541 t.Fatalf("NewIssuer() error = %v", err)
542 }
543
544 handler := NewHandler(issuer, nil)
545
546 // Test with manually constructed auth header
547 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil)
548 auth := base64.StdEncoding.EncodeToString([]byte("username:password"))
549 req.Header.Set("Authorization", "Basic "+auth)
550 w := httptest.NewRecorder()
551
552 handler.ServeHTTP(w, req)
553
554 // Should fail because we don't have valid credentials, but we're testing the header parsing
555 if w.Code != http.StatusUnauthorized {
556 t.Logf("Got status %d (this is fine, we're just testing header parsing)", w.Code)
557 }
558}
559
560func TestHandler_ServeHTTP_ContentType(t *testing.T) {
561 keyPath := getSharedTestKey(t)
562
563 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
564 if err != nil {
565 t.Fatalf("NewIssuer() error = %v", err)
566 }
567
568 deviceStore, database := setupTestDeviceStore(t)
569 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
570
571 handler := NewHandler(issuer, deviceStore)
572
573 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull", nil)
574 req.SetBasicAuth("alice", deviceSecret)
575 w := httptest.NewRecorder()
576
577 handler.ServeHTTP(w, req)
578
579 if w.Code != http.StatusOK {
580 t.Fatalf("Expected status %d, got %d", http.StatusOK, w.Code)
581 }
582
583 contentType := w.Header().Get("Content-Type")
584 if contentType != "application/json" {
585 t.Errorf("Expected Content-Type 'application/json', got %q", contentType)
586 }
587}
588
589func TestHandler_ServeHTTP_ExpiresIn(t *testing.T) {
590 keyPath := getSharedTestKey(t)
591
592 // Create issuer with specific expiration
593 expiration := 10 * time.Minute
594 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", expiration)
595 if err != nil {
596 t.Fatalf("NewIssuer() error = %v", err)
597 }
598
599 deviceStore, database := setupTestDeviceStore(t)
600 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
601
602 handler := NewHandler(issuer, deviceStore)
603
604 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull", nil)
605 req.SetBasicAuth("alice", deviceSecret)
606 w := httptest.NewRecorder()
607
608 handler.ServeHTTP(w, req)
609
610 var resp TokenResponse
611 if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
612 t.Fatalf("Failed to decode response: %v", err)
613 }
614
615 expectedExpiresIn := int(expiration.Seconds())
616 if resp.ExpiresIn != expectedExpiresIn {
617 t.Errorf("Expected expires_in %d, got %d", expectedExpiresIn, resp.ExpiresIn)
618 }
619}
620
621func TestHandler_ServeHTTP_PullOnlyAccess(t *testing.T) {
622 keyPath := getSharedTestKey(t)
623
624 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute)
625 if err != nil {
626 t.Fatalf("NewIssuer() error = %v", err)
627 }
628
629 deviceStore, database := setupTestDeviceStore(t)
630 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social")
631
632 handler := NewHandler(issuer, deviceStore)
633
634 // Pull from someone else's repo should be allowed
635 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:bob.bsky.social/myapp:pull", nil)
636 req.SetBasicAuth("alice", deviceSecret)
637 w := httptest.NewRecorder()
638
639 handler.ServeHTTP(w, req)
640
641 if w.Code != http.StatusOK {
642 t.Errorf("Expected status %d for pull-only access, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String())
643 }
644}