A container registry that uses the AT Protocol for manifest storage and S3 for blob storage.
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

try and provide more helpful reponses when oauth expires and when pushing manifest lists

+434 -42
+6 -32
.tangled/workflows/release.yml
··· 2 2 # Triggers on version tags and builds cross-platform binaries using buildah 3 3 4 4 when: 5 - - event: ["manual"] 5 + - event: ["push"] 6 6 tag: ["v*"] 7 7 8 - engine: "nixery" 9 - 10 - dependencies: 11 - nixpkgs: 12 - - buildah 13 - - gnugrep # Required for tag detection 8 + engine: "kubernetes" 14 9 15 10 environment: 16 11 IMAGE_REGISTRY: atcr.io 17 12 IMAGE_USER: evan.jarrett.net 18 13 19 14 steps: 20 - - name: Get tag for current commit 21 - command: | 22 - # Fetch tags (shallow clone doesn't include them by default) 23 - git fetch --tags 24 - 25 - # Find the tag that points to the current commit 26 - TAG=$(git tag --points-at HEAD | grep -E '^v[0-9]' | head -n1) 27 - 28 - if [ -z "$TAG" ]; then 29 - echo "Error: No version tag found for current commit" 30 - echo "Available tags:" 31 - git tag 32 - echo "Current commit:" 33 - git rev-parse HEAD 34 - exit 1 35 - fi 36 - 37 - echo "Building version: $TAG" 38 - echo "$TAG" > .version 39 15 40 16 - name: Setup build environment 41 17 command: | ··· 53 29 54 30 - name: Build and push AppView image 55 31 command: | 56 - TAG=$(cat .version) 57 - 32 + echo ${TANGLED_REF_NAME} 58 33 buildah bud \ 59 34 --storage-driver vfs \ 60 - --tag ${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-appview:${TAG} \ 35 + --tag ${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-appview:${TANGLED_REF_NAME} \ 61 36 --tag ${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-appview:latest \ 62 37 --file ./Dockerfile.appview \ 63 38 . ··· 68 43 69 44 - name: Build and push Hold image 70 45 command: | 71 - TAG=$(cat .version) 72 - 46 + echo ${TANGLED_REF_NAME} 73 47 buildah bud \ 74 48 --storage-driver vfs \ 75 - --tag ${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-hold:${TAG} \ 49 + --tag ${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-hold:${TANGLED_REF_NAME} \ 76 50 --tag ${IMAGE_REGISTRY}/${IMAGE_USER}/atcr-hold:latest \ 77 51 --file ./Dockerfile.hold \ 78 52 .
+4
cmd/appview/serve.go
··· 409 409 // Basic Auth token endpoint (supports device secrets and app passwords) 410 410 tokenHandler := token.NewHandler(issuer, deviceStore) 411 411 412 + // Register OAuth session checker for device auth validation 413 + // This ensures device secrets only work when the linked OAuth session exists 414 + tokenHandler.SetOAuthSessionChecker(oauthStore) 415 + 412 416 // Register token post-auth callback for profile management 413 417 // This decouples the token package from AppView-specific dependencies 414 418 tokenHandler.SetPostAuthCallback(func(ctx context.Context, did, handle, pdsEndpoint, accessToken string) error {
+79 -6
cmd/credential-helper/main.go
··· 67 67 Error string `json:"error,omitempty"` 68 68 } 69 69 70 + // AuthErrorResponse is the JSON error response from /auth/token 71 + type AuthErrorResponse struct { 72 + Error string `json:"error"` 73 + Message string `json:"message"` 74 + LoginURL string `json:"login_url,omitempty"` 75 + } 76 + 77 + // ValidationResult represents the result of credential validation 78 + type ValidationResult struct { 79 + Valid bool 80 + OAuthSessionExpired bool 81 + LoginURL string 82 + } 83 + 70 84 var ( 71 85 version = "dev" 72 86 commit = "none" ··· 123 137 124 138 // If credentials exist, validate them 125 139 if found && deviceConfig.DeviceSecret != "" { 126 - if !validateCredentials(appViewURL, deviceConfig.Handle, deviceConfig.DeviceSecret) { 140 + result := validateCredentials(appViewURL, deviceConfig.Handle, deviceConfig.DeviceSecret) 141 + if !result.Valid { 142 + if result.OAuthSessionExpired { 143 + // OAuth session expired - need to re-authenticate via browser 144 + // Device secret is still valid, just need to restore OAuth session 145 + fmt.Fprintf(os.Stderr, "OAuth session expired. Opening browser to re-authenticate...\n") 146 + 147 + loginURL := result.LoginURL 148 + if loginURL == "" { 149 + loginURL = appViewURL + "/auth/oauth/login" 150 + } 151 + 152 + // Try to open browser 153 + if err := openBrowser(loginURL); err != nil { 154 + fmt.Fprintf(os.Stderr, "Could not open browser automatically.\n") 155 + fmt.Fprintf(os.Stderr, "Please visit: %s\n", loginURL) 156 + } else { 157 + fmt.Fprintf(os.Stderr, "Please complete authentication in your browser.\n") 158 + } 159 + 160 + // Wait for user to complete OAuth flow, then retry 161 + fmt.Fprintf(os.Stderr, "Waiting for authentication") 162 + for i := 0; i < 60; i++ { // Wait up to 2 minutes 163 + time.Sleep(2 * time.Second) 164 + fmt.Fprintf(os.Stderr, ".") 165 + 166 + // Retry validation 167 + retryResult := validateCredentials(appViewURL, deviceConfig.Handle, deviceConfig.DeviceSecret) 168 + if retryResult.Valid { 169 + fmt.Fprintf(os.Stderr, "\n✓ Re-authenticated successfully!\n") 170 + goto credentialsValid 171 + } 172 + } 173 + fmt.Fprintf(os.Stderr, "\nAuthentication timed out. Please try again.\n") 174 + os.Exit(1) 175 + } 176 + 177 + // Generic auth failure - delete credentials and re-authorize 127 178 fmt.Fprintf(os.Stderr, "Stored credentials for %s are invalid or expired\n", appViewURL) 128 179 // Delete the invalid credentials 129 180 delete(allCreds.Credentials, appViewURL) ··· 134 185 found = false 135 186 } 136 187 } 188 + credentialsValid: 137 189 138 190 if !found || deviceConfig.DeviceSecret == "" { 139 191 // No credentials for this AppView ··· 550 602 } 551 603 552 604 // validateCredentials checks if the credentials are still valid by making a test request 553 - func validateCredentials(appViewURL, handle, deviceSecret string) bool { 605 + func validateCredentials(appViewURL, handle, deviceSecret string) ValidationResult { 554 606 // Call /auth/token to validate device secret and get JWT 555 607 // This is the proper way to validate credentials - /v2/ requires JWT, not Basic Auth 556 608 client := &http.Client{ ··· 562 614 563 615 req, err := http.NewRequest("GET", tokenURL, nil) 564 616 if err != nil { 565 - return false 617 + return ValidationResult{Valid: false} 566 618 } 567 619 568 620 // Set basic auth with device credentials ··· 572 624 if err != nil { 573 625 // Network error - assume credentials are valid but server unreachable 574 626 // Don't trigger re-auth on network issues 575 - return true 627 + return ValidationResult{Valid: true} 576 628 } 577 629 defer resp.Body.Close() 578 630 579 631 // 200 = valid credentials 580 - // 401 = invalid/expired credentials 632 + if resp.StatusCode == http.StatusOK { 633 + return ValidationResult{Valid: true} 634 + } 635 + 636 + // 401 = check if it's OAuth session expired 637 + if resp.StatusCode == http.StatusUnauthorized { 638 + // Try to parse JSON error response 639 + body, err := io.ReadAll(resp.Body) 640 + if err == nil { 641 + var authErr AuthErrorResponse 642 + if json.Unmarshal(body, &authErr) == nil && authErr.Error == "oauth_session_expired" { 643 + return ValidationResult{ 644 + Valid: false, 645 + OAuthSessionExpired: true, 646 + LoginURL: authErr.LoginURL, 647 + } 648 + } 649 + } 650 + // Generic auth failure 651 + return ValidationResult{Valid: false} 652 + } 653 + 581 654 // Any other error = assume valid (don't re-auth on server issues) 582 - return resp.StatusCode == http.StatusOK 655 + return ValidationResult{Valid: true} 583 656 }
+14
pkg/appview/db/oauth_store.go
··· 212 212 return &sessionData, sessionID, nil 213 213 } 214 214 215 + // HasSessionForDID checks if an OAuth session exists for the given DID 216 + // This is a lightweight check used by the token handler to verify device auth 217 + func (s *OAuthStore) HasSessionForDID(ctx context.Context, did string) bool { 218 + var count int 219 + err := s.db.QueryRowContext(ctx, ` 220 + SELECT COUNT(*) FROM oauth_sessions WHERE account_did = ? 221 + `, did).Scan(&count) 222 + if err != nil { 223 + slog.Debug("Failed to check session existence", "did", did, "error", err) 224 + return false 225 + } 226 + return count > 0 227 + } 228 + 215 229 // CleanupOldSessions removes sessions older than the specified duration 216 230 func (s *OAuthStore) CleanupOldSessions(ctx context.Context, olderThan time.Duration) error { 217 231 cutoff := time.Now().Add(-olderThan)
+31
pkg/appview/storage/manifest_store.go
··· 143 143 isManifestList := strings.Contains(manifestRecord.MediaType, "manifest.list") || 144 144 strings.Contains(manifestRecord.MediaType, "image.index") 145 145 146 + // Validate manifest list child references 147 + // Reject manifest lists that reference non-existent child manifests 148 + // This matches Docker Hub/ECR behavior and prevents users from accidentally pushing 149 + // manifest lists where the underlying images don't exist 150 + if isManifestList { 151 + for _, ref := range manifestRecord.Manifests { 152 + // Check if referenced manifest exists in user's PDS 153 + refDigest, err := digest.Parse(ref.Digest) 154 + if err != nil { 155 + return "", fmt.Errorf("invalid digest in manifest list: %s", ref.Digest) 156 + } 157 + 158 + exists, err := s.Exists(ctx, refDigest) 159 + if err != nil { 160 + return "", fmt.Errorf("failed to check manifest reference: %w", err) 161 + } 162 + 163 + if !exists { 164 + platform := "unknown" 165 + if ref.Platform != nil { 166 + platform = fmt.Sprintf("%s/%s", ref.Platform.OS, ref.Platform.Architecture) 167 + } 168 + slog.Warn("Manifest list references non-existent child manifest", 169 + "repository", s.ctx.Repository, 170 + "missingDigest", ref.Digest, 171 + "platform", platform) 172 + return "", distribution.ErrManifestBlobUnknown{Digest: refDigest} 173 + } 174 + } 175 + } 176 + 146 177 if !isManifestList && s.blobStore != nil && manifestRecord.Config != nil && manifestRecord.Config.Digest != "" { 147 178 labels, err := s.extractConfigLabels(ctx, manifestRecord.Config.Digest) 148 179 if err != nil {
+247
pkg/appview/storage/manifest_store_test.go
··· 3 3 import ( 4 4 "context" 5 5 "encoding/json" 6 + "errors" 6 7 "io" 7 8 "net/http" 8 9 "net/http/httptest" ··· 912 913 }) 913 914 } 914 915 } 916 + 917 + // TestManifestStore_Put_ManifestListValidation tests validation of manifest list child references 918 + func TestManifestStore_Put_ManifestListValidation(t *testing.T) { 919 + // Create a valid child manifest that exists 920 + childManifest := []byte(`{ 921 + "schemaVersion":2, 922 + "mediaType":"application/vnd.oci.image.manifest.v1+json", 923 + "config":{"digest":"sha256:config123","size":100}, 924 + "layers":[{"digest":"sha256:layer1","size":200}] 925 + }`) 926 + childDigest := digest.FromBytes(childManifest) 927 + 928 + tests := []struct { 929 + name string 930 + manifestList []byte 931 + childExists bool // Whether the child manifest exists 932 + wantErr bool 933 + wantErrType string // "ErrManifestBlobUnknown" or empty 934 + checkErrDigest string // Expected digest in error 935 + }{ 936 + { 937 + name: "valid manifest list - child exists", 938 + manifestList: []byte(`{ 939 + "schemaVersion":2, 940 + "mediaType":"application/vnd.oci.image.index.v1+json", 941 + "manifests":[ 942 + {"digest":"` + childDigest.String() + `","size":300,"mediaType":"application/vnd.oci.image.manifest.v1+json","platform":{"os":"linux","architecture":"amd64"}} 943 + ] 944 + }`), 945 + childExists: true, 946 + wantErr: false, 947 + }, 948 + { 949 + name: "invalid manifest list - child does not exist", 950 + manifestList: []byte(`{ 951 + "schemaVersion":2, 952 + "mediaType":"application/vnd.oci.image.index.v1+json", 953 + "manifests":[ 954 + {"digest":"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef","size":300,"mediaType":"application/vnd.oci.image.manifest.v1+json","platform":{"os":"linux","architecture":"amd64"}} 955 + ] 956 + }`), 957 + childExists: false, 958 + wantErr: true, 959 + wantErrType: "ErrManifestBlobUnknown", 960 + checkErrDigest: "sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", 961 + }, 962 + { 963 + name: "attestation-only manifest list - attestation must also exist", 964 + manifestList: []byte(`{ 965 + "schemaVersion":2, 966 + "mediaType":"application/vnd.oci.image.index.v1+json", 967 + "manifests":[ 968 + {"digest":"sha256:4444444444444444444444444444444444444444444444444444444444444444","size":100,"mediaType":"application/vnd.oci.image.manifest.v1+json","platform":{"os":"unknown","architecture":"unknown"}} 969 + ] 970 + }`), 971 + childExists: false, 972 + wantErr: true, 973 + wantErrType: "ErrManifestBlobUnknown", 974 + checkErrDigest: "sha256:4444444444444444444444444444444444444444444444444444444444444444", 975 + }, 976 + { 977 + name: "mixed manifest list - real platform missing, attestation present", 978 + manifestList: []byte(`{ 979 + "schemaVersion":2, 980 + "mediaType":"application/vnd.oci.image.index.v1+json", 981 + "manifests":[ 982 + {"digest":"sha256:1111111111111111111111111111111111111111111111111111111111111111","size":300,"mediaType":"application/vnd.oci.image.manifest.v1+json","platform":{"os":"linux","architecture":"arm64"}}, 983 + {"digest":"sha256:5555555555555555555555555555555555555555555555555555555555555555","size":100,"mediaType":"application/vnd.oci.image.manifest.v1+json","platform":{"os":"unknown","architecture":"unknown"}} 984 + ] 985 + }`), 986 + childExists: false, 987 + wantErr: true, 988 + wantErrType: "ErrManifestBlobUnknown", 989 + checkErrDigest: "sha256:1111111111111111111111111111111111111111111111111111111111111111", 990 + }, 991 + { 992 + name: "docker manifest list media type - child missing", 993 + manifestList: []byte(`{ 994 + "schemaVersion":2, 995 + "mediaType":"application/vnd.docker.distribution.manifest.list.v2+json", 996 + "manifests":[ 997 + {"digest":"sha256:2222222222222222222222222222222222222222222222222222222222222222","size":300,"mediaType":"application/vnd.docker.distribution.manifest.v2+json","platform":{"os":"linux","architecture":"amd64"}} 998 + ] 999 + }`), 1000 + childExists: false, 1001 + wantErr: true, 1002 + wantErrType: "ErrManifestBlobUnknown", 1003 + checkErrDigest: "sha256:2222222222222222222222222222222222222222222222222222222222222222", 1004 + }, 1005 + { 1006 + name: "manifest list with nil platform - should still validate", 1007 + manifestList: []byte(`{ 1008 + "schemaVersion":2, 1009 + "mediaType":"application/vnd.oci.image.index.v1+json", 1010 + "manifests":[ 1011 + {"digest":"sha256:3333333333333333333333333333333333333333333333333333333333333333","size":300,"mediaType":"application/vnd.oci.image.manifest.v1+json"} 1012 + ] 1013 + }`), 1014 + childExists: false, 1015 + wantErr: true, 1016 + wantErrType: "ErrManifestBlobUnknown", 1017 + checkErrDigest: "sha256:3333333333333333333333333333333333333333333333333333333333333333", 1018 + }, 1019 + } 1020 + 1021 + for _, tt := range tests { 1022 + t.Run(tt.name, func(t *testing.T) { 1023 + // Track GetRecord calls for manifest existence checks 1024 + getRecordCalls := make(map[string]bool) 1025 + 1026 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1027 + // Handle uploadBlob 1028 + if r.URL.Path == atproto.RepoUploadBlob { 1029 + w.WriteHeader(http.StatusOK) 1030 + w.Write([]byte(`{"blob":{"$type":"blob","ref":{"$link":"bafytest"},"mimeType":"application/json","size":100}}`)) 1031 + return 1032 + } 1033 + 1034 + // Handle getRecord (for Exists check) 1035 + if r.URL.Path == atproto.RepoGetRecord { 1036 + rkey := r.URL.Query().Get("rkey") 1037 + getRecordCalls[rkey] = true 1038 + 1039 + // If child should exist, return it; otherwise return RecordNotFound 1040 + if tt.childExists || rkey == childDigest.Encoded() { 1041 + w.WriteHeader(http.StatusOK) 1042 + w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/` + rkey + `","cid":"bafytest","value":{}}`)) 1043 + } else { 1044 + w.WriteHeader(http.StatusBadRequest) 1045 + w.Write([]byte(`{"error":"RecordNotFound","message":"Record not found"}`)) 1046 + } 1047 + return 1048 + } 1049 + 1050 + // Handle putRecord 1051 + if r.URL.Path == atproto.RepoPutRecord { 1052 + w.WriteHeader(http.StatusOK) 1053 + w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/test123","cid":"bafytest"}`)) 1054 + return 1055 + } 1056 + 1057 + w.WriteHeader(http.StatusOK) 1058 + })) 1059 + defer server.Close() 1060 + 1061 + client := atproto.NewClient(server.URL, "did:plc:test123", "token") 1062 + db := &mockDatabaseMetrics{} 1063 + ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", db) 1064 + store := NewManifestStore(ctx, nil) 1065 + 1066 + manifest := &rawManifest{ 1067 + mediaType: "application/vnd.oci.image.index.v1+json", 1068 + payload: tt.manifestList, 1069 + } 1070 + 1071 + _, err := store.Put(context.Background(), manifest) 1072 + 1073 + if (err != nil) != tt.wantErr { 1074 + t.Errorf("Put() error = %v, wantErr %v", err, tt.wantErr) 1075 + return 1076 + } 1077 + 1078 + if tt.wantErr && tt.wantErrType == "ErrManifestBlobUnknown" { 1079 + // Check that the error is of the correct type 1080 + var blobErr distribution.ErrManifestBlobUnknown 1081 + if !errors.As(err, &blobErr) { 1082 + t.Errorf("Put() error type = %T, want distribution.ErrManifestBlobUnknown", err) 1083 + return 1084 + } 1085 + 1086 + // Check that the error contains the expected digest 1087 + if tt.checkErrDigest != "" { 1088 + expectedDigest, _ := digest.Parse(tt.checkErrDigest) 1089 + if blobErr.Digest != expectedDigest { 1090 + t.Errorf("ErrManifestBlobUnknown.Digest = %v, want %v", blobErr.Digest, expectedDigest) 1091 + } 1092 + } 1093 + } 1094 + }) 1095 + } 1096 + } 1097 + 1098 + // TestManifestStore_Put_ManifestListValidation_MultipleChildren tests validation with multiple child manifests 1099 + func TestManifestStore_Put_ManifestListValidation_MultipleChildren(t *testing.T) { 1100 + // Create two valid child manifests 1101 + childManifest1 := []byte(`{"schemaVersion":2,"mediaType":"application/vnd.oci.image.manifest.v1+json","config":{"digest":"sha256:config1","size":100},"layers":[]}`) 1102 + childManifest2 := []byte(`{"schemaVersion":2,"mediaType":"application/vnd.oci.image.manifest.v1+json","config":{"digest":"sha256:config2","size":100},"layers":[]}`) 1103 + childDigest1 := digest.FromBytes(childManifest1) 1104 + childDigest2 := digest.FromBytes(childManifest2) 1105 + 1106 + // Track which manifests exist 1107 + existingManifests := map[string]bool{ 1108 + childDigest1.Encoded(): true, 1109 + childDigest2.Encoded(): true, 1110 + } 1111 + 1112 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1113 + if r.URL.Path == atproto.RepoUploadBlob { 1114 + w.Write([]byte(`{"blob":{"$type":"blob","ref":{"$link":"bafytest"},"size":100}}`)) 1115 + return 1116 + } 1117 + 1118 + if r.URL.Path == atproto.RepoGetRecord { 1119 + rkey := r.URL.Query().Get("rkey") 1120 + if existingManifests[rkey] { 1121 + w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/` + rkey + `","cid":"bafytest","value":{}}`)) 1122 + } else { 1123 + w.WriteHeader(http.StatusBadRequest) 1124 + w.Write([]byte(`{"error":"RecordNotFound"}`)) 1125 + } 1126 + return 1127 + } 1128 + 1129 + if r.URL.Path == atproto.RepoPutRecord { 1130 + w.Write([]byte(`{"uri":"at://did:plc:test123/io.atcr.manifest/test123","cid":"bafytest"}`)) 1131 + return 1132 + } 1133 + 1134 + w.WriteHeader(http.StatusOK) 1135 + })) 1136 + defer server.Close() 1137 + 1138 + client := atproto.NewClient(server.URL, "did:plc:test123", "token") 1139 + ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", nil) 1140 + store := NewManifestStore(ctx, nil) 1141 + 1142 + // Create manifest list with both children 1143 + manifestList := []byte(`{ 1144 + "schemaVersion":2, 1145 + "mediaType":"application/vnd.oci.image.index.v1+json", 1146 + "manifests":[ 1147 + {"digest":"` + childDigest1.String() + `","size":300,"mediaType":"application/vnd.oci.image.manifest.v1+json","platform":{"os":"linux","architecture":"amd64"}}, 1148 + {"digest":"` + childDigest2.String() + `","size":300,"mediaType":"application/vnd.oci.image.manifest.v1+json","platform":{"os":"linux","architecture":"arm64"}} 1149 + ] 1150 + }`) 1151 + 1152 + manifest := &rawManifest{ 1153 + mediaType: "application/vnd.oci.image.index.v1+json", 1154 + payload: manifestList, 1155 + } 1156 + 1157 + _, err := store.Put(context.Background(), manifest) 1158 + if err != nil { 1159 + t.Errorf("Put() should succeed when all child manifests exist, got error: %v", err) 1160 + } 1161 + }
+53 -4
pkg/auth/token/handler.go
··· 20 20 // without coupling the token package to AppView-specific dependencies. 21 21 type PostAuthCallback func(ctx context.Context, did, handle, pdsEndpoint, accessToken string) error 22 22 23 + // OAuthSessionChecker checks if an OAuth session exists for a DID 24 + // This interface allows the token handler to verify OAuth sessions without 25 + // depending directly on the OAuth store implementation. 26 + type OAuthSessionChecker interface { 27 + HasSessionForDID(ctx context.Context, did string) bool 28 + } 29 + 23 30 // Handler handles /auth/token requests 24 31 type Handler struct { 25 - issuer *Issuer 26 - validator *auth.SessionValidator 27 - deviceStore *db.DeviceStore // For validating device secrets 28 - postAuthCallback PostAuthCallback 32 + issuer *Issuer 33 + validator *auth.SessionValidator 34 + deviceStore *db.DeviceStore // For validating device secrets 35 + postAuthCallback PostAuthCallback 36 + oauthSessionChecker OAuthSessionChecker 29 37 } 30 38 31 39 // NewHandler creates a new token handler ··· 41 49 // This allows AppView to inject business logic without coupling the token package 42 50 func (h *Handler) SetPostAuthCallback(callback PostAuthCallback) { 43 51 h.postAuthCallback = callback 52 + } 53 + 54 + // SetOAuthSessionChecker sets the OAuth session checker for validating device auth 55 + // When set, the handler will verify OAuth sessions exist before issuing tokens for device auth 56 + func (h *Handler) SetOAuthSessionChecker(checker OAuthSessionChecker) { 57 + h.oauthSessionChecker = checker 44 58 } 45 59 46 60 // TokenResponse represents the response from /auth/token ··· 80 94 (use your ATProto handle + app-password)`, message, baseURL, r.Host), http.StatusUnauthorized) 81 95 } 82 96 97 + // AuthErrorResponse is returned when authentication fails in a way the credential helper can handle 98 + type AuthErrorResponse struct { 99 + Error string `json:"error"` 100 + Message string `json:"message"` 101 + LoginURL string `json:"login_url,omitempty"` 102 + } 103 + 104 + // sendOAuthSessionExpiredError sends a JSON error response when OAuth session is missing 105 + // This allows the credential helper to detect this specific error and open the browser 106 + func sendOAuthSessionExpiredError(w http.ResponseWriter, r *http.Request) { 107 + baseURL := getBaseURL(r) 108 + loginURL := baseURL + "/auth/oauth/login" 109 + 110 + w.Header().Set("WWW-Authenticate", `Basic realm="ATCR Registry"`) 111 + w.Header().Set("Content-Type", "application/json") 112 + w.WriteHeader(http.StatusUnauthorized) 113 + 114 + resp := AuthErrorResponse{ 115 + Error: "oauth_session_expired", 116 + Message: "OAuth session expired or invalidated. Please re-authenticate in your browser.", 117 + LoginURL: loginURL, 118 + } 119 + json.NewEncoder(w).Encode(resp) 120 + } 121 + 83 122 // ServeHTTP handles the token request 84 123 func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 85 124 slog.Debug("Received token request", "method", r.Method, "path", r.URL.Path) ··· 128 167 slog.Debug("Device secret validation failed", "error", err) 129 168 sendAuthError(w, r, "authentication failed") 130 169 return 170 + } 171 + 172 + // Check if OAuth session exists for this device's DID 173 + // Device secrets are permanent, but they require an active OAuth session to work 174 + if h.oauthSessionChecker != nil { 175 + if !h.oauthSessionChecker.HasSessionForDID(r.Context(), device.DID) { 176 + slog.Debug("No OAuth session for device", "did", device.DID) 177 + sendOAuthSessionExpiredError(w, r) 178 + return 179 + } 131 180 } 132 181 133 182 did = device.DID