A container registry that uses the AT Protocol for manifest storage and S3 for blob storage. atcr.io
docker container atproto go
72
fork

Configure Feed

Select the types of activity you want to include in your feed.

at refactor 173 lines 5.2 kB view raw
1// Package auth provides authentication and authorization for ATCR, including 2// ATProto session validation, hold authorization (captain/crew membership), 3// scope parsing, and token caching for OAuth and service tokens. 4package auth 5 6import ( 7 "bytes" 8 "context" 9 "crypto/sha256" 10 "encoding/hex" 11 "encoding/json" 12 "fmt" 13 "io" 14 "log/slog" 15 "net/http" 16 "sync" 17 "time" 18 19 "atcr.io/pkg/atproto" 20) 21 22// CachedSession represents a cached session 23type CachedSession struct { 24 DID string 25 Handle string 26 PDS string 27 AccessToken string 28 ExpiresAt time.Time 29} 30 31// SessionValidator validates ATProto credentials 32type SessionValidator struct { 33 httpClient *http.Client 34 cache map[string]*CachedSession 35 cacheMu sync.RWMutex 36} 37 38// NewSessionValidator creates a new ATProto session validator 39func NewSessionValidator() *SessionValidator { 40 return &SessionValidator{ 41 httpClient: &http.Client{}, 42 cache: make(map[string]*CachedSession), 43 } 44} 45 46// getCacheKey generates a cache key from username and password 47func getCacheKey(username, password string) string { 48 h := sha256.New() 49 h.Write([]byte(username + ":" + password)) 50 return hex.EncodeToString(h.Sum(nil)) 51} 52 53// getCachedSession retrieves a cached session if valid 54func (v *SessionValidator) getCachedSession(cacheKey string) (*CachedSession, bool) { 55 v.cacheMu.RLock() 56 defer v.cacheMu.RUnlock() 57 58 session, ok := v.cache[cacheKey] 59 if !ok { 60 return nil, false 61 } 62 63 // Check if expired (with 5 minute buffer) 64 if time.Now().After(session.ExpiresAt.Add(-5 * time.Minute)) { 65 return nil, false 66 } 67 68 return session, true 69} 70 71// setCachedSession stores a session in the cache 72func (v *SessionValidator) setCachedSession(cacheKey string, session *CachedSession) { 73 v.cacheMu.Lock() 74 defer v.cacheMu.Unlock() 75 v.cache[cacheKey] = session 76} 77 78// SessionResponse represents the response from createSession 79type SessionResponse struct { 80 DID string `json:"did"` 81 Handle string `json:"handle"` 82 AccessJWT string `json:"accessJwt"` 83 RefreshJWT string `json:"refreshJwt"` 84 Email string `json:"email,omitempty"` 85 AccessToken string `json:"access_token,omitempty"` // Alternative field name 86} 87 88// CreateSessionAndGetToken creates a session and returns the DID, handle, and access token 89func (v *SessionValidator) CreateSessionAndGetToken(ctx context.Context, identifier, password string) (did, handle, accessToken string, err error) { 90 // Check cache first 91 cacheKey := getCacheKey(identifier, password) 92 if cached, ok := v.getCachedSession(cacheKey); ok { 93 slog.Debug("Using cached session", "identifier", identifier, "did", cached.DID) 94 return cached.DID, cached.Handle, cached.AccessToken, nil 95 } 96 97 slog.Debug("No cached session, creating new session", "identifier", identifier) 98 99 // Resolve identifier to PDS endpoint 100 _, _, pds, err := atproto.ResolveIdentity(ctx, identifier) 101 if err != nil { 102 return "", "", "", err 103 } 104 105 // Create session 106 sessionResp, err := v.createSession(ctx, pds, identifier, password) 107 if err != nil { 108 return "", "", "", fmt.Errorf("authentication failed: %w", err) 109 } 110 111 // Cache the session (ATProto sessions typically last 2 hours) 112 v.setCachedSession(cacheKey, &CachedSession{ 113 DID: sessionResp.DID, 114 Handle: sessionResp.Handle, 115 PDS: pds, 116 AccessToken: sessionResp.AccessJWT, 117 ExpiresAt: time.Now().Add(2 * time.Hour), 118 }) 119 slog.Debug("Cached session (expires in 2 hours)", "identifier", identifier, "did", sessionResp.DID) 120 121 return sessionResp.DID, sessionResp.Handle, sessionResp.AccessJWT, nil 122} 123 124// createSession calls com.atproto.server.createSession 125func (v *SessionValidator) createSession(ctx context.Context, pdsEndpoint, identifier, password string) (*SessionResponse, error) { 126 payload := map[string]string{ 127 "identifier": identifier, 128 "password": password, 129 } 130 131 body, err := json.Marshal(payload) 132 if err != nil { 133 return nil, fmt.Errorf("failed to marshal request: %w", err) 134 } 135 136 url := fmt.Sprintf("%s%s", pdsEndpoint, atproto.ServerCreateSession) 137 slog.Debug("Creating ATProto session", "url", url) 138 139 req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) 140 if err != nil { 141 return nil, err 142 } 143 144 req.Header.Set("Content-Type", "application/json") 145 146 resp, err := v.httpClient.Do(req) 147 if err != nil { 148 slog.Debug("Session creation HTTP request failed", "error", err) 149 return nil, fmt.Errorf("failed to create session: %w", err) 150 } 151 defer resp.Body.Close() 152 153 slog.Debug("Received session creation response", "status", resp.StatusCode) 154 155 if resp.StatusCode == http.StatusUnauthorized { 156 bodyBytes, _ := io.ReadAll(resp.Body) 157 slog.Debug("Session creation unauthorized", "response", string(bodyBytes)) 158 return nil, fmt.Errorf("invalid credentials") 159 } 160 161 if resp.StatusCode != http.StatusOK { 162 bodyBytes, _ := io.ReadAll(resp.Body) 163 slog.Debug("Session creation failed", "status", resp.StatusCode, "response", string(bodyBytes)) 164 return nil, fmt.Errorf("create session failed with status %d: %s", resp.StatusCode, string(bodyBytes)) 165 } 166 167 var sessionResp SessionResponse 168 if err := json.NewDecoder(resp.Body).Decode(&sessionResp); err != nil { 169 return nil, fmt.Errorf("failed to decode response: %w", err) 170 } 171 172 return &sessionResp, nil 173}