// Package sqlite provides an oauth.ClientAuthStore backed by SQLite. package sqlite import ( "context" "database/sql" "encoding/json" "errors" "fmt" "sync" "github.com/bluesky-social/indigo/atproto/auth/oauth" "github.com/bluesky-social/indigo/atproto/syntax" ) var ( ErrSessionNotFound = errors.New("session not found") ErrStateNotFound = errors.New("auth state not found") ) // Verify interface compliance at compile time. var _ oauth.ClientAuthStore = (*Store)(nil) // Store implements [oauth.ClientAuthStore] using SQLite. type Store struct { db *sql.DB mu sync.RWMutex } // New creates a new SQLite-backed auth store. The caller owns the *sql.DB // and is responsible for closing it. func New(db *sql.DB) (*Store, error) { s := &Store{db: db} if err := s.migrate(); err != nil { return nil, fmt.Errorf("migrate auth tables: %w", err) } return s, nil } func (s *Store) migrate() error { _, err := s.db.Exec(` CREATE TABLE IF NOT EXISTS oauth_sessions ( did TEXT NOT NULL, session_id TEXT NOT NULL, data TEXT NOT NULL, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, PRIMARY KEY (did, session_id) ); CREATE TABLE IF NOT EXISTS oauth_auth_requests ( state TEXT PRIMARY KEY, data TEXT NOT NULL, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ); CREATE INDEX IF NOT EXISTS idx_oauth_auth_requests_created_at ON oauth_auth_requests(created_at); `) return err } // GetSession retrieves an OAuth session by DID and session ID. func (s *Store) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) { s.mu.RLock() defer s.mu.RUnlock() var dataJSON string err := s.db.QueryRowContext(ctx, `SELECT data FROM oauth_sessions WHERE did = ? AND session_id = ?`, did.String(), sessionID, ).Scan(&dataJSON) if errors.Is(err, sql.ErrNoRows) { return nil, ErrSessionNotFound } if err != nil { return nil, fmt.Errorf("query session: %w", err) } var data oauth.ClientSessionData if err := json.Unmarshal([]byte(dataJSON), &data); err != nil { return nil, fmt.Errorf("unmarshal session: %w", err) } return &data, nil } // SaveSession persists an OAuth session (upsert). func (s *Store) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error { s.mu.Lock() defer s.mu.Unlock() dataJSON, err := json.Marshal(sess) if err != nil { return fmt.Errorf("marshal session: %w", err) } _, err = s.db.ExecContext(ctx, ` INSERT INTO oauth_sessions (did, session_id, data, updated_at) VALUES (?, ?, ?, CURRENT_TIMESTAMP) ON CONFLICT(did, session_id) DO UPDATE SET data = excluded.data, updated_at = CURRENT_TIMESTAMP `, sess.AccountDID.String(), sess.SessionID, string(dataJSON)) return err } // DeleteSession removes an OAuth session. func (s *Store) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error { s.mu.Lock() defer s.mu.Unlock() _, err := s.db.ExecContext(ctx, `DELETE FROM oauth_sessions WHERE did = ? AND session_id = ?`, did.String(), sessionID, ) return err } // GetAuthRequestInfo retrieves pending auth request data by state token. func (s *Store) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) { s.mu.RLock() defer s.mu.RUnlock() var dataJSON string err := s.db.QueryRowContext(ctx, `SELECT data FROM oauth_auth_requests WHERE state = ?`, state, ).Scan(&dataJSON) if errors.Is(err, sql.ErrNoRows) { return nil, ErrStateNotFound } if err != nil { return nil, fmt.Errorf("query auth request: %w", err) } var data oauth.AuthRequestData if err := json.Unmarshal([]byte(dataJSON), &data); err != nil { return nil, fmt.Errorf("unmarshal auth request: %w", err) } return &data, nil } // SaveAuthRequestInfo stores auth request data keyed by state token. func (s *Store) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error { s.mu.Lock() defer s.mu.Unlock() dataJSON, err := json.Marshal(info) if err != nil { return fmt.Errorf("marshal auth request: %w", err) } _, err = s.db.ExecContext(ctx, `INSERT INTO oauth_auth_requests (state, data) VALUES (?, ?)`, info.State, string(dataJSON), ) return err } // DeleteAuthRequestInfo removes auth request data by state token. func (s *Store) DeleteAuthRequestInfo(ctx context.Context, state string) error { s.mu.Lock() defer s.mu.Unlock() _, err := s.db.ExecContext(ctx, `DELETE FROM oauth_auth_requests WHERE state = ?`, state, ) return err } // CleanupExpiredRequests removes auth requests older than the given duration. // Call this periodically (e.g. every 10 minutes) to prevent unbounded growth. func (s *Store) CleanupExpiredRequests(ctx context.Context, olderThanMinutes int) error { s.mu.Lock() defer s.mu.Unlock() _, err := s.db.ExecContext(ctx, `DELETE FROM oauth_auth_requests WHERE created_at < datetime('now', ?)`, fmt.Sprintf("-%d minutes", olderThanMinutes), ) return err }