forked from
tangled.org/core
this repo has no description
1package oauth
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "net/http"
11 "slices"
12 "strings"
13 "time"
14
15 comatproto "github.com/bluesky-social/indigo/api/atproto"
16 atpclient "github.com/bluesky-social/indigo/atproto/atclient"
17 "github.com/bluesky-social/indigo/atproto/auth/oauth"
18 lexutil "github.com/bluesky-social/indigo/lex/util"
19 xrpc "github.com/bluesky-social/indigo/xrpc"
20 "github.com/go-chi/chi/v5"
21 "github.com/posthog/posthog-go"
22 "tangled.org/core/api/tangled"
23 "tangled.org/core/appview/db"
24 "tangled.org/core/appview/models"
25 "tangled.org/core/consts"
26 "tangled.org/core/idresolver"
27 "tangled.org/core/orm"
28 "tangled.org/core/tid"
29)
30
31func (o *OAuth) Router() http.Handler {
32 r := chi.NewRouter()
33
34 r.Get("/oauth/client-metadata.json", o.clientMetadata)
35 r.Get("/oauth/jwks.json", o.jwks)
36 r.Get("/oauth/callback", o.callback)
37 return r
38}
39
40func (o *OAuth) clientMetadata(w http.ResponseWriter, r *http.Request) {
41 doc := o.ClientApp.Config.ClientMetadata()
42 doc.JWKSURI = &o.JwksUri
43 doc.ClientName = &o.ClientName
44 doc.ClientURI = &o.ClientUri
45 doc.Scope = doc.Scope + " identity:handle"
46
47 w.Header().Set("Content-Type", "application/json")
48 if err := json.NewEncoder(w).Encode(doc); err != nil {
49 http.Error(w, err.Error(), http.StatusInternalServerError)
50 return
51 }
52}
53
54func (o *OAuth) jwks(w http.ResponseWriter, r *http.Request) {
55 w.Header().Set("Content-Type", "application/json")
56 body := o.ClientApp.Config.PublicJWKS()
57 if err := json.NewEncoder(w).Encode(body); err != nil {
58 http.Error(w, err.Error(), http.StatusInternalServerError)
59 return
60 }
61}
62
63func (o *OAuth) callback(w http.ResponseWriter, r *http.Request) {
64 ctx := r.Context()
65 l := o.Logger.With("query", r.URL.Query())
66
67 redirectURL := o.GetAuthReturn(r)
68 _ = o.ClearAuthReturn(w, r)
69
70 sessData, err := o.ClientApp.ProcessCallback(ctx, r.URL.Query())
71 if err != nil {
72 var callbackErr *oauth.AuthRequestCallbackError
73 if errors.As(err, &callbackErr) {
74 l.Debug("callback error", "err", callbackErr)
75 http.Redirect(w, r, fmt.Sprintf("/login?error=%s", callbackErr.ErrorCode), http.StatusFound)
76 return
77 }
78 l.Error("failed to process callback", "err", err)
79 http.Redirect(w, r, "/login?error=oauth", http.StatusFound)
80 return
81 }
82
83 if err := o.SaveSession(w, r, sessData); err != nil {
84 l.Error("failed to save session", "data", sessData, "err", err)
85 errorCode := "session"
86 if errors.Is(err, ErrMaxAccountsReached) {
87 errorCode = "max_accounts"
88 }
89 http.Redirect(w, r, fmt.Sprintf("/login?error=%s", errorCode), http.StatusFound)
90 return
91 }
92
93 o.Logger.Debug("session saved successfully")
94
95 go o.addToDefaultKnot(sessData.AccountDID.String())
96 go o.addToDefaultSpindle(sessData.AccountDID.String())
97 go o.ensureTangledProfile(sessData)
98 go o.autoClaimTnglShDomain(sessData.AccountDID.String())
99 go o.drainPdsRewrites(sessData)
100
101 if !o.Config.Core.Dev {
102 err = o.Posthog.Enqueue(posthog.Capture{
103 DistinctId: sessData.AccountDID.String(),
104 Event: "signin",
105 })
106 if err != nil {
107 o.Logger.Error("failed to enqueue posthog event", "err", err)
108 }
109 }
110
111 if redirectURL == "" {
112 redirectURL = "/"
113 }
114
115 if o.isAccountDeactivated(sessData) {
116 redirectURL = "/settings/profile"
117 }
118
119 http.Redirect(w, r, redirectURL, http.StatusFound)
120}
121
122func (o *OAuth) isAccountDeactivated(sessData *oauth.ClientSessionData) bool {
123 pdsClient := &xrpc.Client{
124 Host: sessData.HostURL,
125 Client: &http.Client{Timeout: 5 * time.Second},
126 }
127
128 _, err := comatproto.RepoDescribeRepo(
129 context.Background(),
130 pdsClient,
131 sessData.AccountDID.String(),
132 )
133 if err == nil {
134 return false
135 }
136
137 var xrpcErr *xrpc.Error
138 var xrpcBody *xrpc.XRPCError
139 return errors.As(err, &xrpcErr) &&
140 errors.As(xrpcErr.Wrapped, &xrpcBody) &&
141 xrpcBody.ErrStr == "RepoDeactivated"
142}
143
144func (o *OAuth) addToDefaultSpindle(did string) {
145 l := o.Logger.With("subject", did)
146
147 // use the tangled.sh app password to get an accessJwt
148 // and create an sh.tangled.spindle.member record with that
149 spindleMembers, err := db.GetSpindleMembers(
150 o.Db,
151 orm.FilterEq("instance", "spindle.tangled.sh"),
152 orm.FilterEq("subject", did),
153 )
154 if err != nil {
155 l.Error("failed to get spindle members", "err", err)
156 return
157 }
158
159 if len(spindleMembers) != 0 {
160 l.Warn("already a member of the default spindle")
161 return
162 }
163
164 l.Debug("adding to default spindle")
165 session, err := o.getAppPasswordSession()
166 if err != nil {
167 l.Error("failed to create session", "err", err)
168 return
169 }
170
171 record := tangled.SpindleMember{
172 LexiconTypeID: tangled.SpindleMemberNSID,
173 Subject: did,
174 Instance: consts.DefaultSpindle,
175 CreatedAt: time.Now().Format(time.RFC3339),
176 }
177
178 if err := session.putRecord(record, tangled.SpindleMemberNSID); err != nil {
179 l.Error("failed to add to default spindle", "err", err)
180 return
181 }
182
183 l.Debug("successfully added to default spindle", "did", did)
184}
185
186func (o *OAuth) addToDefaultKnot(did string) {
187 l := o.Logger.With("subject", did)
188
189 // use the tangled.sh app password to get an accessJwt
190 // and create an sh.tangled.spindle.member record with that
191
192 allKnots, err := o.Enforcer.GetKnotsForUser(did)
193 if err != nil {
194 l.Error("failed to get knot members for did", "err", err)
195 return
196 }
197
198 if slices.Contains(allKnots, consts.DefaultKnot) {
199 l.Warn("already a member of the default knot")
200 return
201 }
202
203 l.Debug("adding to default knot")
204 session, err := o.getAppPasswordSession()
205 if err != nil {
206 l.Error("failed to create session", "err", err)
207 return
208 }
209
210 record := tangled.KnotMember{
211 LexiconTypeID: tangled.KnotMemberNSID,
212 Subject: did,
213 Domain: consts.DefaultKnot,
214 CreatedAt: time.Now().Format(time.RFC3339),
215 }
216
217 if err := session.putRecord(record, tangled.KnotMemberNSID); err != nil {
218 l.Error("failed to add to default knot", "err", err)
219 return
220 }
221
222 if err := o.Enforcer.AddKnotMember(consts.DefaultKnot, did); err != nil {
223 l.Error("failed to set up enforcer rules", "err", err)
224 return
225 }
226
227 l.Debug("successfully added to default knot")
228}
229
230func (o *OAuth) ensureTangledProfile(sessData *oauth.ClientSessionData) {
231 ctx := context.Background()
232 did := sessData.AccountDID.String()
233 l := o.Logger.With("did", did)
234
235 profile, _ := db.GetProfile(o.Db, did)
236 if profile != nil {
237 l.Debug("profile already exists in DB")
238 return
239 }
240
241 l.Debug("creating empty Tangled profile")
242
243 sess, err := o.ClientApp.ResumeSession(ctx, sessData.AccountDID, sessData.SessionID)
244 if err != nil {
245 l.Error("failed to resume session for profile creation", "err", err)
246 return
247 }
248 client := sess.APIClient()
249
250 _, err = comatproto.RepoPutRecord(ctx, client, &comatproto.RepoPutRecord_Input{
251 Collection: tangled.ActorProfileNSID,
252 Repo: did,
253 Rkey: "self",
254 Record: &lexutil.LexiconTypeDecoder{Val: &tangled.ActorProfile{}},
255 })
256
257 if err != nil {
258 l.Error("failed to create empty profile on PDS", "err", err)
259 return
260 }
261
262 tx, err := o.Db.BeginTx(ctx, nil)
263 if err != nil {
264 l.Error("failed to start transaction", "err", err)
265 return
266 }
267
268 emptyProfile := &models.Profile{Did: did}
269 if err := db.UpsertProfile(tx, emptyProfile); err != nil {
270 l.Error("failed to create empty profile in DB", "err", err)
271 return
272 }
273
274 l.Debug("successfully created empty Tangled profile on PDS and DB")
275}
276
277func (o *OAuth) drainPdsRewrites(sessData *oauth.ClientSessionData) {
278 ctx := context.Background()
279 did := sessData.AccountDID.String()
280 l := o.Logger.With("did", did, "handler", "drainPdsRewrites")
281
282 rewrites, err := db.GetPendingPdsRewrites(o.Db, did)
283 if err != nil {
284 l.Error("failed to get pending rewrites", "err", err)
285 return
286 }
287 if len(rewrites) == 0 {
288 return
289 }
290
291 l.Info("draining pending PDS rewrites", "count", len(rewrites))
292
293 sess, err := o.ClientApp.ResumeSession(ctx, sessData.AccountDID, sessData.SessionID)
294 if err != nil {
295 l.Error("failed to resume session for PDS rewrites", "err", err)
296 return
297 }
298 client := sess.APIClient()
299
300 for _, rw := range rewrites {
301 if err := o.rewritePdsRecord(ctx, client, did, rw); err != nil {
302 l.Error("failed to rewrite PDS record",
303 "nsid", rw.RecordNsid,
304 "rkey", rw.RecordRkey,
305 "repo_did", rw.RepoDid,
306 "err", err)
307 continue
308 }
309
310 if err := db.CompletePdsRewrite(o.Db, rw.Id); err != nil {
311 l.Error("failed to mark rewrite complete", "id", rw.Id, "err", err)
312 }
313 }
314}
315
316func (o *OAuth) rewritePdsRecord(ctx context.Context, client *atpclient.APIClient, userDid string, rw db.PdsRewrite) error {
317 ex, err := comatproto.RepoGetRecord(ctx, client, "", rw.RecordNsid, userDid, rw.RecordRkey)
318 if err != nil {
319 return fmt.Errorf("get record: %w", err)
320 }
321
322 val := ex.Value.Val
323 repoDid := rw.RepoDid
324
325 switch rw.RecordNsid {
326 case tangled.RepoNSID:
327 rec, ok := val.(*tangled.Repo)
328 if !ok {
329 return fmt.Errorf("unexpected type for repo record")
330 }
331 rec.RepoDid = &repoDid
332
333 case tangled.RepoIssueNSID:
334 rec, ok := val.(*tangled.RepoIssue)
335 if !ok {
336 return fmt.Errorf("unexpected type for issue record")
337 }
338 rec.RepoDid = &repoDid
339
340 case tangled.RepoPullNSID:
341 rec, ok := val.(*tangled.RepoPull)
342 if !ok {
343 return fmt.Errorf("unexpected type for pull record")
344 }
345 if rec.Target != nil {
346 rec.Target.RepoDid = &repoDid
347 }
348 if rec.Source != nil && rec.Source.Repo != nil && *rec.Source.Repo == rw.OldRepoAt {
349 rec.Source.RepoDid = &repoDid
350 }
351
352 case tangled.RepoCollaboratorNSID:
353 rec, ok := val.(*tangled.RepoCollaborator)
354 if !ok {
355 return fmt.Errorf("unexpected type for collaborator record")
356 }
357 rec.RepoDid = &repoDid
358
359 case tangled.RepoArtifactNSID:
360 rec, ok := val.(*tangled.RepoArtifact)
361 if !ok {
362 return fmt.Errorf("unexpected type for artifact record")
363 }
364 rec.RepoDid = &repoDid
365
366 case tangled.FeedStarNSID:
367 rec, ok := val.(*tangled.FeedStar)
368 if !ok {
369 return fmt.Errorf("unexpected type for star record")
370 }
371 rec.SubjectDid = &repoDid
372
373 case tangled.ActorProfileNSID:
374 rec, ok := val.(*tangled.ActorProfile)
375 if !ok {
376 return fmt.Errorf("unexpected type for profile record")
377 }
378 rewritten := make([]string, 0, len(rec.PinnedRepositories))
379 for _, pin := range rec.PinnedRepositories {
380 if strings.HasPrefix(pin, "did:") {
381 rewritten = append(rewritten, pin)
382 continue
383 }
384 repo, repoErr := db.GetRepoByAtUri(o.Db, pin)
385 if repoErr != nil || repo.RepoDid == "" {
386 rewritten = append(rewritten, pin)
387 continue
388 }
389 rewritten = append(rewritten, repo.RepoDid)
390 }
391 rec.PinnedRepositories = rewritten
392
393 default:
394 return fmt.Errorf("unsupported NSID for PDS rewrite: %s", rw.RecordNsid)
395 }
396
397 _, err = comatproto.RepoPutRecord(ctx, client, &comatproto.RepoPutRecord_Input{
398 Collection: rw.RecordNsid,
399 Repo: userDid,
400 Rkey: rw.RecordRkey,
401 SwapRecord: ex.Cid,
402 Record: &lexutil.LexiconTypeDecoder{Val: val},
403 })
404 if err != nil {
405 return fmt.Errorf("put record: %w", err)
406 }
407
408 return nil
409}
410
// AppPasswordSession holds the state of a PDS session that was created with
// an app password.
type AppPasswordSession struct {
	// AccessJwt is the bearer token sent with authenticated requests.
	AccessJwt string `json:"accessJwt"`
	// RefreshJwt is the bearer token sent when refreshing the session.
	RefreshJwt string `json:"refreshJwt"`
	// PdsEndpoint is the base URL the XRPC paths are appended to.
	PdsEndpoint string
	// Did is the account the session belongs to; it is used as the repo
	// when putting records.
	Did string
	// Logger receives the session's debug output.
	Logger *slog.Logger
	// ExpiresAt is the instant after which IsValid reports false.
	ExpiresAt time.Time
}
420
421func CreateAppPasswordSession(res *idresolver.Resolver, appPassword, did string, logger *slog.Logger) (*AppPasswordSession, error) {
422 if appPassword == "" {
423 return nil, fmt.Errorf("no app password configured")
424 }
425
426 resolved, err := res.ResolveIdent(context.Background(), did)
427 if err != nil {
428 return nil, fmt.Errorf("failed to resolve tangled.sh DID %s: %v", did, err)
429 }
430
431 pdsEndpoint := resolved.PDSEndpoint()
432 if pdsEndpoint == "" {
433 return nil, fmt.Errorf("no PDS endpoint found for tangled.sh DID %s", did)
434 }
435
436 sessionPayload := map[string]string{
437 "identifier": did,
438 "password": appPassword,
439 }
440 sessionBytes, err := json.Marshal(sessionPayload)
441 if err != nil {
442 return nil, fmt.Errorf("failed to marshal session payload: %v", err)
443 }
444
445 sessionURL := pdsEndpoint + "/xrpc/com.atproto.server.createSession"
446 sessionReq, err := http.NewRequestWithContext(context.Background(), "POST", sessionURL, bytes.NewBuffer(sessionBytes))
447 if err != nil {
448 return nil, fmt.Errorf("failed to create session request: %v", err)
449 }
450 sessionReq.Header.Set("Content-Type", "application/json")
451
452 logger.Debug("creating app password session", "url", sessionURL, "headers", sessionReq.Header)
453
454 client := &http.Client{Timeout: 30 * time.Second}
455 sessionResp, err := client.Do(sessionReq)
456 if err != nil {
457 return nil, fmt.Errorf("failed to create session: %v", err)
458 }
459 defer sessionResp.Body.Close()
460
461 if sessionResp.StatusCode != http.StatusOK {
462 return nil, fmt.Errorf("failed to create session: HTTP %d", sessionResp.StatusCode)
463 }
464
465 var session AppPasswordSession
466 if err := json.NewDecoder(sessionResp.Body).Decode(&session); err != nil {
467 return nil, fmt.Errorf("failed to decode session response: %v", err)
468 }
469
470 session.PdsEndpoint = pdsEndpoint
471 session.Did = did
472 session.Logger = logger
473 session.ExpiresAt = time.Now().Add(115 * time.Minute)
474
475 return &session, nil
476}
477
478func (s *AppPasswordSession) RefreshSession() error {
479 refreshURL := s.PdsEndpoint + "/xrpc/com.atproto.server.refreshSession"
480 req, err := http.NewRequestWithContext(context.Background(), "POST", refreshURL, nil)
481 if err != nil {
482 return fmt.Errorf("failed to create refresh request: %w", err)
483 }
484
485 req.Header.Set("Authorization", "Bearer "+s.RefreshJwt)
486
487 s.Logger.Debug("refreshing app password session", "url", refreshURL)
488
489 client := &http.Client{Timeout: 30 * time.Second}
490 resp, err := client.Do(req)
491 if err != nil {
492 return fmt.Errorf("failed to refresh session: %w", err)
493 }
494 defer resp.Body.Close()
495
496 if resp.StatusCode != http.StatusOK {
497 var errorResponse map[string]any
498 if err := json.NewDecoder(resp.Body).Decode(&errorResponse); err != nil {
499 return fmt.Errorf("failed to refresh session: HTTP %d (failed to decode error response: %w)", resp.StatusCode, err)
500 }
501 errorBytes, _ := json.Marshal(errorResponse)
502 return fmt.Errorf("failed to refresh session: HTTP %d, response: %s", resp.StatusCode, string(errorBytes))
503 }
504
505 var refreshResponse struct {
506 AccessJwt string `json:"accessJwt"`
507 RefreshJwt string `json:"refreshJwt"`
508 }
509 if err := json.NewDecoder(resp.Body).Decode(&refreshResponse); err != nil {
510 return fmt.Errorf("failed to decode refresh response: %w", err)
511 }
512
513 s.AccessJwt = refreshResponse.AccessJwt
514 s.RefreshJwt = refreshResponse.RefreshJwt
515 // Set new expiry time with 5 minute buffer
516 s.ExpiresAt = time.Now().Add(115 * time.Minute)
517
518 s.Logger.Debug("successfully refreshed app password session")
519 return nil
520}
521
522func (s *AppPasswordSession) IsValid() bool {
523 return time.Now().Before(s.ExpiresAt)
524}
525
526func (s *AppPasswordSession) putRecord(record any, collection string) error {
527 if !s.IsValid() {
528 s.Logger.Debug("access token expired, refreshing session")
529 if err := s.RefreshSession(); err != nil {
530 return fmt.Errorf("failed to refresh session: %w", err)
531 }
532 s.Logger.Debug("session refreshed")
533 }
534
535 recordBytes, err := json.Marshal(record)
536 if err != nil {
537 return fmt.Errorf("failed to marshal knot member record: %w", err)
538 }
539
540 payload := map[string]any{
541 "repo": s.Did,
542 "collection": collection,
543 "rkey": tid.TID(),
544 "record": json.RawMessage(recordBytes),
545 }
546
547 payloadBytes, err := json.Marshal(payload)
548 if err != nil {
549 return fmt.Errorf("failed to marshal request payload: %w", err)
550 }
551
552 url := s.PdsEndpoint + "/xrpc/com.atproto.repo.putRecord"
553 req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(payloadBytes))
554 if err != nil {
555 return fmt.Errorf("failed to create HTTP request: %w", err)
556 }
557
558 req.Header.Set("Content-Type", "application/json")
559 req.Header.Set("Authorization", "Bearer "+s.AccessJwt)
560
561 s.Logger.Debug("putting record", "url", url, "collection", collection)
562
563 client := &http.Client{Timeout: 30 * time.Second}
564 resp, err := client.Do(req)
565 if err != nil {
566 return fmt.Errorf("failed to add user to default service: %w", err)
567 }
568 defer resp.Body.Close()
569
570 if resp.StatusCode != http.StatusOK {
571 var errorResponse map[string]any
572 if err := json.NewDecoder(resp.Body).Decode(&errorResponse); err != nil {
573 return fmt.Errorf("failed to add user to default service: HTTP %d (failed to decode error response: %w)", resp.StatusCode, err)
574 }
575 return fmt.Errorf("failed to add user to default service: HTTP %d, response: %v", resp.StatusCode, errorResponse)
576 }
577
578 return nil
579}
580
581// autoClaimTnglShDomain checks if the user has a .tngl.sh handle and, if so,
582// ensures their corresponding sites domain is claimed. This is idempotent —
583// ClaimDomain is a no-op if the claim already exists.
584func (o *OAuth) autoClaimTnglShDomain(did string) {
585 l := o.Logger.With("did", did)
586
587 pdsDomain := strings.TrimPrefix(o.Config.Pds.Host, "https://")
588 pdsDomain = strings.TrimPrefix(pdsDomain, "http://")
589
590 resolved, err := o.IdResolver.ResolveIdent(context.Background(), did)
591 if err != nil {
592 l.Error("autoClaimTnglShDomain: failed to resolve ident", "err", err)
593 return
594 }
595
596 handle := resolved.Handle.String()
597 if !strings.HasSuffix(handle, "."+pdsDomain) {
598 return
599 }
600
601 if err := db.ClaimDomain(o.Db, did, handle); err != nil {
602 l.Warn("autoClaimTnglShDomain: failed to claim domain", "domain", handle, "err", err)
603 } else {
604 l.Info("autoClaimTnglShDomain: claimed domain", "domain", handle)
605 }
606}
607
608// getAppPasswordSession returns a cached AppPasswordSession, creating one if needed.
609func (o *OAuth) getAppPasswordSession() (*AppPasswordSession, error) {
610 o.appPasswordSessionMu.Lock()
611 defer o.appPasswordSessionMu.Unlock()
612
613 if o.appPasswordSession != nil {
614 return o.appPasswordSession, nil
615 }
616
617 session, err := CreateAppPasswordSession(o.IdResolver, o.Config.Core.AppPassword, consts.TangledDid, o.Logger)
618 if err != nil {
619 return nil, err
620 }
621
622 o.appPasswordSession = session
623 return session, nil
624}