Monorepo for Tangled
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

at master 406 lines 11 kB view raw
1package oauth 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "log/slog" 8 "net/http" 9 "net/url" 10 "sync" 11 "time" 12 13 comatproto "github.com/bluesky-social/indigo/api/atproto" 14 "github.com/bluesky-social/indigo/atproto/atclient" 15 "github.com/bluesky-social/indigo/atproto/atcrypto" 16 "github.com/bluesky-social/indigo/atproto/auth/oauth" 17 "github.com/bluesky-social/indigo/atproto/syntax" 18 xrpc "github.com/bluesky-social/indigo/xrpc" 19 "github.com/gorilla/sessions" 20 "github.com/posthog/posthog-go" 21 "tangled.org/core/appview/config" 22 "tangled.org/core/appview/db" 23 "tangled.org/core/idresolver" 24 "tangled.org/core/rbac" 25) 26 27type OAuth struct { 28 ClientApp *oauth.ClientApp 29 SessStore *sessions.CookieStore 30 Config *config.Config 31 JwksUri string 32 ClientName string 33 ClientUri string 34 Posthog posthog.Client 35 Db *db.DB 36 Enforcer *rbac.Enforcer 37 IdResolver *idresolver.Resolver 38 Logger *slog.Logger 39 40 appPasswordSession *AppPasswordSession 41 appPasswordSessionMu sync.Mutex 42} 43 44func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) { 45 var oauthConfig oauth.ClientConfig 46 var clientUri string 47 if config.Core.Dev { 48 clientUri = "http://127.0.0.1:3000" 49 callbackUri := clientUri + "/oauth/callback" 50 oauthConfig = oauth.NewLocalhostConfig(callbackUri, TangledScopes) 51 } else { 52 clientUri = "https://" + config.Core.AppviewHost 53 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri) 54 callbackUri := clientUri + "/oauth/callback" 55 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, TangledScopes) 56 } 57 58 // configure client secret 59 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret) 60 if err != nil { 61 return nil, err 62 } 63 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil { 64 return nil, err 65 } 66 67 jwksUri := clientUri + "/oauth/jwks.json" 68 69 authStore, err := NewRedisStore(&RedisStoreConfig{ 70 RedisURL: config.Redis.ToURL(), 71 SessionExpiryDuration: time.Hour * 24 * 90, 72 SessionInactivityDuration: time.Hour * 24 * 14, 73 AuthRequestExpiryDuration: time.Minute * 30, 74 }) 75 if err != nil { 76 return nil, err 77 } 78 79 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret)) 80 81 clientApp := oauth.NewClientApp(&oauthConfig, authStore) 82 clientApp.Dir = res.Directory() 83 // allow non-public transports in dev mode 84 if config.Core.Dev { 85 clientApp.Resolver.Client.Transport = http.DefaultTransport 86 } 87 88 clientName := config.Core.AppviewName 89 90 logger.Info("oauth setup successfully", "IsConfidential", clientApp.Config.IsConfidential()) 91 return &OAuth{ 92 ClientApp: clientApp, 93 Config: config, 94 SessStore: sessStore, 95 JwksUri: jwksUri, 96 ClientName: clientName, 97 ClientUri: clientUri, 98 Posthog: ph, 99 Db: db, 100 Enforcer: enforcer, 101 IdResolver: res, 102 Logger: logger, 103 }, nil 104} 105 106func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error { 107 userSession, err := o.SessStore.Get(r, SessionName) 108 if err != nil { 109 o.Logger.Warn("failed to decode existing session cookie, will create new", "err", err) 110 } 111 112 userSession.Values[SessionDid] = sessData.AccountDID.String() 113 userSession.Values[SessionPds] = sessData.HostURL 114 userSession.Values[SessionId] = sessData.SessionID 115 userSession.Values[SessionAuthenticated] = true 116 117 if err := userSession.Save(r, w); err != nil { 118 return err 119 } 120 121 handle := "" 122 resolved, err := o.IdResolver.ResolveIdent(r.Context(), sessData.AccountDID.String()) 123 if err == nil && resolved.Handle.String() != "" { 124 handle = resolved.Handle.String() 125 } 126 127 registry := o.GetAccounts(r) 128 if err := registry.AddAccount(sessData.AccountDID.String(), handle, sessData.SessionID); err != nil { 129 return err 130 } 131 return o.saveAccounts(w, r, registry) 132} 133 134func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) { 135 userSession, err := o.SessStore.Get(r, SessionName) 136 if err != nil { 137 return nil, fmt.Errorf("error getting user session: %w", err) 138 } 139 if userSession.IsNew { 140 return nil, fmt.Errorf("no session available for user") 141 } 142 143 d := userSession.Values[SessionDid].(string) 144 sessDid, err := syntax.ParseDID(d) 145 if err != nil { 146 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 147 } 148 149 sessId := userSession.Values[SessionId].(string) 150 151 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId) 152 if err != nil { 153 return nil, fmt.Errorf("failed to resume session: %w", err) 154 } 155 156 return clientSess, nil 157} 158 159func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error { 160 userSession, err := o.SessStore.Get(r, SessionName) 161 if err != nil { 162 return fmt.Errorf("error getting user session: %w", err) 163 } 164 if userSession.IsNew { 165 return fmt.Errorf("no session available for user") 166 } 167 168 d := userSession.Values[SessionDid].(string) 169 sessDid, err := syntax.ParseDID(d) 170 if err != nil { 171 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err) 172 } 173 174 sessId := userSession.Values[SessionId].(string) 175 176 // delete the session 177 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId) 178 if err1 != nil { 179 err1 = fmt.Errorf("failed to logout: %w", err1) 180 } 181 182 // remove the cookie 183 userSession.Options.MaxAge = -1 184 err2 := o.SessStore.Save(r, w, userSession) 185 if err2 != nil { 186 err2 = fmt.Errorf("failed to save into session store: %w", err2) 187 } 188 189 return errors.Join(err1, err2) 190} 191 192func (o *OAuth) SwitchAccount(w http.ResponseWriter, r *http.Request, targetDid string) error { 193 registry := o.GetAccounts(r) 194 account := registry.FindAccount(targetDid) 195 if account == nil { 196 return fmt.Errorf("account not found in registry: %s", targetDid) 197 } 198 199 did, err := syntax.ParseDID(targetDid) 200 if err != nil { 201 return fmt.Errorf("invalid DID: %w", err) 202 } 203 204 sess, err := o.ClientApp.ResumeSession(r.Context(), did, account.SessionId) 205 if err != nil { 206 registry.RemoveAccount(targetDid) 207 _ = o.saveAccounts(w, r, registry) 208 return fmt.Errorf("session expired for account: %w", err) 209 } 210 211 userSession, err := o.SessStore.Get(r, SessionName) 212 if err != nil { 213 return err 214 } 215 216 userSession.Values[SessionDid] = sess.Data.AccountDID.String() 217 userSession.Values[SessionPds] = sess.Data.HostURL 218 userSession.Values[SessionId] = sess.Data.SessionID 219 userSession.Values[SessionAuthenticated] = true 220 221 return userSession.Save(r, w) 222} 223 224func (o *OAuth) RemoveAccount(w http.ResponseWriter, r *http.Request, targetDid string) error { 225 registry := o.GetAccounts(r) 226 account := registry.FindAccount(targetDid) 227 if account == nil { 228 return nil 229 } 230 231 did, err := syntax.ParseDID(targetDid) 232 if err == nil { 233 _ = o.ClientApp.Logout(r.Context(), did, account.SessionId) 234 } 235 236 registry.RemoveAccount(targetDid) 237 return o.saveAccounts(w, r, registry) 238} 239 240func (o *OAuth) GetDid(r *http.Request) string { 241 if u := o.GetMultiAccountUser(r); u != nil { 242 return u.Did 243 } 244 245 return "" 246} 247 248func (o *OAuth) AuthorizedClient(r *http.Request) (*atclient.APIClient, error) { 249 session, err := o.ResumeSession(r) 250 if err != nil { 251 return nil, fmt.Errorf("error getting session: %w", err) 252 } 253 return session.APIClient(), nil 254} 255 256// this is a higher level abstraction on ServerGetServiceAuth 257type ServiceClientOpts struct { 258 service string 259 exp int64 260 lxm string 261 dev bool 262 timeout time.Duration 263} 264 265type ServiceClientOpt func(*ServiceClientOpts) 266 267func DefaultServiceClientOpts() ServiceClientOpts { 268 return ServiceClientOpts{ 269 timeout: time.Second * 5, 270 } 271} 272 273func WithService(service string) ServiceClientOpt { 274 return func(s *ServiceClientOpts) { 275 s.service = service 276 } 277} 278 279// Specify the Duration in seconds for the expiry of this token 280// 281// The time of expiry is calculated as time.Now().Unix() + exp 282func WithExp(exp int64) ServiceClientOpt { 283 return func(s *ServiceClientOpts) { 284 s.exp = time.Now().Unix() + exp 285 } 286} 287 288func WithLxm(lxm string) ServiceClientOpt { 289 return func(s *ServiceClientOpts) { 290 s.lxm = lxm 291 } 292} 293 294func WithDev(dev bool) ServiceClientOpt { 295 return func(s *ServiceClientOpts) { 296 s.dev = dev 297 } 298} 299 300func WithTimeout(timeout time.Duration) ServiceClientOpt { 301 return func(s *ServiceClientOpts) { 302 s.timeout = timeout 303 } 304} 305 306func (s *ServiceClientOpts) Audience() string { 307 return fmt.Sprintf("did:web:%s", s.service) 308} 309 310func (s *ServiceClientOpts) Host() string { 311 scheme := "https://" 312 if s.dev { 313 scheme = "http://" 314 } 315 316 return scheme + s.service 317} 318 319func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) { 320 opts := DefaultServiceClientOpts() 321 for _, o := range os { 322 o(&opts) 323 } 324 325 client, err := o.AuthorizedClient(r) 326 if err != nil { 327 return nil, err 328 } 329 330 // force expiry to atleast 60 seconds in the future 331 sixty := time.Now().Unix() + 60 332 if opts.exp < sixty { 333 opts.exp = sixty 334 } 335 336 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm) 337 if err != nil { 338 return nil, err 339 } 340 341 return &xrpc.Client{ 342 Auth: &xrpc.AuthInfo{ 343 AccessJwt: resp.Token, 344 }, 345 Host: opts.Host(), 346 Client: &http.Client{ 347 Timeout: opts.timeout, 348 }, 349 }, nil 350} 351 352func (o *OAuth) StartElevatedAuthFlow(ctx context.Context, w http.ResponseWriter, r *http.Request, did string, extraScopes []string, returnURL string) (string, error) { 353 parsedDid, err := syntax.ParseDID(did) 354 if err != nil { 355 return "", fmt.Errorf("invalid DID: %w", err) 356 } 357 358 ident, err := o.ClientApp.Dir.Lookup(ctx, parsedDid.AtIdentifier()) 359 if err != nil { 360 return "", fmt.Errorf("failed to resolve DID (%s): %w", did, err) 361 } 362 363 host := ident.PDSEndpoint() 364 if host == "" { 365 return "", fmt.Errorf("identity does not link to an atproto host (PDS)") 366 } 367 368 authserverURL, err := o.ClientApp.Resolver.ResolveAuthServerURL(ctx, host) 369 if err != nil { 370 return "", fmt.Errorf("resolving auth server: %w", err) 371 } 372 373 authserverMeta, err := o.ClientApp.Resolver.ResolveAuthServerMetadata(ctx, authserverURL) 374 if err != nil { 375 return "", fmt.Errorf("fetching auth server metadata: %w", err) 376 } 377 378 scopes := make([]string, 0, len(TangledScopes)+len(extraScopes)) 379 scopes = append(scopes, TangledScopes...) 380 scopes = append(scopes, extraScopes...) 381 382 loginHint := did 383 if ident.Handle != "" && !ident.Handle.IsInvalidHandle() { 384 loginHint = ident.Handle.String() 385 } 386 387 info, err := o.ClientApp.SendAuthRequest(ctx, authserverMeta, scopes, loginHint) 388 if err != nil { 389 return "", fmt.Errorf("auth request failed: %w", err) 390 } 391 392 info.AccountDID = &parsedDid 393 o.ClientApp.Store.SaveAuthRequestInfo(ctx, *info) 394 395 if err := o.SetAuthReturn(w, r, returnURL); err != nil { 396 return "", fmt.Errorf("failed to set auth return: %w", err) 397 } 398 399 redirectURL := fmt.Sprintf("%s?client_id=%s&request_uri=%s", 400 authserverMeta.AuthorizationEndpoint, 401 url.QueryEscape(o.ClientApp.Config.ClientID), 402 url.QueryEscape(info.RequestURI), 403 ) 404 405 return redirectURL, nil 406}