forked from
tangled.org/core
Monorepo for Tangled
1package oauth
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "net/http"
9 "net/url"
10 "sync"
11 "time"
12
13 comatproto "github.com/bluesky-social/indigo/api/atproto"
14 "github.com/bluesky-social/indigo/atproto/atclient"
15 "github.com/bluesky-social/indigo/atproto/atcrypto"
16 "github.com/bluesky-social/indigo/atproto/auth/oauth"
17 "github.com/bluesky-social/indigo/atproto/syntax"
18 xrpc "github.com/bluesky-social/indigo/xrpc"
19 "github.com/gorilla/sessions"
20 "github.com/posthog/posthog-go"
21 "tangled.org/core/appview/config"
22 "tangled.org/core/appview/db"
23 "tangled.org/core/idresolver"
24 "tangled.org/core/rbac"
25)
26
27type OAuth struct {
28 ClientApp *oauth.ClientApp
29 SessStore *sessions.CookieStore
30 Config *config.Config
31 JwksUri string
32 ClientName string
33 ClientUri string
34 Posthog posthog.Client
35 Db *db.DB
36 Enforcer *rbac.Enforcer
37 IdResolver *idresolver.Resolver
38 Logger *slog.Logger
39
40 appPasswordSession *AppPasswordSession
41 appPasswordSessionMu sync.Mutex
42}
43
// New builds the appview's OAuth client from config.
//
// In dev mode (config.Core.Dev) the client uses the localhost client
// configuration rooted at http://127.0.0.1:3000, and its resolver is allowed
// to use http.DefaultTransport. Otherwise it is a public client whose client
// ID is https://<AppviewHost>/oauth/client-metadata.json. In both modes:
//   - the client secret from config.OAuth is attached;
//   - OAuth sessions and auth requests are kept in a Redis-backed store;
//   - browser sessions live in a cookie store keyed by config.Core.CookieSecret.
//
// It returns an error if the client secret cannot be parsed or applied, or if
// the Redis store cannot be created. Each such error is wrapped with the step
// that failed.
func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) {
	var oauthConfig oauth.ClientConfig
	var clientUri string
	if config.Core.Dev {
		clientUri = "http://127.0.0.1:3000"
		callbackUri := clientUri + "/oauth/callback"
		oauthConfig = oauth.NewLocalhostConfig(callbackUri, TangledScopes)
	} else {
		clientUri = "https://" + config.Core.AppviewHost
		clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri)
		callbackUri := clientUri + "/oauth/callback"
		oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, TangledScopes)
	}

	// configure client secret
	priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret)
	if err != nil {
		return nil, fmt.Errorf("parsing oauth client secret: %w", err)
	}
	if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil {
		return nil, fmt.Errorf("setting oauth client secret: %w", err)
	}

	jwksUri := clientUri + "/oauth/jwks.json"

	authStore, err := NewRedisStore(&RedisStoreConfig{
		RedisURL:                  config.Redis.ToURL(),
		SessionExpiryDuration:     time.Hour * 24 * 90,
		SessionInactivityDuration: time.Hour * 24 * 14,
		AuthRequestExpiryDuration: time.Minute * 30,
	})
	if err != nil {
		return nil, fmt.Errorf("creating oauth redis store: %w", err)
	}

	sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret))

	clientApp := oauth.NewClientApp(&oauthConfig, authStore)
	clientApp.Dir = res.Directory()
	// allow non-public transports in dev mode
	if config.Core.Dev {
		clientApp.Resolver.Client.Transport = http.DefaultTransport
	}

	clientName := config.Core.AppviewName

	logger.Info("oauth setup successfully", "IsConfidential", clientApp.Config.IsConfidential())
	return &OAuth{
		ClientApp:  clientApp,
		Config:     config,
		SessStore:  sessStore,
		JwksUri:    jwksUri,
		ClientName: clientName,
		ClientUri:  clientUri,
		Posthog:    ph,
		Db:         db,
		Enforcer:   enforcer,
		IdResolver: res,
		Logger:     logger,
	}, nil
}
105
106func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error {
107 userSession, err := o.SessStore.Get(r, SessionName)
108 if err != nil {
109 o.Logger.Warn("failed to decode existing session cookie, will create new", "err", err)
110 }
111
112 userSession.Values[SessionDid] = sessData.AccountDID.String()
113 userSession.Values[SessionPds] = sessData.HostURL
114 userSession.Values[SessionId] = sessData.SessionID
115 userSession.Values[SessionAuthenticated] = true
116
117 if err := userSession.Save(r, w); err != nil {
118 return err
119 }
120
121 handle := ""
122 resolved, err := o.IdResolver.ResolveIdent(r.Context(), sessData.AccountDID.String())
123 if err == nil && resolved.Handle.String() != "" {
124 handle = resolved.Handle.String()
125 }
126
127 registry := o.GetAccounts(r)
128 if err := registry.AddAccount(sessData.AccountDID.String(), handle, sessData.SessionID); err != nil {
129 return err
130 }
131 return o.saveAccounts(w, r, registry)
132}
133
// ResumeSession restores the OAuth client session referenced by the request's
// session cookie.
//
// It returns an error in any of these cases:
//   - the cookie cannot be decoded;
//   - no session cookie is present;
//   - the cookie lacks a string DID or session ID, or the DID is malformed;
//   - the OAuth client cannot resume the stored session.
func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) {
	userSession, err := o.SessStore.Get(r, SessionName)
	if err != nil {
		return nil, fmt.Errorf("error getting user session: %w", err)
	}
	if userSession.IsNew {
		return nil, fmt.Errorf("no session available for user")
	}

	// Checked assertions: a cookie that decodes but is missing these keys (or
	// holds another type) must yield an error instead of panicking the handler.
	d, ok := userSession.Values[SessionDid].(string)
	if !ok {
		return nil, fmt.Errorf("session cookie has no DID")
	}
	sessDid, err := syntax.ParseDID(d)
	if err != nil {
		return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
	}

	sessId, ok := userSession.Values[SessionId].(string)
	if !ok {
		return nil, fmt.Errorf("session cookie has no session ID")
	}

	clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId)
	if err != nil {
		return nil, fmt.Errorf("failed to resume session: %w", err)
	}

	return clientSess, nil
}
158
// DeleteSession logs the current browser session out.
//
// It reads the account DID and session ID from the session cookie and asks
// the OAuth client to log that session out. It then expires the cookie
// (MaxAge = -1) and saves it. The cookie is expired even when logout fails.
// Errors from logout and from saving the cookie are combined with errors.Join.
//
// It returns an error without touching the cookie in these cases:
//   - the cookie cannot be decoded or is absent;
//   - it lacks a string DID or session ID;
//   - the DID is malformed.
func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error {
	userSession, err := o.SessStore.Get(r, SessionName)
	if err != nil {
		return fmt.Errorf("error getting user session: %w", err)
	}
	if userSession.IsNew {
		return fmt.Errorf("no session available for user")
	}

	// Checked assertions: missing or mistyped cookie values must surface as
	// errors instead of panicking the handler.
	d, ok := userSession.Values[SessionDid].(string)
	if !ok {
		return fmt.Errorf("session cookie has no DID")
	}
	sessDid, err := syntax.ParseDID(d)
	if err != nil {
		return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
	}

	sessId, ok := userSession.Values[SessionId].(string)
	if !ok {
		return fmt.Errorf("session cookie has no session ID")
	}

	// delete the session
	err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId)
	if err1 != nil {
		err1 = fmt.Errorf("failed to logout: %w", err1)
	}

	// remove the cookie, regardless of whether logout succeeded
	userSession.Options.MaxAge = -1
	err2 := o.SessStore.Save(r, w, userSession)
	if err2 != nil {
		err2 = fmt.Errorf("failed to save into session store: %w", err2)
	}

	return errors.Join(err1, err2)
}
191
// SwitchAccount makes targetDid the active account for this browser.
//
// The account must already be in the multi-account registry. Its stored OAuth
// session is resumed. If resuming fails, the account is dropped from the
// registry (persisting that change is best-effort) and an error is returned.
// On success the session cookie is rewritten to point at the resumed session.
func (o *OAuth) SwitchAccount(w http.ResponseWriter, r *http.Request, targetDid string) error {
	accounts := o.GetAccounts(r)
	entry := accounts.FindAccount(targetDid)
	if entry == nil {
		return fmt.Errorf("account not found in registry: %s", targetDid)
	}

	parsed, err := syntax.ParseDID(targetDid)
	if err != nil {
		return fmt.Errorf("invalid DID: %w", err)
	}

	resumed, err := o.ClientApp.ResumeSession(r.Context(), parsed, entry.SessionId)
	if err != nil {
		// The stored session is unusable, so forget the account. Saving the
		// pruned registry is best-effort; the resume error is what we report.
		accounts.RemoveAccount(targetDid)
		_ = o.saveAccounts(w, r, accounts)
		return fmt.Errorf("session expired for account: %w", err)
	}

	cookie, err := o.SessStore.Get(r, SessionName)
	if err != nil {
		return err
	}

	data := resumed.Data
	vals := cookie.Values
	vals[SessionDid] = data.AccountDID.String()
	vals[SessionPds] = data.HostURL
	vals[SessionId] = data.SessionID
	vals[SessionAuthenticated] = true

	return cookie.Save(r, w)
}
223
// RemoveAccount logs targetDid out and drops it from the multi-account
// registry, then persists the registry.
//
// Removing an account that is not in the registry is a no-op that returns
// nil. Logout is best-effort: an unparsable DID or a failed logout does not
// prevent the registry entry from being removed.
func (o *OAuth) RemoveAccount(w http.ResponseWriter, r *http.Request, targetDid string) error {
	accounts := o.GetAccounts(r)
	entry := accounts.FindAccount(targetDid)
	if entry == nil {
		return nil
	}

	if parsed, parseErr := syntax.ParseDID(targetDid); parseErr == nil {
		// deliberately ignored: removal proceeds even if logout fails
		_ = o.ClientApp.Logout(r.Context(), parsed, entry.SessionId)
	}

	accounts.RemoveAccount(targetDid)
	return o.saveAccounts(w, r, accounts)
}
239
// GetDid returns the DID of the current multi-account user. It returns the
// empty string when GetMultiAccountUser reports no user.
func (o *OAuth) GetDid(r *http.Request) string {
	user := o.GetMultiAccountUser(r)
	if user == nil {
		return ""
	}
	return user.Did
}
247
// AuthorizedClient returns an API client backed by the OAuth session that
// ResumeSession restores from the request's session cookie.
func (o *OAuth) AuthorizedClient(r *http.Request) (*atclient.APIClient, error) {
	sess, err := o.ResumeSession(r)
	if err == nil {
		return sess.APIClient(), nil
	}
	return nil, fmt.Errorf("error getting session: %w", err)
}
255
256// this is a higher level abstraction on ServerGetServiceAuth
257type ServiceClientOpts struct {
258 service string
259 exp int64
260 lxm string
261 dev bool
262 timeout time.Duration
263}
264
265type ServiceClientOpt func(*ServiceClientOpts)
266
267func DefaultServiceClientOpts() ServiceClientOpts {
268 return ServiceClientOpts{
269 timeout: time.Second * 5,
270 }
271}
272
273func WithService(service string) ServiceClientOpt {
274 return func(s *ServiceClientOpts) {
275 s.service = service
276 }
277}
278
279// Specify the Duration in seconds for the expiry of this token
280//
281// The time of expiry is calculated as time.Now().Unix() + exp
282func WithExp(exp int64) ServiceClientOpt {
283 return func(s *ServiceClientOpts) {
284 s.exp = time.Now().Unix() + exp
285 }
286}
287
288func WithLxm(lxm string) ServiceClientOpt {
289 return func(s *ServiceClientOpts) {
290 s.lxm = lxm
291 }
292}
293
294func WithDev(dev bool) ServiceClientOpt {
295 return func(s *ServiceClientOpts) {
296 s.dev = dev
297 }
298}
299
300func WithTimeout(timeout time.Duration) ServiceClientOpt {
301 return func(s *ServiceClientOpts) {
302 s.timeout = timeout
303 }
304}
305
306func (s *ServiceClientOpts) Audience() string {
307 return fmt.Sprintf("did:web:%s", s.service)
308}
309
310func (s *ServiceClientOpts) Host() string {
311 scheme := "https://"
312 if s.dev {
313 scheme = "http://"
314 }
315
316 return scheme + s.service
317}
318
// ServiceClient returns an XRPC client for the service selected with
// WithService. The client is authenticated with a service-auth token obtained
// via com.atproto.server.getServiceAuth, using the user's authorized OAuth
// session.
//
// Options are applied on top of DefaultServiceClientOpts. The token expiry is
// raised to at least 60 seconds from now, so a missing or too-small WithExp
// still yields a usable token. The returned client uses opts.Host() and an
// HTTP client with opts.timeout.
func (o *OAuth) ServiceClient(r *http.Request, optFns ...ServiceClientOpt) (*xrpc.Client, error) {
	opts := DefaultServiceClientOpts()
	// Named "apply" rather than "o" so the loop does not shadow the receiver.
	for _, apply := range optFns {
		apply(&opts)
	}

	client, err := o.AuthorizedClient(r)
	if err != nil {
		return nil, err
	}

	// force expiry to at least 60 seconds in the future
	minExp := time.Now().Unix() + 60
	if opts.exp < minExp {
		opts.exp = minExp
	}

	resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm)
	if err != nil {
		return nil, err
	}

	return &xrpc.Client{
		Auth: &xrpc.AuthInfo{
			AccessJwt: resp.Token,
		},
		Host: opts.Host(),
		Client: &http.Client{
			Timeout: opts.timeout,
		},
	}, nil
}
351
// StartElevatedAuthFlow begins a new OAuth authorization for did that
// requests TangledScopes plus extraScopes. It returns the authorization
// endpoint URL the browser should be redirected to.
//
// Steps:
//  1. Resolve did to its PDS.
//  2. Resolve that PDS's authorization server and fetch its metadata.
//  3. Send an auth request. The login hint is the account's handle, or did
//     when the handle is missing or invalid.
//  4. Bind the returned request info to did, save it in the client store,
//     and record returnURL via SetAuthReturn.
//
// Any failure, including failing to persist the auth request, is returned
// as an error with an empty URL.
func (o *OAuth) StartElevatedAuthFlow(ctx context.Context, w http.ResponseWriter, r *http.Request, did string, extraScopes []string, returnURL string) (string, error) {
	parsedDid, err := syntax.ParseDID(did)
	if err != nil {
		return "", fmt.Errorf("invalid DID: %w", err)
	}

	ident, err := o.ClientApp.Dir.Lookup(ctx, parsedDid.AtIdentifier())
	if err != nil {
		return "", fmt.Errorf("failed to resolve DID (%s): %w", did, err)
	}

	host := ident.PDSEndpoint()
	if host == "" {
		return "", fmt.Errorf("identity does not link to an atproto host (PDS)")
	}

	authserverURL, err := o.ClientApp.Resolver.ResolveAuthServerURL(ctx, host)
	if err != nil {
		return "", fmt.Errorf("resolving auth server: %w", err)
	}

	authserverMeta, err := o.ClientApp.Resolver.ResolveAuthServerMetadata(ctx, authserverURL)
	if err != nil {
		return "", fmt.Errorf("fetching auth server metadata: %w", err)
	}

	scopes := make([]string, 0, len(TangledScopes)+len(extraScopes))
	scopes = append(scopes, TangledScopes...)
	scopes = append(scopes, extraScopes...)

	loginHint := did
	if ident.Handle != "" && !ident.Handle.IsInvalidHandle() {
		loginHint = ident.Handle.String()
	}

	info, err := o.ClientApp.SendAuthRequest(ctx, authserverMeta, scopes, loginHint)
	if err != nil {
		return "", fmt.Errorf("auth request failed: %w", err)
	}

	info.AccountDID = &parsedDid
	// The stored request is presumably needed to complete the OAuth callback.
	// Fail now rather than redirect into a flow that cannot finish.
	if err := o.ClientApp.Store.SaveAuthRequestInfo(ctx, *info); err != nil {
		return "", fmt.Errorf("failed to save auth request info: %w", err)
	}

	if err := o.SetAuthReturn(w, r, returnURL); err != nil {
		return "", fmt.Errorf("failed to set auth return: %w", err)
	}

	redirectURL := fmt.Sprintf("%s?client_id=%s&request_uri=%s",
		authserverMeta.AuthorizationEndpoint,
		url.QueryEscape(o.ClientApp.Config.ClientID),
		url.QueryEscape(info.RequestURI),
	)

	return redirectURL, nil
}