Go boilerplate library for building atproto apps
atproto
go
1// Package sqlite provides an oauth.ClientAuthStore backed by SQLite.
2package sqlite
3
4import (
5 "context"
6 "database/sql"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "sync"
11
12 "github.com/bluesky-social/indigo/atproto/auth/oauth"
13 "github.com/bluesky-social/indigo/atproto/syntax"
14)
15
// ErrSessionNotFound is returned by GetSession when no stored row matches
// the requested DID and session ID.
var ErrSessionNotFound = errors.New("session not found")

// ErrStateNotFound is returned by GetAuthRequestInfo when no pending auth
// request is stored under the requested state token.
var ErrStateNotFound = errors.New("auth state not found")
20
21// Verify interface compliance at compile time.
22var _ oauth.ClientAuthStore = (*Store)(nil)
23
24// Store implements [oauth.ClientAuthStore] using SQLite.
25type Store struct {
26 db *sql.DB
27 mu sync.RWMutex
28}
29
30// New creates a new SQLite-backed auth store. The caller owns the *sql.DB
31// and is responsible for closing it.
32func New(db *sql.DB) (*Store, error) {
33 s := &Store{db: db}
34 if err := s.migrate(); err != nil {
35 return nil, fmt.Errorf("migrate auth tables: %w", err)
36 }
37 return s, nil
38}
39
40func (s *Store) migrate() error {
41 _, err := s.db.Exec(`
42 CREATE TABLE IF NOT EXISTS oauth_sessions (
43 did TEXT NOT NULL,
44 session_id TEXT NOT NULL,
45 data TEXT NOT NULL,
46 created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
47 updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
48 PRIMARY KEY (did, session_id)
49 );
50
51 CREATE TABLE IF NOT EXISTS oauth_auth_requests (
52 state TEXT PRIMARY KEY,
53 data TEXT NOT NULL,
54 created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
55 );
56
57 CREATE INDEX IF NOT EXISTS idx_oauth_auth_requests_created_at
58 ON oauth_auth_requests(created_at);
59 `)
60 return err
61}
62
63// GetSession retrieves an OAuth session by DID and session ID.
64func (s *Store) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) {
65 s.mu.RLock()
66 defer s.mu.RUnlock()
67
68 var dataJSON string
69 err := s.db.QueryRowContext(ctx,
70 `SELECT data FROM oauth_sessions WHERE did = ? AND session_id = ?`,
71 did.String(), sessionID,
72 ).Scan(&dataJSON)
73
74 if errors.Is(err, sql.ErrNoRows) {
75 return nil, ErrSessionNotFound
76 }
77 if err != nil {
78 return nil, fmt.Errorf("query session: %w", err)
79 }
80
81 var data oauth.ClientSessionData
82 if err := json.Unmarshal([]byte(dataJSON), &data); err != nil {
83 return nil, fmt.Errorf("unmarshal session: %w", err)
84 }
85 return &data, nil
86}
87
88// SaveSession persists an OAuth session (upsert).
89func (s *Store) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error {
90 s.mu.Lock()
91 defer s.mu.Unlock()
92
93 dataJSON, err := json.Marshal(sess)
94 if err != nil {
95 return fmt.Errorf("marshal session: %w", err)
96 }
97
98 _, err = s.db.ExecContext(ctx, `
99 INSERT INTO oauth_sessions (did, session_id, data, updated_at)
100 VALUES (?, ?, ?, CURRENT_TIMESTAMP)
101 ON CONFLICT(did, session_id) DO UPDATE SET
102 data = excluded.data,
103 updated_at = CURRENT_TIMESTAMP
104 `, sess.AccountDID.String(), sess.SessionID, string(dataJSON))
105 return err
106}
107
108// DeleteSession removes an OAuth session.
109func (s *Store) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
110 s.mu.Lock()
111 defer s.mu.Unlock()
112
113 _, err := s.db.ExecContext(ctx,
114 `DELETE FROM oauth_sessions WHERE did = ? AND session_id = ?`,
115 did.String(), sessionID,
116 )
117 return err
118}
119
120// GetAuthRequestInfo retrieves pending auth request data by state token.
121func (s *Store) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
122 s.mu.RLock()
123 defer s.mu.RUnlock()
124
125 var dataJSON string
126 err := s.db.QueryRowContext(ctx,
127 `SELECT data FROM oauth_auth_requests WHERE state = ?`,
128 state,
129 ).Scan(&dataJSON)
130
131 if errors.Is(err, sql.ErrNoRows) {
132 return nil, ErrStateNotFound
133 }
134 if err != nil {
135 return nil, fmt.Errorf("query auth request: %w", err)
136 }
137
138 var data oauth.AuthRequestData
139 if err := json.Unmarshal([]byte(dataJSON), &data); err != nil {
140 return nil, fmt.Errorf("unmarshal auth request: %w", err)
141 }
142 return &data, nil
143}
144
145// SaveAuthRequestInfo stores auth request data keyed by state token.
146func (s *Store) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
147 s.mu.Lock()
148 defer s.mu.Unlock()
149
150 dataJSON, err := json.Marshal(info)
151 if err != nil {
152 return fmt.Errorf("marshal auth request: %w", err)
153 }
154
155 _, err = s.db.ExecContext(ctx,
156 `INSERT INTO oauth_auth_requests (state, data) VALUES (?, ?)`,
157 info.State, string(dataJSON),
158 )
159 return err
160}
161
162// DeleteAuthRequestInfo removes auth request data by state token.
163func (s *Store) DeleteAuthRequestInfo(ctx context.Context, state string) error {
164 s.mu.Lock()
165 defer s.mu.Unlock()
166
167 _, err := s.db.ExecContext(ctx,
168 `DELETE FROM oauth_auth_requests WHERE state = ?`,
169 state,
170 )
171 return err
172}
173
174// CleanupExpiredRequests removes auth requests older than the given duration.
175// Call this periodically (e.g. every 10 minutes) to prevent unbounded growth.
176func (s *Store) CleanupExpiredRequests(ctx context.Context, olderThanMinutes int) error {
177 s.mu.Lock()
178 defer s.mu.Unlock()
179
180 _, err := s.db.ExecContext(ctx,
181 `DELETE FROM oauth_auth_requests WHERE created_at < datetime('now', ?)`,
182 fmt.Sprintf("-%d minutes", olderThanMinutes),
183 )
184 return err
185}