A container registry that uses the AT Protocol for manifest storage and S3 for blob storage.
1// Package auth provides authentication and authorization for ATCR, including
2// ATProto session validation, hold authorization (captain/crew membership),
3// scope parsing, and token caching for OAuth and service tokens.
4package auth
5
6import (
7 "bytes"
8 "context"
9 "crypto/sha256"
10 "encoding/hex"
11 "encoding/json"
12 "errors"
13 "fmt"
14 "io"
15 "log/slog"
16 "net/http"
17 "sync"
18 "time"
19
20 "atcr.io/pkg/atproto"
21)
22
// Sentinel errors reported by SessionValidator. They are usually returned
// wrapped with extra context, so match them with errors.Is rather than ==.
var (
	// ErrInvalidCredentials means the PDS answered 401: the password or app
	// password was rejected.
	ErrInvalidCredentials = errors.New("invalid credentials")

	// ErrPDSUnavailable means the PDS could not be reached, or it answered
	// with a server error or any other status besides 200 and 401.
	ErrPDSUnavailable = errors.New("PDS unavailable")

	// ErrIdentityResolution means the handle/DID could not be resolved to
	// a PDS endpoint.
	ErrIdentityResolution = errors.New("identity resolution failed")
)
32
// CachedSession is one entry in SessionValidator's in-memory session cache.
// It records the result of a successful createSession call.
type CachedSession struct {
	// DID is the account DID reported by the PDS in its createSession response.
	DID string
	// Handle is the account handle reported by the PDS.
	Handle string
	// PDS is the endpoint the session was created on, as resolved from the
	// login identifier.
	PDS string
	// AccessToken is the access JWT (accessJwt) issued by the PDS.
	AccessToken string
	// ExpiresAt is a locally assumed expiry (creation time + 2 hours); it is
	// not read from the token itself. getCachedSession stops returning the
	// entry 5 minutes before this instant.
	ExpiresAt time.Time
}
41
// SessionValidator validates ATProto credentials by creating a session on the
// account's PDS (com.atproto.server.createSession). It caches successful
// sessions in memory.
//
// The zero value is not usable because the cache map must be allocated first;
// construct one with NewSessionValidator. The struct embeds a sync.RWMutex, so
// share it by pointer and never copy it.
type SessionValidator struct {
	// httpClient performs the createSession requests.
	httpClient *http.Client
	// cache maps getCacheKey(identifier, password) to the session obtained
	// with those credentials. Guarded by cacheMu.
	cache   map[string]*CachedSession
	cacheMu sync.RWMutex
}
48
49// NewSessionValidator creates a new ATProto session validator
50func NewSessionValidator() *SessionValidator {
51 return &SessionValidator{
52 httpClient: &http.Client{},
53 cache: make(map[string]*CachedSession),
54 }
55}
56
// getCacheKey derives the session-cache key for a username/password pair as
// the hex-encoded SHA-256 of an unambiguous encoding of both values. This
// means raw passwords are never stored as map keys.
//
// The username is length-prefixed ("<len>:<username>:<password>"). Joining with
// a bare ":" would be ambiguous, because either field may contain colons (DIDs
// always do). For example, ("a:b", "c") and ("a", "b:c") would otherwise map to
// the same key.
func getCacheKey(username, password string) string {
	h := sha256.New()
	// hash.Hash.Write never returns an error, so the Fprintf result is ignored.
	fmt.Fprintf(h, "%d:%s:%s", len(username), username, password)
	return hex.EncodeToString(h.Sum(nil))
}
63
64// getCachedSession retrieves a cached session if valid
65func (v *SessionValidator) getCachedSession(cacheKey string) (*CachedSession, bool) {
66 v.cacheMu.RLock()
67 defer v.cacheMu.RUnlock()
68
69 session, ok := v.cache[cacheKey]
70 if !ok {
71 return nil, false
72 }
73
74 // Check if expired (with 5 minute buffer)
75 if time.Now().After(session.ExpiresAt.Add(-5 * time.Minute)) {
76 return nil, false
77 }
78
79 return session, true
80}
81
82// setCachedSession stores a session in the cache
83func (v *SessionValidator) setCachedSession(cacheKey string, session *CachedSession) {
84 v.cacheMu.Lock()
85 defer v.cacheMu.Unlock()
86 v.cache[cacheKey] = session
87}
88
// SessionResponse is the JSON body of a successful
// com.atproto.server.createSession response, as decoded by createSession.
type SessionResponse struct {
	// DID is the account DID the PDS created the session for.
	DID string `json:"did"`
	// Handle is the account handle reported by the PDS.
	Handle string `json:"handle"`
	// AccessJWT is the access token ("accessJwt" key).
	AccessJWT string `json:"accessJwt"`
	// RefreshJWT is the refresh token ("refreshJwt" key).
	RefreshJWT string `json:"refreshJwt"`
	// Email is the account email, if the PDS includes it.
	Email string `json:"email,omitempty"`
	// AccessToken receives an access token sent under the alternative
	// "access_token" key instead of "accessJwt".
	AccessToken string `json:"access_token,omitempty"`
}
98
99// CreateSessionAndGetToken creates a session and returns the DID, handle, and access token
100func (v *SessionValidator) CreateSessionAndGetToken(ctx context.Context, identifier, password string) (did, handle, accessToken string, err error) {
101 // Check cache first
102 cacheKey := getCacheKey(identifier, password)
103 if cached, ok := v.getCachedSession(cacheKey); ok {
104 slog.Debug("Using cached session", "identifier", identifier, "did", cached.DID)
105 return cached.DID, cached.Handle, cached.AccessToken, nil
106 }
107
108 slog.Debug("No cached session, creating new session", "identifier", identifier)
109
110 // Resolve identifier to PDS endpoint
111 _, _, pds, err := atproto.ResolveIdentity(ctx, identifier)
112 if err != nil {
113 return "", "", "", fmt.Errorf("%w: %v", ErrIdentityResolution, err)
114 }
115
116 // Create session
117 sessionResp, err := v.createSession(ctx, pds, identifier, password)
118 if err != nil {
119 // Pass through typed errors from createSession
120 return "", "", "", err
121 }
122
123 // Cache the session (ATProto sessions typically last 2 hours)
124 v.setCachedSession(cacheKey, &CachedSession{
125 DID: sessionResp.DID,
126 Handle: sessionResp.Handle,
127 PDS: pds,
128 AccessToken: sessionResp.AccessJWT,
129 ExpiresAt: time.Now().Add(2 * time.Hour),
130 })
131 slog.Debug("Cached session (expires in 2 hours)", "identifier", identifier, "did", sessionResp.DID)
132
133 return sessionResp.DID, sessionResp.Handle, sessionResp.AccessJWT, nil
134}
135
136// createSession calls com.atproto.server.createSession
137func (v *SessionValidator) createSession(ctx context.Context, pdsEndpoint, identifier, password string) (*SessionResponse, error) {
138 payload := map[string]string{
139 "identifier": identifier,
140 "password": password,
141 }
142
143 body, err := json.Marshal(payload)
144 if err != nil {
145 return nil, fmt.Errorf("failed to marshal request: %w", err)
146 }
147
148 url := fmt.Sprintf("%s%s", pdsEndpoint, atproto.ServerCreateSession)
149 slog.Debug("Creating ATProto session", "url", url)
150
151 req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
152 if err != nil {
153 return nil, err
154 }
155
156 req.Header.Set("Content-Type", "application/json")
157
158 resp, err := v.httpClient.Do(req)
159 if err != nil {
160 slog.Debug("Session creation HTTP request failed", "error", err)
161 return nil, fmt.Errorf("%w: %v", ErrPDSUnavailable, err)
162 }
163 defer resp.Body.Close()
164
165 slog.Debug("Received session creation response", "status", resp.StatusCode)
166
167 if resp.StatusCode == http.StatusUnauthorized {
168 bodyBytes, _ := io.ReadAll(resp.Body)
169 slog.Debug("Session creation unauthorized", "response", string(bodyBytes))
170 return nil, ErrInvalidCredentials
171 }
172
173 if resp.StatusCode >= 500 {
174 bodyBytes, _ := io.ReadAll(resp.Body)
175 slog.Debug("PDS server error", "status", resp.StatusCode, "response", string(bodyBytes))
176 return nil, fmt.Errorf("%w: server returned %d", ErrPDSUnavailable, resp.StatusCode)
177 }
178
179 if resp.StatusCode != http.StatusOK {
180 bodyBytes, _ := io.ReadAll(resp.Body)
181 slog.Debug("Session creation failed", "status", resp.StatusCode, "response", string(bodyBytes))
182 return nil, fmt.Errorf("%w: unexpected status %d: %s", ErrPDSUnavailable, resp.StatusCode, string(bodyBytes))
183 }
184
185 var sessionResp SessionResponse
186 if err := json.NewDecoder(resp.Body).Decode(&sessionResp); err != nil {
187 return nil, fmt.Errorf("failed to decode response: %w", err)
188 }
189
190 return &sessionResp, nil
191}