A container registry that uses the AT Protocol for manifest storage and S3 for blob storage.
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 191 lines 6.0 kB view raw
1// Package auth provides authentication and authorization for ATCR, including 2// ATProto session validation, hold authorization (captain/crew membership), 3// scope parsing, and token caching for OAuth and service tokens. 4package auth 5 6import ( 7 "bytes" 8 "context" 9 "crypto/sha256" 10 "encoding/hex" 11 "encoding/json" 12 "errors" 13 "fmt" 14 "io" 15 "log/slog" 16 "net/http" 17 "sync" 18 "time" 19 20 "atcr.io/pkg/atproto" 21) 22 23// Sentinel errors for authentication failures 24var ( 25 // ErrIdentityResolution indicates handle/DID resolution failed 26 ErrIdentityResolution = errors.New("identity resolution failed") 27 // ErrInvalidCredentials indicates PDS returned 401 (bad password/app-password) 28 ErrInvalidCredentials = errors.New("invalid credentials") 29 // ErrPDSUnavailable indicates PDS is unreachable or returned a server error 30 ErrPDSUnavailable = errors.New("PDS unavailable") 31) 32 33// CachedSession represents a cached session 34type CachedSession struct { 35 DID string 36 Handle string 37 PDS string 38 AccessToken string 39 ExpiresAt time.Time 40} 41 42// SessionValidator validates ATProto credentials 43type SessionValidator struct { 44 httpClient *http.Client 45 cache map[string]*CachedSession 46 cacheMu sync.RWMutex 47} 48 49// NewSessionValidator creates a new ATProto session validator 50func NewSessionValidator() *SessionValidator { 51 return &SessionValidator{ 52 httpClient: &http.Client{}, 53 cache: make(map[string]*CachedSession), 54 } 55} 56 57// getCacheKey generates a cache key from username and password 58func getCacheKey(username, password string) string { 59 h := sha256.New() 60 h.Write([]byte(username + ":" + password)) 61 return hex.EncodeToString(h.Sum(nil)) 62} 63 64// getCachedSession retrieves a cached session if valid 65func (v *SessionValidator) getCachedSession(cacheKey string) (*CachedSession, bool) { 66 v.cacheMu.RLock() 67 defer v.cacheMu.RUnlock() 68 69 session, ok := v.cache[cacheKey] 70 if !ok { 71 return nil, false 72 } 73 74 // Check if expired (with 5 minute buffer) 75 if time.Now().After(session.ExpiresAt.Add(-5 * time.Minute)) { 76 return nil, false 77 } 78 79 return session, true 80} 81 82// setCachedSession stores a session in the cache 83func (v *SessionValidator) setCachedSession(cacheKey string, session *CachedSession) { 84 v.cacheMu.Lock() 85 defer v.cacheMu.Unlock() 86 v.cache[cacheKey] = session 87} 88 89// SessionResponse represents the response from createSession 90type SessionResponse struct { 91 DID string `json:"did"` 92 Handle string `json:"handle"` 93 AccessJWT string `json:"accessJwt"` 94 RefreshJWT string `json:"refreshJwt"` 95 Email string `json:"email,omitempty"` 96 AccessToken string `json:"access_token,omitempty"` // Alternative field name 97} 98 99// CreateSessionAndGetToken creates a session and returns the DID, handle, and access token 100func (v *SessionValidator) CreateSessionAndGetToken(ctx context.Context, identifier, password string) (did, handle, accessToken string, err error) { 101 // Check cache first 102 cacheKey := getCacheKey(identifier, password) 103 if cached, ok := v.getCachedSession(cacheKey); ok { 104 slog.Debug("Using cached session", "identifier", identifier, "did", cached.DID) 105 return cached.DID, cached.Handle, cached.AccessToken, nil 106 } 107 108 slog.Debug("No cached session, creating new session", "identifier", identifier) 109 110 // Resolve identifier to PDS endpoint 111 _, _, pds, err := atproto.ResolveIdentity(ctx, identifier) 112 if err != nil { 113 return "", "", "", fmt.Errorf("%w: %v", ErrIdentityResolution, err) 114 } 115 116 // Create session 117 sessionResp, err := v.createSession(ctx, pds, identifier, password) 118 if err != nil { 119 // Pass through typed errors from createSession 120 return "", "", "", err 121 } 122 123 // Cache the session (ATProto sessions typically last 2 hours) 124 v.setCachedSession(cacheKey, &CachedSession{ 125 DID: sessionResp.DID, 126 Handle: sessionResp.Handle, 127 PDS: pds, 128 AccessToken: sessionResp.AccessJWT, 129 ExpiresAt: time.Now().Add(2 * time.Hour), 130 }) 131 slog.Debug("Cached session (expires in 2 hours)", "identifier", identifier, "did", sessionResp.DID) 132 133 return sessionResp.DID, sessionResp.Handle, sessionResp.AccessJWT, nil 134} 135 136// createSession calls com.atproto.server.createSession 137func (v *SessionValidator) createSession(ctx context.Context, pdsEndpoint, identifier, password string) (*SessionResponse, error) { 138 payload := map[string]string{ 139 "identifier": identifier, 140 "password": password, 141 } 142 143 body, err := json.Marshal(payload) 144 if err != nil { 145 return nil, fmt.Errorf("failed to marshal request: %w", err) 146 } 147 148 url := fmt.Sprintf("%s%s", pdsEndpoint, atproto.ServerCreateSession) 149 slog.Debug("Creating ATProto session", "url", url) 150 151 req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) 152 if err != nil { 153 return nil, err 154 } 155 156 req.Header.Set("Content-Type", "application/json") 157 158 resp, err := v.httpClient.Do(req) 159 if err != nil { 160 slog.Debug("Session creation HTTP request failed", "error", err) 161 return nil, fmt.Errorf("%w: %v", ErrPDSUnavailable, err) 162 } 163 defer resp.Body.Close() 164 165 slog.Debug("Received session creation response", "status", resp.StatusCode) 166 167 if resp.StatusCode == http.StatusUnauthorized { 168 bodyBytes, _ := io.ReadAll(resp.Body) 169 slog.Debug("Session creation unauthorized", "response", string(bodyBytes)) 170 return nil, ErrInvalidCredentials 171 } 172 173 if resp.StatusCode >= 500 { 174 bodyBytes, _ := io.ReadAll(resp.Body) 175 slog.Debug("PDS server error", "status", resp.StatusCode, "response", string(bodyBytes)) 176 return nil, fmt.Errorf("%w: server returned %d", ErrPDSUnavailable, resp.StatusCode) 177 } 178 179 if resp.StatusCode != http.StatusOK { 180 bodyBytes, _ := io.ReadAll(resp.Body) 181 slog.Debug("Session creation failed", "status", resp.StatusCode, "response", string(bodyBytes)) 182 return nil, fmt.Errorf("%w: unexpected status %d: %s", ErrPDSUnavailable, resp.StatusCode, string(bodyBytes)) 183 } 184 185 var sessionResp SessionResponse 186 if err := json.NewDecoder(resp.Body).Decode(&sessionResp); err != nil { 187 return nil, fmt.Errorf("failed to decode response: %w", err) 188 } 189 190 return &sessionResp, nil 191}