A container registry that uses the AT Protocol for manifest storage and S3 for blob storage.
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 644 lines 19 kB view raw
1package token 2 3import ( 4 "context" 5 "crypto/tls" 6 "database/sql" 7 "encoding/base64" 8 "encoding/json" 9 "net/http" 10 "net/http/httptest" 11 "os" 12 "path/filepath" 13 "strings" 14 "sync" 15 "testing" 16 "time" 17 18 "atcr.io/pkg/appview/db" 19) 20 21// Shared test key to avoid generating a new RSA key for each test 22// Generating a 2048-bit RSA key takes ~0.15s, so reusing one key saves ~4.5s for 32 tests 23var ( 24 sharedTestKeyPath string 25 sharedTestKeyOnce sync.Once 26 sharedTestKeyDir string 27) 28 29// getSharedTestKey returns a shared RSA key and its file path for all tests 30// The key is generated once and reused across all tests in this package 31func getSharedTestKey(t *testing.T) string { 32 sharedTestKeyOnce.Do(func() { 33 // Create a persistent temp directory for the shared key 34 var err error 35 sharedTestKeyDir, err = os.MkdirTemp("", "atcr-test-keys-*") 36 if err != nil { 37 t.Fatalf("Failed to create test key directory: %v", err) 38 } 39 40 sharedTestKeyPath = filepath.Join(sharedTestKeyDir, "test-key.pem") 41 42 // Generate the key once (this is the expensive operation we want to avoid repeating) 43 // This will also generate the certificate via NewIssuer 44 _, err = NewIssuer(sharedTestKeyPath, "atcr.io", "registry", 15*time.Minute) 45 if err != nil { 46 t.Fatalf("Failed to generate shared test key: %v", err) 47 } 48 }) 49 50 return sharedTestKeyPath 51} 52 53// setupTestDeviceStore creates an in-memory SQLite database for testing 54func setupTestDeviceStore(t *testing.T) (*db.DeviceStore, *sql.DB) { 55 testDB, err := db.InitDB(":memory:", db.LibsqlConfig{}) 56 if err != nil { 57 t.Fatalf("Failed to initialize test database: %v", err) 58 } 59 return db.NewDeviceStore(testDB), testDB 60} 61 62// createTestDevice creates a device in the test database and returns its secret 63// Requires both DeviceStore and sql.DB to insert user record first 64func createTestDevice(t *testing.T, store *db.DeviceStore, testDB *sql.DB, did, handle string) string { 65 // First create a user record (required by foreign key constraint) 66 user := &db.User{ 67 DID: did, 68 Handle: handle, 69 PDSEndpoint: "https://pds.example.com", 70 } 71 err := db.UpsertUser(testDB, user) 72 if err != nil { 73 t.Fatalf("Failed to create user: %v", err) 74 } 75 76 // Create pending authorization 77 pending, err := store.CreatePendingAuth("Test Device", "127.0.0.1", "test-agent") 78 if err != nil { 79 t.Fatalf("Failed to create pending auth: %v", err) 80 } 81 82 // Approve the pending authorization 83 secret, err := store.ApprovePending(pending.UserCode, did, handle) 84 if err != nil { 85 t.Fatalf("Failed to approve pending auth: %v", err) 86 } 87 88 return secret 89} 90 91func TestNewHandler(t *testing.T) { 92 keyPath := getSharedTestKey(t) 93 94 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 95 if err != nil { 96 t.Fatalf("NewIssuer() error = %v", err) 97 } 98 99 handler := NewHandler(issuer, nil) 100 if handler == nil { 101 t.Fatal("Expected non-nil handler") 102 } 103 104 if handler.issuer == nil { 105 t.Error("Expected issuer to be set") 106 } 107 108 if handler.validator == nil { 109 t.Error("Expected validator to be initialized") 110 } 111} 112 113func TestHandler_SetPostAuthCallback(t *testing.T) { 114 keyPath := getSharedTestKey(t) 115 116 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 117 if err != nil { 118 t.Fatalf("NewIssuer() error = %v", err) 119 } 120 121 handler := NewHandler(issuer, nil) 122 123 handler.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, token string) error { 124 return nil 125 }) 126 127 if handler.postAuthCallback == nil { 128 t.Error("Expected post-auth callback to be set") 129 } 130} 131 132func TestHandler_ServeHTTP_NoAuth(t *testing.T) { 133 keyPath := getSharedTestKey(t) 134 135 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 136 if err != nil { 137 t.Fatalf("NewIssuer() error = %v", err) 138 } 139 140 handler := NewHandler(issuer, nil) 141 142 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil) 143 w := httptest.NewRecorder() 144 145 handler.ServeHTTP(w, req) 146 147 if w.Code != http.StatusUnauthorized { 148 t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code) 149 } 150 151 // Check for WWW-Authenticate header 152 if w.Header().Get("WWW-Authenticate") == "" { 153 t.Error("Expected WWW-Authenticate header") 154 } 155} 156 157func TestHandler_ServeHTTP_WrongMethod(t *testing.T) { 158 keyPath := getSharedTestKey(t) 159 160 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 161 if err != nil { 162 t.Fatalf("NewIssuer() error = %v", err) 163 } 164 165 handler := NewHandler(issuer, nil) 166 167 // Try POST instead of GET 168 req := httptest.NewRequest(http.MethodPost, "/auth/token", nil) 169 w := httptest.NewRecorder() 170 171 handler.ServeHTTP(w, req) 172 173 if w.Code != http.StatusMethodNotAllowed { 174 t.Errorf("Expected status %d, got %d", http.StatusMethodNotAllowed, w.Code) 175 } 176} 177 178func TestHandler_ServeHTTP_DeviceAuth_Valid(t *testing.T) { 179 keyPath := getSharedTestKey(t) 180 181 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 182 if err != nil { 183 t.Fatalf("NewIssuer() error = %v", err) 184 } 185 186 // Create real device store with in-memory database 187 deviceStore, database := setupTestDeviceStore(t) 188 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:user123", "alice.bsky.social") 189 190 handler := NewHandler(issuer, deviceStore) 191 192 // Create request with device secret 193 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull,push", nil) 194 req.SetBasicAuth("alice.bsky.social", deviceSecret) 195 w := httptest.NewRecorder() 196 197 handler.ServeHTTP(w, req) 198 199 if w.Code != http.StatusOK { 200 t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code) 201 t.Logf("Response body: %s", w.Body.String()) 202 } 203 204 // Parse response 205 var resp TokenResponse 206 if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { 207 t.Fatalf("Failed to decode response: %v", err) 208 } 209 210 if resp.Token == "" { 211 t.Error("Expected non-empty token") 212 } 213 214 if resp.AccessToken == "" { 215 t.Error("Expected non-empty access_token") 216 } 217 218 if resp.ExpiresIn == 0 { 219 t.Error("Expected non-zero expires_in") 220 } 221 222 // Verify token and access_token are the same 223 if resp.Token != resp.AccessToken { 224 t.Error("Expected token and access_token to be the same") 225 } 226} 227 228func TestHandler_ServeHTTP_DeviceAuth_Invalid(t *testing.T) { 229 keyPath := getSharedTestKey(t) 230 231 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 232 if err != nil { 233 t.Fatalf("NewIssuer() error = %v", err) 234 } 235 236 // Create device store but don't add any devices 237 deviceStore, _ := setupTestDeviceStore(t) 238 239 handler := NewHandler(issuer, deviceStore) 240 241 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil) 242 req.SetBasicAuth("alice", "atcr_device_invalid") 243 w := httptest.NewRecorder() 244 245 handler.ServeHTTP(w, req) 246 247 if w.Code != http.StatusUnauthorized { 248 t.Errorf("Expected status %d, got %d", http.StatusUnauthorized, w.Code) 249 } 250} 251 252func TestHandler_ServeHTTP_InvalidScope(t *testing.T) { 253 keyPath := getSharedTestKey(t) 254 255 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 256 if err != nil { 257 t.Fatalf("NewIssuer() error = %v", err) 258 } 259 260 deviceStore, database := setupTestDeviceStore(t) 261 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:user123", "alice.bsky.social") 262 263 handler := NewHandler(issuer, deviceStore) 264 265 // Invalid scope format (missing colons) 266 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=invalid", nil) 267 req.SetBasicAuth("alice", deviceSecret) 268 w := httptest.NewRecorder() 269 270 handler.ServeHTTP(w, req) 271 272 if w.Code != http.StatusBadRequest { 273 t.Errorf("Expected status %d, got %d", http.StatusBadRequest, w.Code) 274 } 275 276 body := w.Body.String() 277 if !strings.Contains(body, "invalid scope") { 278 t.Errorf("Expected error message to contain 'invalid scope', got: %s", body) 279 } 280} 281 282func TestHandler_ServeHTTP_AccessDenied(t *testing.T) { 283 keyPath := getSharedTestKey(t) 284 285 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 286 if err != nil { 287 t.Fatalf("NewIssuer() error = %v", err) 288 } 289 290 deviceStore, database := setupTestDeviceStore(t) 291 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social") 292 293 handler := NewHandler(issuer, deviceStore) 294 295 // Try to push to someone else's repository 296 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:bob.bsky.social/myapp:push", nil) 297 req.SetBasicAuth("alice", deviceSecret) 298 w := httptest.NewRecorder() 299 300 handler.ServeHTTP(w, req) 301 302 if w.Code != http.StatusForbidden { 303 t.Errorf("Expected status %d, got %d", http.StatusForbidden, w.Code) 304 } 305 306 body := w.Body.String() 307 if !strings.Contains(body, "access denied") { 308 t.Errorf("Expected error message to contain 'access denied', got: %s", body) 309 } 310} 311 312func TestHandler_ServeHTTP_WithCallback(t *testing.T) { 313 keyPath := getSharedTestKey(t) 314 315 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 316 if err != nil { 317 t.Fatalf("NewIssuer() error = %v", err) 318 } 319 320 deviceStore, database := setupTestDeviceStore(t) 321 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:user123", "alice.bsky.social") 322 323 handler := NewHandler(issuer, deviceStore) 324 325 // Set callback to track if it's called 326 callbackCalled := false 327 handler.SetPostAuthCallback(func(ctx context.Context, did, handle, pds, token string) error { 328 callbackCalled = true 329 // Note: We don't check the values because callback shouldn't be called for device auth 330 return nil 331 }) 332 333 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull", nil) 334 req.SetBasicAuth("alice", deviceSecret) 335 w := httptest.NewRecorder() 336 337 handler.ServeHTTP(w, req) 338 339 // Note: Callback is only called for app password auth, not device auth 340 // So callbackCalled should be false for this test 341 if callbackCalled { 342 t.Error("Expected callback NOT to be called for device auth") 343 } 344} 345 346func TestHandler_ServeHTTP_MultipleScopes(t *testing.T) { 347 keyPath := getSharedTestKey(t) 348 349 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 350 if err != nil { 351 t.Fatalf("NewIssuer() error = %v", err) 352 } 353 354 deviceStore, database := setupTestDeviceStore(t) 355 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social") 356 357 handler := NewHandler(issuer, deviceStore) 358 359 // Multiple scopes separated by space (URL encoded) 360 scopes := "repository%3Aalice.bsky.social%2Fapp1%3Apull+repository%3Aalice.bsky.social%2Fapp2%3Apush" 361 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope="+scopes, nil) 362 req.SetBasicAuth("alice", deviceSecret) 363 w := httptest.NewRecorder() 364 365 handler.ServeHTTP(w, req) 366 367 if w.Code != http.StatusOK { 368 t.Errorf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String()) 369 } 370} 371 372func TestHandler_ServeHTTP_WildcardScope(t *testing.T) { 373 keyPath := getSharedTestKey(t) 374 375 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 376 if err != nil { 377 t.Fatalf("NewIssuer() error = %v", err) 378 } 379 380 deviceStore, database := setupTestDeviceStore(t) 381 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social") 382 383 handler := NewHandler(issuer, deviceStore) 384 385 // Wildcard scope should be allowed 386 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:*:pull,push", nil) 387 req.SetBasicAuth("alice", deviceSecret) 388 w := httptest.NewRecorder() 389 390 handler.ServeHTTP(w, req) 391 392 if w.Code != http.StatusOK { 393 t.Errorf("Expected status %d, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String()) 394 } 395} 396 397func TestHandler_ServeHTTP_NoScope(t *testing.T) { 398 keyPath := getSharedTestKey(t) 399 400 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 401 if err != nil { 402 t.Fatalf("NewIssuer() error = %v", err) 403 } 404 405 deviceStore, database := setupTestDeviceStore(t) 406 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social") 407 408 handler := NewHandler(issuer, deviceStore) 409 410 // No scope parameter - should still work (empty access) 411 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil) 412 req.SetBasicAuth("alice", deviceSecret) 413 w := httptest.NewRecorder() 414 415 handler.ServeHTTP(w, req) 416 417 if w.Code != http.StatusOK { 418 t.Errorf("Expected status %d, got %d", http.StatusOK, w.Code) 419 } 420 421 var resp TokenResponse 422 if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { 423 t.Fatalf("Failed to decode response: %v", err) 424 } 425 426 if resp.Token == "" { 427 t.Error("Expected non-empty token even with no scope") 428 } 429} 430 431func TestGetBaseURL(t *testing.T) { 432 tests := []struct { 433 name string 434 host string 435 headers map[string]string 436 expectedURL string 437 }{ 438 { 439 name: "simple host", 440 host: "registry.example.com", 441 headers: map[string]string{}, 442 expectedURL: "http://registry.example.com", 443 }, 444 { 445 name: "with TLS", 446 host: "registry.example.com", 447 headers: map[string]string{}, 448 expectedURL: "https://registry.example.com", // Would need TLS in request 449 }, 450 { 451 name: "with X-Forwarded-Host", 452 host: "internal-host", 453 headers: map[string]string{ 454 "X-Forwarded-Host": "registry.example.com", 455 }, 456 expectedURL: "http://registry.example.com", 457 }, 458 { 459 name: "with X-Forwarded-Proto", 460 host: "registry.example.com", 461 headers: map[string]string{ 462 "X-Forwarded-Proto": "https", 463 }, 464 expectedURL: "https://registry.example.com", 465 }, 466 { 467 name: "with both forwarded headers", 468 host: "internal", 469 headers: map[string]string{ 470 "X-Forwarded-Host": "registry.example.com", 471 "X-Forwarded-Proto": "https", 472 }, 473 expectedURL: "https://registry.example.com", 474 }, 475 } 476 477 for _, tt := range tests { 478 t.Run(tt.name, func(t *testing.T) { 479 req := httptest.NewRequest(http.MethodGet, "/", nil) 480 req.Host = tt.host 481 482 for key, value := range tt.headers { 483 req.Header.Set(key, value) 484 } 485 486 // For TLS test 487 if tt.expectedURL == "https://registry.example.com" && len(tt.headers) == 0 { 488 req.TLS = &tls.ConnectionState{} // Non-nil TLS indicates HTTPS 489 } 490 491 baseURL := getBaseURL(req) 492 493 if baseURL != tt.expectedURL { 494 t.Errorf("Expected URL %q, got %q", tt.expectedURL, baseURL) 495 } 496 }) 497 } 498} 499 500func TestTokenResponse_JSONFormat(t *testing.T) { 501 resp := TokenResponse{ 502 Token: "jwt_token_here", 503 AccessToken: "jwt_token_here", 504 ExpiresIn: 900, 505 IssuedAt: "2025-01-01T00:00:00Z", 506 } 507 508 data, err := json.Marshal(resp) 509 if err != nil { 510 t.Fatalf("Failed to marshal response: %v", err) 511 } 512 513 // Verify JSON structure 514 var decoded map[string]any 515 if err := json.Unmarshal(data, &decoded); err != nil { 516 t.Fatalf("Failed to unmarshal JSON: %v", err) 517 } 518 519 if decoded["token"] != "jwt_token_here" { 520 t.Error("Expected token field in JSON") 521 } 522 523 if decoded["access_token"] != "jwt_token_here" { 524 t.Error("Expected access_token field in JSON") 525 } 526 527 if decoded["expires_in"] != float64(900) { 528 t.Error("Expected expires_in field in JSON") 529 } 530 531 if decoded["issued_at"] != "2025-01-01T00:00:00Z" { 532 t.Error("Expected issued_at field in JSON") 533 } 534} 535 536func TestHandler_ServeHTTP_AuthHeader(t *testing.T) { 537 keyPath := getSharedTestKey(t) 538 539 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 540 if err != nil { 541 t.Fatalf("NewIssuer() error = %v", err) 542 } 543 544 handler := NewHandler(issuer, nil) 545 546 // Test with manually constructed auth header 547 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry", nil) 548 auth := base64.StdEncoding.EncodeToString([]byte("username:password")) 549 req.Header.Set("Authorization", "Basic "+auth) 550 w := httptest.NewRecorder() 551 552 handler.ServeHTTP(w, req) 553 554 // Should fail because we don't have valid credentials, but we're testing the header parsing 555 if w.Code != http.StatusUnauthorized { 556 t.Logf("Got status %d (this is fine, we're just testing header parsing)", w.Code) 557 } 558} 559 560func TestHandler_ServeHTTP_ContentType(t *testing.T) { 561 keyPath := getSharedTestKey(t) 562 563 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 564 if err != nil { 565 t.Fatalf("NewIssuer() error = %v", err) 566 } 567 568 deviceStore, database := setupTestDeviceStore(t) 569 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social") 570 571 handler := NewHandler(issuer, deviceStore) 572 573 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull", nil) 574 req.SetBasicAuth("alice", deviceSecret) 575 w := httptest.NewRecorder() 576 577 handler.ServeHTTP(w, req) 578 579 if w.Code != http.StatusOK { 580 t.Fatalf("Expected status %d, got %d", http.StatusOK, w.Code) 581 } 582 583 contentType := w.Header().Get("Content-Type") 584 if contentType != "application/json" { 585 t.Errorf("Expected Content-Type 'application/json', got %q", contentType) 586 } 587} 588 589func TestHandler_ServeHTTP_ExpiresIn(t *testing.T) { 590 keyPath := getSharedTestKey(t) 591 592 // Create issuer with specific expiration 593 expiration := 10 * time.Minute 594 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", expiration) 595 if err != nil { 596 t.Fatalf("NewIssuer() error = %v", err) 597 } 598 599 deviceStore, database := setupTestDeviceStore(t) 600 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social") 601 602 handler := NewHandler(issuer, deviceStore) 603 604 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:alice.bsky.social/myapp:pull", nil) 605 req.SetBasicAuth("alice", deviceSecret) 606 w := httptest.NewRecorder() 607 608 handler.ServeHTTP(w, req) 609 610 var resp TokenResponse 611 if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { 612 t.Fatalf("Failed to decode response: %v", err) 613 } 614 615 expectedExpiresIn := int(expiration.Seconds()) 616 if resp.ExpiresIn != expectedExpiresIn { 617 t.Errorf("Expected expires_in %d, got %d", expectedExpiresIn, resp.ExpiresIn) 618 } 619} 620 621func TestHandler_ServeHTTP_PullOnlyAccess(t *testing.T) { 622 keyPath := getSharedTestKey(t) 623 624 issuer, err := NewIssuer(keyPath, "atcr.io", "registry", 15*time.Minute) 625 if err != nil { 626 t.Fatalf("NewIssuer() error = %v", err) 627 } 628 629 deviceStore, database := setupTestDeviceStore(t) 630 deviceSecret := createTestDevice(t, deviceStore, database, "did:plc:alice123", "alice.bsky.social") 631 632 handler := NewHandler(issuer, deviceStore) 633 634 // Pull from someone else's repo should be allowed 635 req := httptest.NewRequest(http.MethodGet, "/auth/token?service=registry&scope=repository:bob.bsky.social/myapp:pull", nil) 636 req.SetBasicAuth("alice", deviceSecret) 637 w := httptest.NewRecorder() 638 639 handler.ServeHTTP(w, req) 640 641 if w.Code != http.StatusOK { 642 t.Errorf("Expected status %d for pull-only access, got %d. Body: %s", http.StatusOK, w.Code, w.Body.String()) 643 } 644}