audio streaming app plyr.fm


refactor: split backend monoliths into focused packages (#886)

- auth.py (1,400 lines) → auth/ package (8 modules)
- background_tasks.py (803 lines) → tasks/ package (5 domain modules)
- 5 *_client.py files → clients/ package
- extract upload pipeline into 7 named phase functions
- extract shared tag operations to utilities/tags.py
- remove 2 loq.toml line-count exemptions

424 tests pass, lint clean. all public APIs preserved via __init__.py re-exports.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
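
the "public APIs preserved" claim rests on package re-exports. a minimal sketch of what that means for call sites (the import paths are from this diff; the assertion itself is illustrative, not part of the PR):

```python
# old call sites keep importing from the package root...
from backend._internal.auth import create_session

# ...which now re-exports the implementation from its new home
from backend._internal.auth.session import create_session as impl

assert create_session is impl  # same object, so no caller changes needed
```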

authored by nate nowack and Claude Opus 4.6, committed by GitHub
4bc83742 42ae7300

+3134 -2684
-1400
backend/src/backend/_internal/auth.py
```python
"""OAuth 2.1 authentication and session management."""

import json
import logging
import secrets
import time
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Annotated, Any

from atproto_oauth import OAuthClient, OAuthState, PromptType
from atproto_oauth.client import (
    discover_authserver_from_pds_async,
    fetch_authserver_metadata_async,
)
from atproto_oauth.dpop import DPoPManager
from atproto_oauth.pkce import PKCEManager
from atproto_oauth.stores.memory import MemorySessionStore
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from fastapi import Cookie, Header, HTTPException
from jose import jwk
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from backend._internal.oauth_stores import PostgresStateStore
from backend.config import settings
from backend.models import ExchangeToken, PendingDevToken, UserPreferences, UserSession
from backend.utilities.database import db_session

logger = logging.getLogger(__name__)

PUBLIC_REFRESH_TOKEN_DAYS = 14
CONFIDENTIAL_REFRESH_TOKEN_DAYS = 180


def _parse_scopes(scope_string: str) -> set[str]:
    """parse an OAuth scope string into a set of individual scopes.

    handles format like: "atproto repo:fm.plyr.track repo:fm.plyr.like"
    returns: {"repo:fm.plyr.track", "repo:fm.plyr.like"}
    """
    parts = scope_string.split()
    # filter out the "atproto" prefix and keep just the repo: scopes
    return {p for p in parts if p.startswith("repo:")}


def _check_scope_coverage(granted_scope: str, required_scope: str) -> bool:
    """check if granted scope covers all required scopes.

    returns True if the session has all required permissions.
    """
    granted = _parse_scopes(granted_scope)
    required = _parse_scopes(required_scope)
    return required.issubset(granted)


def _get_missing_scopes(granted_scope: str, required_scope: str) -> set[str]:
    """get the scopes that are required but not granted."""
    granted = _parse_scopes(granted_scope)
    required = _parse_scopes(required_scope)
    return required - granted


@dataclass
class Session:
    """authenticated user session."""

    session_id: str
    did: str
    handle: str
    oauth_session: dict  # store OAuth session data

    def get_oauth_session_id(self) -> str:
        """extract OAuth session ID for retrieving from session store."""
        return self.oauth_session.get("session_id", self.did)


# OAuth stores
# state store: postgres-backed for multi-instance resilience
# session store: in-memory (not used, we use UserSession table instead)
_state_store = PostgresStateStore()
_session_store = MemorySessionStore()

# confidential client key (loaded lazily)
_client_secret_key: EllipticCurvePrivateKey | None = None
_client_secret_kid: str | None = None
_client_secret_key_loaded = False


def _load_client_secret() -> tuple[EllipticCurvePrivateKey | None, str | None]:
    """load EC private key and kid from OAUTH_JWK setting for confidential client.

    the key is expected to be a JSON-serialized JWK with ES256 (P-256) key.
    returns (None, None) if OAUTH_JWK is not configured (public client mode).
    """
    global _client_secret_key, _client_secret_kid, _client_secret_key_loaded

    if _client_secret_key_loaded:
        return _client_secret_key, _client_secret_kid

    _client_secret_key_loaded = True

    if not settings.atproto.oauth_jwk:
        logger.info("OAUTH_JWK not configured, using public OAuth client")
        return None, None

    try:
        # parse JWK JSON
        jwk_data = json.loads(settings.atproto.oauth_jwk)

        # extract kid (required for client assertions)
        _client_secret_kid = jwk_data.get("kid")
        if not _client_secret_kid:
            raise ValueError("OAUTH_JWK must include 'kid' field")

        # convert JWK to PEM format using python-jose
        key_obj = jwk.construct(jwk_data, algorithm="ES256")
        pem_bytes = key_obj.to_pem()

        # load as cryptography key
        loaded_key = load_pem_private_key(pem_bytes, password=None)

        if not isinstance(loaded_key, ec.EllipticCurvePrivateKey):
            raise ValueError("OAUTH_JWK must be an EC key (ES256)")

        _client_secret_key = loaded_key
        logger.info(f"loaded confidential OAuth client key (kid={_client_secret_kid})")
        return _client_secret_key, _client_secret_kid

    except Exception as e:
        logger.error(f"failed to load OAUTH_JWK: {e}")
        raise RuntimeError(f"invalid OAUTH_JWK configuration: {e}") from e


def get_public_jwks() -> dict | None:
    """get public JWKS for the /.well-known/jwks.json endpoint.

    returns None if confidential client is not configured.
    """
    if not settings.atproto.oauth_jwk:
        return None

    try:
        # parse private JWK
        jwk_data = json.loads(settings.atproto.oauth_jwk)

        # construct key and extract public components
        key_obj = jwk.construct(jwk_data, algorithm="ES256")
        public_jwk = key_obj.to_dict()

        # remove private key components, keep only public
        public_jwk.pop("d", None)  # private key scalar

        # ensure required fields for public key
        public_jwk["use"] = "sig"
        public_jwk["alg"] = "ES256"

        # preserve kid from original JWK (python-jose's to_dict() doesn't include it)
        if "kid" in jwk_data:
            public_jwk["kid"] = jwk_data["kid"]

        return {"keys": [public_jwk]}

    except Exception as e:
        logger.error(f"failed to generate public JWKS: {e}")
        return None


def is_confidential_client() -> bool:
    """check if confidential OAuth client is configured."""
    return bool(settings.atproto.oauth_jwk)


def get_client_auth_method(oauth_session_data: dict[str, Any] | None = None) -> str:
    """resolve client auth method for a session."""
    if oauth_session_data:
        method = oauth_session_data.get("client_auth_method")
        if method in {"public", "confidential"}:
            return method
    return "confidential" if is_confidential_client() else "public"


def get_refresh_token_lifetime_days(client_auth_method: str | None) -> int:
    """get expected refresh token lifetime in days."""
    method = client_auth_method or get_client_auth_method()
    return (
        CONFIDENTIAL_REFRESH_TOKEN_DAYS
        if method == "confidential"
        else PUBLIC_REFRESH_TOKEN_DAYS
    )


def _compute_refresh_token_expires_at(
    now: datetime, client_auth_method: str | None
) -> datetime:
    """compute refresh token expiration time."""
    return now + timedelta(days=get_refresh_token_lifetime_days(client_auth_method))


def _parse_datetime(value: str | None) -> datetime | None:
    """parse ISO datetime string safely."""
    if not value:
        return None
    try:
        return datetime.fromisoformat(value)
    except ValueError:
        return None


def _get_refresh_token_expires_at(
    user_session: UserSession,
    oauth_session_data: dict[str, Any],
) -> datetime | None:
    """determine refresh token expiry for a session."""
    parsed = _parse_datetime(oauth_session_data.get("refresh_token_expires_at"))
    if parsed:
        return parsed

    client_auth_method = oauth_session_data.get("client_auth_method")
    if client_auth_method:
        return user_session.created_at + timedelta(
            days=get_refresh_token_lifetime_days(client_auth_method)
        )

    if user_session.is_developer_token:
        return user_session.created_at + timedelta(days=PUBLIC_REFRESH_TOKEN_DAYS)

    return None


def get_oauth_client(include_teal: bool = False) -> OAuthClient:
    """create an OAuth client with the appropriate scopes.

    at ~17 OAuth flows/day, instantiation cost is negligible.
    this eliminates the need for pre-instantiated bifurcated clients.

    if OAUTH_JWK is configured, creates a confidential client with
    private_key_jwt authentication (180-day refresh tokens).
    otherwise creates a public client (2-week refresh tokens).
    """
    scope = (
        settings.atproto.resolved_scope_with_teal(
            settings.teal.play_collection, settings.teal.status_collection
        )
        if include_teal
        else settings.atproto.resolved_scope
    )

    # load confidential client key if configured
    client_secret_key, client_secret_kid = _load_client_secret()

    return OAuthClient(
        client_id=settings.atproto.client_id,
        redirect_uri=settings.atproto.redirect_uri,
        scope=scope,
        state_store=_state_store,
        session_store=_session_store,
        client_secret_key=client_secret_key,
        client_secret_kid=client_secret_kid,
    )


def get_oauth_client_for_scope(scope: str) -> OAuthClient:
    """get an OAuth client matching a given scope string.

    used during callback to match the scope that was used during authorization.
    """
    include_teal = settings.teal.play_collection in scope
    return get_oauth_client(include_teal=include_teal)


# encryption for sensitive OAuth data at rest
# CRITICAL: encryption key must be configured and stable across restarts
# otherwise all sessions become undecipherable after restart
if not settings.atproto.oauth_encryption_key:
    raise RuntimeError(
        "oauth_encryption_key must be configured in settings. "
        "generate one with: python -c 'from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())'"
    )

_encryption_key = settings.atproto.oauth_encryption_key.encode()
_fernet = Fernet(_encryption_key)


def _encrypt_data(data: str) -> str:
    """encrypt sensitive data for storage."""
    return _fernet.encrypt(data.encode()).decode()


def _decrypt_data(encrypted: str) -> str | None:
    """decrypt sensitive data from storage.

    returns None if decryption fails (e.g., key changed, data corrupted).
    """
    try:
        return _fernet.decrypt(encrypted.encode()).decode()
    except Exception:
        # decryption failed - likely key mismatch or corrupted data
        return None


async def create_session(
    did: str,
    handle: str,
    oauth_session: dict[str, Any],
    expires_in_days: int = 14,
    is_developer_token: bool = False,
    token_name: str | None = None,
    group_id: str | None = None,
) -> str:
    """create a new session for authenticated user with encrypted OAuth data.

    args:
        did: user's decentralized identifier
        handle: user's ATProto handle
        oauth_session: OAuth session data to encrypt and store
        expires_in_days: session expiration in days (default 14, capped by refresh lifetime)
        is_developer_token: whether this is a developer token (for listing/revocation)
        token_name: optional name for the token (only for developer tokens)
        group_id: optional session group ID for multi-account support
    """
    session_id = secrets.token_urlsafe(32)
    now = datetime.now(UTC)

    client_auth_method = get_client_auth_method(oauth_session)
    refresh_lifetime_days = get_refresh_token_lifetime_days(client_auth_method)
    refresh_expires_at = _compute_refresh_token_expires_at(now, client_auth_method)

    oauth_session = dict(oauth_session)
    oauth_session.setdefault("client_auth_method", client_auth_method)
    oauth_session.setdefault("refresh_token_lifetime_days", refresh_lifetime_days)
    oauth_session.setdefault("refresh_token_expires_at", refresh_expires_at.isoformat())

    effective_days = (
        refresh_lifetime_days
        if expires_in_days <= 0
        else min(expires_in_days, refresh_lifetime_days)
    )
    expires_at = now + timedelta(days=effective_days)

    encrypted_data = _encrypt_data(json.dumps(oauth_session))

    async with db_session() as db:
        user_session = UserSession(
            session_id=session_id,
            did=did,
            handle=handle,
            oauth_session_data=encrypted_data,
            expires_at=expires_at,
            is_developer_token=is_developer_token,
            token_name=token_name,
            group_id=group_id,
        )
        db.add(user_session)
        await db.commit()

    return session_id


async def get_session(session_id: str) -> Session | None:
    """retrieve session by id, decrypt OAuth data, and validate expiration."""
    async with db_session() as db:
        result = await db.execute(
            select(UserSession).where(UserSession.session_id == session_id)
        )
        if not (user_session := result.scalar_one_or_none()):
            return None

        # check if session is expired
        if user_session.expires_at and datetime.now(UTC) > user_session.expires_at:
            # session expired - delete it and return None
            await delete_session(session_id)
            return None

        # decrypt OAuth session data
        decrypted_data = _decrypt_data(user_session.oauth_session_data)
        if decrypted_data is None:
            # decryption failed - session is invalid (key changed or data corrupted)
            # delete the corrupted session
            await delete_session(session_id)
            return None

        oauth_session_data = json.loads(decrypted_data)

        refresh_expires_at = _get_refresh_token_expires_at(
            user_session, oauth_session_data
        )
        if refresh_expires_at and datetime.now(UTC) > refresh_expires_at:
            await delete_session(session_id)
            return None

        return Session(
            session_id=user_session.session_id,
            did=user_session.did,
            handle=user_session.handle,
            oauth_session=oauth_session_data,
        )


async def update_session_tokens(
    session_id: str, oauth_session_data: dict[str, Any]
) -> None:
    """update OAuth session data for a session (e.g., after token refresh)."""
    async with db_session() as db:
        result = await db.execute(
            select(UserSession).where(UserSession.session_id == session_id)
        )
        if user_session := result.scalar_one_or_none():
            # encrypt updated OAuth session data
            encrypted_data = _encrypt_data(json.dumps(oauth_session_data))
            user_session.oauth_session_data = encrypted_data
            await db.commit()


async def delete_session(session_id: str) -> None:
    """delete a session."""
    async with db_session() as db:
        result = await db.execute(
            select(UserSession).where(UserSession.session_id == session_id)
        )
        if user_session := result.scalar_one_or_none():
            await db.delete(user_session)
            await db.commit()


async def _check_teal_preference(did: str) -> bool:
    """check if user has enabled teal.fm scrobbling."""
    async with db_session() as db:
        result = await db.execute(
            select(UserPreferences.enable_teal_scrobbling).where(
                UserPreferences.did == did
            )
        )
        pref = result.scalar_one_or_none()
        return pref is True


async def start_oauth_flow(
    handle: str, prompt: PromptType | None = None
) -> tuple[str, str]:
    """start OAuth flow and return (auth_url, state).

    uses extended scope if user has enabled teal.fm scrobbling.

    args:
        handle: user's ATProto handle
        prompt: optional OAuth prompt parameter (login, select_account, consent, none)
    """
    from backend._internal.atproto.handles import resolve_handle

    try:
        # resolve handle to DID to check preferences
        resolved = await resolve_handle(handle)
        if resolved:
            did = resolved["did"]
            wants_teal = await _check_teal_preference(did)
            client = get_oauth_client(include_teal=wants_teal)
            logger.info(f"starting OAuth for {handle} (did={did}, teal={wants_teal})")
        else:
            # fallback to base client if resolution fails
            # (OAuth flow will resolve handle again internally)
            client = get_oauth_client(include_teal=False)
            logger.info(f"starting OAuth for {handle} (resolution failed, using base)")

        auth_url, state = await client.start_authorization(handle, prompt=prompt)
        return auth_url, state
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"failed to start OAuth flow: {e}",
        ) from e


async def start_oauth_flow_with_scopes(
    handle: str, include_teal: bool, prompt: PromptType | None = None
) -> tuple[str, str]:
    """start OAuth flow with explicit scope selection.

    unlike start_oauth_flow which checks user preferences, this explicitly
    requests the specified scopes. used for scope upgrade flows.

    args:
        handle: user's ATProto handle
        include_teal: whether to include teal.fm scopes
        prompt: optional OAuth prompt parameter (login, select_account, consent, none)
    """
    try:
        client = get_oauth_client(include_teal=include_teal)
        logger.info(f"starting scope upgrade OAuth for {handle} (teal={include_teal})")
        auth_url, state = await client.start_authorization(handle, prompt=prompt)
        return auth_url, state
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"failed to start OAuth flow: {e}",
        ) from e


async def start_oauth_flow_for_pds(pds_url: str) -> tuple[str, str]:
    """start OAuth flow for account creation on a PDS.

    unlike start_oauth_flow which resolves a handle to DID, this discovers
    the auth server directly from the PDS URL and sends PAR with prompt=create
    to trigger the account creation UI.

    args:
        pds_url: URL of the PDS to create account on (e.g., 'https://bsky.social')

    returns:
        tuple of (authorization_url, state) for redirecting user.
    """
    from urllib.parse import urlencode

    import httpx

    try:
        pds_url = pds_url.rstrip("/")

        # discover auth server from PDS
        authserver_url = await discover_authserver_from_pds_async(pds_url)
        authserver_url = authserver_url.rstrip("/")

        # fetch auth server metadata
        authserver_meta = await fetch_authserver_metadata_async(authserver_url)

        # get OAuth client for scope/keys
        client = get_oauth_client(include_teal=False)

        # generate PKCE and DPoP
        pkce = PKCEManager()
        pkce_verifier, pkce_challenge = pkce.generate_pair()
        dpop = DPoPManager()
        dpop_key = dpop.generate_keypair()
        state_token = secrets.token_urlsafe(32)

        # build PAR request with prompt=create and no login_hint
        par_url = authserver_meta.pushed_authorization_request_endpoint
        params: dict[str, str] = {
            "response_type": "code",
            "code_challenge": pkce_challenge,
            "code_challenge_method": "S256",
            "state": state_token,
            "redirect_uri": client.redirect_uri,
            "scope": client.scope,
            "client_id": client.client_id,
            "prompt": "create",
        }

        # add client authentication if confidential client
        client_secret_key, client_secret_kid = _load_client_secret()
        if client_secret_key and client_secret_kid:
            client_assertion = _create_client_assertion(
                client.client_id,
                authserver_meta.issuer,
                client_secret_key,
                client_secret_kid,
            )
            params["client_assertion_type"] = (
                "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
            )
            params["client_assertion"] = client_assertion

        # make PAR request with DPoP nonce retry
        dpop_nonce = ""
        for attempt in range(2):
            dpop_proof = dpop.create_proof(
                method="POST",
                url=par_url,
                private_key=dpop_key,
                nonce=dpop_nonce if dpop_nonce else None,
            )

            async with httpx.AsyncClient() as http:
                response = await http.post(
                    par_url, data=params, headers={"DPoP": dpop_proof}
                )

            if dpop.is_dpop_nonce_error(response):
                new_nonce = dpop.extract_nonce_from_response(response)
                if new_nonce and attempt == 0:
                    dpop_nonce = new_nonce
                    continue

            dpop_nonce = dpop.extract_nonce_from_response(response) or dpop_nonce
            break

        if response.status_code not in (200, 201):
            raise HTTPException(
                status_code=400,
                detail=f"PAR request failed: {response.status_code} {response.text}",
            )

        par_response = response.json()
        request_uri = par_response["request_uri"]

        # store state with did=None (unknown until account created)
        oauth_state = OAuthState(
            state=state_token,
            pkce_verifier=pkce_verifier,
            redirect_uri=client.redirect_uri,
            scope=client.scope,
            authserver_iss=authserver_meta.issuer,
            dpop_private_key=dpop_key,
            dpop_authserver_nonce=dpop_nonce,
            did=None,
            handle=None,
            pds_url=pds_url,
        )
        await _state_store.save_state(oauth_state)

        # build authorization URL
        auth_params = {"client_id": client.client_id, "request_uri": request_uri}
        auth_url = f"{authserver_meta.authorization_endpoint}?{urlencode(auth_params)}"

        logger.info(f"starting account creation OAuth for PDS {pds_url}")
        return auth_url, state_token

    except Exception as e:
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(
            status_code=400,
            detail=f"failed to start account creation OAuth: {e}",
        ) from e


def _create_client_assertion(
    client_id: str,
    audience: str,
    private_key: EllipticCurvePrivateKey,
    kid: str,
) -> str:
    """create client assertion JWT for confidential client."""
    header = {"alg": "ES256", "typ": "JWT", "kid": kid}
    now = int(time.time())
    payload = {
        "iss": client_id,
        "sub": client_id,
        "aud": audience,
        "jti": secrets.token_urlsafe(16),
        "iat": now,
        "exp": now + 60,
    }

    dpop = DPoPManager()
    return dpop._sign_jwt(header, payload, private_key)


async def _resolve_handle_from_pds(pds_url: str, did: str) -> str | None:
    """resolve handle from PDS when OAuth doesn't return it.

    this happens for newly created accounts on third-party PDSes where
    the handle isn't yet indexed by the Bluesky AppView.
    """
    import httpx

    try:
        async with httpx.AsyncClient() as client:
            resp = await client.get(
                f"{pds_url}/xrpc/com.atproto.repo.describeRepo",
                params={"repo": did},
                timeout=10.0,
            )
            if resp.status_code == 200:
                data = resp.json()
                handle = data.get("handle")
                if handle:
                    logger.info(f"resolved handle from PDS: {handle}")
                    return handle
    except Exception as e:
        logger.warning(f"failed to resolve handle from PDS: {e}")
    return None


async def handle_oauth_callback(
    code: str, state: str, iss: str
) -> tuple[str, str, dict]:
    """handle OAuth callback and return (did, handle, oauth_session).

    uses the appropriate OAuth client based on stored state's scope.
    """
    try:
        # look up stored state to determine which scope was used
        if stored_state := await _state_store.get_state(state):
            client = get_oauth_client_for_scope(stored_state.scope)
            logger.info(
                f"callback using client for scope: {stored_state.scope[:50]}..."
            )
        else:
            # fallback to base client (state might have been cleaned up)
            client = get_oauth_client(include_teal=False)
            logger.warning(f"state {state[:8]}... not found, using base client")

        oauth_session = await client.handle_callback(
            code=code,
            state=state,
            iss=iss,
        )

        # resolve handle from PDS if not provided by OAuth
        # (happens for newly created accounts on third-party PDSes)
        handle = oauth_session.handle
        if not handle:
            handle = (
                await _resolve_handle_from_pds(oauth_session.pds_url, oauth_session.did)
                or ""
            )

        # serialize DPoP private key for storage
        from cryptography.hazmat.primitives import serialization

        dpop_key_pem = oauth_session.dpop_private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        ).decode("utf-8")

        client_auth_method = get_client_auth_method()
        refresh_lifetime_days = get_refresh_token_lifetime_days(client_auth_method)
        refresh_expires_at = _compute_refresh_token_expires_at(
            datetime.now(UTC), client_auth_method
        )

        # store full OAuth session with tokens in database
        session_data = {
            "did": oauth_session.did,
            "handle": handle,
            "pds_url": oauth_session.pds_url,
            "authserver_iss": oauth_session.authserver_iss,
            "scope": oauth_session.scope,
            "access_token": oauth_session.access_token,
            "refresh_token": oauth_session.refresh_token,
            "dpop_private_key_pem": dpop_key_pem,
            "dpop_authserver_nonce": oauth_session.dpop_authserver_nonce,
            "dpop_pds_nonce": oauth_session.dpop_pds_nonce or "",
            "client_auth_method": client_auth_method,
            "refresh_token_lifetime_days": refresh_lifetime_days,
            "refresh_token_expires_at": refresh_expires_at.isoformat(),
        }
        return oauth_session.did, handle, session_data
    except Exception as e:
        raise HTTPException(
            status_code=401,
            detail=f"OAuth callback failed: {e}",
        ) from e


async def check_artist_profile_exists(did: str) -> bool:
    """check if artist profile exists for a DID."""
    from backend.models import Artist

    async with db_session() as db:
        result = await db.execute(select(Artist).where(Artist.did == did))
        artist = result.scalar_one_or_none()
        return artist is not None


async def ensure_artist_exists(did: str, handle: str) -> bool:
    """ensure an Artist record exists for the given DID, creating a minimal one if needed.

    this ensures all authenticated users have at least a basic Artist record,
    which is needed for displaying handles in share link stats, comments, etc.

    returns True if artist was created, False if it already existed.
    """
    from backend._internal.atproto.profile import fetch_user_avatar
    from backend.models import Artist

    async with db_session() as db:
        result = await db.execute(select(Artist).where(Artist.did == did))
        if result.scalar_one_or_none():
            return False  # already exists

        # fetch avatar from Bluesky
        avatar_url = await fetch_user_avatar(did)

        # create minimal artist record
        artist = Artist(
            did=did,
            handle=handle,
            display_name=handle,  # use handle as initial display name
            avatar_url=avatar_url,
        )
        db.add(artist)
        await db.commit()
        logger.info(f"created minimal artist record for {did} (@{handle})")
        return True


async def create_exchange_token(session_id: str, is_dev_token: bool = False) -> str:
    """create a one-time use exchange token for secure OAuth callback.

    exchange tokens expire after 60 seconds and can only be used once,
    preventing session_id exposure in browser history/referrers.

    args:
        session_id: the session to associate with this exchange token
        is_dev_token: if True, the exchange will not set a browser cookie
    """
    token = secrets.token_urlsafe(32)

    async with db_session() as db:
        exchange_token = ExchangeToken(
            token=token,
            session_id=session_id,
            is_dev_token=is_dev_token,
        )
        db.add(exchange_token)
        await db.commit()

    return token


async def consume_exchange_token(token: str) -> tuple[str, bool] | None:
    """consume an exchange token and return (session_id, is_dev_token).

    returns None if token is invalid, expired, or already used.
    uses atomic UPDATE to prevent race conditions (token can only be used once).
    """
    from sqlalchemy import update

    async with db_session() as db:
        # first, check if token exists and is not expired
        result = await db.execute(
            select(ExchangeToken).where(ExchangeToken.token == token)
        )
        exchange_token = result.scalar_one_or_none()

        if not exchange_token:
            return None

        # check if expired
        if datetime.now(UTC) > exchange_token.expires_at:
            return None

        # capture is_dev_token before atomic update
        is_dev_token = exchange_token.is_dev_token

        # atomically mark as used ONLY if not already used
        # this prevents race conditions where two requests try to use the same token
        result = await db.execute(
            update(ExchangeToken)
            .where(ExchangeToken.token == token, ExchangeToken.used == False)  # noqa: E712
            .values(used=True)
            .returning(ExchangeToken.session_id)
        )
        await db.commit()

        # if no rows were updated, token was already used
        session_id = result.scalar_one_or_none()
        if session_id is None:
            return None

        return session_id, is_dev_token


async def require_auth(
    authorization: Annotated[str | None, Header()] = None,
    session_id: Annotated[str | None, Cookie(alias="session_id")] = None,
) -> Session:
    """fastapi dependency to require authentication with expiration validation.

    checks cookie first (for browser requests), then falls back to Authorization
    header (for SDK/CLI clients). this enables secure HttpOnly cookies for browsers
    while maintaining bearer token support for API clients.

    also validates that the session's granted scopes cover all currently required
    scopes. if not, returns 403 with "scope_upgrade_required" to prompt re-login.
    """
    session_id_value = None

    if session_id:
        session_id_value = session_id
    elif authorization and authorization.startswith("Bearer "):
        session_id_value = authorization.removeprefix("Bearer ")

    if not session_id_value:
        raise HTTPException(
            status_code=401,
            detail="not authenticated - login required",
        )

    session = await get_session(session_id_value)
    if not session:
        raise HTTPException(
            status_code=401,
            detail="invalid or expired session",
        )

    # check if session has all required scopes
    granted_scope = session.oauth_session.get("scope", "")
    required_scope = settings.atproto.resolved_scope

    if not _check_scope_coverage(granted_scope, required_scope):
        missing = _get_missing_scopes(granted_scope, required_scope)
        logger.info(
            f"session {session.did} missing scopes: {missing}, prompting re-auth"
        )
        raise HTTPException(
            status_code=403,
            detail="scope_upgrade_required",
        )

    return session


async def get_optional_session(
    authorization: Annotated[str | None, Header()] = None,
    session_id: Annotated[str | None, Cookie(alias="session_id")] = None,
) -> Session | None:
    """fastapi dependency to optionally get the current session.

    returns None if not authenticated, otherwise returns the session.
    useful for public endpoints that show additional info for logged-in users.
    """
    session_id_value = None

    if session_id:
        session_id_value = session_id
    elif authorization and authorization.startswith("Bearer "):
        session_id_value = authorization.removeprefix("Bearer ")

    if not session_id_value:
        return None

    return await get_session(session_id_value)


async def require_artist_profile(
    authorization: Annotated[str | None, Header()] = None,
    session_id: Annotated[str | None, Cookie(alias="session_id")] = None,
) -> Session:
    """fastapi dependency to require authentication AND complete artist profile.

    Returns 403 with specific message if artist profile doesn't exist,
    prompting frontend to redirect to profile setup.
    """
    session = await require_auth(authorization, session_id)

    # check if artist profile exists
    if not await check_artist_profile_exists(session.did):
        raise HTTPException(
            status_code=403,
            detail="artist_profile_required",
        )

    return session


@dataclass
class DeveloperToken:
    """developer token metadata (without sensitive session data)."""

    session_id: str
    token_name: str | None
    created_at: datetime
    expires_at: datetime | None


async def list_developer_tokens(did: str) -> list[DeveloperToken]:
    """list all developer tokens for a user."""
    async with db_session() as db:
        result = await db.execute(
            select(UserSession).where(
                UserSession.did == did,
                UserSession.is_developer_token == True,  # noqa: E712
            )
        )
        sessions = result.scalars().all()

        tokens = []
        now = datetime.now(UTC)
        for session in sessions:
            decrypted_data = _decrypt_data(session.oauth_session_data)
            oauth_session_data = (
                json.loads(decrypted_data) if decrypted_data is not None else {}
            )
            refresh_expires_at = _get_refresh_token_expires_at(
                session, oauth_session_data
            )
            effective_expires_at = session.expires_at
            if refresh_expires_at and (
                effective_expires_at is None
                or refresh_expires_at < effective_expires_at
            ):
                effective_expires_at = refresh_expires_at

            # check if expired
            if effective_expires_at and now > effective_expires_at:
                continue  # skip expired tokens

            tokens.append(
                DeveloperToken(
                    session_id=session.session_id,
                    token_name=session.token_name,
                    created_at=session.created_at,
                    expires_at=effective_expires_at,
                )
            )

        return tokens


async def revoke_developer_token(did: str, session_id: str) -> bool:
    """revoke a developer token. returns True if successful, False if not found."""
    async with db_session() as db:
        result = await db.execute(
            select(UserSession).where(
                UserSession.session_id == session_id,
                UserSession.did == did,  # ensure user owns this token
                UserSession.is_developer_token == True,  # noqa: E712
            )
        )
        session = result.scalar_one_or_none()

        if not session:
            return False

        await db.delete(session)
        await db.commit()
        return True


@dataclass
class PendingDevTokenData:
    """metadata for a pending developer token OAuth flow."""

    state: str
    did: str
    token_name: str | None
    expires_in_days: int


async def save_pending_dev_token(
    state: str,
    did: str,
    token_name: str | None,
    expires_in_days: int,
) -> None:
    """save pending dev token metadata keyed by OAuth state."""
    async with db_session() as db:
        pending = PendingDevToken(
            state=state,
            did=did,
            token_name=token_name,
            expires_in_days=expires_in_days,
        )
        db.add(pending)
        await db.commit()


async def get_pending_dev_token(state: str) -> PendingDevTokenData | None:
    """get pending dev token metadata by OAuth state."""
    async with db_session() as db:
        result = await db.execute(
            select(PendingDevToken).where(PendingDevToken.state == state)
        )
        pending = result.scalar_one_or_none()

        if not pending:
            return None

        # check if expired
        if datetime.now(UTC) > pending.expires_at:
            await db.delete(pending)
            await db.commit()
            return None

        return PendingDevTokenData(
            state=pending.state,
            did=pending.did,
            token_name=pending.token_name,
            expires_in_days=pending.expires_in_days,
        )


async def delete_pending_dev_token(state: str) -> None:
    """delete pending dev token metadata after use."""
    async with db_session() as db:
        result = await db.execute(
            select(PendingDevToken).where(PendingDevToken.state == state)
        )
        if pending := result.scalar_one_or_none():
            await db.delete(pending)
            await db.commit()


# scope upgrade flow helpers


@dataclass
class PendingScopeUpgradeData:
    """metadata for a pending scope upgrade OAuth flow."""

    state: str
    did: str
    old_session_id: str
    requested_scopes: str


async def save_pending_scope_upgrade(
    state: str,
    did: str,
    old_session_id: str,
    requested_scopes: str,
) -> None:
    """save pending scope upgrade metadata keyed by OAuth state."""
    from backend.models import PendingScopeUpgrade

    async with db_session() as db:
        pending = PendingScopeUpgrade(
            state=state,
            did=did,
            old_session_id=old_session_id,
            requested_scopes=requested_scopes,
        )
        db.add(pending)
        await db.commit()


async def get_pending_scope_upgrade(state: str) -> PendingScopeUpgradeData | None:
    """get pending scope upgrade metadata by OAuth state."""
    from backend.models import PendingScopeUpgrade

    async with db_session() as db:
        result = await db.execute(
            select(PendingScopeUpgrade).where(PendingScopeUpgrade.state == state)
        )
        pending = result.scalar_one_or_none()

        if not pending:
            return None

        # check if expired
        if datetime.now(UTC) > pending.expires_at:
            await db.delete(pending)
            await db.commit()
            return None

        return PendingScopeUpgradeData(
            state=pending.state,
            did=pending.did,
            old_session_id=pending.old_session_id,
            requested_scopes=pending.requested_scopes,
        )


async def delete_pending_scope_upgrade(state: str) -> None:
    """delete pending scope upgrade metadata after use."""
    from backend.models import PendingScopeUpgrade

    async with db_session() as db:
        result = await db.execute(
            select(PendingScopeUpgrade).where(PendingScopeUpgrade.state == state)
        )
        if pending := result.scalar_one_or_none():
            await db.delete(pending)
            await db.commit()


# multi-account session group helpers


@dataclass
class LinkedAccount:
    """account info for account switcher UI."""

    did: str
    handle: str
    session_id: str


async def _get_session_group_impl(
    session_id: str, db: AsyncSession
) -> list[LinkedAccount]:
    """implementation of get_session_group using provided db session."""
    result = await db.execute(
        select(UserSession.group_id).where(UserSession.session_id == session_id)
    )
    group_id = result.scalar_one_or_none()

    if not group_id:
        return []

    result = await db.execute(
        select(UserSession).where(
            UserSession.group_id == group_id,
            UserSession.is_developer_token == False,  # noqa: E712
        )
    )
    sessions = result.scalars().all()

    accounts = []
    for session in sessions:
        if session.expires_at and datetime.now(UTC) > session.expires_at:
            continue

        accounts.append(
            LinkedAccount(
                did=session.did,
                handle=session.handle,
                session_id=session.session_id,
            )
        )

    return accounts


async def get_session_group(
    session_id: str, db: AsyncSession | None = None
) -> list[LinkedAccount]:
    """get all accounts in the same session group.

    returns empty list if session has no group_id (single account).

    args:
        session_id: the session to look up
        db: optional database session to reuse (avoids new connection)
    """
    if db is not None:
        return await _get_session_group_impl(session_id, db)

    async with db_session() as new_db:
        return await _get_session_group_impl(session_id, new_db)


async def get_or_create_group_id(session_id: str) -> str:
    """get existing group_id or create one for this session.

    used when adding a second account to create a group.
    """
    async with db_session() as db:
        result = await db.execute(
            select(UserSession).where(UserSession.session_id == session_id)
        )
        session = result.scalar_one_or_none()

        if not session:
            raise HTTPException(status_code=404, detail="session not found")

        if session.group_id:
            return session.group_id

        # create new group_id for this session
        group_id = secrets.token_urlsafe(32)
        session.group_id = group_id
        await db.commit()

        return group_id


async def _switch_active_account_impl(
    current_session_id: str, target_session_id: str, db: AsyncSession
) -> str:
    """implementation of switch_active_account using provided db session."""
    result = await db.execute(
        select(UserSession).where(UserSession.session_id == current_session_id)
    )
    current_session = result.scalar_one_or_none()

    if not current_session or not current_session.group_id:
        raise HTTPException(status_code=400, detail="no session group found")

    result = await db.execute(
        select(UserSession).where(UserSession.session_id == target_session_id)
    )
    target_session = result.scalar_one_or_none()

    if not target_session:
        raise HTTPException(status_code=404, detail="target session not found")

    if target_session.group_id != current_session.group_id:
        raise HTTPException(status_code=403, detail="target session not in same group")

    if target_session.expires_at and datetime.now(UTC) > target_session.expires_at:
        raise HTTPException(status_code=401, detail="target session expired")

    return target_session_id


async def switch_active_account(
    current_session_id: str, target_session_id: str, db: AsyncSession | None = None
) -> str:
    """switch to a different account within a session group.

    validates that the target session exists, is in the same group, and isn't expired.
    returns the target session_id (caller updates the cookie).

    args:
        current_session_id: the current session
        target_session_id: the session to switch to
        db: optional database session to reuse (avoids new connection)
    """
    if db is not None:
        return await _switch_active_account_impl(
            current_session_id, target_session_id, db
        )

    async with db_session() as new_db:
        return await _switch_active_account_impl(
            current_session_id, target_session_id, new_db
        )


async def remove_account_from_group(session_id: str) -> str | None:
    """remove a session from its group and delete it.

    returns session_id of another account in the group, or None if last account.
    """
    async with db_session() as db:
        result = await db.execute(
            select(UserSession).where(UserSession.session_id == session_id)
        )
        session = result.scalar_one_or_none()

        if not session:
            return None

        group_id = session.group_id

        await db.delete(session)
        await db.commit()

        if not group_id:
            return None

        result = await db.execute(
            select(UserSession).where(
                UserSession.group_id == group_id,
                UserSession.is_developer_token == False,  # noqa: E712
            )
        )
        remaining = result.scalars().first()

        return remaining.session_id if remaining else None


# pending add account flow helpers


@dataclass
class PendingAddAccountData:
    """metadata for a pending add-account OAuth flow."""

    state: str
    group_id: str


async def save_pending_add_account(state: str, group_id: str) -> None:
    """save pending add-account metadata keyed by OAuth state."""
    from backend.models import PendingAddAccount

    async with db_session() as db:
        pending = PendingAddAccount(
            state=state,
            group_id=group_id,
        )
        db.add(pending)
        await db.commit()


async def get_pending_add_account(state: str) -> PendingAddAccountData | None:
    """get pending add-account metadata by OAuth state."""
    from backend.models import PendingAddAccount

    async with db_session() as db:
        result = await db.execute(
            select(PendingAddAccount).where(PendingAddAccount.state == state)
        )
        pending = result.scalar_one_or_none()

        if not pending:
            return None

        # check if expired
        if datetime.now(UTC) > pending.expires_at:
            await db.delete(pending)
            await db.commit()
            return None

        return PendingAddAccountData(
            state=pending.state,
            group_id=pending.group_id,
        )


async def delete_pending_add_account(state: str) -> None:
    """delete pending add-account metadata after use."""
    from backend.models import PendingAddAccount

    async with db_session() as db:
        result = await db.execute(
            select(PendingAddAccount).where(PendingAddAccount.state == state)
        )
        if pending := result.scalar_one_or_none():
            await db.delete(pending)
            await db.commit()
```
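
for context on the at-rest encryption above: a standalone sketch of the Fernet round-trip with a throwaway key (in production the key comes from settings.atproto.oauth_encryption_key and must stay stable across restarts):

```python
import json

from cryptography.fernet import Fernet

key = Fernet.generate_key()  # throwaway; real deployments pin this in settings
fernet = Fernet(key)

oauth_session = {"did": "did:plc:example", "access_token": "..."}  # hypothetical payload
token = fernet.encrypt(json.dumps(oauth_session).encode()).decode()

# decrypt raises InvalidToken on a key mismatch or corrupted ciphertext,
# which is why _decrypt_data catches and returns None, and get_session
# then deletes the undecipherable session.
assert json.loads(fernet.decrypt(token.encode()).decode()) == oauth_session
```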
+116
backend/src/backend/_internal/auth/__init__.py
```python
"""OAuth 2.1 authentication and session management."""

from backend._internal.auth.account_groups import (
    LinkedAccount,
    PendingAddAccountData,
    PendingScopeUpgradeData,
    delete_pending_add_account,
    delete_pending_scope_upgrade,
    get_or_create_group_id,
    get_pending_add_account,
    get_pending_scope_upgrade,
    get_session_group,
    remove_account_from_group,
    save_pending_add_account,
    save_pending_scope_upgrade,
    switch_active_account,
)
from backend._internal.auth.dependencies import (
    get_optional_session,
    require_artist_profile,
    require_auth,
)
from backend._internal.auth.developer_tokens import (
    DeveloperToken,
    PendingDevTokenData,
    delete_pending_dev_token,
    get_pending_dev_token,
    list_developer_tokens,
    revoke_developer_token,
    save_pending_dev_token,
)
from backend._internal.auth.encryption import _decrypt_data, _encrypt_data
from backend._internal.auth.exchange import (
    consume_exchange_token,
    create_exchange_token,
)
from backend._internal.auth.oauth import (
    check_artist_profile_exists,
    ensure_artist_exists,
    get_oauth_client,
    get_oauth_client_for_scope,
    get_public_jwks,
    handle_oauth_callback,
    start_oauth_flow,
    start_oauth_flow_for_pds,
    start_oauth_flow_with_scopes,
)
from backend._internal.auth.scopes import (
    _check_scope_coverage,
    _get_missing_scopes,
    _parse_scopes,
)
from backend._internal.auth.session import (
    CONFIDENTIAL_REFRESH_TOKEN_DAYS,
    PUBLIC_REFRESH_TOKEN_DAYS,
    Session,
    create_session,
    delete_session,
    get_client_auth_method,
    get_refresh_token_lifetime_days,
    get_session,
    is_confidential_client,
    update_session_tokens,
)

__all__ = [
    "CONFIDENTIAL_REFRESH_TOKEN_DAYS",
    "PUBLIC_REFRESH_TOKEN_DAYS",
    "DeveloperToken",
    "LinkedAccount",
    "PendingAddAccountData",
    "PendingDevTokenData",
    "PendingScopeUpgradeData",
    "Session",
    "_check_scope_coverage",
    "_decrypt_data",
    "_encrypt_data",
    "_get_missing_scopes",
    "_parse_scopes",
    "check_artist_profile_exists",
    "consume_exchange_token",
    "create_exchange_token",
    "create_session",
    "delete_pending_add_account",
    "delete_pending_dev_token",
    "delete_pending_scope_upgrade",
    "delete_session",
    "ensure_artist_exists",
    "get_client_auth_method",
    "get_oauth_client",
    "get_oauth_client_for_scope",
    "get_optional_session",
    "get_or_create_group_id",
    "get_pending_add_account",
    "get_pending_dev_token",
    "get_pending_scope_upgrade",
    "get_public_jwks",
    "get_refresh_token_lifetime_days",
    "get_session",
    "get_session_group",
    "handle_oauth_callback",
    "is_confidential_client",
    "list_developer_tokens",
    "remove_account_from_group",
    "require_artist_profile",
    "require_auth",
    "revoke_developer_token",
    "save_pending_add_account",
    "save_pending_dev_token",
    "save_pending_scope_upgrade",
    "start_oauth_flow",
    "start_oauth_flow_for_pds",
    "start_oauth_flow_with_scopes",
    "switch_active_account",
    "update_session_tokens",
]
```
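
since the package's public contract is now this re-export surface, one cheap guard (a sketch, not part of the PR's 424 tests) is to assert that every name in `__all__` actually resolves:

```python
import backend._internal.auth as auth

# any typo'd or forgotten re-export shows up here immediately
missing = [name for name in auth.__all__ if not hasattr(auth, name)]
assert not missing, f"broken re-exports: {missing}"
```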
+321
backend/src/backend/_internal/auth/account_groups.py
··· 1 + """Multi-account session groups: LinkedAccount, group CRUD, switch, remove, pending add-account.""" 2 + 3 + import secrets 4 + from dataclasses import dataclass 5 + from datetime import UTC, datetime 6 + 7 + from fastapi import HTTPException 8 + from sqlalchemy import select 9 + from sqlalchemy.ext.asyncio import AsyncSession 10 + 11 + from backend.models import UserSession 12 + from backend.utilities.database import db_session 13 + 14 + 15 + @dataclass 16 + class LinkedAccount: 17 + """account info for account switcher UI.""" 18 + 19 + did: str 20 + handle: str 21 + session_id: str 22 + 23 + 24 + async def _get_session_group_impl( 25 + session_id: str, db: AsyncSession 26 + ) -> list[LinkedAccount]: 27 + """implementation of get_session_group using provided db session.""" 28 + result = await db.execute( 29 + select(UserSession.group_id).where(UserSession.session_id == session_id) 30 + ) 31 + group_id = result.scalar_one_or_none() 32 + 33 + if not group_id: 34 + return [] 35 + 36 + result = await db.execute( 37 + select(UserSession).where( 38 + UserSession.group_id == group_id, 39 + UserSession.is_developer_token == False, # noqa: E712 40 + ) 41 + ) 42 + sessions = result.scalars().all() 43 + 44 + accounts = [] 45 + for session in sessions: 46 + if session.expires_at and datetime.now(UTC) > session.expires_at: 47 + continue 48 + 49 + accounts.append( 50 + LinkedAccount( 51 + did=session.did, 52 + handle=session.handle, 53 + session_id=session.session_id, 54 + ) 55 + ) 56 + 57 + return accounts 58 + 59 + 60 + async def get_session_group( 61 + session_id: str, db: AsyncSession | None = None 62 + ) -> list[LinkedAccount]: 63 + """get all accounts in the same session group. 64 + 65 + returns empty list if session has no group_id (single account). 66 + 67 + args: 68 + session_id: the session to look up 69 + db: optional database session to reuse (avoids new connection) 70 + """ 71 + if db is not None: 72 + return await _get_session_group_impl(session_id, db) 73 + 74 + async with db_session() as new_db: 75 + return await _get_session_group_impl(session_id, new_db) 76 + 77 + 78 + async def get_or_create_group_id(session_id: str) -> str: 79 + """get existing group_id or create one for this session. 80 + 81 + used when adding a second account to create a group. 
82 + """ 83 + async with db_session() as db: 84 + result = await db.execute( 85 + select(UserSession).where(UserSession.session_id == session_id) 86 + ) 87 + session = result.scalar_one_or_none() 88 + 89 + if not session: 90 + raise HTTPException(status_code=404, detail="session not found") 91 + 92 + if session.group_id: 93 + return session.group_id 94 + 95 + # create new group_id for this session 96 + group_id = secrets.token_urlsafe(32) 97 + session.group_id = group_id 98 + await db.commit() 99 + 100 + return group_id 101 + 102 + 103 + async def _switch_active_account_impl( 104 + current_session_id: str, target_session_id: str, db: AsyncSession 105 + ) -> str: 106 + """implementation of switch_active_account using provided db session.""" 107 + result = await db.execute( 108 + select(UserSession).where(UserSession.session_id == current_session_id) 109 + ) 110 + current_session = result.scalar_one_or_none() 111 + 112 + if not current_session or not current_session.group_id: 113 + raise HTTPException(status_code=400, detail="no session group found") 114 + 115 + result = await db.execute( 116 + select(UserSession).where(UserSession.session_id == target_session_id) 117 + ) 118 + target_session = result.scalar_one_or_none() 119 + 120 + if not target_session: 121 + raise HTTPException(status_code=404, detail="target session not found") 122 + 123 + if target_session.group_id != current_session.group_id: 124 + raise HTTPException(status_code=403, detail="target session not in same group") 125 + 126 + if target_session.expires_at and datetime.now(UTC) > target_session.expires_at: 127 + raise HTTPException(status_code=401, detail="target session expired") 128 + 129 + return target_session_id 130 + 131 + 132 + async def switch_active_account( 133 + current_session_id: str, target_session_id: str, db: AsyncSession | None = None 134 + ) -> str: 135 + """switch to a different account within a session group. 136 + 137 + validates that the target session exists, is in the same group, and isn't expired. 138 + returns the target session_id (caller updates the cookie). 139 + 140 + args: 141 + current_session_id: the current session 142 + target_session_id: the session to switch to 143 + db: optional database session to reuse (avoids new connection) 144 + """ 145 + if db is not None: 146 + return await _switch_active_account_impl( 147 + current_session_id, target_session_id, db 148 + ) 149 + 150 + async with db_session() as new_db: 151 + return await _switch_active_account_impl( 152 + current_session_id, target_session_id, new_db 153 + ) 154 + 155 + 156 + async def remove_account_from_group(session_id: str) -> str | None: 157 + """remove a session from its group and delete it. 158 + 159 + returns session_id of another account in the group, or None if last account. 
160 + """ 161 + async with db_session() as db: 162 + result = await db.execute( 163 + select(UserSession).where(UserSession.session_id == session_id) 164 + ) 165 + session = result.scalar_one_or_none() 166 + 167 + if not session: 168 + return None 169 + 170 + group_id = session.group_id 171 + 172 + await db.delete(session) 173 + await db.commit() 174 + 175 + if not group_id: 176 + return None 177 + 178 + result = await db.execute( 179 + select(UserSession).where( 180 + UserSession.group_id == group_id, 181 + UserSession.is_developer_token == False, # noqa: E712 182 + ) 183 + ) 184 + remaining = result.scalars().first() 185 + 186 + return remaining.session_id if remaining else None 187 + 188 + 189 + # pending add account flow helpers 190 + 191 + 192 + @dataclass 193 + class PendingAddAccountData: 194 + """metadata for a pending add-account OAuth flow.""" 195 + 196 + state: str 197 + group_id: str 198 + 199 + 200 + async def save_pending_add_account(state: str, group_id: str) -> None: 201 + """save pending add-account metadata keyed by OAuth state.""" 202 + from backend.models import PendingAddAccount 203 + 204 + async with db_session() as db: 205 + pending = PendingAddAccount( 206 + state=state, 207 + group_id=group_id, 208 + ) 209 + db.add(pending) 210 + await db.commit() 211 + 212 + 213 + async def get_pending_add_account(state: str) -> PendingAddAccountData | None: 214 + """get pending add-account metadata by OAuth state.""" 215 + from backend.models import PendingAddAccount 216 + 217 + async with db_session() as db: 218 + result = await db.execute( 219 + select(PendingAddAccount).where(PendingAddAccount.state == state) 220 + ) 221 + pending = result.scalar_one_or_none() 222 + 223 + if not pending: 224 + return None 225 + 226 + # check if expired 227 + if datetime.now(UTC) > pending.expires_at: 228 + await db.delete(pending) 229 + await db.commit() 230 + return None 231 + 232 + return PendingAddAccountData( 233 + state=pending.state, 234 + group_id=pending.group_id, 235 + ) 236 + 237 + 238 + async def delete_pending_add_account(state: str) -> None: 239 + """delete pending add-account metadata after use.""" 240 + from backend.models import PendingAddAccount 241 + 242 + async with db_session() as db: 243 + result = await db.execute( 244 + select(PendingAddAccount).where(PendingAddAccount.state == state) 245 + ) 246 + if pending := result.scalar_one_or_none(): 247 + await db.delete(pending) 248 + await db.commit() 249 + 250 + 251 + # scope upgrade flow helpers 252 + 253 + 254 + @dataclass 255 + class PendingScopeUpgradeData: 256 + """metadata for a pending scope upgrade OAuth flow.""" 257 + 258 + state: str 259 + did: str 260 + old_session_id: str 261 + requested_scopes: str 262 + 263 + 264 + async def save_pending_scope_upgrade( 265 + state: str, 266 + did: str, 267 + old_session_id: str, 268 + requested_scopes: str, 269 + ) -> None: 270 + """save pending scope upgrade metadata keyed by OAuth state.""" 271 + from backend.models import PendingScopeUpgrade 272 + 273 + async with db_session() as db: 274 + pending = PendingScopeUpgrade( 275 + state=state, 276 + did=did, 277 + old_session_id=old_session_id, 278 + requested_scopes=requested_scopes, 279 + ) 280 + db.add(pending) 281 + await db.commit() 282 + 283 + 284 + async def get_pending_scope_upgrade(state: str) -> PendingScopeUpgradeData | None: 285 + """get pending scope upgrade metadata by OAuth state.""" 286 + from backend.models import PendingScopeUpgrade 287 + 288 + async with db_session() as db: 289 + result = await db.execute( 290 + 
select(PendingScopeUpgrade).where(PendingScopeUpgrade.state == state) 291 + ) 292 + pending = result.scalar_one_or_none() 293 + 294 + if not pending: 295 + return None 296 + 297 + # check if expired 298 + if datetime.now(UTC) > pending.expires_at: 299 + await db.delete(pending) 300 + await db.commit() 301 + return None 302 + 303 + return PendingScopeUpgradeData( 304 + state=pending.state, 305 + did=pending.did, 306 + old_session_id=pending.old_session_id, 307 + requested_scopes=pending.requested_scopes, 308 + ) 309 + 310 + 311 + async def delete_pending_scope_upgrade(state: str) -> None: 312 + """delete pending scope upgrade metadata after use.""" 313 + from backend.models import PendingScopeUpgrade 314 + 315 + async with db_session() as db: 316 + result = await db.execute( 317 + select(PendingScopeUpgrade).where(PendingScopeUpgrade.state == state) 318 + ) 319 + if pending := result.scalar_one_or_none(): 320 + await db.delete(pending) 321 + await db.commit()
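These pending-flow helpers share one pattern: persist a small record keyed by the OAuth state token before redirecting, then fetch and delete it on callback, dropping expired records on read. A minimal sketch of the add-account side, assuming the helpers are importable from backend._internal.auth; the two wrapper functions are illustrative, not part of this diff:

from backend._internal.auth import (
    delete_pending_add_account,
    get_pending_add_account,
    save_pending_add_account,
    start_oauth_flow,
)

async def begin_add_account(handle: str, group_id: str) -> str:
    # start the OAuth dance, then key the pending metadata by its state token
    auth_url, state = await start_oauth_flow(handle)
    await save_pending_add_account(state=state, group_id=group_id)
    return auth_url  # caller redirects the browser here

async def finish_add_account(state: str) -> str | None:
    # on callback: look up the pending record (None if missing or expired),
    # consume it, and return the group the new session should join
    pending = await get_pending_add_account(state)
    if pending is None:
        return None
    await delete_pending_add_account(state)
    return pending.group_id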
+106
backend/src/backend/_internal/auth/dependencies.py
··· 1 + """FastAPI dependencies: require_auth, get_optional_session, require_artist_profile.""" 2 + 3 + import logging 4 + from typing import Annotated 5 + 6 + from fastapi import Cookie, Header, HTTPException 7 + 8 + from backend._internal.auth.oauth import check_artist_profile_exists 9 + from backend._internal.auth.scopes import _check_scope_coverage, _get_missing_scopes 10 + from backend._internal.auth.session import Session, get_session 11 + from backend.config import settings 12 + 13 + logger = logging.getLogger(__name__) 14 + 15 + 16 + async def require_auth( 17 + authorization: Annotated[str | None, Header()] = None, 18 + session_id: Annotated[str | None, Cookie(alias="session_id")] = None, 19 + ) -> Session: 20 + """fastapi dependency to require authentication with expiration validation. 21 + 22 + checks cookie first (for browser requests), then falls back to Authorization 23 + header (for SDK/CLI clients). this enables secure HttpOnly cookies for browsers 24 + while maintaining bearer token support for API clients. 25 + 26 + also validates that the session's granted scopes cover all currently required 27 + scopes. if not, returns 403 with "scope_upgrade_required" to prompt re-login. 28 + """ 29 + session_id_value = None 30 + 31 + if session_id: 32 + session_id_value = session_id 33 + elif authorization and authorization.startswith("Bearer "): 34 + session_id_value = authorization.removeprefix("Bearer ") 35 + 36 + if not session_id_value: 37 + raise HTTPException( 38 + status_code=401, 39 + detail="not authenticated - login required", 40 + ) 41 + 42 + session = await get_session(session_id_value) 43 + if not session: 44 + raise HTTPException( 45 + status_code=401, 46 + detail="invalid or expired session", 47 + ) 48 + 49 + # check if session has all required scopes 50 + granted_scope = session.oauth_session.get("scope", "") 51 + required_scope = settings.atproto.resolved_scope 52 + 53 + if not _check_scope_coverage(granted_scope, required_scope): 54 + missing = _get_missing_scopes(granted_scope, required_scope) 55 + logger.info( 56 + f"session {session.did} missing scopes: {missing}, prompting re-auth" 57 + ) 58 + raise HTTPException( 59 + status_code=403, 60 + detail="scope_upgrade_required", 61 + ) 62 + 63 + return session 64 + 65 + 66 + async def get_optional_session( 67 + authorization: Annotated[str | None, Header()] = None, 68 + session_id: Annotated[str | None, Cookie(alias="session_id")] = None, 69 + ) -> Session | None: 70 + """fastapi dependency to optionally get the current session. 71 + 72 + returns None if not authenticated, otherwise returns the session. 73 + useful for public endpoints that show additional info for logged-in users. 74 + """ 75 + session_id_value = None 76 + 77 + if session_id: 78 + session_id_value = session_id 79 + elif authorization and authorization.startswith("Bearer "): 80 + session_id_value = authorization.removeprefix("Bearer ") 81 + 82 + if not session_id_value: 83 + return None 84 + 85 + return await get_session(session_id_value) 86 + 87 + 88 + async def require_artist_profile( 89 + authorization: Annotated[str | None, Header()] = None, 90 + session_id: Annotated[str | None, Cookie(alias="session_id")] = None, 91 + ) -> Session: 92 + """fastapi dependency to require authentication AND complete artist profile. 93 + 94 + Returns 403 with specific message if artist profile doesn't exist, 95 + prompting frontend to redirect to profile setup. 
96 + """ 97 + session = await require_auth(authorization, session_id) 98 + 99 + # check if artist profile exists 100 + if not await check_artist_profile_exists(session.did): 101 + raise HTTPException( 102 + status_code=403, 103 + detail="artist_profile_required", 104 + ) 105 + 106 + return session
+153
backend/src/backend/_internal/auth/developer_tokens.py
··· 1 + """Developer token management: list, revoke, pending dev token flow.""" 2 + 3 + import json 4 + import logging 5 + from dataclasses import dataclass 6 + from datetime import UTC, datetime 7 + 8 + from sqlalchemy import select 9 + 10 + from backend._internal.auth.encryption import _decrypt_data 11 + from backend._internal.auth.session import _get_refresh_token_expires_at 12 + from backend.models import PendingDevToken, UserSession 13 + from backend.utilities.database import db_session 14 + 15 + logger = logging.getLogger(__name__) 16 + 17 + 18 + @dataclass 19 + class DeveloperToken: 20 + """developer token metadata (without sensitive session data).""" 21 + 22 + session_id: str 23 + token_name: str | None 24 + created_at: datetime 25 + expires_at: datetime | None 26 + 27 + 28 + async def list_developer_tokens(did: str) -> list[DeveloperToken]: 29 + """list all developer tokens for a user.""" 30 + async with db_session() as db: 31 + result = await db.execute( 32 + select(UserSession).where( 33 + UserSession.did == did, 34 + UserSession.is_developer_token == True, # noqa: E712 35 + ) 36 + ) 37 + sessions = result.scalars().all() 38 + 39 + tokens = [] 40 + now = datetime.now(UTC) 41 + for session in sessions: 42 + decrypted_data = _decrypt_data(session.oauth_session_data) 43 + oauth_session_data = ( 44 + json.loads(decrypted_data) if decrypted_data is not None else {} 45 + ) 46 + refresh_expires_at = _get_refresh_token_expires_at( 47 + session, oauth_session_data 48 + ) 49 + effective_expires_at = session.expires_at 50 + if refresh_expires_at and ( 51 + effective_expires_at is None 52 + or refresh_expires_at < effective_expires_at 53 + ): 54 + effective_expires_at = refresh_expires_at 55 + 56 + # check if expired 57 + if effective_expires_at and now > effective_expires_at: 58 + continue # skip expired tokens 59 + 60 + tokens.append( 61 + DeveloperToken( 62 + session_id=session.session_id, 63 + token_name=session.token_name, 64 + created_at=session.created_at, 65 + expires_at=effective_expires_at, 66 + ) 67 + ) 68 + 69 + return tokens 70 + 71 + 72 + async def revoke_developer_token(did: str, session_id: str) -> bool: 73 + """revoke a developer token. 
returns True if successful, False if not found.""" 74 + async with db_session() as db: 75 + result = await db.execute( 76 + select(UserSession).where( 77 + UserSession.session_id == session_id, 78 + UserSession.did == did, # ensure user owns this token 79 + UserSession.is_developer_token == True, # noqa: E712 80 + ) 81 + ) 82 + session = result.scalar_one_or_none() 83 + 84 + if not session: 85 + return False 86 + 87 + await db.delete(session) 88 + await db.commit() 89 + return True 90 + 91 + 92 + @dataclass 93 + class PendingDevTokenData: 94 + """metadata for a pending developer token OAuth flow.""" 95 + 96 + state: str 97 + did: str 98 + token_name: str | None 99 + expires_in_days: int 100 + 101 + 102 + async def save_pending_dev_token( 103 + state: str, 104 + did: str, 105 + token_name: str | None, 106 + expires_in_days: int, 107 + ) -> None: 108 + """save pending dev token metadata keyed by OAuth state.""" 109 + async with db_session() as db: 110 + pending = PendingDevToken( 111 + state=state, 112 + did=did, 113 + token_name=token_name, 114 + expires_in_days=expires_in_days, 115 + ) 116 + db.add(pending) 117 + await db.commit() 118 + 119 + 120 + async def get_pending_dev_token(state: str) -> PendingDevTokenData | None: 121 + """get pending dev token metadata by OAuth state.""" 122 + async with db_session() as db: 123 + result = await db.execute( 124 + select(PendingDevToken).where(PendingDevToken.state == state) 125 + ) 126 + pending = result.scalar_one_or_none() 127 + 128 + if not pending: 129 + return None 130 + 131 + # check if expired 132 + if datetime.now(UTC) > pending.expires_at: 133 + await db.delete(pending) 134 + await db.commit() 135 + return None 136 + 137 + return PendingDevTokenData( 138 + state=pending.state, 139 + did=pending.did, 140 + token_name=pending.token_name, 141 + expires_in_days=pending.expires_in_days, 142 + ) 143 + 144 + 145 + async def delete_pending_dev_token(state: str) -> None: 146 + """delete pending dev token metadata after use.""" 147 + async with db_session() as db: 148 + result = await db.execute( 149 + select(PendingDevToken).where(PendingDevToken.state == state) 150 + ) 151 + if pending := result.scalar_one_or_none(): 152 + await db.delete(pending) 153 + await db.commit()
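Illustrative wiring for a token-management endpoint pair on top of these helpers; the router prefix and response shapes are hypothetical:

from fastapi import APIRouter, Depends, HTTPException

from backend._internal.auth.dependencies import require_auth
from backend._internal.auth.developer_tokens import (
    list_developer_tokens,
    revoke_developer_token,
)
from backend._internal.auth.session import Session

router = APIRouter(prefix="/developer-tokens")

@router.get("")
async def list_tokens(session: Session = Depends(require_auth)):
    # expired tokens are already filtered out inside list_developer_tokens
    tokens = await list_developer_tokens(session.did)
    return [
        {"session_id": t.session_id, "name": t.token_name, "expires_at": t.expires_at}
        for t in tokens
    ]

@router.delete("/{token_session_id}")
async def revoke(token_session_id: str, session: Session = Depends(require_auth)):
    # revocation is scoped to the caller's DID, so users can only revoke their own tokens
    if not await revoke_developer_token(session.did, token_session_id):
        raise HTTPException(status_code=404, detail="token not found")
    return {"revoked": True}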
+33
backend/src/backend/_internal/auth/encryption.py
··· 1 + """Fernet encryption for sensitive OAuth data at rest.""" 2 + 3 + from cryptography.fernet import Fernet 4 + 5 + from backend.config import settings 6 + 7 + # CRITICAL: encryption key must be configured and stable across restarts 8 + # otherwise all sessions become undecipherable after restart 9 + if not settings.atproto.oauth_encryption_key: 10 + raise RuntimeError( 11 + "oauth_encryption_key must be configured in settings. " 12 + "generate one with: python -c 'from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())'" 13 + ) 14 + 15 + _encryption_key = settings.atproto.oauth_encryption_key.encode() 16 + _fernet = Fernet(_encryption_key) 17 + 18 + 19 + def _encrypt_data(data: str) -> str: 20 + """encrypt sensitive data for storage.""" 21 + return _fernet.encrypt(data.encode()).decode() 22 + 23 + 24 + def _decrypt_data(encrypted: str) -> str | None: 25 + """decrypt sensitive data from storage. 26 + 27 + returns None if decryption fails (e.g., key changed, data corrupted). 28 + """ 29 + try: 30 + return _fernet.decrypt(encrypted.encode()).decode() 31 + except Exception: 32 + # decryption failed - likely key mismatch or corrupted data 33 + return None
+74
backend/src/backend/_internal/auth/exchange.py
··· 1 + """Exchange token creation and consumption.""" 2 + 3 + import secrets 4 + from datetime import UTC, datetime 5 + 6 + from sqlalchemy import select, update 7 + 8 + from backend.models import ExchangeToken 9 + from backend.utilities.database import db_session 10 + 11 + 12 + async def create_exchange_token(session_id: str, is_dev_token: bool = False) -> str: 13 + """create a one-time use exchange token for secure OAuth callback. 14 + 15 + exchange tokens expire after 60 seconds and can only be used once, 16 + preventing session_id exposure in browser history/referrers. 17 + 18 + args: 19 + session_id: the session to associate with this exchange token 20 + is_dev_token: if True, the exchange will not set a browser cookie 21 + """ 22 + token = secrets.token_urlsafe(32) 23 + 24 + async with db_session() as db: 25 + exchange_token = ExchangeToken( 26 + token=token, 27 + session_id=session_id, 28 + is_dev_token=is_dev_token, 29 + ) 30 + db.add(exchange_token) 31 + await db.commit() 32 + 33 + return token 34 + 35 + 36 + async def consume_exchange_token(token: str) -> tuple[str, bool] | None: 37 + """consume an exchange token and return (session_id, is_dev_token). 38 + 39 + returns None if token is invalid, expired, or already used. 40 + uses atomic UPDATE to prevent race conditions (token can only be used once). 41 + """ 42 + async with db_session() as db: 43 + # first, check if token exists and is not expired 44 + result = await db.execute( 45 + select(ExchangeToken).where(ExchangeToken.token == token) 46 + ) 47 + exchange_token = result.scalar_one_or_none() 48 + 49 + if not exchange_token: 50 + return None 51 + 52 + # check if expired 53 + if datetime.now(UTC) > exchange_token.expires_at: 54 + return None 55 + 56 + # capture is_dev_token before atomic update 57 + is_dev_token = exchange_token.is_dev_token 58 + 59 + # atomically mark as used ONLY if not already used 60 + # this prevents race conditions where two requests try to use the same token 61 + result = await db.execute( 62 + update(ExchangeToken) 63 + .where(ExchangeToken.token == token, ExchangeToken.used == False) # noqa: E712 64 + .values(used=True) 65 + .returning(ExchangeToken.session_id) 66 + ) 67 + await db.commit() 68 + 69 + # if no rows were updated, token was already used 70 + session_id = result.scalar_one_or_none() 71 + if session_id is None: 72 + return None 73 + 74 + return session_id, is_dev_token
+486
backend/src/backend/_internal/auth/oauth.py
··· 1 + """OAuth client config, flows, callback, and artist profile.""" 2 + 3 + import json 4 + import logging 5 + import secrets 6 + import time 7 + from datetime import UTC, datetime 8 + 9 + from atproto_oauth import OAuthClient, OAuthState, PromptType 10 + from atproto_oauth.client import ( 11 + discover_authserver_from_pds_async, 12 + fetch_authserver_metadata_async, 13 + ) 14 + from atproto_oauth.dpop import DPoPManager 15 + from atproto_oauth.pkce import PKCEManager 16 + from atproto_oauth.stores.memory import MemorySessionStore 17 + from cryptography.hazmat.primitives.asymmetric import ec 18 + from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey 19 + from cryptography.hazmat.primitives.serialization import load_pem_private_key 20 + from fastapi import HTTPException 21 + from jose import jwk 22 + from sqlalchemy import select 23 + 24 + from backend._internal.auth.session import ( 25 + _check_teal_preference, 26 + _compute_refresh_token_expires_at, 27 + get_client_auth_method, 28 + get_refresh_token_lifetime_days, 29 + ) 30 + from backend._internal.oauth_stores import PostgresStateStore 31 + from backend.config import settings 32 + from backend.models import Artist 33 + from backend.utilities.database import db_session 34 + 35 + logger = logging.getLogger(__name__) 36 + 37 + # OAuth stores 38 + # state store: postgres-backed for multi-instance resilience 39 + # session store: in-memory (not used, we use UserSession table instead) 40 + _state_store = PostgresStateStore() 41 + _session_store = MemorySessionStore() 42 + 43 + # confidential client key (loaded lazily) 44 + _client_secret_key: EllipticCurvePrivateKey | None = None 45 + _client_secret_kid: str | None = None 46 + _client_secret_key_loaded = False 47 + 48 + 49 + def _load_client_secret() -> tuple[EllipticCurvePrivateKey | None, str | None]: 50 + """load EC private key and kid from OAUTH_JWK setting. 51 + 52 + returns (None, None) if OAUTH_JWK is not configured (public client mode). 53 + """ 54 + global _client_secret_key, _client_secret_kid, _client_secret_key_loaded 55 + 56 + if _client_secret_key_loaded: 57 + return _client_secret_key, _client_secret_kid 58 + 59 + _client_secret_key_loaded = True 60 + 61 + if not settings.atproto.oauth_jwk: 62 + logger.info("OAUTH_JWK not configured, using public OAuth client") 63 + return None, None 64 + 65 + try: 66 + # parse JWK JSON 67 + jwk_data = json.loads(settings.atproto.oauth_jwk) 68 + 69 + # extract kid (required for client assertions) 70 + _client_secret_kid = jwk_data.get("kid") 71 + if not _client_secret_kid: 72 + raise ValueError("OAUTH_JWK must include 'kid' field") 73 + 74 + # convert JWK to PEM format using python-jose 75 + key_obj = jwk.construct(jwk_data, algorithm="ES256") 76 + pem_bytes = key_obj.to_pem() 77 + 78 + # load as cryptography key 79 + loaded_key = load_pem_private_key(pem_bytes, password=None) 80 + 81 + if not isinstance(loaded_key, ec.EllipticCurvePrivateKey): 82 + raise ValueError("OAUTH_JWK must be an EC key (ES256)") 83 + 84 + _client_secret_key = loaded_key 85 + logger.info(f"loaded confidential OAuth client key (kid={_client_secret_kid})") 86 + return _client_secret_key, _client_secret_kid 87 + 88 + except Exception as e: 89 + logger.error(f"failed to load OAUTH_JWK: {e}") 90 + raise RuntimeError(f"invalid OAUTH_JWK configuration: {e}") from e 91 + 92 + 93 + def get_public_jwks() -> dict | None: 94 + """get public JWKS for the /.well-known/jwks.json endpoint. 95 + 96 + returns None if confidential client is not configured. 
97 + """ 98 + if not settings.atproto.oauth_jwk: 99 + return None 100 + 101 + try: 102 + # parse private JWK 103 + jwk_data = json.loads(settings.atproto.oauth_jwk) 104 + 105 + # construct key and extract public components 106 + key_obj = jwk.construct(jwk_data, algorithm="ES256") 107 + public_jwk = key_obj.to_dict() 108 + 109 + # remove private key components, keep only public 110 + public_jwk.pop("d", None) # private key scalar 111 + 112 + # ensure required fields for public key 113 + public_jwk["use"] = "sig" 114 + public_jwk["alg"] = "ES256" 115 + 116 + # preserve kid from original JWK (python-jose's to_dict() doesn't include it) 117 + if "kid" in jwk_data: 118 + public_jwk["kid"] = jwk_data["kid"] 119 + 120 + return {"keys": [public_jwk]} 121 + 122 + except Exception as e: 123 + logger.error(f"failed to generate public JWKS: {e}") 124 + return None 125 + 126 + 127 + def get_oauth_client(include_teal: bool = False) -> OAuthClient: 128 + """create an OAuth client with the appropriate scopes. 129 + 130 + if OAUTH_JWK is configured, creates a confidential client with 131 + private_key_jwt authentication. otherwise creates a public client. 132 + """ 133 + scope = ( 134 + settings.atproto.resolved_scope_with_teal( 135 + settings.teal.play_collection, settings.teal.status_collection 136 + ) 137 + if include_teal 138 + else settings.atproto.resolved_scope 139 + ) 140 + 141 + # load confidential client key if configured 142 + client_secret_key, client_secret_kid = _load_client_secret() 143 + 144 + return OAuthClient( 145 + client_id=settings.atproto.client_id, 146 + redirect_uri=settings.atproto.redirect_uri, 147 + scope=scope, 148 + state_store=_state_store, 149 + session_store=_session_store, 150 + client_secret_key=client_secret_key, 151 + client_secret_kid=client_secret_kid, 152 + ) 153 + 154 + 155 + def get_oauth_client_for_scope(scope: str) -> OAuthClient: 156 + """get an OAuth client matching a given scope string. 157 + 158 + used during callback to match the scope that was used during authorization. 159 + """ 160 + include_teal = settings.teal.play_collection in scope 161 + return get_oauth_client(include_teal=include_teal) 162 + 163 + 164 + async def start_oauth_flow( 165 + handle: str, prompt: PromptType | None = None 166 + ) -> tuple[str, str]: 167 + """start OAuth flow and return (auth_url, state). 168 + 169 + uses extended scope if user has enabled teal.fm scrobbling. 
170 + """ 171 + from backend._internal.atproto.handles import resolve_handle 172 + 173 + try: 174 + # resolve handle to DID to check preferences 175 + resolved = await resolve_handle(handle) 176 + if resolved: 177 + did = resolved["did"] 178 + wants_teal = await _check_teal_preference(did) 179 + client = get_oauth_client(include_teal=wants_teal) 180 + logger.info(f"starting OAuth for {handle} (did={did}, teal={wants_teal})") 181 + else: 182 + # fallback to base client if resolution fails 183 + # (OAuth flow will resolve handle again internally) 184 + client = get_oauth_client(include_teal=False) 185 + logger.info(f"starting OAuth for {handle} (resolution failed, using base)") 186 + 187 + auth_url, state = await client.start_authorization(handle, prompt=prompt) 188 + return auth_url, state 189 + except Exception as e: 190 + raise HTTPException( 191 + status_code=400, 192 + detail=f"failed to start OAuth flow: {e}", 193 + ) from e 194 + 195 + 196 + async def start_oauth_flow_with_scopes( 197 + handle: str, include_teal: bool, prompt: PromptType | None = None 198 + ) -> tuple[str, str]: 199 + """start OAuth flow with explicit scope selection (used for scope upgrades).""" 200 + try: 201 + client = get_oauth_client(include_teal=include_teal) 202 + logger.info(f"starting scope upgrade OAuth for {handle} (teal={include_teal})") 203 + auth_url, state = await client.start_authorization(handle, prompt=prompt) 204 + return auth_url, state 205 + except Exception as e: 206 + raise HTTPException( 207 + status_code=400, 208 + detail=f"failed to start OAuth flow: {e}", 209 + ) from e 210 + 211 + 212 + async def start_oauth_flow_for_pds(pds_url: str) -> tuple[str, str]: 213 + """start OAuth flow for account creation on a PDS. 214 + 215 + discovers auth server from PDS URL and sends PAR with prompt=create. 
216 + """ 217 + from urllib.parse import urlencode 218 + 219 + import httpx 220 + 221 + try: 222 + pds_url = pds_url.rstrip("/") 223 + 224 + # discover auth server from PDS 225 + authserver_url = await discover_authserver_from_pds_async(pds_url) 226 + authserver_url = authserver_url.rstrip("/") 227 + 228 + # fetch auth server metadata 229 + authserver_meta = await fetch_authserver_metadata_async(authserver_url) 230 + 231 + # get OAuth client for scope/keys 232 + client = get_oauth_client(include_teal=False) 233 + 234 + # generate PKCE and DPoP 235 + pkce = PKCEManager() 236 + pkce_verifier, pkce_challenge = pkce.generate_pair() 237 + dpop = DPoPManager() 238 + dpop_key = dpop.generate_keypair() 239 + state_token = secrets.token_urlsafe(32) 240 + 241 + # build PAR request with prompt=create and no login_hint 242 + par_url = authserver_meta.pushed_authorization_request_endpoint 243 + params: dict[str, str] = { 244 + "response_type": "code", 245 + "code_challenge": pkce_challenge, 246 + "code_challenge_method": "S256", 247 + "state": state_token, 248 + "redirect_uri": client.redirect_uri, 249 + "scope": client.scope, 250 + "client_id": client.client_id, 251 + "prompt": "create", 252 + } 253 + 254 + # add client authentication if confidential client 255 + client_secret_key, client_secret_kid = _load_client_secret() 256 + if client_secret_key and client_secret_kid: 257 + client_assertion = _create_client_assertion( 258 + client.client_id, 259 + authserver_meta.issuer, 260 + client_secret_key, 261 + client_secret_kid, 262 + ) 263 + params["client_assertion_type"] = ( 264 + "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" 265 + ) 266 + params["client_assertion"] = client_assertion 267 + 268 + # make PAR request with DPoP nonce retry 269 + dpop_nonce = "" 270 + for attempt in range(2): 271 + dpop_proof = dpop.create_proof( 272 + method="POST", 273 + url=par_url, 274 + private_key=dpop_key, 275 + nonce=dpop_nonce if dpop_nonce else None, 276 + ) 277 + 278 + async with httpx.AsyncClient() as http: 279 + response = await http.post( 280 + par_url, data=params, headers={"DPoP": dpop_proof} 281 + ) 282 + 283 + if dpop.is_dpop_nonce_error(response): 284 + new_nonce = dpop.extract_nonce_from_response(response) 285 + if new_nonce and attempt == 0: 286 + dpop_nonce = new_nonce 287 + continue 288 + 289 + dpop_nonce = dpop.extract_nonce_from_response(response) or dpop_nonce 290 + break 291 + 292 + if response.status_code not in (200, 201): 293 + raise HTTPException( 294 + status_code=400, 295 + detail=f"PAR request failed: {response.status_code} {response.text}", 296 + ) 297 + 298 + par_response = response.json() 299 + request_uri = par_response["request_uri"] 300 + 301 + # store state with did=None (unknown until account created) 302 + oauth_state = OAuthState( 303 + state=state_token, 304 + pkce_verifier=pkce_verifier, 305 + redirect_uri=client.redirect_uri, 306 + scope=client.scope, 307 + authserver_iss=authserver_meta.issuer, 308 + dpop_private_key=dpop_key, 309 + dpop_authserver_nonce=dpop_nonce, 310 + did=None, 311 + handle=None, 312 + pds_url=pds_url, 313 + ) 314 + await _state_store.save_state(oauth_state) 315 + 316 + # build authorization URL 317 + auth_params = {"client_id": client.client_id, "request_uri": request_uri} 318 + auth_url = f"{authserver_meta.authorization_endpoint}?{urlencode(auth_params)}" 319 + 320 + logger.info(f"starting account creation OAuth for PDS {pds_url}") 321 + return auth_url, state_token 322 + 323 + except Exception as e: 324 + if isinstance(e, HTTPException): 325 
+ raise 326 + raise HTTPException( 327 + status_code=400, 328 + detail=f"failed to start account creation OAuth: {e}", 329 + ) from e 330 + 331 + 332 + def _create_client_assertion( 333 + client_id: str, 334 + audience: str, 335 + private_key: EllipticCurvePrivateKey, 336 + kid: str, 337 + ) -> str: 338 + """create client assertion JWT for confidential client.""" 339 + header = {"alg": "ES256", "typ": "JWT", "kid": kid} 340 + now = int(time.time()) 341 + payload = { 342 + "iss": client_id, 343 + "sub": client_id, 344 + "aud": audience, 345 + "jti": secrets.token_urlsafe(16), 346 + "iat": now, 347 + "exp": now + 60, 348 + } 349 + 350 + dpop = DPoPManager() 351 + return dpop._sign_jwt(header, payload, private_key) 352 + 353 + 354 + async def _resolve_handle_from_pds(pds_url: str, did: str) -> str | None: 355 + """resolve handle from PDS when OAuth doesn't return it. 356 + 357 + this happens for newly created accounts on third-party PDSes where 358 + the handle isn't yet indexed by the Bluesky AppView. 359 + """ 360 + import httpx 361 + 362 + try: 363 + async with httpx.AsyncClient() as client: 364 + resp = await client.get( 365 + f"{pds_url}/xrpc/com.atproto.repo.describeRepo", 366 + params={"repo": did}, 367 + timeout=10.0, 368 + ) 369 + if resp.status_code == 200: 370 + data = resp.json() 371 + handle = data.get("handle") 372 + if handle: 373 + logger.info(f"resolved handle from PDS: {handle}") 374 + return handle 375 + except Exception as e: 376 + logger.warning(f"failed to resolve handle from PDS: {e}") 377 + return None 378 + 379 + 380 + async def handle_oauth_callback( 381 + code: str, state: str, iss: str 382 + ) -> tuple[str, str, dict]: 383 + """handle OAuth callback and return (did, handle, oauth_session). 384 + 385 + uses the appropriate OAuth client based on stored state's scope. 386 + """ 387 + try: 388 + # look up stored state to determine which scope was used 389 + if stored_state := await _state_store.get_state(state): 390 + client = get_oauth_client_for_scope(stored_state.scope) 391 + logger.info( 392 + f"callback using client for scope: {stored_state.scope[:50]}..." 393 + ) 394 + else: 395 + # fallback to base client (state might have been cleaned up) 396 + client = get_oauth_client(include_teal=False) 397 + logger.warning(f"state {state[:8]}... 
not found, using base client") 398 + 399 + oauth_session = await client.handle_callback( 400 + code=code, 401 + state=state, 402 + iss=iss, 403 + ) 404 + 405 + # resolve handle from PDS if not provided by OAuth 406 + # (happens for newly created accounts on third-party PDSes) 407 + handle = oauth_session.handle 408 + if not handle: 409 + handle = ( 410 + await _resolve_handle_from_pds(oauth_session.pds_url, oauth_session.did) 411 + or "" 412 + ) 413 + 414 + # serialize DPoP private key for storage 415 + from cryptography.hazmat.primitives import serialization 416 + 417 + dpop_key_pem = oauth_session.dpop_private_key.private_bytes( 418 + encoding=serialization.Encoding.PEM, 419 + format=serialization.PrivateFormat.PKCS8, 420 + encryption_algorithm=serialization.NoEncryption(), 421 + ).decode("utf-8") 422 + 423 + client_auth_method = get_client_auth_method() 424 + refresh_lifetime_days = get_refresh_token_lifetime_days(client_auth_method) 425 + refresh_expires_at = _compute_refresh_token_expires_at( 426 + datetime.now(UTC), client_auth_method 427 + ) 428 + 429 + # store full OAuth session with tokens in database 430 + session_data = { 431 + "did": oauth_session.did, 432 + "handle": handle, 433 + "pds_url": oauth_session.pds_url, 434 + "authserver_iss": oauth_session.authserver_iss, 435 + "scope": oauth_session.scope, 436 + "access_token": oauth_session.access_token, 437 + "refresh_token": oauth_session.refresh_token, 438 + "dpop_private_key_pem": dpop_key_pem, 439 + "dpop_authserver_nonce": oauth_session.dpop_authserver_nonce, 440 + "dpop_pds_nonce": oauth_session.dpop_pds_nonce or "", 441 + "client_auth_method": client_auth_method, 442 + "refresh_token_lifetime_days": refresh_lifetime_days, 443 + "refresh_token_expires_at": refresh_expires_at.isoformat(), 444 + } 445 + return oauth_session.did, handle, session_data 446 + except Exception as e: 447 + raise HTTPException( 448 + status_code=401, 449 + detail=f"OAuth callback failed: {e}", 450 + ) from e 451 + 452 + 453 + async def check_artist_profile_exists(did: str) -> bool: 454 + """check if artist profile exists for a DID.""" 455 + async with db_session() as db: 456 + result = await db.execute(select(Artist).where(Artist.did == did)) 457 + artist = result.scalar_one_or_none() 458 + return artist is not None 459 + 460 + 461 + async def ensure_artist_exists(did: str, handle: str) -> bool: 462 + """ensure an Artist record exists for the given DID, creating one if needed. 463 + 464 + returns True if artist was created, False if it already existed. 465 + """ 466 + from backend._internal.atproto.profile import fetch_user_avatar 467 + 468 + async with db_session() as db: 469 + result = await db.execute(select(Artist).where(Artist.did == did)) 470 + if result.scalar_one_or_none(): 471 + return False # already exists 472 + 473 + # fetch avatar from Bluesky 474 + avatar_url = await fetch_user_avatar(did) 475 + 476 + # create minimal artist record 477 + artist = Artist( 478 + did=did, 479 + handle=handle, 480 + display_name=handle, # use handle as initial display name 481 + avatar_url=avatar_url, 482 + ) 483 + db.add(artist) 484 + await db.commit() 485 + logger.info(f"created minimal artist record for {did} (@{handle})") 486 + return True
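Pulling the module together, a minimal login/callback pair built on these functions. The routes are illustrative, and a real callback would hand the session off via an exchange token (see exchange.py) rather than returning the id directly:

from fastapi import APIRouter
from fastapi.responses import RedirectResponse

from backend._internal.auth.oauth import (
    ensure_artist_exists,
    handle_oauth_callback,
    start_oauth_flow,
)
from backend._internal.auth.session import create_session

router = APIRouter()

@router.get("/auth/login")
async def login(handle: str) -> RedirectResponse:
    # scope selection (base vs teal-extended) happens inside start_oauth_flow
    auth_url, _state = await start_oauth_flow(handle)
    return RedirectResponse(auth_url)

@router.get("/auth/callback")
async def callback(code: str, state: str, iss: str):
    did, handle, session_data = await handle_oauth_callback(code, state, iss)
    await ensure_artist_exists(did, handle)
    session_id = await create_session(did, handle, session_data)
    return {"session_id": session_id}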
+29
backend/src/backend/_internal/auth/scopes.py
··· 1 + """OAuth scope parsing and validation.""" 2 + 3 + 4 + def _parse_scopes(scope_string: str) -> set[str]: 5 + """parse an OAuth scope string into a set of individual scopes. 6 + 7 + handles format like: "atproto repo:fm.plyr.track repo:fm.plyr.like" 8 + returns: {"repo:fm.plyr.track", "repo:fm.plyr.like"} 9 + """ 10 + parts = scope_string.split() 11 + # filter out the "atproto" prefix and keep just the repo: scopes 12 + return {p for p in parts if p.startswith("repo:")} 13 + 14 + 15 + def _check_scope_coverage(granted_scope: str, required_scope: str) -> bool: 16 + """check if granted scope covers all required scopes. 17 + 18 + returns True if the session has all required permissions. 19 + """ 20 + granted = _parse_scopes(granted_scope) 21 + required = _parse_scopes(required_scope) 22 + return required.issubset(granted) 23 + 24 + 25 + def _get_missing_scopes(granted_scope: str, required_scope: str) -> set[str]: 26 + """get the scopes that are required but not granted.""" 27 + granted = _parse_scopes(granted_scope) 28 + required = _parse_scopes(required_scope) 29 + return required - granted
+232
backend/src/backend/_internal/auth/session.py
··· 1 + """Session dataclass, CRUD, token update, and teal check.""" 2 + 3 + import json 4 + import logging 5 + import secrets 6 + from dataclasses import dataclass 7 + from datetime import UTC, datetime, timedelta 8 + from typing import Any 9 + 10 + from sqlalchemy import select 11 + 12 + from backend._internal.auth.encryption import _decrypt_data, _encrypt_data 13 + from backend.config import settings 14 + from backend.models import UserPreferences, UserSession 15 + from backend.utilities.database import db_session 16 + 17 + logger = logging.getLogger(__name__) 18 + 19 + PUBLIC_REFRESH_TOKEN_DAYS = 14 20 + CONFIDENTIAL_REFRESH_TOKEN_DAYS = 180 21 + 22 + 23 + @dataclass 24 + class Session: 25 + """authenticated user session.""" 26 + 27 + session_id: str 28 + did: str 29 + handle: str 30 + oauth_session: dict # store OAuth session data 31 + 32 + def get_oauth_session_id(self) -> str: 33 + """extract OAuth session ID for retrieving from session store.""" 34 + return self.oauth_session.get("session_id", self.did) 35 + 36 + 37 + def is_confidential_client() -> bool: 38 + """check if confidential OAuth client is configured.""" 39 + return bool(settings.atproto.oauth_jwk) 40 + 41 + 42 + def get_client_auth_method(oauth_session_data: dict[str, Any] | None = None) -> str: 43 + """resolve client auth method for a session.""" 44 + if oauth_session_data: 45 + method = oauth_session_data.get("client_auth_method") 46 + if method in {"public", "confidential"}: 47 + return method 48 + return "confidential" if is_confidential_client() else "public" 49 + 50 + 51 + def get_refresh_token_lifetime_days(client_auth_method: str | None) -> int: 52 + """get expected refresh token lifetime in days.""" 53 + method = client_auth_method or get_client_auth_method() 54 + return ( 55 + CONFIDENTIAL_REFRESH_TOKEN_DAYS 56 + if method == "confidential" 57 + else PUBLIC_REFRESH_TOKEN_DAYS 58 + ) 59 + 60 + 61 + def _compute_refresh_token_expires_at( 62 + now: datetime, client_auth_method: str | None 63 + ) -> datetime: 64 + """compute refresh token expiration time.""" 65 + return now + timedelta(days=get_refresh_token_lifetime_days(client_auth_method)) 66 + 67 + 68 + def _parse_datetime(value: str | None) -> datetime | None: 69 + """parse ISO datetime string safely.""" 70 + if not value: 71 + return None 72 + try: 73 + return datetime.fromisoformat(value) 74 + except ValueError: 75 + return None 76 + 77 + 78 + def _get_refresh_token_expires_at( 79 + user_session: UserSession, 80 + oauth_session_data: dict[str, Any], 81 + ) -> datetime | None: 82 + """determine refresh token expiry for a session.""" 83 + parsed = _parse_datetime(oauth_session_data.get("refresh_token_expires_at")) 84 + if parsed: 85 + return parsed 86 + 87 + client_auth_method = oauth_session_data.get("client_auth_method") 88 + if client_auth_method: 89 + return user_session.created_at + timedelta( 90 + days=get_refresh_token_lifetime_days(client_auth_method) 91 + ) 92 + 93 + if user_session.is_developer_token: 94 + return user_session.created_at + timedelta(days=PUBLIC_REFRESH_TOKEN_DAYS) 95 + 96 + return None 97 + 98 + 99 + async def create_session( 100 + did: str, 101 + handle: str, 102 + oauth_session: dict[str, Any], 103 + expires_in_days: int = 14, 104 + is_developer_token: bool = False, 105 + token_name: str | None = None, 106 + group_id: str | None = None, 107 + ) -> str: 108 + """create a new session for authenticated user with encrypted OAuth data. 
109 + 110 + args: 111 + did: user's decentralized identifier 112 + handle: user's ATProto handle 113 + oauth_session: OAuth session data to encrypt and store 114 + expires_in_days: session expiration in days (default 14, capped by refresh lifetime) 115 + is_developer_token: whether this is a developer token (for listing/revocation) 116 + token_name: optional name for the token (only for developer tokens) 117 + group_id: optional session group ID for multi-account support 118 + """ 119 + session_id = secrets.token_urlsafe(32) 120 + now = datetime.now(UTC) 121 + 122 + client_auth_method = get_client_auth_method(oauth_session) 123 + refresh_lifetime_days = get_refresh_token_lifetime_days(client_auth_method) 124 + refresh_expires_at = _compute_refresh_token_expires_at(now, client_auth_method) 125 + 126 + oauth_session = dict(oauth_session) 127 + oauth_session.setdefault("client_auth_method", client_auth_method) 128 + oauth_session.setdefault("refresh_token_lifetime_days", refresh_lifetime_days) 129 + oauth_session.setdefault("refresh_token_expires_at", refresh_expires_at.isoformat()) 130 + 131 + effective_days = ( 132 + refresh_lifetime_days 133 + if expires_in_days <= 0 134 + else min(expires_in_days, refresh_lifetime_days) 135 + ) 136 + expires_at = now + timedelta(days=effective_days) 137 + 138 + encrypted_data = _encrypt_data(json.dumps(oauth_session)) 139 + 140 + async with db_session() as db: 141 + user_session = UserSession( 142 + session_id=session_id, 143 + did=did, 144 + handle=handle, 145 + oauth_session_data=encrypted_data, 146 + expires_at=expires_at, 147 + is_developer_token=is_developer_token, 148 + token_name=token_name, 149 + group_id=group_id, 150 + ) 151 + db.add(user_session) 152 + await db.commit() 153 + 154 + return session_id 155 + 156 + 157 + async def get_session(session_id: str) -> Session | None: 158 + """retrieve session by id, decrypt OAuth data, and validate expiration.""" 159 + async with db_session() as db: 160 + result = await db.execute( 161 + select(UserSession).where(UserSession.session_id == session_id) 162 + ) 163 + if not (user_session := result.scalar_one_or_none()): 164 + return None 165 + 166 + # check if session is expired 167 + if user_session.expires_at and datetime.now(UTC) > user_session.expires_at: 168 + # session expired - delete it and return None 169 + await delete_session(session_id) 170 + return None 171 + 172 + # decrypt OAuth session data 173 + decrypted_data = _decrypt_data(user_session.oauth_session_data) 174 + if decrypted_data is None: 175 + # decryption failed - session is invalid (key changed or data corrupted) 176 + # delete the corrupted session 177 + await delete_session(session_id) 178 + return None 179 + 180 + oauth_session_data = json.loads(decrypted_data) 181 + 182 + refresh_expires_at = _get_refresh_token_expires_at( 183 + user_session, oauth_session_data 184 + ) 185 + if refresh_expires_at and datetime.now(UTC) > refresh_expires_at: 186 + await delete_session(session_id) 187 + return None 188 + 189 + return Session( 190 + session_id=user_session.session_id, 191 + did=user_session.did, 192 + handle=user_session.handle, 193 + oauth_session=oauth_session_data, 194 + ) 195 + 196 + 197 + async def update_session_tokens( 198 + session_id: str, oauth_session_data: dict[str, Any] 199 + ) -> None: 200 + """update OAuth session data for a session (e.g., after token refresh).""" 201 + async with db_session() as db: 202 + result = await db.execute( 203 + select(UserSession).where(UserSession.session_id == session_id) 204 + ) 205 + if 
user_session := result.scalar_one_or_none(): 206 + # encrypt updated OAuth session data 207 + encrypted_data = _encrypt_data(json.dumps(oauth_session_data)) 208 + user_session.oauth_session_data = encrypted_data 209 + await db.commit() 210 + 211 + 212 + async def delete_session(session_id: str) -> None: 213 + """delete a session.""" 214 + async with db_session() as db: 215 + result = await db.execute( 216 + select(UserSession).where(UserSession.session_id == session_id) 217 + ) 218 + if user_session := result.scalar_one_or_none(): 219 + await db.delete(user_session) 220 + await db.commit() 221 + 222 + 223 + async def _check_teal_preference(did: str) -> bool: 224 + """check if user has enabled teal.fm scrobbling.""" 225 + async with db_session() as db: 226 + result = await db.execute( 227 + select(UserPreferences.enable_teal_scrobbling).where( 228 + UserPreferences.did == did 229 + ) 230 + ) 231 + pref = result.scalar_one_or_none() 232 + return pref is True
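A lifecycle sketch tying the CRUD helpers together, assuming an oauth_session dict of the shape built in oauth.py's callback handler; the demo function itself is illustrative:

from backend._internal.auth.session import (
    create_session,
    delete_session,
    get_session,
    update_session_tokens,
)

async def demo(did: str, handle: str, oauth_session: dict) -> None:
    # expires_in_days is capped by the refresh-token lifetime: 14 days for a
    # public client, 180 for a confidential one, so 90 becomes 14 here if
    # no OAUTH_JWK is configured
    session_id = await create_session(did, handle, oauth_session, expires_in_days=90)

    session = await get_session(session_id)  # None once expired or undecryptable
    assert session is not None and session.did == did

    # after a token refresh, persist the new tokens (re-encrypted at rest)
    session.oauth_session["access_token"] = "new-access-token"
    await update_session_tokens(session_id, session.oauth_session)

    await delete_session(session_id)
    assert await get_session(session_id) is None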
+1 -1
backend/src/backend/_internal/background.py
··· 113 113 tasks must be registered before they can be executed by workers. 114 114 new tasks should be added to background_tasks.background_tasks list. 115 115 """ 116 - docket.register_collection("backend._internal.background_tasks:background_tasks") 116 + docket.register_collection("backend._internal.tasks:background_tasks") 117 117 118 118 logger.info("registered background tasks")
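For context on what that "module:attribute" path resolves to: the tasks package exposes a flat list of task callables, mirroring the old background_tasks list. A sketch of the package's __init__.py under that assumption; the submodule names are guesses for illustration, and only the list shape is implied by register_collection:

# backend/src/backend/_internal/tasks/__init__.py (sketch; submodule names assumed)
from backend._internal.tasks.moderation import (
    scan_copyright,
    sync_copyright_resolutions,
)
from backend._internal.tasks.pds import pds_create_like, pds_delete_like

# docket resolves "backend._internal.tasks:background_tasks" to this list
background_tasks = [
    scan_copyright,
    sync_copyright_resolutions,
    pds_create_like,
    pds_delete_like,
    # ... one entry per task function moved out of the old monolith
]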
-803
backend/src/backend/_internal/background_tasks.py
··· 1 - """background task functions for docket. 2 - 3 - these functions are registered with docket and executed by workers. 4 - they should be self-contained and handle their own database sessions. 5 - 6 - requires DOCKET_URL to be set (Redis is always available). 7 - """ 8 - 9 - import logging 10 - from datetime import UTC, datetime, timedelta 11 - 12 - import httpx 13 - import logfire 14 - from docket import Perpetual 15 - from sqlalchemy import select 16 - 17 - from backend._internal.atproto.records import ( 18 - create_comment_record, 19 - create_like_record, 20 - delete_record_by_uri, 21 - update_comment_record, 22 - ) 23 - from backend._internal.auth import get_session 24 - from backend._internal.background import get_docket 25 - from backend._internal.clap_client import get_clap_client 26 - from backend._internal.export_tasks import process_export 27 - from backend._internal.pds_backfill_tasks import backfill_tracks_to_pds 28 - from backend._internal.tpuf_client import upsert as tpuf_upsert 29 - from backend.config import settings 30 - from backend.models import Artist, CopyrightScan, Track, TrackComment, TrackLike 31 - from backend.utilities.database import db_session 32 - 33 - logger = logging.getLogger(__name__) 34 - 35 - 36 - async def scan_copyright(track_id: int, audio_url: str) -> None: 37 - """scan a track for potential copyright matches. 38 - 39 - args: 40 - track_id: database ID of the track to scan 41 - audio_url: public URL of the audio file (R2) 42 - """ 43 - from backend._internal.moderation import scan_track_for_copyright 44 - 45 - await scan_track_for_copyright(track_id, audio_url) 46 - 47 - 48 - async def schedule_copyright_scan(track_id: int, audio_url: str) -> None: 49 - """schedule a copyright scan via docket.""" 50 - docket = get_docket() 51 - await docket.add(scan_copyright)(track_id, audio_url) 52 - logfire.info("scheduled copyright scan", track_id=track_id) 53 - 54 - 55 - async def sync_copyright_resolutions( 56 - perpetual: Perpetual = Perpetual(every=timedelta(minutes=5), automatic=True), # noqa: B008 57 - ) -> None: 58 - """sync resolution status from labeler to backend database. 59 - 60 - finds tracks that are flagged but have no resolution, checks the labeler 61 - to see if the labels were negated (dismissed), and marks them as resolved. 62 - 63 - this replaces the lazy reconciliation that was happening on read paths. 64 - runs automatically every 5 minutes via docket's Perpetual. 
65 - """ 66 - from backend._internal.moderation_client import get_moderation_client 67 - 68 - async with db_session() as db: 69 - # find flagged scans with AT URIs that haven't been resolved 70 - result = await db.execute( 71 - select(CopyrightScan, Track.atproto_record_uri) 72 - .join(Track, CopyrightScan.track_id == Track.id) 73 - .where( 74 - CopyrightScan.is_flagged == True, # noqa: E712 75 - Track.atproto_record_uri.isnot(None), 76 - ) 77 - ) 78 - rows = result.all() 79 - 80 - if not rows: 81 - logfire.debug("sync_copyright_resolutions: no flagged scans to check") 82 - return 83 - 84 - # batch check with labeler 85 - scan_by_uri: dict[str, CopyrightScan] = {} 86 - for scan, uri in rows: 87 - if uri: 88 - scan_by_uri[uri] = scan 89 - 90 - if not scan_by_uri: 91 - return 92 - 93 - client = get_moderation_client() 94 - active_uris = await client.get_active_labels(list(scan_by_uri.keys())) 95 - 96 - # find scans that are no longer active (label was negated) 97 - resolved_count = 0 98 - for uri, scan in scan_by_uri.items(): 99 - if uri not in active_uris: 100 - # label was negated - track is no longer flagged 101 - scan.is_flagged = False 102 - resolved_count += 1 103 - 104 - if resolved_count > 0: 105 - await db.commit() 106 - logfire.info( 107 - "sync_copyright_resolutions: resolved {count} scans", 108 - count=resolved_count, 109 - ) 110 - else: 111 - logfire.debug( 112 - "sync_copyright_resolutions: checked {count} scans, none resolved", 113 - count=len(scan_by_uri), 114 - ) 115 - 116 - 117 - async def schedule_copyright_resolution_sync() -> None: 118 - """schedule a copyright resolution sync via docket.""" 119 - docket = get_docket() 120 - await docket.add(sync_copyright_resolutions)() 121 - logfire.info("scheduled copyright resolution sync") 122 - 123 - 124 - async def sync_atproto(session_id: str, user_did: str) -> None: 125 - """sync ATProto records (profile, albums, liked tracks) for a user. 126 - 127 - this runs after login or scope upgrade to ensure the user's PDS 128 - has up-to-date records for their plyr.fm data. 129 - 130 - args: 131 - session_id: the user's session ID for authentication 132 - user_did: the user's DID 133 - """ 134 - from backend._internal.atproto.sync import sync_atproto_records 135 - from backend._internal.auth import get_session 136 - 137 - auth_session = await get_session(session_id) 138 - if not auth_session: 139 - logger.warning(f"sync_atproto: session {session_id[:8]}... not found") 140 - return 141 - 142 - await sync_atproto_records(auth_session, user_did) 143 - 144 - 145 - async def schedule_atproto_sync(session_id: str, user_did: str) -> None: 146 - """schedule an ATProto sync via docket.""" 147 - docket = get_docket() 148 - await docket.add(sync_atproto)(session_id, user_did) 149 - logfire.info("scheduled atproto sync", user_did=user_did) 150 - 151 - 152 - async def scrobble_to_teal( 153 - session_id: str, 154 - track_id: int, 155 - track_title: str, 156 - artist_name: str, 157 - duration: int | None, 158 - album_name: str | None, 159 - ) -> None: 160 - """scrobble a play to teal.fm (creates play record + updates status). 
161 - 162 - args: 163 - session_id: the user's session ID for authentication 164 - track_id: database ID of the track 165 - track_title: title of the track 166 - artist_name: name of the artist 167 - duration: track duration in seconds 168 - album_name: album name (optional) 169 - """ 170 - from backend._internal.atproto.teal import ( 171 - create_teal_play_record, 172 - update_teal_status, 173 - ) 174 - from backend._internal.auth import get_session 175 - from backend.config import settings 176 - 177 - auth_session = await get_session(session_id) 178 - if not auth_session: 179 - logger.warning(f"teal scrobble: session {session_id[:8]}... not found") 180 - return 181 - 182 - origin_url = f"{settings.frontend.url}/track/{track_id}" 183 - 184 - try: 185 - # create play record (scrobble) 186 - play_uri = await create_teal_play_record( 187 - auth_session=auth_session, 188 - track_name=track_title, 189 - artist_name=artist_name, 190 - duration=duration, 191 - album_name=album_name, 192 - origin_url=origin_url, 193 - ) 194 - logger.info(f"teal play record created: {play_uri}") 195 - 196 - # update status (now playing) 197 - status_uri = await update_teal_status( 198 - auth_session=auth_session, 199 - track_name=track_title, 200 - artist_name=artist_name, 201 - duration=duration, 202 - album_name=album_name, 203 - origin_url=origin_url, 204 - ) 205 - logger.info(f"teal status updated: {status_uri}") 206 - 207 - except Exception as e: 208 - logger.error(f"teal scrobble failed for track {track_id}: {e}", exc_info=True) 209 - 210 - 211 - async def schedule_teal_scrobble( 212 - session_id: str, 213 - track_id: int, 214 - track_title: str, 215 - artist_name: str, 216 - duration: int | None, 217 - album_name: str | None, 218 - ) -> None: 219 - """schedule a teal scrobble via docket.""" 220 - docket = get_docket() 221 - await docket.add(scrobble_to_teal)( 222 - session_id, track_id, track_title, artist_name, duration, album_name 223 - ) 224 - logfire.info("scheduled teal scrobble", track_id=track_id) 225 - 226 - 227 - async def sync_album_list(session_id: str, album_id: str) -> None: 228 - """sync a single album's ATProto list record. 229 - 230 - creates or updates the album's list record on the user's PDS. 231 - called after track uploads or album mutations. 232 - 233 - args: 234 - session_id: the user's session ID for authentication 235 - album_id: the album's database ID 236 - """ 237 - from sqlalchemy import select 238 - 239 - from backend._internal.atproto.records.fm_plyr import upsert_album_list_record 240 - from backend._internal.auth import get_session 241 - from backend.models import Album, Track 242 - from backend.utilities.database import db_session 243 - 244 - auth_session = await get_session(session_id) 245 - if not auth_session: 246 - logger.warning(f"sync_album_list: session {session_id[:8]}... 
not found") 247 - return 248 - 249 - async with db_session() as session: 250 - # fetch album 251 - album_result = await session.execute(select(Album).where(Album.id == album_id)) 252 - album = album_result.scalar_one_or_none() 253 - if not album: 254 - logger.warning(f"sync_album_list: album {album_id} not found") 255 - return 256 - 257 - # verify album belongs to this user 258 - if album.artist_did != auth_session.did: 259 - logger.warning( 260 - f"sync_album_list: album {album_id} does not belong to {auth_session.did}" 261 - ) 262 - return 263 - 264 - # fetch tracks with ATProto records 265 - tracks_result = await session.execute( 266 - select(Track) 267 - .where( 268 - Track.album_id == album_id, 269 - Track.atproto_record_uri.isnot(None), 270 - Track.atproto_record_cid.isnot(None), 271 - ) 272 - .order_by(Track.created_at.asc()) 273 - ) 274 - tracks = tracks_result.scalars().all() 275 - 276 - if not tracks: 277 - logger.debug( 278 - f"sync_album_list: album {album_id} has no tracks with ATProto records" 279 - ) 280 - return 281 - 282 - track_refs = [ 283 - {"uri": t.atproto_record_uri, "cid": t.atproto_record_cid} 284 - for t in tracks 285 - if t.atproto_record_uri and t.atproto_record_cid 286 - ] 287 - 288 - try: 289 - result = await upsert_album_list_record( 290 - auth_session, 291 - album_id=album_id, 292 - album_title=album.title, 293 - track_refs=track_refs, 294 - existing_uri=album.atproto_record_uri, 295 - existing_created_at=album.created_at, 296 - ) 297 - if result: 298 - album.atproto_record_uri = result[0] 299 - album.atproto_record_cid = result[1] 300 - await session.commit() 301 - logger.info(f"synced album list record for {album_id}: {result[0]}") 302 - except Exception as e: 303 - logger.warning(f"failed to sync album list record for {album_id}: {e}") 304 - 305 - 306 - async def schedule_album_list_sync(session_id: str, album_id: str) -> None: 307 - """schedule an album list sync via docket.""" 308 - docket = get_docket() 309 - await docket.add(sync_album_list)(session_id, album_id) 310 - logfire.info("scheduled album list sync", album_id=album_id) 311 - 312 - 313 - # --------------------------------------------------------------------------- 314 - # PDS record write tasks 315 - # 316 - # these tasks handle writing records to the user's PDS (Personal Data Server) 317 - # in the background, then updating the local database with the result. 318 - # this keeps API responses fast while ensuring PDS and DB stay in sync. 319 - # --------------------------------------------------------------------------- 320 - 321 - 322 - async def pds_create_like( 323 - session_id: str, 324 - like_id: int, 325 - subject_uri: str, 326 - subject_cid: str, 327 - ) -> None: 328 - """create a like record on the user's PDS and update the database. 329 - 330 - args: 331 - session_id: the user's session ID for authentication 332 - like_id: database ID of the TrackLike record to update 333 - subject_uri: AT URI of the track being liked 334 - subject_cid: CID of the track being liked 335 - """ 336 - auth_session = await get_session(session_id) 337 - if not auth_session: 338 - logger.warning(f"pds_create_like: session {session_id[:8]}... 
not found") 339 - return 340 - 341 - try: 342 - like_uri = await create_like_record( 343 - auth_session=auth_session, 344 - subject_uri=subject_uri, 345 - subject_cid=subject_cid, 346 - ) 347 - 348 - # update database with the ATProto URI 349 - async with db_session() as session: 350 - result = await session.execute( 351 - select(TrackLike).where(TrackLike.id == like_id) 352 - ) 353 - like = result.scalar_one_or_none() 354 - if like: 355 - like.atproto_like_uri = like_uri 356 - await session.commit() 357 - logger.info(f"pds_create_like: created like record {like_uri}") 358 - else: 359 - # like was deleted before we could update it - clean up orphan 360 - logger.warning(f"pds_create_like: like {like_id} no longer exists") 361 - await delete_record_by_uri(auth_session, like_uri) 362 - 363 - except Exception as e: 364 - logger.error(f"pds_create_like failed for like {like_id}: {e}", exc_info=True) 365 - # note: we don't delete the DB record on failure - user still sees "liked" 366 - # and we can retry or fix later. this is better than inconsistent state. 367 - 368 - 369 - async def schedule_pds_create_like( 370 - session_id: str, 371 - like_id: int, 372 - subject_uri: str, 373 - subject_cid: str, 374 - ) -> None: 375 - """schedule a like record creation via docket.""" 376 - docket = get_docket() 377 - await docket.add(pds_create_like)(session_id, like_id, subject_uri, subject_cid) 378 - logfire.info("scheduled pds like creation", like_id=like_id) 379 - 380 - 381 - async def pds_delete_like( 382 - session_id: str, 383 - like_uri: str, 384 - ) -> None: 385 - """delete a like record from the user's PDS. 386 - 387 - args: 388 - session_id: the user's session ID for authentication 389 - like_uri: AT URI of the like record to delete 390 - """ 391 - auth_session = await get_session(session_id) 392 - if not auth_session: 393 - logger.warning(f"pds_delete_like: session {session_id[:8]}... not found") 394 - return 395 - 396 - try: 397 - await delete_record_by_uri(auth_session, like_uri) 398 - logger.info(f"pds_delete_like: deleted like record {like_uri}") 399 - except Exception as e: 400 - logger.error(f"pds_delete_like failed for {like_uri}: {e}", exc_info=True) 401 - # deletion failed - the PDS record may still exist, but DB is already clean 402 - # this is acceptable: orphaned PDS records are harmless 403 - 404 - 405 - async def schedule_pds_delete_like(session_id: str, like_uri: str) -> None: 406 - """schedule a like record deletion via docket.""" 407 - docket = get_docket() 408 - await docket.add(pds_delete_like)(session_id, like_uri) 409 - logfire.info("scheduled pds like deletion", like_uri=like_uri) 410 - 411 - 412 - async def pds_create_comment( 413 - session_id: str, 414 - comment_id: int, 415 - subject_uri: str, 416 - subject_cid: str, 417 - text: str, 418 - timestamp_ms: int, 419 - ) -> None: 420 - """create a comment record on the user's PDS and update the database. 421 - 422 - args: 423 - session_id: the user's session ID for authentication 424 - comment_id: database ID of the TrackComment record to update 425 - subject_uri: AT URI of the track being commented on 426 - subject_cid: CID of the track being commented on 427 - text: comment text 428 - timestamp_ms: playback position when comment was made 429 - """ 430 - auth_session = await get_session(session_id) 431 - if not auth_session: 432 - logger.warning(f"pds_create_comment: session {session_id[:8]}... 
not found") 433 - return 434 - 435 - try: 436 - comment_uri = await create_comment_record( 437 - auth_session=auth_session, 438 - subject_uri=subject_uri, 439 - subject_cid=subject_cid, 440 - text=text, 441 - timestamp_ms=timestamp_ms, 442 - ) 443 - 444 - # update database with the ATProto URI 445 - async with db_session() as session: 446 - result = await session.execute( 447 - select(TrackComment).where(TrackComment.id == comment_id) 448 - ) 449 - comment = result.scalar_one_or_none() 450 - if comment: 451 - comment.atproto_comment_uri = comment_uri 452 - await session.commit() 453 - logger.info(f"pds_create_comment: created comment record {comment_uri}") 454 - else: 455 - # comment was deleted before we could update it - clean up orphan 456 - logger.warning( 457 - f"pds_create_comment: comment {comment_id} no longer exists" 458 - ) 459 - await delete_record_by_uri(auth_session, comment_uri) 460 - 461 - except Exception as e: 462 - logger.error( 463 - f"pds_create_comment failed for comment {comment_id}: {e}", exc_info=True 464 - ) 465 - 466 - 467 - async def schedule_pds_create_comment( 468 - session_id: str, 469 - comment_id: int, 470 - subject_uri: str, 471 - subject_cid: str, 472 - text: str, 473 - timestamp_ms: int, 474 - ) -> None: 475 - """schedule a comment record creation via docket.""" 476 - docket = get_docket() 477 - await docket.add(pds_create_comment)( 478 - session_id, comment_id, subject_uri, subject_cid, text, timestamp_ms 479 - ) 480 - logfire.info("scheduled pds comment creation", comment_id=comment_id) 481 - 482 - 483 - async def pds_delete_comment( 484 - session_id: str, 485 - comment_uri: str, 486 - ) -> None: 487 - """delete a comment record from the user's PDS. 488 - 489 - args: 490 - session_id: the user's session ID for authentication 491 - comment_uri: AT URI of the comment record to delete 492 - """ 493 - auth_session = await get_session(session_id) 494 - if not auth_session: 495 - logger.warning(f"pds_delete_comment: session {session_id[:8]}... not found") 496 - return 497 - 498 - try: 499 - await delete_record_by_uri(auth_session, comment_uri) 500 - logger.info(f"pds_delete_comment: deleted comment record {comment_uri}") 501 - except Exception as e: 502 - logger.error(f"pds_delete_comment failed for {comment_uri}: {e}", exc_info=True) 503 - 504 - 505 - async def schedule_pds_delete_comment(session_id: str, comment_uri: str) -> None: 506 - """schedule a comment record deletion via docket.""" 507 - docket = get_docket() 508 - await docket.add(pds_delete_comment)(session_id, comment_uri) 509 - logfire.info("scheduled pds comment deletion", comment_uri=comment_uri) 510 - 511 - 512 - async def pds_update_comment( 513 - session_id: str, 514 - comment_id: int, 515 - comment_uri: str, 516 - subject_uri: str, 517 - subject_cid: str, 518 - text: str, 519 - timestamp_ms: int, 520 - created_at: datetime, 521 - ) -> None: 522 - """update a comment record on the user's PDS. 523 - 524 - args: 525 - session_id: the user's session ID for authentication 526 - comment_id: database ID of the TrackComment record 527 - comment_uri: AT URI of the comment record to update 528 - subject_uri: AT URI of the track being commented on 529 - subject_cid: CID of the track being commented on 530 - text: new comment text 531 - timestamp_ms: playback position when comment was made 532 - created_at: original creation timestamp 533 - """ 534 - auth_session = await get_session(session_id) 535 - if not auth_session: 536 - logger.warning(f"pds_update_comment: session {session_id[:8]}... 
not found") 537 - return 538 - 539 - try: 540 - await update_comment_record( 541 - auth_session=auth_session, 542 - comment_uri=comment_uri, 543 - subject_uri=subject_uri, 544 - subject_cid=subject_cid, 545 - text=text, 546 - timestamp_ms=timestamp_ms, 547 - created_at=created_at, 548 - updated_at=datetime.now(UTC), 549 - ) 550 - logger.info(f"pds_update_comment: updated comment record {comment_uri}") 551 - except Exception as e: 552 - logger.error( 553 - f"pds_update_comment failed for comment {comment_id}: {e}", exc_info=True 554 - ) 555 - 556 - 557 - async def schedule_pds_update_comment( 558 - session_id: str, 559 - comment_id: int, 560 - comment_uri: str, 561 - subject_uri: str, 562 - subject_cid: str, 563 - text: str, 564 - timestamp_ms: int, 565 - created_at: datetime, 566 - ) -> None: 567 - """schedule a comment record update via docket.""" 568 - docket = get_docket() 569 - await docket.add(pds_update_comment)( 570 - session_id, 571 - comment_id, 572 - comment_uri, 573 - subject_uri, 574 - subject_cid, 575 - text, 576 - timestamp_ms, 577 - created_at, 578 - ) 579 - logfire.info("scheduled pds comment update", comment_id=comment_id) 580 - 581 - 582 - async def generate_embedding(track_id: int, audio_url: str) -> None: 583 - """generate a CLAP embedding for a track and store in turbopuffer. 584 - 585 - args: 586 - track_id: database ID of the track 587 - audio_url: public URL of the audio file (R2) 588 - """ 589 - if not (settings.modal.enabled and settings.turbopuffer.enabled): 590 - logger.debug("embedding generation disabled, skipping track %d", track_id) 591 - return 592 - 593 - async with db_session() as db: 594 - result = await db.execute( 595 - select(Track) 596 - .join(Artist, Track.artist_did == Artist.did) 597 - .where(Track.id == track_id) 598 - ) 599 - row = result.first() 600 - if not row: 601 - logger.warning("generate_embedding: track %d not found", track_id) 602 - return 603 - 604 - track = row[0] 605 - 606 - artist_result = await db.execute( 607 - select(Artist).where(Artist.did == track.artist_did) 608 - ) 609 - artist = artist_result.scalar_one_or_none() 610 - if not artist: 611 - logger.warning( 612 - "generate_embedding: artist not found for track %d", track_id 613 - ) 614 - return 615 - 616 - # download audio from R2 617 - async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client: 618 - resp = await client.get(audio_url) 619 - resp.raise_for_status() 620 - audio_bytes = resp.content 621 - 622 - # generate embedding via CLAP 623 - clap_client = get_clap_client() 624 - embed_result = await clap_client.embed_audio(audio_bytes) 625 - 626 - if not embed_result.success or not embed_result.embedding: 627 - logger.error( 628 - "generate_embedding: CLAP embedding failed for track %d: %s", 629 - track_id, 630 - embed_result.error, 631 - ) 632 - return 633 - 634 - # store in turbopuffer 635 - await tpuf_upsert( 636 - track_id=track_id, 637 - embedding=embed_result.embedding, 638 - title=track.title, 639 - artist_handle=artist.handle, 640 - artist_did=artist.did, 641 - ) 642 - 643 - logfire.info("generated embedding for track", track_id=track_id) 644 - 645 - 646 - async def schedule_embedding_generation(track_id: int, audio_url: str) -> None: 647 - """schedule an embedding generation via docket.""" 648 - docket = get_docket() 649 - await docket.add(generate_embedding)(track_id, audio_url) 650 - logfire.info("scheduled embedding generation", track_id=track_id) 651 - 652 - 653 - async def classify_genres(track_id: int, audio_url: str) -> None: 654 - """classify 
genres for a track via Replicate effnet-discogs and store results. 655 - 656 - args: 657 - track_id: database ID of the track 658 - audio_url: public URL of the audio file (R2) 659 - """ 660 - from backend._internal.replicate_client import get_replicate_client 661 - 662 - if not settings.replicate.enabled: 663 - logger.debug("genre classification disabled, skipping track %d", track_id) 664 - return 665 - 666 - client = get_replicate_client() 667 - result = await client.classify(audio_url) 668 - 669 - if not result.success: 670 - logger.error( 671 - "genre classification failed for track %d: %s", 672 - track_id, 673 - result.error, 674 - ) 675 - return 676 - 677 - predictions = [{"name": g.name, "confidence": g.confidence} for g in result.genres] 678 - 679 - async with db_session() as db: 680 - db_result = await db.execute(select(Track).where(Track.id == track_id)) 681 - track = db_result.scalar_one_or_none() 682 - if not track: 683 - logger.warning("classify_genres: track %d not found", track_id) 684 - return 685 - 686 - extra = dict(track.extra) if track.extra else {} 687 - extra["genre_predictions"] = predictions 688 - extra["genre_predictions_file_id"] = track.file_id 689 - 690 - # auto-tag if requested 691 - if extra.get("auto_tag") and predictions: 692 - from backend.api.tracks.uploads import _add_tags_to_track 693 - 694 - # ratio-to-top: keep tags scoring >= 50% of top score 695 - top_confidence = float(predictions[0]["confidence"]) 696 - top_tags = [ 697 - str(p["name"]) 698 - for p in predictions 699 - if float(p["confidence"]) >= top_confidence * 0.5 700 - ][:5] # cap at 5 701 - 702 - if top_tags: 703 - await _add_tags_to_track(db, track_id, top_tags, track.artist_did) 704 - logfire.info( 705 - "auto-tagged track", 706 - track_id=track_id, 707 - tags=top_tags, 708 - ) 709 - 710 - # clean up flag 711 - del extra["auto_tag"] 712 - 713 - track.extra = extra 714 - await db.commit() 715 - 716 - logfire.info( 717 - "classified genres for track", 718 - track_id=track_id, 719 - top_genre=predictions[0]["name"] if predictions else None, 720 - ) 721 - 722 - 723 - async def schedule_genre_classification(track_id: int, audio_url: str) -> None: 724 - """schedule a genre classification via docket.""" 725 - docket = get_docket() 726 - await docket.add(classify_genres)(track_id, audio_url) 727 - logfire.info("scheduled genre classification", track_id=track_id) 728 - 729 - 730 - async def move_track_audio(track_id: int, to_private: bool) -> None: 731 - """move a track's audio file between public and private buckets. 732 - 733 - called when support_gate is toggled on an existing track. 
734 - 735 - args: 736 - track_id: database ID of the track 737 - to_private: if True, move to private bucket; if False, move to public 738 - """ 739 - from backend.models import Track 740 - from backend.storage import storage 741 - 742 - async with db_session() as db: 743 - result = await db.execute(select(Track).where(Track.id == track_id)) 744 - track = result.scalar_one_or_none() 745 - 746 - if not track: 747 - logger.warning(f"move_track_audio: track {track_id} not found") 748 - return 749 - 750 - if not track.file_id or not track.file_type: 751 - logger.warning( 752 - f"move_track_audio: track {track_id} missing file_id/file_type" 753 - ) 754 - return 755 - 756 - result_url = await storage.move_audio( 757 - file_id=track.file_id, 758 - extension=track.file_type, 759 - to_private=to_private, 760 - ) 761 - 762 - # update r2_url: None for private, public URL for public 763 - if to_private: 764 - # moved to private - result_url is None on success, None on failure 765 - # we check by verifying the file was actually moved (no error logged) 766 - track.r2_url = None 767 - await db.commit() 768 - logger.info(f"moved track {track_id} to private bucket") 769 - elif result_url: 770 - # moved to public - result_url is the public URL 771 - track.r2_url = result_url 772 - await db.commit() 773 - logger.info(f"moved track {track_id} to public bucket") 774 - else: 775 - logger.error(f"failed to move track {track_id}") 776 - 777 - 778 - async def schedule_move_track_audio(track_id: int, to_private: bool) -> None: 779 - """schedule a track audio move via docket.""" 780 - docket = get_docket() 781 - await docket.add(move_track_audio)(track_id, to_private) 782 - direction = "private" if to_private else "public" 783 - logfire.info(f"scheduled track audio move to {direction}", track_id=track_id) 784 - 785 - 786 - # collection of all background task functions for docket registration 787 - background_tasks = [ 788 - scan_copyright, 789 - sync_copyright_resolutions, 790 - process_export, 791 - sync_atproto, 792 - scrobble_to_teal, 793 - sync_album_list, 794 - pds_create_like, 795 - pds_delete_like, 796 - pds_create_comment, 797 - pds_delete_comment, 798 - pds_update_comment, 799 - backfill_tracks_to_pds, 800 - move_track_audio, 801 - generate_embedding, 802 - classify_genres, 803 - ]
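every task deleted above follows the same two-function convention, which the new tasks/ package keeps: an async task body that docket workers execute, plus a schedule_* twin that request handlers call to enqueue it. a minimal sketch of the shape, using a hypothetical pds_create_repost pair - get_docket, docket.add, and logfire.info are exactly the calls the real tasks use:

import logfire

from backend._internal.background import get_docket


async def pds_create_repost(session_id: str, repost_id: int) -> None:
    """hypothetical task body: runs on a docket worker, owns its own db session."""
    ...


async def schedule_pds_create_repost(session_id: str, repost_id: int) -> None:
    """enqueue via docket; handlers call this, never the task directly."""
    docket = get_docket()
    await docket.add(pds_create_repost)(session_id, repost_id)
    logfire.info("scheduled pds repost creation", repost_id=repost_id)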
backend/src/backend/_internal/clap_client.py → backend/src/backend/_internal/clients/clap.py
+1
backend/src/backend/_internal/clients/__init__.py
··· 1 + """HTTP service clients."""
+1 -1
backend/src/backend/_internal/moderation.py
··· 8 8 from sqlalchemy import select 9 9 from sqlalchemy.orm import joinedload 10 10 11 - from backend._internal.moderation_client import get_moderation_client 11 + from backend._internal.clients.moderation import get_moderation_client 12 12 from backend._internal.notifications import notification_service 13 13 from backend.config import settings 14 14 from backend.models import CopyrightScan, Track
backend/src/backend/_internal/moderation_client.py → backend/src/backend/_internal/clients/moderation.py
backend/src/backend/_internal/replicate_client.py → backend/src/backend/_internal/clients/replicate.py
+95
backend/src/backend/_internal/tasks/__init__.py
··· 1 + """background task functions for docket. 2 + 3 + these functions are registered with docket and executed by workers. 4 + they should be self-contained and handle their own database sessions. 5 + 6 + requires DOCKET_URL to be set (Redis is always available). 7 + """ 8 + 9 + from backend._internal.export_tasks import process_export 10 + from backend._internal.pds_backfill_tasks import backfill_tracks_to_pds 11 + from backend._internal.tasks.copyright import ( 12 + scan_copyright, 13 + schedule_copyright_resolution_sync, 14 + schedule_copyright_scan, 15 + sync_copyright_resolutions, 16 + ) 17 + from backend._internal.tasks.ml import ( 18 + classify_genres, 19 + generate_embedding, 20 + schedule_embedding_generation, 21 + schedule_genre_classification, 22 + ) 23 + from backend._internal.tasks.pds import ( 24 + pds_create_comment, 25 + pds_create_like, 26 + pds_delete_comment, 27 + pds_delete_like, 28 + pds_update_comment, 29 + schedule_pds_create_comment, 30 + schedule_pds_create_like, 31 + schedule_pds_delete_comment, 32 + schedule_pds_delete_like, 33 + schedule_pds_update_comment, 34 + ) 35 + from backend._internal.tasks.storage import ( 36 + move_track_audio, 37 + schedule_move_track_audio, 38 + ) 39 + from backend._internal.tasks.sync import ( 40 + schedule_album_list_sync, 41 + schedule_atproto_sync, 42 + schedule_teal_scrobble, 43 + scrobble_to_teal, 44 + sync_album_list, 45 + sync_atproto, 46 + ) 47 + 48 + # collection of all background task functions for docket registration 49 + background_tasks = [ 50 + scan_copyright, 51 + sync_copyright_resolutions, 52 + process_export, 53 + sync_atproto, 54 + scrobble_to_teal, 55 + sync_album_list, 56 + pds_create_like, 57 + pds_delete_like, 58 + pds_create_comment, 59 + pds_delete_comment, 60 + pds_update_comment, 61 + backfill_tracks_to_pds, 62 + move_track_audio, 63 + generate_embedding, 64 + classify_genres, 65 + ] 66 + 67 + __all__ = [ 68 + "background_tasks", 69 + "classify_genres", 70 + "generate_embedding", 71 + "move_track_audio", 72 + "pds_create_comment", 73 + "pds_create_like", 74 + "pds_delete_comment", 75 + "pds_delete_like", 76 + "pds_update_comment", 77 + "scan_copyright", 78 + "schedule_album_list_sync", 79 + "schedule_atproto_sync", 80 + "schedule_copyright_resolution_sync", 81 + "schedule_copyright_scan", 82 + "schedule_embedding_generation", 83 + "schedule_genre_classification", 84 + "schedule_move_track_audio", 85 + "schedule_pds_create_comment", 86 + "schedule_pds_create_like", 87 + "schedule_pds_delete_comment", 88 + "schedule_pds_delete_like", 89 + "schedule_pds_update_comment", 90 + "schedule_teal_scrobble", 91 + "scrobble_to_teal", 92 + "sync_album_list", 93 + "sync_atproto", 94 + "sync_copyright_resolutions", 95 + ]
+102
backend/src/backend/_internal/tasks/copyright.py
··· 1 + """copyright scanning and resolution sync background tasks.""" 2 + 3 + import logging 4 + from datetime import timedelta 5 + 6 + import logfire 7 + from docket import Perpetual 8 + from sqlalchemy import select 9 + 10 + from backend._internal.background import get_docket 11 + from backend.models import CopyrightScan, Track 12 + from backend.utilities.database import db_session 13 + 14 + logger = logging.getLogger(__name__) 15 + 16 + 17 + async def scan_copyright(track_id: int, audio_url: str) -> None: 18 + """scan a track for potential copyright matches. 19 + 20 + args: 21 + track_id: database ID of the track to scan 22 + audio_url: public URL of the audio file (R2) 23 + """ 24 + from backend._internal.moderation import scan_track_for_copyright 25 + 26 + await scan_track_for_copyright(track_id, audio_url) 27 + 28 + 29 + async def schedule_copyright_scan(track_id: int, audio_url: str) -> None: 30 + """schedule a copyright scan via docket.""" 31 + docket = get_docket() 32 + await docket.add(scan_copyright)(track_id, audio_url) 33 + logfire.info("scheduled copyright scan", track_id=track_id) 34 + 35 + 36 + async def sync_copyright_resolutions( 37 + perpetual: Perpetual = Perpetual(every=timedelta(minutes=5), automatic=True), # noqa: B008 38 + ) -> None: 39 + """sync resolution status from labeler to backend database. 40 + 41 + finds tracks that are flagged but have no resolution, checks the labeler 42 + to see if the labels were negated (dismissed), and marks them as resolved. 43 + 44 + this replaces the lazy reconciliation that was happening on read paths. 45 + runs automatically every 5 minutes via docket's Perpetual. 46 + """ 47 + from backend._internal.clients.moderation import get_moderation_client 48 + 49 + async with db_session() as db: 50 + # find flagged scans with AT URIs that haven't been resolved 51 + result = await db.execute( 52 + select(CopyrightScan, Track.atproto_record_uri) 53 + .join(Track, CopyrightScan.track_id == Track.id) 54 + .where( 55 + CopyrightScan.is_flagged == True, # noqa: E712 56 + Track.atproto_record_uri.isnot(None), 57 + ) 58 + ) 59 + rows = result.all() 60 + 61 + if not rows: 62 + logfire.debug("sync_copyright_resolutions: no flagged scans to check") 63 + return 64 + 65 + # batch check with labeler 66 + scan_by_uri: dict[str, CopyrightScan] = {} 67 + for scan, uri in rows: 68 + if uri: 69 + scan_by_uri[uri] = scan 70 + 71 + if not scan_by_uri: 72 + return 73 + 74 + client = get_moderation_client() 75 + active_uris = await client.get_active_labels(list(scan_by_uri.keys())) 76 + 77 + # find scans that are no longer active (label was negated) 78 + resolved_count = 0 79 + for uri, scan in scan_by_uri.items(): 80 + if uri not in active_uris: 81 + # label was negated - track is no longer flagged 82 + scan.is_flagged = False 83 + resolved_count += 1 84 + 85 + if resolved_count > 0: 86 + await db.commit() 87 + logfire.info( 88 + "sync_copyright_resolutions: resolved {count} scans", 89 + count=resolved_count, 90 + ) 91 + else: 92 + logfire.debug( 93 + "sync_copyright_resolutions: checked {count} scans, none resolved", 94 + count=len(scan_by_uri), 95 + ) 96 + 97 + 98 + async def schedule_copyright_resolution_sync() -> None: 99 + """schedule a copyright resolution sync via docket.""" 100 + docket = get_docket() 101 + await docket.add(sync_copyright_resolutions)() 102 + logfire.info("scheduled copyright resolution sync")
+164
backend/src/backend/_internal/tasks/ml.py
··· 1 + """ML background tasks (embeddings, genre classification).""" 2 + 3 + import logging 4 + 5 + import httpx 6 + import logfire 7 + from sqlalchemy import select 8 + 9 + from backend._internal.background import get_docket 10 + from backend._internal.clients.clap import get_clap_client 11 + from backend._internal.clients.tpuf import upsert as tpuf_upsert 12 + from backend.config import settings 13 + from backend.models import Artist, Track 14 + from backend.utilities.database import db_session 15 + 16 + logger = logging.getLogger(__name__) 17 + 18 + 19 + async def generate_embedding(track_id: int, audio_url: str) -> None: 20 + """generate a CLAP embedding for a track and store in turbopuffer. 21 + 22 + args: 23 + track_id: database ID of the track 24 + audio_url: public URL of the audio file (R2) 25 + """ 26 + if not (settings.modal.enabled and settings.turbopuffer.enabled): 27 + logger.debug("embedding generation disabled, skipping track %d", track_id) 28 + return 29 + 30 + async with db_session() as db: 31 + result = await db.execute( 32 + select(Track) 33 + .join(Artist, Track.artist_did == Artist.did) 34 + .where(Track.id == track_id) 35 + ) 36 + row = result.first() 37 + if not row: 38 + logger.warning("generate_embedding: track %d not found", track_id) 39 + return 40 + 41 + track = row[0] 42 + 43 + artist_result = await db.execute( 44 + select(Artist).where(Artist.did == track.artist_did) 45 + ) 46 + artist = artist_result.scalar_one_or_none() 47 + if not artist: 48 + logger.warning( 49 + "generate_embedding: artist not found for track %d", track_id 50 + ) 51 + return 52 + 53 + # download audio from R2 54 + async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client: 55 + resp = await client.get(audio_url) 56 + resp.raise_for_status() 57 + audio_bytes = resp.content 58 + 59 + # generate embedding via CLAP 60 + clap_client = get_clap_client() 61 + embed_result = await clap_client.embed_audio(audio_bytes) 62 + 63 + if not embed_result.success or not embed_result.embedding: 64 + logger.error( 65 + "generate_embedding: CLAP embedding failed for track %d: %s", 66 + track_id, 67 + embed_result.error, 68 + ) 69 + return 70 + 71 + # store in turbopuffer 72 + await tpuf_upsert( 73 + track_id=track_id, 74 + embedding=embed_result.embedding, 75 + title=track.title, 76 + artist_handle=artist.handle, 77 + artist_did=artist.did, 78 + ) 79 + 80 + logfire.info("generated embedding for track", track_id=track_id) 81 + 82 + 83 + async def schedule_embedding_generation(track_id: int, audio_url: str) -> None: 84 + """schedule an embedding generation via docket.""" 85 + docket = get_docket() 86 + await docket.add(generate_embedding)(track_id, audio_url) 87 + logfire.info("scheduled embedding generation", track_id=track_id) 88 + 89 + 90 + async def classify_genres(track_id: int, audio_url: str) -> None: 91 + """classify genres for a track via Replicate effnet-discogs and store results. 
92 + 93 + args: 94 + track_id: database ID of the track 95 + audio_url: public URL of the audio file (R2) 96 + """ 97 + from backend._internal.clients.replicate import get_replicate_client 98 + 99 + if not settings.replicate.enabled: 100 + logger.debug("genre classification disabled, skipping track %d", track_id) 101 + return 102 + 103 + client = get_replicate_client() 104 + result = await client.classify(audio_url) 105 + 106 + if not result.success: 107 + logger.error( 108 + "genre classification failed for track %d: %s", 109 + track_id, 110 + result.error, 111 + ) 112 + return 113 + 114 + predictions = [{"name": g.name, "confidence": g.confidence} for g in result.genres] 115 + 116 + async with db_session() as db: 117 + db_result = await db.execute(select(Track).where(Track.id == track_id)) 118 + track = db_result.scalar_one_or_none() 119 + if not track: 120 + logger.warning("classify_genres: track %d not found", track_id) 121 + return 122 + 123 + extra = dict(track.extra) if track.extra else {} 124 + extra["genre_predictions"] = predictions 125 + extra["genre_predictions_file_id"] = track.file_id 126 + 127 + # auto-tag if requested 128 + if extra.get("auto_tag") and predictions: 129 + from backend.utilities.tags import add_tags_to_track 130 + 131 + # ratio-to-top: keep tags scoring >= 50% of top score 132 + top_confidence = float(predictions[0]["confidence"]) 133 + top_tags = [ 134 + str(p["name"]) 135 + for p in predictions 136 + if float(p["confidence"]) >= top_confidence * 0.5 137 + ][:5] # cap at 5 138 + 139 + if top_tags: 140 + await add_tags_to_track(db, track_id, top_tags, track.artist_did) 141 + logfire.info( 142 + "auto-tagged track", 143 + track_id=track_id, 144 + tags=top_tags, 145 + ) 146 + 147 + # clean up flag 148 + del extra["auto_tag"] 149 + 150 + track.extra = extra 151 + await db.commit() 152 + 153 + logfire.info( 154 + "classified genres for track", 155 + track_id=track_id, 156 + top_genre=predictions[0]["name"] if predictions else None, 157 + ) 158 + 159 + 160 + async def schedule_genre_classification(track_id: int, audio_url: str) -> None: 161 + """schedule a genre classification via docket.""" 162 + docket = get_docket() 163 + await docket.add(classify_genres)(track_id, audio_url) 164 + logfire.info("scheduled genre classification", track_id=track_id)
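the auto-tag heuristic buried in classify_genres is worth calling out: keep every prediction scoring at least 50% of the top score, capped at five tags. distilled into a pure function with a worked example (numbers invented; like the real code, this assumes predictions arrive sorted by confidence, highest first):

def ratio_to_top(predictions: list[dict], ratio: float = 0.5, cap: int = 5) -> list[str]:
    """keep tag names scoring >= ratio * top confidence, at most cap tags."""
    if not predictions:
        return []
    top = float(predictions[0]["confidence"])
    kept = [str(p["name"]) for p in predictions if float(p["confidence"]) >= top * ratio]
    return kept[:cap]


# top score 0.8 gives threshold 0.4: house (0.8) and techno (0.5) pass, ambient (0.3) drops
assert ratio_to_top(
    [
        {"name": "house", "confidence": 0.8},
        {"name": "techno", "confidence": 0.5},
        {"name": "ambient", "confidence": 0.3},
    ]
) == ["house", "techno"]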
+280
backend/src/backend/_internal/tasks/pds.py
··· 1 + """PDS record write background tasks (likes, comments).""" 2 + 3 + import logging 4 + from datetime import UTC, datetime 5 + 6 + import logfire 7 + from sqlalchemy import select 8 + 9 + from backend._internal.atproto.records import ( 10 + create_comment_record, 11 + create_like_record, 12 + delete_record_by_uri, 13 + update_comment_record, 14 + ) 15 + from backend._internal.auth import get_session 16 + from backend._internal.background import get_docket 17 + from backend.models import TrackComment, TrackLike 18 + from backend.utilities.database import db_session 19 + 20 + logger = logging.getLogger(__name__) 21 + 22 + 23 + async def pds_create_like( 24 + session_id: str, 25 + like_id: int, 26 + subject_uri: str, 27 + subject_cid: str, 28 + ) -> None: 29 + """create a like record on the user's PDS and update the database. 30 + 31 + args: 32 + session_id: the user's session ID for authentication 33 + like_id: database ID of the TrackLike record to update 34 + subject_uri: AT URI of the track being liked 35 + subject_cid: CID of the track being liked 36 + """ 37 + auth_session = await get_session(session_id) 38 + if not auth_session: 39 + logger.warning(f"pds_create_like: session {session_id[:8]}... not found") 40 + return 41 + 42 + try: 43 + like_uri = await create_like_record( 44 + auth_session=auth_session, 45 + subject_uri=subject_uri, 46 + subject_cid=subject_cid, 47 + ) 48 + 49 + # update database with the ATProto URI 50 + async with db_session() as session: 51 + result = await session.execute( 52 + select(TrackLike).where(TrackLike.id == like_id) 53 + ) 54 + like = result.scalar_one_or_none() 55 + if like: 56 + like.atproto_like_uri = like_uri 57 + await session.commit() 58 + logger.info(f"pds_create_like: created like record {like_uri}") 59 + else: 60 + # like was deleted before we could update it - clean up orphan 61 + logger.warning(f"pds_create_like: like {like_id} no longer exists") 62 + await delete_record_by_uri(auth_session, like_uri) 63 + 64 + except Exception as e: 65 + logger.error(f"pds_create_like failed for like {like_id}: {e}", exc_info=True) 66 + # note: we don't delete the DB record on failure - user still sees "liked" 67 + # and we can retry or fix later. this is better than inconsistent state. 68 + 69 + 70 + async def schedule_pds_create_like( 71 + session_id: str, 72 + like_id: int, 73 + subject_uri: str, 74 + subject_cid: str, 75 + ) -> None: 76 + """schedule a like record creation via docket.""" 77 + docket = get_docket() 78 + await docket.add(pds_create_like)(session_id, like_id, subject_uri, subject_cid) 79 + logfire.info("scheduled pds like creation", like_id=like_id) 80 + 81 + 82 + async def pds_delete_like( 83 + session_id: str, 84 + like_uri: str, 85 + ) -> None: 86 + """delete a like record from the user's PDS. 87 + 88 + args: 89 + session_id: the user's session ID for authentication 90 + like_uri: AT URI of the like record to delete 91 + """ 92 + auth_session = await get_session(session_id) 93 + if not auth_session: 94 + logger.warning(f"pds_delete_like: session {session_id[:8]}... 
not found") 95 + return 96 + 97 + try: 98 + await delete_record_by_uri(auth_session, like_uri) 99 + logger.info(f"pds_delete_like: deleted like record {like_uri}") 100 + except Exception as e: 101 + logger.error(f"pds_delete_like failed for {like_uri}: {e}", exc_info=True) 102 + # deletion failed - the PDS record may still exist, but DB is already clean 103 + # this is acceptable: orphaned PDS records are harmless 104 + 105 + 106 + async def schedule_pds_delete_like(session_id: str, like_uri: str) -> None: 107 + """schedule a like record deletion via docket.""" 108 + docket = get_docket() 109 + await docket.add(pds_delete_like)(session_id, like_uri) 110 + logfire.info("scheduled pds like deletion", like_uri=like_uri) 111 + 112 + 113 + async def pds_create_comment( 114 + session_id: str, 115 + comment_id: int, 116 + subject_uri: str, 117 + subject_cid: str, 118 + text: str, 119 + timestamp_ms: int, 120 + ) -> None: 121 + """create a comment record on the user's PDS and update the database. 122 + 123 + args: 124 + session_id: the user's session ID for authentication 125 + comment_id: database ID of the TrackComment record to update 126 + subject_uri: AT URI of the track being commented on 127 + subject_cid: CID of the track being commented on 128 + text: comment text 129 + timestamp_ms: playback position when comment was made 130 + """ 131 + auth_session = await get_session(session_id) 132 + if not auth_session: 133 + logger.warning(f"pds_create_comment: session {session_id[:8]}... not found") 134 + return 135 + 136 + try: 137 + comment_uri = await create_comment_record( 138 + auth_session=auth_session, 139 + subject_uri=subject_uri, 140 + subject_cid=subject_cid, 141 + text=text, 142 + timestamp_ms=timestamp_ms, 143 + ) 144 + 145 + # update database with the ATProto URI 146 + async with db_session() as session: 147 + result = await session.execute( 148 + select(TrackComment).where(TrackComment.id == comment_id) 149 + ) 150 + comment = result.scalar_one_or_none() 151 + if comment: 152 + comment.atproto_comment_uri = comment_uri 153 + await session.commit() 154 + logger.info(f"pds_create_comment: created comment record {comment_uri}") 155 + else: 156 + # comment was deleted before we could update it - clean up orphan 157 + logger.warning( 158 + f"pds_create_comment: comment {comment_id} no longer exists" 159 + ) 160 + await delete_record_by_uri(auth_session, comment_uri) 161 + 162 + except Exception as e: 163 + logger.error( 164 + f"pds_create_comment failed for comment {comment_id}: {e}", exc_info=True 165 + ) 166 + 167 + 168 + async def schedule_pds_create_comment( 169 + session_id: str, 170 + comment_id: int, 171 + subject_uri: str, 172 + subject_cid: str, 173 + text: str, 174 + timestamp_ms: int, 175 + ) -> None: 176 + """schedule a comment record creation via docket.""" 177 + docket = get_docket() 178 + await docket.add(pds_create_comment)( 179 + session_id, comment_id, subject_uri, subject_cid, text, timestamp_ms 180 + ) 181 + logfire.info("scheduled pds comment creation", comment_id=comment_id) 182 + 183 + 184 + async def pds_delete_comment( 185 + session_id: str, 186 + comment_uri: str, 187 + ) -> None: 188 + """delete a comment record from the user's PDS. 189 + 190 + args: 191 + session_id: the user's session ID for authentication 192 + comment_uri: AT URI of the comment record to delete 193 + """ 194 + auth_session = await get_session(session_id) 195 + if not auth_session: 196 + logger.warning(f"pds_delete_comment: session {session_id[:8]}... 
not found") 197 + return 198 + 199 + try: 200 + await delete_record_by_uri(auth_session, comment_uri) 201 + logger.info(f"pds_delete_comment: deleted comment record {comment_uri}") 202 + except Exception as e: 203 + logger.error(f"pds_delete_comment failed for {comment_uri}: {e}", exc_info=True) 204 + 205 + 206 + async def schedule_pds_delete_comment(session_id: str, comment_uri: str) -> None: 207 + """schedule a comment record deletion via docket.""" 208 + docket = get_docket() 209 + await docket.add(pds_delete_comment)(session_id, comment_uri) 210 + logfire.info("scheduled pds comment deletion", comment_uri=comment_uri) 211 + 212 + 213 + async def pds_update_comment( 214 + session_id: str, 215 + comment_id: int, 216 + comment_uri: str, 217 + subject_uri: str, 218 + subject_cid: str, 219 + text: str, 220 + timestamp_ms: int, 221 + created_at: datetime, 222 + ) -> None: 223 + """update a comment record on the user's PDS. 224 + 225 + args: 226 + session_id: the user's session ID for authentication 227 + comment_id: database ID of the TrackComment record 228 + comment_uri: AT URI of the comment record to update 229 + subject_uri: AT URI of the track being commented on 230 + subject_cid: CID of the track being commented on 231 + text: new comment text 232 + timestamp_ms: playback position when comment was made 233 + created_at: original creation timestamp 234 + """ 235 + auth_session = await get_session(session_id) 236 + if not auth_session: 237 + logger.warning(f"pds_update_comment: session {session_id[:8]}... not found") 238 + return 239 + 240 + try: 241 + await update_comment_record( 242 + auth_session=auth_session, 243 + comment_uri=comment_uri, 244 + subject_uri=subject_uri, 245 + subject_cid=subject_cid, 246 + text=text, 247 + timestamp_ms=timestamp_ms, 248 + created_at=created_at, 249 + updated_at=datetime.now(UTC), 250 + ) 251 + logger.info(f"pds_update_comment: updated comment record {comment_uri}") 252 + except Exception as e: 253 + logger.error( 254 + f"pds_update_comment failed for comment {comment_id}: {e}", exc_info=True 255 + ) 256 + 257 + 258 + async def schedule_pds_update_comment( 259 + session_id: str, 260 + comment_id: int, 261 + comment_uri: str, 262 + subject_uri: str, 263 + subject_cid: str, 264 + text: str, 265 + timestamp_ms: int, 266 + created_at: datetime, 267 + ) -> None: 268 + """schedule a comment record update via docket.""" 269 + docket = get_docket() 270 + await docket.add(pds_update_comment)( 271 + session_id, 272 + comment_id, 273 + comment_uri, 274 + subject_uri, 275 + subject_cid, 276 + text, 277 + timestamp_ms, 278 + created_at, 279 + ) 280 + logfire.info("scheduled pds comment update", comment_id=comment_id)
+67
backend/src/backend/_internal/tasks/storage.py
··· 1 + """storage-related background tasks.""" 2 + 3 + import logging 4 + 5 + import logfire 6 + from sqlalchemy import select 7 + 8 + from backend._internal.background import get_docket 9 + from backend.utilities.database import db_session 10 + 11 + logger = logging.getLogger(__name__) 12 + 13 + 14 + async def move_track_audio(track_id: int, to_private: bool) -> None: 15 + """move a track's audio file between public and private buckets. 16 + 17 + called when support_gate is toggled on an existing track. 18 + 19 + args: 20 + track_id: database ID of the track 21 + to_private: if True, move to private bucket; if False, move to public 22 + """ 23 + from backend.models import Track 24 + from backend.storage import storage 25 + 26 + async with db_session() as db: 27 + result = await db.execute(select(Track).where(Track.id == track_id)) 28 + track = result.scalar_one_or_none() 29 + 30 + if not track: 31 + logger.warning(f"move_track_audio: track {track_id} not found") 32 + return 33 + 34 + if not track.file_id or not track.file_type: 35 + logger.warning( 36 + f"move_track_audio: track {track_id} missing file_id/file_type" 37 + ) 38 + return 39 + 40 + result_url = await storage.move_audio( 41 + file_id=track.file_id, 42 + extension=track.file_type, 43 + to_private=to_private, 44 + ) 45 + 46 + # update r2_url: None for private, public URL for public 47 + if to_private: 48 + # moved to private - result_url is None on success, None on failure 49 + # we check by verifying the file was actually moved (no error logged) 50 + track.r2_url = None 51 + await db.commit() 52 + logger.info(f"moved track {track_id} to private bucket") 53 + elif result_url: 54 + # moved to public - result_url is the public URL 55 + track.r2_url = result_url 56 + await db.commit() 57 + logger.info(f"moved track {track_id} to public bucket") 58 + else: 59 + logger.error(f"failed to move track {track_id}") 60 + 61 + 62 + async def schedule_move_track_audio(track_id: int, to_private: bool) -> None: 63 + """schedule a track audio move via docket.""" 64 + docket = get_docket() 65 + await docket.add(move_track_audio)(track_id, to_private) 66 + direction = "private" if to_private else "public" 67 + logfire.info(f"scheduled track audio move to {direction}", track_id=track_id)
+197
backend/src/backend/_internal/tasks/sync.py
··· 1 + """ATProto sync and teal scrobble background tasks.""" 2 + 3 + import logging 4 + 5 + import logfire 6 + from sqlalchemy import select 7 + 8 + from backend._internal.background import get_docket 9 + 10 + logger = logging.getLogger(__name__) 11 + 12 + 13 + async def sync_atproto(session_id: str, user_did: str) -> None: 14 + """sync ATProto records (profile, albums, liked tracks) for a user. 15 + 16 + this runs after login or scope upgrade to ensure the user's PDS 17 + has up-to-date records for their plyr.fm data. 18 + 19 + args: 20 + session_id: the user's session ID for authentication 21 + user_did: the user's DID 22 + """ 23 + from backend._internal.atproto.sync import sync_atproto_records 24 + from backend._internal.auth import get_session 25 + 26 + auth_session = await get_session(session_id) 27 + if not auth_session: 28 + logger.warning(f"sync_atproto: session {session_id[:8]}... not found") 29 + return 30 + 31 + await sync_atproto_records(auth_session, user_did) 32 + 33 + 34 + async def schedule_atproto_sync(session_id: str, user_did: str) -> None: 35 + """schedule an ATProto sync via docket.""" 36 + docket = get_docket() 37 + await docket.add(sync_atproto)(session_id, user_did) 38 + logfire.info("scheduled atproto sync", user_did=user_did) 39 + 40 + 41 + async def scrobble_to_teal( 42 + session_id: str, 43 + track_id: int, 44 + track_title: str, 45 + artist_name: str, 46 + duration: int | None, 47 + album_name: str | None, 48 + ) -> None: 49 + """scrobble a play to teal.fm (creates play record + updates status). 50 + 51 + args: 52 + session_id: the user's session ID for authentication 53 + track_id: database ID of the track 54 + track_title: title of the track 55 + artist_name: name of the artist 56 + duration: track duration in seconds 57 + album_name: album name (optional) 58 + """ 59 + from backend._internal.atproto.teal import ( 60 + create_teal_play_record, 61 + update_teal_status, 62 + ) 63 + from backend._internal.auth import get_session 64 + from backend.config import settings 65 + 66 + auth_session = await get_session(session_id) 67 + if not auth_session: 68 + logger.warning(f"teal scrobble: session {session_id[:8]}... 
not found") 69 + return 70 + 71 + origin_url = f"{settings.frontend.url}/track/{track_id}" 72 + 73 + try: 74 + # create play record (scrobble) 75 + play_uri = await create_teal_play_record( 76 + auth_session=auth_session, 77 + track_name=track_title, 78 + artist_name=artist_name, 79 + duration=duration, 80 + album_name=album_name, 81 + origin_url=origin_url, 82 + ) 83 + logger.info(f"teal play record created: {play_uri}") 84 + 85 + # update status (now playing) 86 + status_uri = await update_teal_status( 87 + auth_session=auth_session, 88 + track_name=track_title, 89 + artist_name=artist_name, 90 + duration=duration, 91 + album_name=album_name, 92 + origin_url=origin_url, 93 + ) 94 + logger.info(f"teal status updated: {status_uri}") 95 + 96 + except Exception as e: 97 + logger.error(f"teal scrobble failed for track {track_id}: {e}", exc_info=True) 98 + 99 + 100 + async def schedule_teal_scrobble( 101 + session_id: str, 102 + track_id: int, 103 + track_title: str, 104 + artist_name: str, 105 + duration: int | None, 106 + album_name: str | None, 107 + ) -> None: 108 + """schedule a teal scrobble via docket.""" 109 + docket = get_docket() 110 + await docket.add(scrobble_to_teal)( 111 + session_id, track_id, track_title, artist_name, duration, album_name 112 + ) 113 + logfire.info("scheduled teal scrobble", track_id=track_id) 114 + 115 + 116 + async def sync_album_list(session_id: str, album_id: str) -> None: 117 + """sync a single album's ATProto list record. 118 + 119 + creates or updates the album's list record on the user's PDS. 120 + called after track uploads or album mutations. 121 + 122 + args: 123 + session_id: the user's session ID for authentication 124 + album_id: the album's database ID 125 + """ 126 + from backend._internal.atproto.records.fm_plyr import upsert_album_list_record 127 + from backend._internal.auth import get_session 128 + from backend.models import Album, Track 129 + from backend.utilities.database import db_session 130 + 131 + auth_session = await get_session(session_id) 132 + if not auth_session: 133 + logger.warning(f"sync_album_list: session {session_id[:8]}... 
not found") 134 + return 135 + 136 + async with db_session() as session: 137 + # fetch album 138 + album_result = await session.execute(select(Album).where(Album.id == album_id)) 139 + album = album_result.scalar_one_or_none() 140 + if not album: 141 + logger.warning(f"sync_album_list: album {album_id} not found") 142 + return 143 + 144 + # verify album belongs to this user 145 + if album.artist_did != auth_session.did: 146 + logger.warning( 147 + f"sync_album_list: album {album_id} does not belong to {auth_session.did}" 148 + ) 149 + return 150 + 151 + # fetch tracks with ATProto records 152 + tracks_result = await session.execute( 153 + select(Track) 154 + .where( 155 + Track.album_id == album_id, 156 + Track.atproto_record_uri.isnot(None), 157 + Track.atproto_record_cid.isnot(None), 158 + ) 159 + .order_by(Track.created_at.asc()) 160 + ) 161 + tracks = tracks_result.scalars().all() 162 + 163 + if not tracks: 164 + logger.debug( 165 + f"sync_album_list: album {album_id} has no tracks with ATProto records" 166 + ) 167 + return 168 + 169 + track_refs = [ 170 + {"uri": t.atproto_record_uri, "cid": t.atproto_record_cid} 171 + for t in tracks 172 + if t.atproto_record_uri and t.atproto_record_cid 173 + ] 174 + 175 + try: 176 + result = await upsert_album_list_record( 177 + auth_session, 178 + album_id=album_id, 179 + album_title=album.title, 180 + track_refs=track_refs, 181 + existing_uri=album.atproto_record_uri, 182 + existing_created_at=album.created_at, 183 + ) 184 + if result: 185 + album.atproto_record_uri = result[0] 186 + album.atproto_record_cid = result[1] 187 + await session.commit() 188 + logger.info(f"synced album list record for {album_id}: {result[0]}") 189 + except Exception as e: 190 + logger.warning(f"failed to sync album list record for {album_id}: {e}") 191 + 192 + 193 + async def schedule_album_list_sync(session_id: str, album_id: str) -> None: 194 + """schedule an album list sync via docket.""" 195 + docket = get_docket() 196 + await docket.add(sync_album_list)(session_id, album_id) 197 + logfire.info("scheduled album list sync", album_id=album_id)
backend/src/backend/_internal/tpuf_client.py → backend/src/backend/_internal/clients/tpuf.py
backend/src/backend/_internal/transcoder_client.py → backend/src/backend/_internal/clients/transcoder.py
+1 -1
backend/src/backend/api/albums.py
··· 25 25 from backend._internal import Session as AuthSession 26 26 from backend._internal import require_artist_profile 27 27 from backend._internal.auth import get_session 28 - from backend._internal.moderation_client import get_moderation_client 28 + from backend._internal.clients.moderation import get_moderation_client 29 29 from backend._internal.notifications import notification_service 30 30 from backend.config import settings 31 31 from backend.models import Album, Artist, Track, TrackLike, get_db
+1 -1
backend/src/backend/api/auth.py
··· 39 39 switch_active_account, 40 40 ) 41 41 from backend._internal.auth import get_refresh_token_lifetime_days 42 - from backend._internal.background_tasks import schedule_atproto_sync 42 + from backend._internal.tasks import schedule_atproto_sync 43 43 from backend.config import settings 44 44 from backend.models import Artist, get_db 45 45 from backend.utilities.rate_limit import limiter
+1 -1
backend/src/backend/api/moderation.py
··· 9 9 from pydantic import BaseModel, Field 10 10 11 11 from backend._internal import Session, require_auth 12 - from backend._internal.moderation_client import get_moderation_client 12 + from backend._internal.clients.moderation import get_moderation_client 13 13 from backend._internal.notifications import notification_service 14 14 from backend.utilities.rate_limit import limiter 15 15
+2 -2
backend/src/backend/api/search.py
··· 9 9 from sqlalchemy.ext.asyncio import AsyncSession 10 10 11 11 from backend._internal.atproto.handles import search_handles 12 - from backend._internal.clap_client import get_clap_client 13 - from backend._internal.tpuf_client import query as tpuf_query 12 + from backend._internal.clients.clap import get_clap_client 13 + from backend._internal.clients.tpuf import query as tpuf_query 14 14 from backend.config import settings 15 15 from backend.models import Album, Artist, Playlist, Tag, Track, TrackTag, get_db 16 16
+1 -1
backend/src/backend/api/tracks/comments.py
··· 11 11 12 12 from backend._internal import Session as AuthSession 13 13 from backend._internal import require_auth 14 - from backend._internal.background_tasks import ( 14 + from backend._internal.tasks import ( 15 15 schedule_pds_create_comment, 16 16 schedule_pds_delete_comment, 17 17 schedule_pds_update_comment,
+1 -1
backend/src/backend/api/tracks/likes.py
··· 12 12 13 13 from backend._internal import Session as AuthSession 14 14 from backend._internal import require_auth 15 - from backend._internal.background_tasks import ( 15 + from backend._internal.tasks import ( 16 16 schedule_pds_create_like, 17 17 schedule_pds_delete_like, 18 18 )
+1 -1
backend/src/backend/api/tracks/metadata_service.py
··· 12 12 from starlette.datastructures import UploadFile 13 13 14 14 from backend._internal.atproto.handles import resolve_handle 15 + from backend._internal.clients.moderation import get_moderation_client 15 16 from backend._internal.image import ImageFormat 16 - from backend._internal.moderation_client import get_moderation_client 17 17 from backend._internal.notifications import notification_service 18 18 from backend.config import settings 19 19 from backend.models import Track
+4 -42
backend/src/backend/api/tracks/mutations.py
··· 3 3 import contextlib 4 4 import json 5 5 import logging 6 - from datetime import UTC, datetime 7 6 from typing import Annotated 8 7 from urllib.parse import urljoin 9 8 ··· 11 10 from fastapi import Depends, File, Form, HTTPException, UploadFile 12 11 from pydantic import BaseModel 13 12 from sqlalchemy import select 14 - from sqlalchemy.exc import IntegrityError 15 13 from sqlalchemy.ext.asyncio import AsyncSession 16 14 from sqlalchemy.orm import selectinload 17 15 ··· 30 28 ) 31 29 from backend._internal.atproto.tid import datetime_to_tid 32 30 from backend._internal.audio import AudioFormat 33 - from backend._internal.background_tasks import ( 31 + from backend._internal.tasks import ( 34 32 schedule_album_list_sync, 35 33 schedule_move_track_audio, 36 34 ) 37 35 from backend.config import settings 38 - from backend.models import Artist, Tag, Track, TrackTag, get_db 36 + from backend.models import Artist, Track, TrackTag, get_db 39 37 from backend.schemas import MessageResponse, TrackResponse 40 38 from backend.storage import storage 41 - from backend.utilities.tags import parse_tags_json 39 + from backend.utilities.tags import get_or_create_tag, parse_tags_json 42 40 43 41 from .metadata_service import ( 44 42 apply_album_update, ··· 48 46 from .router import router 49 47 50 48 logger = logging.getLogger(__name__) 51 - 52 - 53 - async def _get_or_create_tag(db: AsyncSession, tag_name: str, creator_did: str) -> Tag: 54 - """get existing tag or create new one, handling race conditions. 55 - 56 - uses a select-then-insert pattern with IntegrityError handling 57 - to safely handle concurrent tag creation. 58 - """ 59 - # first try to find existing tag 60 - result = await db.execute(select(Tag).where(Tag.name == tag_name)) 61 - tag = result.scalar_one_or_none() 62 - if tag: 63 - return tag 64 - 65 - # try to create new tag 66 - tag = Tag( 67 - name=tag_name, 68 - created_by_did=creator_did, 69 - created_at=datetime.now(UTC), 70 - ) 71 - db.add(tag) 72 - 73 - try: 74 - await db.flush() 75 - return tag 76 - except IntegrityError as e: 77 - # only handle unique constraint violation on tag name (pgcode 23505) 78 - # re-raise other integrity errors (e.g., foreign key violations) 79 - pgcode = getattr(e.orig, "pgcode", None) 80 - if pgcode != "23505": 81 - raise 82 - # another process created the tag - rollback and fetch it 83 - await db.rollback() 84 - result = await db.execute(select(Tag).where(Tag.name == tag_name)) 85 - tag = result.scalar_one() 86 - return tag 87 49 88 50 89 51 @router.delete("/{track_id}") ··· 308 270 # get or create tags and create track_tags 309 271 for tag_name in validated_tags: 310 272 # get or create tag with race condition handling 311 - tag = await _get_or_create_tag(db, tag_name, auth_session.did) 273 + tag = await get_or_create_tag(db, tag_name, auth_session.did) 312 274 313 275 # create track_tag association 314 276 track_tag = TrackTag(track_id=track_id, tag_id=tag.id)
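the deleted _get_or_create_tag above was one of two verbatim copies (uploads.py below removes the other); the shared version now lives in backend/utilities/tags.py, which this diff doesn't show. assuming it moved essentially as-is, made public next to add_tags_to_track and parse_tags_json, it would look like:

from datetime import UTC, datetime

from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from backend.models import Tag


async def get_or_create_tag(db: AsyncSession, tag_name: str, creator_did: str) -> Tag:
    """get existing tag or create new one, handling race conditions."""
    result = await db.execute(select(Tag).where(Tag.name == tag_name))
    if tag := result.scalar_one_or_none():
        return tag

    tag = Tag(name=tag_name, created_by_did=creator_did, created_at=datetime.now(UTC))
    db.add(tag)
    try:
        await db.flush()
        return tag
    except IntegrityError as e:
        # only swallow the unique-name violation (pgcode 23505); re-raise the rest
        if getattr(e.orig, "pgcode", None) != "23505":
            raise
        # another process created the tag - roll back and fetch it
        await db.rollback()
        result = await db.execute(select(Tag).where(Tag.name == tag_name))
        return result.scalar_one()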
+1 -1
backend/src/backend/api/tracks/playback.py
··· 11 11 from sqlalchemy.orm import selectinload 12 12 13 13 from backend._internal import Session, get_optional_session 14 - from backend._internal.background_tasks import schedule_teal_scrobble 14 + from backend._internal.tasks import schedule_teal_scrobble 15 15 from backend.config import settings 16 16 from backend.models import ( 17 17 Artist,
+1 -1
backend/src/backend/api/tracks/tags.py
··· 198 198 199 199 if predictions is None and settings.replicate.enabled and track.r2_url: 200 200 # classify on-demand 201 - from backend._internal.replicate_client import get_replicate_client 201 + from backend._internal.clients.replicate import get_replicate_client 202 202 203 203 client = get_replicate_client() 204 204 classify_result = await client.classify(track.r2_url)
+337 -388
backend/src/backend/api/tracks/uploads.py
··· 6 6 import logging 7 7 import tempfile 8 8 from dataclasses import dataclass 9 - from datetime import UTC, datetime 10 9 from pathlib import Path 11 10 from typing import Annotated 12 11 ··· 37 36 ) 38 37 from backend._internal.atproto.handles import resolve_featured_artists 39 38 from backend._internal.audio import AudioFormat 40 - from backend._internal.background_tasks import ( 39 + from backend._internal.clients.transcoder import get_transcoder_client 40 + from backend._internal.image import ImageFormat 41 + from backend._internal.jobs import job_service 42 + from backend._internal.tasks import ( 41 43 schedule_album_list_sync, 42 44 schedule_copyright_scan, 43 45 schedule_embedding_generation, 44 46 schedule_genre_classification, 45 47 ) 46 - from backend._internal.image import ImageFormat 47 - from backend._internal.jobs import job_service 48 - from backend._internal.transcoder_client import get_transcoder_client 49 48 from backend.config import settings 50 - from backend.models import Artist, Tag, Track, TrackTag, UserPreferences 49 + from backend.models import Artist, Track, UserPreferences 51 50 from backend.models.job import JobStatus, JobType 52 51 from backend.storage import storage 53 52 from backend.utilities.audio import extract_duration ··· 55 54 from backend.utilities.hashing import CHUNK_SIZE 56 55 from backend.utilities.progress import R2ProgressTracker 57 56 from backend.utilities.rate_limit import limiter 58 - from backend.utilities.tags import parse_tags_json 57 + from backend.utilities.tags import add_tags_to_track, parse_tags_json 59 58 60 59 from .router import router 61 60 from .services import get_or_create_album ··· 104 103 auto_tag: bool = False 105 104 106 105 107 - async def _get_or_create_tag( 108 - db: "AsyncSession", tag_name: str, creator_did: str 109 - ) -> Tag: 110 - """get existing tag or create new one, handling race conditions. 106 + @dataclass 107 + class AudioInfo: 108 + """result of audio validation phase.""" 111 109 112 - uses a select-then-insert pattern with IntegrityError handling 113 - to safely handle concurrent tag creation. 
114 - """ 115 - # first try to find existing tag 116 - result = await db.execute(select(Tag).where(Tag.name == tag_name)) 117 - tag = result.scalar_one_or_none() 118 - if tag: 119 - return tag 110 + format: AudioFormat 111 + duration: int | None 112 + is_gated: bool 120 113 121 - # try to create new tag 122 - tag = Tag( 123 - name=tag_name, 124 - created_by_did=creator_did, 125 - created_at=datetime.now(UTC), 126 - ) 127 - db.add(tag) 114 + 115 + @dataclass 116 + class StorageResult: 117 + """result of audio storage phase.""" 118 + 119 + file_id: str 120 + original_file_id: str | None 121 + original_file_type: str | None 122 + playable_format: AudioFormat 123 + r2_url: str | None 124 + transcode_info: "TranscodeInfo | None" 125 + 126 + 127 + class UploadPhaseError(Exception): 128 + """raised when an upload phase fails with a user-facing message.""" 128 129 129 - try: 130 - await db.flush() 131 - return tag 132 - except IntegrityError as e: 133 - # only handle unique constraint violation on tag name (pgcode 23505) 134 - # re-raise other integrity errors (e.g., foreign key violations) 135 - pgcode = getattr(e.orig, "pgcode", None) 136 - if pgcode != "23505": 137 - raise 138 - # another process created the tag - rollback and fetch it 139 - await db.rollback() 140 - result = await db.execute(select(Tag).where(Tag.name == tag_name)) 141 - tag = result.scalar_one() 142 - return tag 130 + def __init__(self, error: str) -> None: 131 + self.error = error 132 + super().__init__(error) 143 133 144 134 145 135 async def _save_audio_to_storage( ··· 462 452 ) 463 453 464 454 465 - async def _add_tags_to_track( 466 - db: AsyncSession, 467 - track_id: int, 468 - validated_tags: list[str], 469 - creator_did: str, 470 - ) -> None: 471 - """add validated tags to a track.""" 472 - if not validated_tags: 473 - return 474 - 475 - try: 476 - for tag_name in validated_tags: 477 - tag = await _get_or_create_tag(db, tag_name, creator_did) 478 - track_tag = TrackTag(track_id=track_id, tag_id=tag.id) 479 - db.add(track_tag) 480 - await db.commit() 481 - except Exception as e: 482 - logfire.error( 483 - "failed to add tags to track", 484 - track_id=track_id, 485 - tags=validated_tags, 486 - error=str(e), 487 - ) 488 - 489 - 490 455 async def _send_track_notification(db: AsyncSession, track: Track) -> None: 491 456 """send notification for new track upload.""" 492 457 from backend._internal.notifications import notification_service ··· 500 465 logger.warning(f"failed to send notification for track {track.id}: {e}") 501 466 502 467 503 - async def _process_upload_background(ctx: UploadContext) -> None: 504 - """background task to process upload.""" 505 - with logfire.span( 506 - "process upload background", upload_id=ctx.upload_id, filename=ctx.filename 507 - ): 508 - file_id: str | None = None 509 - original_file_id: str | None = None 510 - original_file_type: str | None = None 511 - image_id: str | None = None 512 - audio_format: AudioFormat | None = None 513 - playable_format: AudioFormat | None = None 514 - pds_blob_result: PdsBlobResult | None = None 468 + async def _validate_audio(ctx: UploadContext) -> AudioInfo: 469 + """phase 1: validate file type, extract duration, check gating requirements.""" 470 + ext = Path(ctx.filename).suffix.lower() 471 + audio_format = AudioFormat.from_extension(ext) 472 + if not audio_format: 473 + raise UploadPhaseError(f"unsupported file type: {ext}") 474 + 475 + with open(ctx.file_path, "rb") as f: 476 + duration = extract_duration(f) 477 + 478 + is_gated = ctx.support_gate is not 
None 479 + if is_gated: 480 + async with db_session() as db: 481 + prefs_result = await db.execute( 482 + select(UserPreferences).where(UserPreferences.did == ctx.artist_did) 483 + ) 484 + prefs = prefs_result.scalar_one_or_none() 485 + if not prefs or prefs.support_url != "atprotofans": 486 + raise UploadPhaseError( 487 + "supporter gating requires atprotofans to be enabled in settings" 488 + ) 489 + 490 + return AudioInfo(format=audio_format, duration=duration, is_gated=is_gated) 491 + 492 + 493 + async def _store_audio(ctx: UploadContext, audio_info: AudioInfo) -> StorageResult: 494 + """phase 2: store audio (transcode if lossless).""" 495 + transcode_info: TranscodeInfo | None = None 496 + 497 + if not audio_info.format.is_web_playable: 498 + if audio_info.is_gated: 499 + raise UploadPhaseError( 500 + "supporter-gated tracks cannot use lossless formats yet" 501 + ) 502 + 503 + original_ext = Path(ctx.filename).suffix.lower().lstrip(".") 504 + transcode_info = await _transcode_audio( 505 + ctx.upload_id, ctx.file_path, ctx.filename, original_ext 506 + ) 507 + if not transcode_info: 508 + raise UploadPhaseError("transcoding failed") 509 + 510 + file_id = transcode_info.transcoded_file_id 511 + playable_format = AudioFormat.from_extension( 512 + transcode_info.transcoded_file_type 513 + ) 514 + if not playable_format: 515 + raise UploadPhaseError("unknown transcoded format") 516 + else: 517 + file_id = await _save_audio_to_storage( 518 + ctx.upload_id, ctx.file_path, ctx.filename, gated=audio_info.is_gated 519 + ) 520 + if not file_id: 521 + raise UploadPhaseError("failed to save audio to storage") 522 + playable_format = audio_info.format 523 + transcode_info = None 524 + 525 + # get R2 URL (only for public tracks) 526 + r2_url: str | None = None 527 + if not audio_info.is_gated: 528 + ext = Path(ctx.filename).suffix.lower() 529 + playable_ext = playable_format.value if playable_format else ext[1:] 530 + r2_url = await storage.get_url( 531 + file_id, file_type="audio", extension=playable_ext 532 + ) 533 + if not r2_url: 534 + raise UploadPhaseError("failed to get public audio URL") 535 + 536 + return StorageResult( 537 + file_id=file_id, 538 + original_file_id=transcode_info.original_file_id if transcode_info else None, 539 + original_file_type=transcode_info.original_file_type 540 + if transcode_info 541 + else None, 542 + playable_format=playable_format, 543 + r2_url=r2_url, 544 + transcode_info=transcode_info, 545 + ) 546 + 515 547 516 - try: 517 - await job_service.update_progress( 518 - ctx.upload_id, JobStatus.PROCESSING, "processing upload..." 
548 + async def _check_duplicate(ctx: UploadContext, sr: StorageResult) -> None:
549 +     """phase 3: check for duplicate tracks."""
550 +     async with db_session() as db:
551 +         result = await db.execute(
552 +             select(Track).where(
553 +                 Track.file_id == sr.file_id,
554 +                 Track.artist_did == ctx.artist_did,
555 +             )
556 +         )
557 +         if existing := result.scalar_one_or_none():
558 +             raise UploadPhaseError(
559 +                 f"duplicate upload: track already exists (id: {existing.id})"
519 560             )
520 561
521 -             # validate file type
522 -             ext = Path(ctx.filename).suffix.lower()
523 -             audio_format = AudioFormat.from_extension(ext)
524 -             if not audio_format:
525 -                 await job_service.update_progress(
526 -                     ctx.upload_id,
527 -                     JobStatus.FAILED,
528 -                     "upload failed",
529 -                     error=f"unsupported file type: {ext}",
530 -                 )
531 -                 return
562 +
563 + async def _upload_to_pds(
564 +     ctx: UploadContext, audio_info: AudioInfo, sr: StorageResult
565 + ) -> PdsBlobResult | None:
566 +     """phase 4: upload to PDS (best-effort). returns None if skipped."""
567 +     if audio_info.is_gated:
568 +         return None
569 +
570 +     async with db_session() as db:
571 +         allow_pds_upload = await _should_upload_pds_blob(db, ctx.artist_did)
572 +     if not allow_pds_upload:
573 +         return None
574 +
575 +     content_type = sr.playable_format.media_type
576 +     if sr.transcode_info:
577 +         pds_file_data = sr.transcode_info.transcoded_data
578 +     else:
579 +         async with aiofiles.open(ctx.file_path, "rb") as f:
580 +             pds_file_data = await f.read()
581 +
582 +     return await _try_upload_to_pds(
583 +         ctx.upload_id, ctx.auth_session, pds_file_data, content_type
584 +     )
585 +
586 +
587 + async def _store_image(ctx: UploadContext) -> tuple[str | None, str | None]:
588 +     """phase 5: store image (optional). returns (image_id, image_url)."""
589 +     if not ctx.image_path or not ctx.image_filename:
590 +         return None, None
591 +     return await _save_image_to_storage(
592 +         ctx.upload_id, ctx.image_path, ctx.image_filename, ctx.image_content_type
593 +     )
594 +
532 595
533 -             # extract duration
534 -             with open(ctx.file_path, "rb") as f:
535 -                 duration = extract_duration(f)
596 + async def _create_records(
597 +     ctx: UploadContext,
598 +     audio_info: AudioInfo,
599 +     sr: StorageResult,
600 +     pds_result: PdsBlobResult | None,
601 +     image_id: str | None,
602 +     image_url: str | None,
603 + ) -> Track:
604 +     """phase 6: create ATProto record + DB track record."""
605 +     ext = Path(ctx.filename).suffix.lower()
606 +     playable_file_type = sr.playable_format.value if sr.playable_format else ext[1:]
536 607
537 -             # validate gating requirements if support_gate is set
538 -             is_gated = ctx.support_gate is not None
539 -             if is_gated:
540 -                 async with db_session() as db:
541 -                     prefs_result = await db.execute(
542 -                         select(UserPreferences).where(
543 -                             UserPreferences.did == ctx.artist_did
544 -                         )
545 -                     )
546 -                     prefs = prefs_result.scalar_one_or_none()
547 -                     if not prefs or prefs.support_url != "atprotofans":
548 -                         await job_service.update_progress(
549 -                             ctx.upload_id,
550 -                             JobStatus.FAILED,
551 -                             "upload failed",
552 -                             error="supporter gating requires atprotofans to be enabled in settings",
553 -                         )
554 -                         return
608 +     async with db_session() as db:
609 +         result = await db.execute(select(Artist).where(Artist.did == ctx.artist_did))
610 +         artist = result.scalar_one_or_none()
611 +         if not artist:
612 +             raise UploadPhaseError("artist profile not found")
555 613
556 -             # for non-web-playable formats, transcode first
557 -             transcode_info: TranscodeInfo | None = None
558 -             if not audio_format.is_web_playable:
559 -                 # gated tracks don't support transcoding yet
560 -                 if is_gated:
561 -                     await job_service.update_progress(
562 -                         ctx.upload_id,
563 -                         JobStatus.FAILED,
564 -                         "upload failed",
565 -                         error="supporter-gated tracks cannot use lossless formats yet",
566 -                     )
567 -                     return
614 +
615 +         # resolve featured artists
616 +         featured_artists: list[dict] = []
617 +         if ctx.features_json:
618 +             await job_service.update_progress(
619 +                 ctx.upload_id,
620 +                 JobStatus.PROCESSING,
621 +                 "resolving featured artists...",
622 +                 phase="metadata",
623 +             )
624 +             featured_artists = await resolve_featured_artists(
625 +                 ctx.features_json, artist.handle
626 +             )
568 626
569 -                 # use actual extension from filename (e.g., "aif" not "aiff")
570 -                 original_ext = Path(ctx.filename).suffix.lower().lstrip(".")
571 -                 transcode_info = await _transcode_audio(
572 -                     ctx.upload_id,
573 -                     ctx.file_path,
574 -                     ctx.filename,
575 -                     original_ext,
576 -                 )
577 -                 if not transcode_info:
578 -                     return
627 +         # create ATProto record
628 +         await job_service.update_progress(
629 +             ctx.upload_id,
630 +             JobStatus.PROCESSING,
631 +             "creating atproto record...",
632 +             phase="atproto",
633 +         )
634 +         try:
635 +             if audio_info.is_gated:
636 +                 from urllib.parse import urljoin
579 637
580 -                 # use transcoded file for playback, store original for export
581 -                 file_id = transcode_info.transcoded_file_id
582 -                 original_file_id = transcode_info.original_file_id
583 -                 original_file_type = transcode_info.original_file_type
584 -                 playable_format = AudioFormat.from_extension(
585 -                     transcode_info.transcoded_file_type
586 -                 )
638 +                 backend_url = settings.atproto.redirect_uri.rsplit("/", 2)[0]
639 +                 audio_url_for_record = urljoin(backend_url + "/", f"audio/{sr.file_id}")
587 640             else:
588 -                 # web-playable format: save directly
589 -                 file_id = await _save_audio_to_storage(
590 -                     ctx.upload_id, ctx.file_path, ctx.filename, gated=is_gated
591 -                 )
592 -                 playable_format = audio_format
641 +                 assert sr.r2_url is not None
642 +                 audio_url_for_record = sr.r2_url
643 +
644 +             atproto_result = await create_track_record(
645 +                 auth_session=ctx.auth_session,
646 +                 title=ctx.title,
647 +                 artist=artist.display_name,
648 +                 audio_url=audio_url_for_record,
649 +                 file_type=playable_file_type,
650 +                 album=ctx.album,
651 +                 duration=audio_info.duration,
652 +                 features=featured_artists or None,
653 +                 image_url=image_url,
654 +                 support_gate=ctx.support_gate,
655 +                 audio_blob=pds_result.blob_ref if pds_result else None,
656 +             )
657 +             if not atproto_result:
658 +                 raise ValueError("PDS returned no record data")
659 +             atproto_uri, atproto_cid = atproto_result
660 +         except Exception as e:
661 +             logger.error("ATProto sync failed for upload %s: %s", ctx.upload_id, e)
662 +             # cleanup orphaned media
663 +             with contextlib.suppress(Exception):
664 +                 await storage.delete(sr.file_id, playable_file_type)
665 +             if sr.original_file_id and sr.original_file_type:
666 +                 with contextlib.suppress(Exception):
667 +                     await storage.delete(sr.original_file_id, sr.original_file_type)
668 +             if image_id:
669 +                 with contextlib.suppress(Exception):
670 +                     await storage.delete(image_id)
671 +             raise UploadPhaseError(f"failed to sync track to ATProto: {e}") from e
593 672
594 -             if not file_id:
595 -                 return
673 +         # create DB record
674 +         await job_service.update_progress(
675 +             ctx.upload_id,
676 +             JobStatus.PROCESSING,
677 +             "saving track metadata...",
678 +             phase="database",
679 +         )
596 680
597 -             # check for duplicate
598 -             async with db_session() as db:
599 -                 result = await db.execute(
600 -                     select(Track).where(
601 -                         Track.file_id == file_id,
602 -                         Track.artist_did == ctx.artist_did,
603 -                     )
604 -                 )
605 -                 if existing := result.scalar_one_or_none():
606 -                     await job_service.update_progress(
607 -                         ctx.upload_id,
608 -                         JobStatus.FAILED,
609 -                         "upload failed",
610 -                         error=f"duplicate upload: track already exists (id: {existing.id})",
611 -                     )
612 -                     return
681 +         extra: dict = {}
682 +         if audio_info.duration:
683 +             extra["duration"] = audio_info.duration
684 +         if ctx.auto_tag:
685 +             extra["auto_tag"] = True
613 686
614 -             # get R2 URL (only for public tracks - gated tracks have no public URL)
615 -             # use playable_format for URL since that's what the ATProto record points to
616 -             r2_url: str | None = None
617 -             if not is_gated:
618 -                 playable_ext = playable_format.value if playable_format else ext[1:]
619 -                 r2_url = await storage.get_url(
620 -                     file_id, file_type="audio", extension=playable_ext
621 -                 )
622 -                 if not r2_url:
623 -                     await job_service.update_progress(
624 -                         ctx.upload_id,
625 -                         JobStatus.FAILED,
626 -                         "upload failed",
627 -                         error="failed to get public audio URL",
628 -                     )
629 -                     return
687 +         album_record = None
688 +         if ctx.album:
689 +             extra["album"] = ctx.album
690 +             album_record = await get_or_create_album(
691 +                 db, artist, ctx.album, image_id, image_url
692 +             )
630 693
631 -             # try uploading blob to user's PDS (best-effort, falls back to R2-only)
632 -             # gated tracks skip PDS blob upload since they need auth-protected access
633 -             if not is_gated and playable_format:
634 -                 async with db_session() as db:
635 -                     allow_pds_upload = await _should_upload_pds_blob(db, ctx.artist_did)
636 -                     if allow_pds_upload:
637 -                         content_type = playable_format.media_type
638 -                         # use transcoded bytes if available, otherwise read original file
639 -                         if transcode_info:
640 -                             pds_file_data = transcode_info.transcoded_data
641 -                         else:
642 -                             async with aiofiles.open(ctx.file_path, "rb") as f:
643 -                                 pds_file_data = await f.read()
644 -                         pds_blob_result = await _try_upload_to_pds(
645 -                             ctx.upload_id,
646 -                             ctx.auth_session,
647 -                             pds_file_data,
648 -                             content_type,
649 -                         )
694 +         has_pds_blob = pds_result and pds_result.cid is not None
695 +         audio_storage = "both" if has_pds_blob else "r2"
650 696
651 -             # save image if provided
652 -             image_url = None
653 -             if ctx.image_path and ctx.image_filename:
654 -                 image_id, image_url = await _save_image_to_storage(
655 -                     ctx.upload_id,
656 -                     ctx.image_path,
657 -                     ctx.image_filename,
658 -                     ctx.image_content_type,
659 -                 )
697 +         track = Track(
698 +             title=ctx.title,
699 +             file_id=sr.file_id,
700 +             file_type=playable_file_type,
701 +             original_file_id=sr.original_file_id,
702 +             original_file_type=sr.original_file_type,
703 +             artist_did=ctx.artist_did,
704 +             extra=extra,
705 +             album_id=album_record.id if album_record else None,
706 +             features=featured_artists,
707 +             r2_url=sr.r2_url,
708 +             atproto_record_uri=atproto_uri,
709 +             atproto_record_cid=atproto_cid,
710 +             image_id=image_id,
711 +             image_url=image_url,
712 +             support_gate=ctx.support_gate,
713 +             audio_storage=audio_storage,
714 +             pds_blob_cid=pds_result.cid if pds_result else None,
715 +             pds_blob_size=pds_result.size if pds_result else None,
716 +         )
660 717
661 -             # get artist and resolve featured artists
662 -             async with db_session() as db:
663 -                 result = await db.execute(
664 -                     select(Artist).where(Artist.did == ctx.artist_did)
665 -                 )
666 -                 artist = result.scalar_one_or_none()
667 -                 if not artist:
668 -                     await job_service.update_progress(
669 -                         ctx.upload_id,
670 -                         JobStatus.FAILED,
671 -                         "upload failed",
672 -                         error="artist profile not found",
673 -                     )
674 -                     return
718 +         db.add(track)
719 +         try:
720 +             await db.commit()
721 +             await db.refresh(track)
722 +         except IntegrityError as e:
723 +             await db.rollback()
724 +             with contextlib.suppress(Exception):
725 +                 await storage.delete(sr.file_id, playable_file_type)
726 +             if sr.original_file_id and sr.original_file_type:
727 +                 with contextlib.suppress(Exception):
728 +                     await storage.delete(sr.original_file_id, sr.original_file_type)
729 +             raise UploadPhaseError(f"database constraint violation: {e!s}") from e
675 730
676 -                 # resolve featured artists
677 -                 featured_artists: list[dict] = []
678 -                 if ctx.features_json:
679 -                     await job_service.update_progress(
680 -                         ctx.upload_id,
681 -                         JobStatus.PROCESSING,
682 -                         "resolving featured artists...",
683 -                         phase="metadata",
684 -                     )
685 -                     featured_artists = await resolve_featured_artists(
686 -                         ctx.features_json, artist.handle
687 -                     )
731 +     return track
688 732
689 -                 # create ATProto record
690 -                 await job_service.update_progress(
691 -                     ctx.upload_id,
692 -                     JobStatus.PROCESSING,
693 -                     "creating atproto record...",
694 -                     phase="atproto",
695 -                 )
696 -                 try:
697 -                     # for gated tracks, use API endpoint URL instead of direct R2 URL
698 -                     # this ensures playback goes through our auth check
699 -                     if is_gated:
700 -                         # use backend URL for gated audio
701 -                         from urllib.parse import urljoin
702 733
703 -                         backend_url = settings.atproto.redirect_uri.rsplit("/", 2)[0]
704 -                         audio_url_for_record = urljoin(
705 -                             backend_url + "/", f"audio/{file_id}"
706 -                         )
707 -                     else:
708 -                         # r2_url is guaranteed non-None here - we returned early above if None
709 -                         assert r2_url is not None
710 -                         audio_url_for_record = r2_url
734 + async def _schedule_post_upload(
735 +     ctx: UploadContext, sr: StorageResult, track: Track
736 + ) -> None:
737 +     """phase 7: post-upload tasks (tags, notifications, background jobs)."""
738 +     async with db_session() as db:
739 +         await add_tags_to_track(db, track.id, ctx.tags, ctx.artist_did)
711 740
712 -                     # use playable format for ATProto record (transcoded if applicable)
713 -                     playable_file_type = (
714 -                         playable_format.value if playable_format else ext[1:]
715 -                     )
716 -                     atproto_result = await create_track_record(
717 -                         auth_session=ctx.auth_session,
718 -                         title=ctx.title,
719 -                         artist=artist.display_name,
720 -                         audio_url=audio_url_for_record,
721 -                         file_type=playable_file_type,
722 -                         album=ctx.album,
723 -                         duration=duration,
724 -                         features=featured_artists or None,
725 -                         image_url=image_url,
726 -                         support_gate=ctx.support_gate,
727 -                         audio_blob=pds_blob_result.blob_ref
728 -                         if pds_blob_result
729 -                         else None,
730 -                     )
731 -                     if not atproto_result:
732 -                         raise ValueError("PDS returned no record data")
733 -                     atproto_uri, atproto_cid = atproto_result
734 -                 except Exception as e:
735 -                     logger.error(
736 -                         "ATProto sync failed for upload %s: %s", ctx.upload_id, e
737 -                     )
738 -                     await job_service.update_progress(
739 -                         ctx.upload_id,
740 -                         JobStatus.FAILED,
741 -                         "upload failed",
742 -                         error=f"failed to sync track to ATProto: {e}",
743 -                         phase="atproto",
744 -                     )
745 -                     # cleanup orphaned media
746 -                     with contextlib.suppress(Exception):
747 -                         await storage.delete(file_id, playable_file_type)
748 -                     if original_file_id and original_file_type:
749 -                         with contextlib.suppress(Exception):
750 -                             await storage.delete(original_file_id, original_file_type)
751 -                     if image_id:
752 -                         with contextlib.suppress(Exception):
753 -                             await storage.delete(image_id)
754 -                     return
741 +         # skip notifications and copyright scan for integration tests on staging
742 +         is_integration_test = (
743 +             settings.observability.environment == "staging"
744 +             and "integration-test" in (ctx.tags or [])
745 +         )
746 +         if not is_integration_test:
747 +             await _send_track_notification(db, track)
748 +         if sr.r2_url and not is_integration_test:
749 +             await schedule_copyright_scan(track.id, sr.r2_url)
755 750
756 -                 # create track record
757 -                 await job_service.update_progress(
758 -                     ctx.upload_id,
759 -                     JobStatus.PROCESSING,
760 -                     "saving track metadata...",
761 -                     phase="database",
762 -                 )
751 +     # generate CLAP embedding for vibe search
752 +     if sr.r2_url and settings.modal.enabled and settings.turbopuffer.enabled:
753 +         await schedule_embedding_generation(track.id, sr.r2_url)
763 754
764 -                 extra: dict = {}
765 -                 if duration:
766 -                     extra["duration"] = duration
767 -                 if ctx.auto_tag:
768 -                     extra["auto_tag"] = True
755 +     # classify genres via Replicate
756 +     if sr.r2_url and settings.replicate.enabled:
757 +         await schedule_genre_classification(track.id, sr.r2_url)
769 758
770 -                 album_record = None
771 -                 if ctx.album:
772 -                     extra["album"] = ctx.album
773 -                     album_record = await get_or_create_album(
774 -                         db, artist, ctx.album, image_id, image_url
775 -                     )
759 +     # sync album list record if track is in an album
760 +     if track.album_id:
761 +         await schedule_album_list_sync(ctx.auth_session.session_id, track.album_id)
776 762
777 -                 # determine audio storage type
778 -                 has_pds_blob = pds_blob_result and pds_blob_result.cid is not None
779 -                 audio_storage = "both" if has_pds_blob else "r2"
780 763
781 -                 track = Track(
782 -                     title=ctx.title,
783 -                     file_id=file_id,
784 -                     file_type=playable_file_type,
785 -                     original_file_id=original_file_id,
786 -                     original_file_type=original_file_type,
787 -                     artist_did=ctx.artist_did,
788 -                     extra=extra,
789 -                     album_id=album_record.id if album_record else None,
790 -                     features=featured_artists,
791 -                     r2_url=r2_url,
792 -                     atproto_record_uri=atproto_uri,
793 -                     atproto_record_cid=atproto_cid,
794 -                     image_id=image_id,
795 -                     image_url=image_url,
796 -                     support_gate=ctx.support_gate,
797 -                     audio_storage=audio_storage,
798 -                     pds_blob_cid=pds_blob_result.cid if pds_blob_result else None,
799 -                     pds_blob_size=pds_blob_result.size if pds_blob_result else None,
800 -                 )
764 + async def _process_upload_background(ctx: UploadContext) -> None:
765 +     """orchestrate the upload pipeline through named phases."""
766 +     with logfire.span(
767 +         "process upload background", upload_id=ctx.upload_id, filename=ctx.filename
768 +     ):
769 +         try:
770 +             await job_service.update_progress(
771 +                 ctx.upload_id, JobStatus.PROCESSING, "processing upload..."
772 +             )
801 773
802 -                 db.add(track)
803 -                 try:
804 -                     await db.commit()
805 -                     await db.refresh(track)
774 +             # phase 1: validate and prepare audio
775 +             audio_info = await _validate_audio(ctx)
806 776
807 -                     await _add_tags_to_track(db, track.id, ctx.tags, ctx.artist_did)
777 +             # phase 2: store audio (transcode if lossless)
778 +             sr = await _store_audio(ctx, audio_info)
808 779
809 -                     # skip notifications and copyright scan for integration tests on staging
810 -                     # (synthetic audio, no point spamming DMs or paying for AudD API calls)
811 -                     is_integration_test = (
812 -                         settings.observability.environment == "staging"
813 -                         and "integration-test" in (ctx.tags or [])
814 -                     )
815 -                     if not is_integration_test:
816 -                         await _send_track_notification(db, track)
817 -                     if r2_url and not is_integration_test:
818 -                         await schedule_copyright_scan(track.id, r2_url)
780 +             # phase 3: check for duplicates
781 +             await _check_duplicate(ctx, sr)
819 782
820 -                     # generate CLAP embedding for vibe search
821 -                     if (
822 -                         r2_url
823 -                         and settings.modal.enabled
824 -                         and settings.turbopuffer.enabled
825 -                     ):
826 -                         await schedule_embedding_generation(track.id, r2_url)
783 +             # phase 4: upload to PDS (best-effort)
784 +             pds_result = await _upload_to_pds(ctx, audio_info, sr)
827 785
828 -                     # classify genres via Replicate
829 -                     if r2_url and settings.replicate.enabled:
830 -                         await schedule_genre_classification(track.id, r2_url)
786 +             # phase 5: store image (optional)
787 +             image_id, image_url = await _store_image(ctx)
831 788
832 -                     # sync album list record if track is in an album
833 -                     if album_record:
834 -                         await schedule_album_list_sync(
835 -                             ctx.auth_session.session_id, album_record.id
836 -                         )
789 +             # phase 6: create records (ATProto + DB)
790 +             track = await _create_records(
791 +                 ctx, audio_info, sr, pds_result, image_id, image_url
792 +             )
837 793
838 -                     await job_service.update_progress(
839 -                         ctx.upload_id,
840 -                         JobStatus.COMPLETED,
841 -                         "upload completed successfully",
842 -                         result={"track_id": track.id},
843 -                     )
794 +             # phase 7: post-upload tasks (tags, notifications, background jobs)
795 +             await _schedule_post_upload(ctx, sr, track)
844 796
845 -                 except IntegrityError as e:
846 -                     await db.rollback()
847 -                     await job_service.update_progress(
848 -                         ctx.upload_id,
849 -                         JobStatus.FAILED,
850 -                         "upload failed",
851 -                         error=f"database constraint violation: {e!s}",
852 -                     )
853 -                     with contextlib.suppress(Exception):
854 -                         await storage.delete(file_id, playable_file_type)
855 -                     if original_file_id and original_file_type:
856 -                         with contextlib.suppress(Exception):
857 -                             await storage.delete(original_file_id, original_file_type)
797 +             await job_service.update_progress(
798 +                 ctx.upload_id,
799 +                 JobStatus.COMPLETED,
800 +                 "upload completed successfully",
801 +                 result={"track_id": track.id},
802 +             )
858 803
804 +         except UploadPhaseError as e:
805 +             await job_service.update_progress(
806 +                 ctx.upload_id, JobStatus.FAILED, "upload failed", error=e.error
807 +             )
859 808          except Exception as e:
860 809              logger.exception(f"upload {ctx.upload_id} failed with unexpected error")
861 810              await job_service.update_progress(
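The orchestrator passes typed results between the phases, but the `AudioInfo`/`StorageResult`/`PdsBlobResult`/`UploadPhaseError` definitions fall outside this hunk. A rough sketch of their shape, with field names inferred from the phase functions above and the constructor calls in `backend/tests/api/test_upload_phases.py` below (the real module may differ in detail):

```python
from dataclasses import dataclass

from backend._internal.audio import AudioFormat


class UploadPhaseError(Exception):
    """raised inside a phase to fail the upload with a user-facing message."""

    def __init__(self, error: str) -> None:
        super().__init__(error)
        self.error = error


@dataclass
class AudioInfo:
    """phase 1 output: validated audio metadata."""

    format: AudioFormat
    duration: int | None  # seconds; exact type is a guess
    is_gated: bool


@dataclass
class StorageResult:
    """phase 2 output: where the playable (and original) bytes ended up."""

    file_id: str
    original_file_id: str | None  # set when a lossless original was kept
    original_file_type: str | None
    playable_format: AudioFormat | None
    r2_url: str | None  # None for gated tracks (no public URL)
    transcode_info: object | None  # TranscodeInfo in the real code


@dataclass
class PdsBlobResult:
    """phase 4 output: outcome of the best-effort PDS blob upload."""

    blob_ref: dict | None
    cid: str | None
    size: int | None
```

Each phase either returns one of these or raises `UploadPhaseError`, which is how the orchestrator collapses the old scattered `update_progress(..., JobStatus.FAILED, ...)` + `return` pairs into a single `except` arm.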
+77
backend/src/backend/utilities/tags.py
··· 1 1 """tag normalization and management utilities.""" 2 2 3 3 import re 4 + from datetime import UTC, datetime 4 5 from typing import Annotated 5 6 7 + import logfire 6 8 from pydantic import Field, TypeAdapter, ValidationError 9 + from sqlalchemy import select 10 + from sqlalchemy.exc import IntegrityError 11 + from sqlalchemy.ext.asyncio import AsyncSession 7 12 8 13 # tags that are hidden by default for new users 9 14 DEFAULT_HIDDEN_TAGS: list[str] = ["ai"] ··· 95 100 def is_tag_hidden_by_default(tag: str) -> bool: 96 101 """check if a tag should be hidden by default.""" 97 102 return normalize_tag(tag) in DEFAULT_HIDDEN_TAGS 103 + 104 + 105 + # --- DB-facing tag operations --- 106 + # Tag/TrackTag imports are deferred to avoid circular import: 107 + # models/__init__.py -> models/preferences.py -> utilities/tags.py 108 + 109 + 110 + async def get_or_create_tag( 111 + db: AsyncSession, tag_name: str, creator_did: str 112 + ): # returns Tag 113 + """get existing tag or create new one, handling race conditions. 114 + 115 + uses a select-then-insert pattern with IntegrityError handling 116 + to safely handle concurrent tag creation. 117 + """ 118 + from backend.models import Tag 119 + 120 + # first try to find existing tag 121 + result = await db.execute(select(Tag).where(Tag.name == tag_name)) 122 + tag = result.scalar_one_or_none() 123 + if tag: 124 + return tag 125 + 126 + # try to create new tag 127 + tag = Tag( 128 + name=tag_name, 129 + created_by_did=creator_did, 130 + created_at=datetime.now(UTC), 131 + ) 132 + db.add(tag) 133 + 134 + try: 135 + await db.flush() 136 + return tag 137 + except IntegrityError as e: 138 + # only handle unique constraint violation on tag name (pgcode 23505) 139 + # re-raise other integrity errors (e.g., foreign key violations) 140 + pgcode = getattr(e.orig, "pgcode", None) 141 + if pgcode != "23505": 142 + raise 143 + # another process created the tag - rollback and fetch it 144 + await db.rollback() 145 + result = await db.execute(select(Tag).where(Tag.name == tag_name)) 146 + tag = result.scalar_one() 147 + return tag 148 + 149 + 150 + async def add_tags_to_track( 151 + db: AsyncSession, 152 + track_id: int, 153 + validated_tags: list[str], 154 + creator_did: str, 155 + ) -> None: 156 + """add validated tags to a track.""" 157 + from backend.models import TrackTag 158 + 159 + if not validated_tags: 160 + return 161 + 162 + try: 163 + for tag_name in validated_tags: 164 + tag = await get_or_create_tag(db, tag_name, creator_did) 165 + track_tag = TrackTag(track_id=track_id, tag_id=tag.id) 166 + db.add(track_tag) 167 + await db.commit() 168 + except Exception as e: 169 + logfire.error( 170 + "failed to add tags to track", 171 + track_id=track_id, 172 + tags=validated_tags, 173 + error=str(e), 174 + )
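A quick usage sketch for the two helpers above (names and signatures are from this diff; the surrounding session handling is illustrative only):

```python
from backend.utilities.database import db_session
from backend.utilities.tags import add_tags_to_track, get_or_create_tag


async def tag_new_track(track_id: int, creator_did: str) -> None:
    async with db_session() as db:
        # race-safe: if a concurrent upload inserts "ambient" first, the
        # IntegrityError branch rolls back and re-selects the winner's row
        tag = await get_or_create_tag(db, "ambient", creator_did)
        assert tag.id is not None

        # bulk path used by upload phase 7; commits internally and logs
        # (rather than raises) on failure
        await add_tags_to_track(db, track_id, ["ambient", "idm"], creator_did)
```

One design note worth keeping in mind: the `IntegrityError` branch calls `db.rollback()`, which discards any earlier uncommitted work in the same session, so callers should commit unrelated changes before invoking these helpers.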
backend/tests/_internal/__init__.py


+62
backend/tests/_internal/test_auth_modules.py
··· 1 + """stateless smoke tests for auth package modules.""" 2 + 3 + from backend._internal.auth.encryption import _decrypt_data, _encrypt_data 4 + from backend._internal.auth.exchange import ( 5 + consume_exchange_token, 6 + create_exchange_token, 7 + ) 8 + from backend._internal.auth.scopes import ( 9 + _check_scope_coverage, 10 + _get_missing_scopes, 11 + _parse_scopes, 12 + ) 13 + from backend._internal.auth.session import Session 14 + 15 + 16 + def test_scopes_parse_roundtrip(): 17 + """parse and validate scope strings.""" 18 + scope = "atproto repo:fm.plyr.track repo:fm.plyr.like" 19 + parsed = _parse_scopes(scope) 20 + assert parsed == {"repo:fm.plyr.track", "repo:fm.plyr.like"} 21 + assert _check_scope_coverage(scope, scope) is True 22 + assert _get_missing_scopes(scope, scope) == set() 23 + 24 + 25 + def test_encryption_roundtrip(): 26 + """encrypt/decrypt produces original.""" 27 + original = '{"access_token": "secret123", "refresh_token": "refresh456"}' 28 + encrypted = _encrypt_data(original) 29 + assert encrypted != original 30 + decrypted = _decrypt_data(encrypted) 31 + assert decrypted == original 32 + 33 + 34 + def test_session_dataclass_fields(): 35 + """Session has expected fields.""" 36 + session = Session( 37 + session_id="sid-123", 38 + did="did:plc:test", 39 + handle="test.bsky.social", 40 + oauth_session={"session_id": "oauth-123"}, 41 + ) 42 + assert session.session_id == "sid-123" 43 + assert session.did == "did:plc:test" 44 + assert session.handle == "test.bsky.social" 45 + assert session.get_oauth_session_id() == "oauth-123" 46 + 47 + 48 + def test_session_oauth_session_id_fallback(): 49 + """Session.get_oauth_session_id falls back to DID.""" 50 + session = Session( 51 + session_id="sid-456", 52 + did="did:plc:fallback", 53 + handle="fallback.bsky.social", 54 + oauth_session={}, 55 + ) 56 + assert session.get_oauth_session_id() == "did:plc:fallback" 57 + 58 + 59 + def test_exchange_token_functions_exist(): 60 + """exchange token creation/consumption functions are importable.""" 61 + assert callable(create_exchange_token) 62 + assert callable(consume_exchange_token)
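These tests import straight from the new submodules (`scopes`, `encryption`, `session`, `exchange`), while the PR description says the old `backend._internal.auth` import surface is preserved via `__init__.py` re-exports. A plausible sketch of that shim, with module placements taken from the patch targets in `test_auth.py` below and the exact export list assumed rather than shown:

```python
# backend/src/backend/_internal/auth/__init__.py (hypothetical sketch)
from backend._internal.auth.exchange import (
    consume_exchange_token,
    create_exchange_token,
)
from backend._internal.auth.oauth import get_public_jwks
from backend._internal.auth.session import Session, is_confidential_client

__all__ = [
    "Session",
    "consume_exchange_token",
    "create_exchange_token",
    "get_public_jwks",
    "is_confidential_client",
]
```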
+79
backend/tests/api/test_upload_phases.py
··· 1 + """stateless unit tests for upload pipeline phase dataclasses.""" 2 + 3 + from backend._internal.audio import AudioFormat 4 + from backend.api.tracks.uploads import ( 5 + AudioInfo, 6 + PdsBlobResult, 7 + StorageResult, 8 + UploadPhaseError, 9 + ) 10 + 11 + 12 + def test_audio_info_dataclass(): 13 + """AudioInfo construction and fields.""" 14 + info = AudioInfo(format=AudioFormat.MP3, duration=120, is_gated=False) 15 + assert info.format == AudioFormat.MP3 16 + assert info.duration == 120 17 + assert info.is_gated is False 18 + 19 + 20 + def test_audio_info_gated(): 21 + """AudioInfo with gated content.""" 22 + info = AudioInfo(format=AudioFormat.WAV, duration=60, is_gated=True) 23 + assert info.is_gated is True 24 + 25 + 26 + def test_storage_result_dataclass(): 27 + """StorageResult construction and fields.""" 28 + sr = StorageResult( 29 + file_id="abc123", 30 + original_file_id=None, 31 + original_file_type=None, 32 + playable_format=AudioFormat.MP3, 33 + r2_url="https://cdn.example.com/abc123.mp3", 34 + transcode_info=None, 35 + ) 36 + assert sr.file_id == "abc123" 37 + assert sr.original_file_id is None 38 + assert sr.playable_format == AudioFormat.MP3 39 + assert sr.r2_url is not None 40 + 41 + 42 + def test_storage_result_with_transcode(): 43 + """StorageResult with original file from transcoding.""" 44 + sr = StorageResult( 45 + file_id="transcoded123", 46 + original_file_id="original456", 47 + original_file_type="flac", 48 + playable_format=AudioFormat.MP3, 49 + r2_url="https://cdn.example.com/transcoded123.mp3", 50 + transcode_info=None, 51 + ) 52 + assert sr.original_file_id == "original456" 53 + assert sr.original_file_type == "flac" 54 + 55 + 56 + def test_pds_result_defaults(): 57 + """PdsBlobResult with None blob_cid.""" 58 + result = PdsBlobResult(blob_ref=None, cid=None, size=None) 59 + assert result.blob_ref is None 60 + assert result.cid is None 61 + assert result.size is None 62 + 63 + 64 + def test_pds_result_with_data(): 65 + """PdsBlobResult with actual values.""" 66 + result = PdsBlobResult( 67 + blob_ref={"ref": {"$link": "bafyreid123"}}, 68 + cid="bafyreid123", 69 + size=1024000, 70 + ) 71 + assert result.cid == "bafyreid123" 72 + assert result.size == 1024000 73 + 74 + 75 + def test_upload_phase_error(): 76 + """UploadPhaseError carries error message.""" 77 + err = UploadPhaseError("something went wrong") 78 + assert err.error == "something went wrong" 79 + assert str(err) == "something went wrong"
+4 -4
backend/tests/test_auth.py
··· 346 346 347 347 def test_is_confidential_client_false_by_default(): 348 348 """verify is_confidential_client returns False when OAUTH_JWK not set.""" 349 - with patch("backend._internal.auth.settings.atproto.oauth_jwk", None): 349 + with patch("backend._internal.auth.session.settings.atproto.oauth_jwk", None): 350 350 assert is_confidential_client() is False 351 351 352 352 ··· 354 354 """verify is_confidential_client returns True when OAUTH_JWK is set.""" 355 355 test_jwk = '{"kty":"EC","crv":"P-256","x":"test","y":"test","d":"test"}' 356 356 357 - with patch("backend._internal.auth.settings.atproto.oauth_jwk", test_jwk): 357 + with patch("backend._internal.auth.session.settings.atproto.oauth_jwk", test_jwk): 358 358 assert is_confidential_client() is True 359 359 360 360 361 361 def test_get_public_jwks_returns_none_without_config(): 362 362 """verify get_public_jwks returns None when OAUTH_JWK not configured.""" 363 - with patch("backend._internal.auth.settings.atproto.oauth_jwk", None): 363 + with patch("backend._internal.auth.oauth.settings.atproto.oauth_jwk", None): 364 364 assert get_public_jwks() is None 365 365 366 366 ··· 383 383 jwk_dict["kid"] = "test-key-id" # add kid to test preservation 384 384 test_jwk = json.dumps(jwk_dict) 385 385 386 - with patch("backend._internal.auth.settings.atproto.oauth_jwk", test_jwk): 386 + with patch("backend._internal.auth.oauth.settings.atproto.oauth_jwk", test_jwk): 387 387 jwks = get_public_jwks() 388 388 389 389 assert jwks is not None
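The target updates above follow from how `unittest.mock.patch` resolves its string: it imports the module path, then walks the remaining attributes. After the split, `backend._internal.auth` is a package whose submodules each do their own `from backend.config import settings`, so the patch must name the submodule that actually binds the chain:

```python
from unittest.mock import patch

# resolves by importing backend._internal.auth.session, then walking
# settings.atproto and patching its oauth_jwk attribute
with patch("backend._internal.auth.session.settings.atproto.oauth_jwk", None):
    ...  # code inside auth/session.py now sees the patched value
```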
+13 -10
backend/tests/test_background_tasks.py
··· 4 4 import contextlib 5 5 from unittest.mock import AsyncMock, MagicMock, patch 6 6 7 - import backend._internal.background_tasks as bg_tasks 8 7 import backend._internal.export_tasks as export_tasks 8 + import backend._internal.tasks.copyright as copyright_tasks 9 + import backend._internal.tasks.sync as sync_tasks 9 10 10 11 11 12 async def test_schedule_export_uses_docket() -> None: ··· 39 40 mock_docket.add = MagicMock(return_value=mock_schedule) 40 41 41 42 with ( 42 - patch.object(bg_tasks, "get_docket", return_value=mock_docket), 43 - patch.object(bg_tasks, "scan_copyright", MagicMock()), 43 + patch.object(copyright_tasks, "get_docket", return_value=mock_docket), 44 + patch.object(copyright_tasks, "scan_copyright", MagicMock()), 44 45 ): 45 - await bg_tasks.schedule_copyright_scan(123, "https://example.com/audio.mp3") 46 + await copyright_tasks.schedule_copyright_scan( 47 + 123, "https://example.com/audio.mp3" 48 + ) 46 49 47 50 mock_docket.add.assert_called_once() 48 51 assert calls == [(123, "https://example.com/audio.mp3")] ··· 59 62 mock_docket.add = MagicMock(return_value=mock_schedule) 60 63 61 64 with ( 62 - patch.object(bg_tasks, "get_docket", return_value=mock_docket), 63 - patch.object(bg_tasks, "sync_atproto", MagicMock()), 65 + patch.object(sync_tasks, "get_docket", return_value=mock_docket), 66 + patch.object(sync_tasks, "sync_atproto", MagicMock()), 64 67 ): 65 - await bg_tasks.schedule_atproto_sync("session-abc", "did:plc:testuser") 68 + await sync_tasks.schedule_atproto_sync("session-abc", "did:plc:testuser") 66 69 67 70 mock_docket.add.assert_called_once() 68 71 assert calls == [("session-abc", "did:plc:testuser")] ··· 88 91 mock_docket.add = MagicMock(return_value=mock_schedule) 89 92 90 93 with ( 91 - patch.object(bg_tasks, "get_docket", return_value=mock_docket), 92 - patch.object(bg_tasks, "scrobble_to_teal", MagicMock()), 94 + patch.object(sync_tasks, "get_docket", return_value=mock_docket), 95 + patch.object(sync_tasks, "scrobble_to_teal", MagicMock()), 93 96 ): 94 - await bg_tasks.schedule_teal_scrobble( 97 + await sync_tasks.schedule_teal_scrobble( 95 98 session_id="session-xyz", 96 99 track_id=42, 97 100 track_title="Test Track",
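From the mock interactions in these tests, the relocated `schedule_*` helpers keep docket's two-step call shape: `docket.add(task)` returns a scheduler that is then awaited with the task's own arguments. A sketch of what `tasks/copyright.py` plausibly does (inferred from this test, not shown in the diff; `get_docket` and `scan_copyright` live in that module per the patch targets):

```python
async def schedule_copyright_scan(track_id: int, r2_url: str) -> None:
    docket = get_docket()                                # shared docket handle
    await docket.add(scan_copyright)(track_id, r2_url)   # enqueue with task args
```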
+7 -7
backend/tests/test_moderation.py
··· 11 11 from sqlalchemy.ext.asyncio import AsyncSession 12 12 13 13 from backend._internal import Session, require_auth 14 - from backend._internal.moderation import ( 15 - get_active_copyright_labels, 16 - scan_track_for_copyright, 17 - ) 18 - from backend._internal.moderation_client import ( 14 + from backend._internal.clients.moderation import ( 19 15 CreateReportResult, 20 16 ModerationClient, 21 17 ScanResult, 22 18 SensitiveImagesResult, 19 + ) 20 + from backend._internal.moderation import ( 21 + get_active_copyright_labels, 22 + scan_track_for_copyright, 23 23 ) 24 24 from backend.main import app 25 25 from backend.models import Artist, CopyrightScan, Track ··· 366 366 367 367 async def test_sync_copyright_resolutions(db_session: AsyncSession) -> None: 368 368 """test that sync_copyright_resolutions updates flagged scans.""" 369 - from backend._internal.background_tasks import sync_copyright_resolutions 369 + from backend._internal.tasks import sync_copyright_resolutions 370 370 371 371 # create test artist and tracks 372 372 artist = Artist( ··· 419 419 await db_session.commit() 420 420 421 421 with patch( 422 - "backend._internal.moderation_client.get_moderation_client" 422 + "backend._internal.clients.moderation.get_moderation_client" 423 423 ) as mock_get_client: 424 424 mock_client = AsyncMock() 425 425 # only track2's URI is still active
+4 -4
backend/tests/test_transcoder.py
··· 7 7 import httpx 8 8 import pytest 9 9 10 - from backend._internal.transcoder_client import ( 10 + from backend._internal.clients.transcoder import ( 11 11 TranscoderClient, 12 12 get_transcoder_client, 13 13 ) ··· 151 151 152 152 def test_transcoder_client_from_settings() -> None: 153 153 """test TranscoderClient.from_settings() creates client correctly.""" 154 - with patch("backend._internal.transcoder_client.settings") as mock_settings: 154 + with patch("backend._internal.clients.transcoder.settings") as mock_settings: 155 155 mock_settings.transcoder.service_url = "https://transcoder.example.com" 156 156 mock_settings.transcoder.auth_token = "secret-token" 157 157 mock_settings.transcoder.timeout_seconds = 120 ··· 166 166 167 167 def test_get_transcoder_client_singleton() -> None: 168 168 """test get_transcoder_client() returns singleton.""" 169 - import backend._internal.transcoder_client as module 169 + import backend._internal.clients.transcoder as module 170 170 171 171 # reset singleton 172 172 module._client = None 173 173 174 - with patch("backend._internal.transcoder_client.settings") as mock_settings: 174 + with patch("backend._internal.clients.transcoder.settings") as mock_settings: 175 175 mock_settings.transcoder.service_url = "https://transcoder.example.com" 176 176 mock_settings.transcoder.auth_token = "token" 177 177 mock_settings.transcoder.timeout_seconds = 60
+73
backend/tests/utilities/test_tags.py
··· 1 + """tests for shared tag operations (DB-backed).""" 2 + 3 + from sqlalchemy import select 4 + from sqlalchemy.ext.asyncio import AsyncSession 5 + 6 + from backend.models import Artist, Tag, Track, TrackTag 7 + from backend.utilities.tags import add_tags_to_track, get_or_create_tag 8 + 9 + 10 + async def _create_artist(db_session: AsyncSession, did: str) -> Artist: 11 + """helper to create an artist for FK constraints.""" 12 + artist = Artist(did=did, handle=f"{did.split(':')[-1]}.test", display_name="Test") 13 + db_session.add(artist) 14 + await db_session.flush() 15 + return artist 16 + 17 + 18 + async def test_get_or_create_tag_creates_new(db_session: AsyncSession): 19 + """creates tag that doesn't exist.""" 20 + await _create_artist(db_session, "did:plc:test1") 21 + tag = await get_or_create_tag(db_session, "electronic", "did:plc:test1") 22 + assert tag.name == "electronic" 23 + assert tag.created_by_did == "did:plc:test1" 24 + assert tag.id is not None 25 + 26 + 27 + async def test_get_or_create_tag_returns_existing(db_session: AsyncSession): 28 + """idempotent - returns existing tag.""" 29 + await _create_artist(db_session, "did:plc:test2") 30 + tag1 = await get_or_create_tag(db_session, "ambient", "did:plc:test2") 31 + await db_session.commit() 32 + 33 + tag2 = await get_or_create_tag(db_session, "ambient", "did:plc:test3") 34 + assert tag2.id == tag1.id 35 + assert tag2.name == "ambient" 36 + 37 + 38 + async def test_add_tags_to_track(db_session: AsyncSession): 39 + """associates tags with track.""" 40 + await _create_artist(db_session, "did:plc:tagger") 41 + 42 + track = Track( 43 + title="test track", 44 + file_id="test_file_id_tags", 45 + file_type="mp3", 46 + artist_did="did:plc:tagger", 47 + atproto_record_uri="at://did:plc:tagger/fm.plyr.track/test", 48 + atproto_record_cid="bafytest", 49 + ) 50 + db_session.add(track) 51 + await db_session.commit() 52 + await db_session.refresh(track) 53 + 54 + await add_tags_to_track(db_session, track.id, ["rock", "indie"], "did:plc:tagger") 55 + 56 + # verify tags were created and associated 57 + result = await db_session.execute( 58 + select(TrackTag).where(TrackTag.track_id == track.id) 59 + ) 60 + track_tags = result.scalars().all() 61 + assert len(track_tags) == 2 62 + 63 + tag_ids = {tt.tag_id for tt in track_tags} 64 + result = await db_session.execute(select(Tag).where(Tag.id.in_(tag_ids))) 65 + tags = result.scalars().all() 66 + tag_names = {t.name for t in tags} 67 + assert tag_names == {"rock", "indie"} 68 + 69 + 70 + async def test_add_tags_to_track_empty(db_session: AsyncSession): 71 + """no-op when tags list is empty.""" 72 + # should not raise 73 + await add_tags_to_track(db_session, 999, [], "did:plc:test")
-8
loq.toml
··· 22 22 max_lines = 612 23 23 24 24 [[rules]] 25 - path = "backend/src/backend/_internal/auth.py" 26 - max_lines = 1410 27 - 28 - [[rules]] 29 - path = "backend/src/backend/_internal/background_tasks.py" 30 - max_lines = 954 31 - 32 - [[rules]] 33 25 path = "backend/src/backend/api/albums.py" 34 26 max_lines = 714 35 27
+2 -2
scripts/backfill_embeddings.py
··· 39 39 import httpx 40 40 from sqlalchemy import select 41 41 42 - from backend._internal.clap_client import get_clap_client 43 - from backend._internal.tpuf_client import upsert 42 + from backend._internal.clients.clap import get_clap_client 43 + from backend._internal.clients.tpuf import upsert 44 44 from backend.config import settings 45 45 from backend.models import Artist, Track 46 46 from backend.utilities.database import db_session
+1 -1
scripts/backfill_genres.py
··· 37 37 38 38 from sqlalchemy import select, text 39 39 40 - from backend._internal.replicate_client import get_replicate_client 40 + from backend._internal.clients.replicate import get_replicate_client 41 41 from backend.config import settings 42 42 from backend.models import Artist, Track 43 43 from backend.utilities.database import db_session
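The script fixes above are mechanical path updates: each former `*_client.py` module now lives under `backend._internal.clients.<name>`. One plausible shape for the package's `__init__.py`, consistent with the PR's "public APIs preserved via __init__.py re-exports" note (export list inferred from the imports used across this diff, so treat it as an assumption):

```python
# backend/src/backend/_internal/clients/__init__.py (hypothetical sketch)
from backend._internal.clients.clap import get_clap_client
from backend._internal.clients.moderation import ModerationClient, get_moderation_client
from backend._internal.clients.replicate import get_replicate_client
from backend._internal.clients.tpuf import upsert
from backend._internal.clients.transcoder import TranscoderClient, get_transcoder_client
```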