[MIRROR ONLY] A correct and efficient ATProto blob proxy for secure content delivery. codeberg.org/Blooym/porxie
36
fork

Configure Feed

Select the types of activity you want to include in your feed.

refactor: improve application codebase quality and perf (#2)

Reviewed-on: https://codeberg.org/Blooym/porxie/pulls/2

authored by

Lyna and committed by
Blooym
99d2099f dca57154

+1077 -871
+1
.gitignore
··· 1 1 target/ 2 2 .direnv/ 3 + lexicons/ 3 4 4 5 test_server.py
+1
Cargo.lock
··· 2429 2429 "multihash-codetable", 2430 2430 "reqwest", 2431 2431 "serde", 2432 + "subtle", 2432 2433 "sysinfo", 2433 2434 "thiserror 2.0.18", 2434 2435 "tokio",
+6 -2
Cargo.toml
··· 54 54 "env-filter", 55 55 ] } 56 56 mime = { version = "0.3.17", default-features = false } 57 - tracing = { version = "0.1.44", features = ["std"], default-features = false } 58 57 moka = { version = "0.12.14", features = [ 59 58 "future", 60 59 "logging", 61 60 ], default-features = false } 62 61 multihash-codetable = { version = "0.2.1", features = [ 63 62 "sha2", 64 - "blake3", 63 + # "blake3", # if it ever gets added to the spec. 65 64 "std", 66 65 ], default-features = false } 67 66 reqwest = { version = "0.12.28", default-features = false, features = [ ··· 79 78 "derive", 80 79 "std", 81 80 ], default-features = false } 81 + subtle = { version = "2.6", default-features = false, features = ["std"] } 82 82 sysinfo = { version = "0.38.4", default-features = false, features = [ 83 83 "system", 84 84 ] } ··· 95 95 "trace", 96 96 "timeout", 97 97 "tracing", 98 + ], default-features = false } 99 + tracing = { version = "0.1.44", features = [ 100 + "attributes", 101 + "std", 98 102 ], default-features = false } 99 103 tracing-subscriber = { version = "0.3.22", features = [ 100 104 "ansi",
+4 -4
flake.lock
··· 2 2 "nodes": { 3 3 "nixpkgs": { 4 4 "locked": { 5 - "lastModified": 1776255774, 6 - "narHash": "sha256-psVTpH6PK3q1htMJpmdz1hLF5pQgEshu7gQWgKO6t6Y=", 5 + "lastModified": 1776329215, 6 + "narHash": "sha256-a8BYi3mzoJ/AcJP8UldOx8emoPRLeWqALZWu4ZvjPXw=", 7 7 "owner": "nixos", 8 8 "repo": "nixpkgs", 9 - "rev": "566acc07c54dc807f91625bb286cb9b321b5f42a", 9 + "rev": "b86751bc4085f48661017fa226dee99fab6c651b", 10 10 "type": "github" 11 11 }, 12 12 "original": { ··· 24 24 }, 25 25 "root": "root", 26 26 "version": 7 27 - } 27 + }
+3
flake.nix
··· 24 24 rustfmt 25 25 clippy 26 26 ]; 27 + env = { 28 + RUST_SRC_PATH = pkgs.rustPlatform.rustLibSrc; 29 + }; 27 30 }; 28 31 } 29 32 );
+349
src/blob_service.rs
··· 1 + use crate::{ 2 + http::{BytesStreamCappedError, PORXIE_USER_AGENT, bytes_stream_capped}, 3 + identity_service::IdentityService, 4 + mime::{is_mime_allowed, sniff_mime}, 5 + types::blob_cid::BlobCid, 6 + }; 7 + use bytes::Bytes; 8 + use cid::Cid; 9 + use jacquard_common::types::did::Did; 10 + use mime::Mime; 11 + use moka::{future::Cache as MokaCache, policy::EvictionPolicy}; 12 + use multihash_codetable::{Code, MultihashDigest}; 13 + use reqwest::{StatusCode, header, header::HeaderValue}; 14 + use std::{num::NonZeroU64, sync::Arc, time::Duration}; 15 + use thiserror::Error; 16 + use tracing::instrument; 17 + 18 + #[derive(Debug, Error)] 19 + pub enum CreateBlobServiceError { 20 + /// An internal http client error occurred, see [`reqwest::Error`]. 21 + #[error(transparent)] 22 + HttpClient(#[from] reqwest::Error), 23 + } 24 + 25 + #[derive(Debug, Error)] 26 + pub enum BlobDownloadError { 27 + /// The blob resolver returned an error. 28 + #[error("blob resolver returned an error")] 29 + BlobResolutionFailure, 30 + /// The blob's computed CID does not match the requested CID. 31 + #[error("blob's computed CID does not match the requested CID")] 32 + CidMismatch, 33 + /// The requested CID uses an unsupported multihash algorithm. 34 + #[error("requested CID uses an unsupported multihash algorithm")] 35 + CidUnsupportedMultihash, 36 + /// The blob could not be found at the requested address. 37 + #[error("blob could not be found at the requested address")] 38 + NotFound, 39 + /// The blob exceeds the maximum size permitted by this server. 40 + #[error("blob exceeded the maximum size")] 41 + TooLarge, 42 + /// The origin returned a non-successful status code while fetching the blob, 43 + /// excluding 404 which is handled by [`Self::NotFound`]. 44 + #[error("origin returned an unsuccessful status code")] 45 + ErrorStatusCode, 46 + /// The request to the origin failed. 47 + #[error("the request to the origin failed")] 48 + FetchFailure, 49 + /// The blob stream was interrupted before it could be fully downloaded, 50 + /// for example due to the connection being unexpectedly reset. 51 + #[error("the blob stream was interrupted before completion")] 52 + StreamFailed, 53 + /// The blob's detected MIME type is not permitted by this server. 54 + #[error("blob's mimetype was not in the allowlist")] 55 + ForbiddenMimeType, 56 + } 57 + 58 + #[derive(Debug, Error)] 59 + pub enum BlobOwnershipError { 60 + /// The blob resolver returned an error. 61 + #[error("blob resolver returned an error")] 62 + BlobResolutionFailure, 63 + /// The blob could not be found in the user's repository. 64 + #[error("blob could not be found at the requested address")] 65 + NotFound, 66 + /// The origin returned a non-successful status code while fetching the blob, 67 + /// excluding 404 which is handled by [`Self::NotFound`]. 68 + #[error("origin returned an unsuccessful status code")] 69 + ErrorStatusCode, 70 + /// The request to the origin failed. 71 + #[error("the request to the origin failed")] 72 + FetchFailure, 73 + } 74 + 75 + #[derive(Clone)] 76 + pub struct BlobData { 77 + pub bytes: Bytes, 78 + pub mime_type: Mime, 79 + } 80 + 81 + pub enum BlobUrlResolver<'a> { 82 + Pds { 83 + identity_service: &'a IdentityService, 84 + }, 85 + } 86 + 87 + #[derive(Debug, Clone, Copy)] 88 + pub struct BlobServiceOptions { 89 + pub data_cache_max_capacity: u64, 90 + pub data_cache_tti: Duration, 91 + pub ownership_cache_max_capacity: u64, 92 + pub ownership_cache_ttl: Duration, 93 + pub http_timeout: Duration, 94 + pub http_connect_timeout: Duration, 95 + } 96 + 97 + pub struct BlobService { 98 + data_cache: MokaCache<BlobCid, BlobData>, 99 + ownership_cache: MokaCache<(BlobCid, Did<'static>), ()>, 100 + http_client: reqwest::Client, 101 + } 102 + 103 + impl BlobService { 104 + pub fn new(options: BlobServiceOptions) -> Result<Self, CreateBlobServiceError> { 105 + tracing::debug!("creating blob service with options: {options:?}"); 106 + Ok(Self { 107 + data_cache: MokaCache::<BlobCid, BlobData>::builder() 108 + .name("blob-content") 109 + .weigher(|_key, value| value.bytes.len().try_into().unwrap_or(u32::MAX)) 110 + .eviction_policy(EvictionPolicy::tiny_lfu()) 111 + .max_capacity(options.data_cache_max_capacity) 112 + .time_to_idle(options.data_cache_tti) 113 + .build(), 114 + ownership_cache: MokaCache::<(BlobCid, Did<'static>), ()>::builder() 115 + .name("blob-ownership") 116 + .weigher(|key, _value| { 117 + (key.0.encoded_len() + key.1.len()) 118 + .try_into() 119 + .unwrap_or(u32::MAX) 120 + }) 121 + .eviction_policy(EvictionPolicy::tiny_lfu()) 122 + .max_capacity(options.ownership_cache_max_capacity) 123 + .time_to_live(options.ownership_cache_ttl) 124 + .support_invalidation_closures() 125 + .build(), 126 + http_client: reqwest::Client::builder() 127 + .user_agent(PORXIE_USER_AGENT) 128 + .https_only(cfg!(debug_assertions)) 129 + .redirect(reqwest::redirect::Policy::limited(3)) 130 + .gzip(true) 131 + .brotli(true) 132 + .zstd(true) 133 + .deflate(true) 134 + .connect_timeout(options.http_connect_timeout) 135 + .timeout(options.http_timeout) 136 + .build() 137 + .map_err(CreateBlobServiceError::HttpClient)?, 138 + }) 139 + } 140 + 141 + /// Fetch the given blob either from the cache if available or from the upstream source. 142 + /// 143 + /// Concurrent requests for the same blob are coalesced. 144 + /// If the initial fetch fails, the next pending request will 145 + /// try instead, continuing until one succeeds or all have failed. 146 + #[instrument(skip_all, fields(did = %did, cid = %cid))] 147 + pub async fn fetch_blob( 148 + &self, 149 + did: &Did<'static>, 150 + cid: &BlobCid, 151 + url_resolver: BlobUrlResolver<'_>, 152 + max_blob_size: NonZeroU64, 153 + allowed_mimetypes: &[Mime], 154 + ) -> Result<BlobData, Arc<BlobDownloadError>> { 155 + tracing::debug!("fetching blob from origin"); 156 + 157 + self.data_cache 158 + .try_get_with_by_ref(cid, async { 159 + let blob_url = match url_resolver { 160 + BlobUrlResolver::Pds { 161 + identity_service: identity_resolver, 162 + } => { 163 + let mut url = identity_resolver 164 + .pds_for_did(did) 165 + .await 166 + .map_err(|_| BlobDownloadError::BlobResolutionFailure)?; 167 + url.set_path("/xrpc/com.atproto.sync.getBlob"); 168 + url.query_pairs_mut() 169 + .append_pair("did", did.as_str()) 170 + .append_pair("cid", &cid.to_string()); 171 + url 172 + } 173 + }; 174 + 175 + let validated_bytes = { 176 + let response = self.http_client.get(blob_url).send().await.map_err(|err| { 177 + tracing::warn!("failed to request blob from origin: {err:?}"); 178 + BlobDownloadError::FetchFailure 179 + })?; 180 + 181 + // Gracefully handle & abort if we do not receive a successful status code. 182 + if !response.status().is_success() { 183 + return Err(match response.status() { 184 + StatusCode::NOT_FOUND => { 185 + tracing::debug!("origin returned 404 for blob"); 186 + BlobDownloadError::NotFound 187 + } 188 + status => { 189 + tracing::debug!("origin returned error status for blob: {status}"); 190 + BlobDownloadError::ErrorStatusCode 191 + } 192 + }); 193 + } 194 + 195 + // Download bytes as a stream, enforcing a max size limit 196 + // and aborting if it's crossed. 197 + let bytes = bytes_stream_capped(response, max_blob_size).await.map_err( 198 + |err| match err { 199 + BytesStreamCappedError::TooLarge => { 200 + tracing::debug!("blob exceeds max size of {} bytes", max_blob_size); 201 + BlobDownloadError::TooLarge 202 + } 203 + BytesStreamCappedError::ClientError(err) => { 204 + tracing::warn!("error reading blob stream: {err:?}"); 205 + BlobDownloadError::StreamFailed 206 + } 207 + }, 208 + )?; 209 + 210 + // Verify request CID matches the blob's computed CID. 211 + // 212 + // This operation is done via spawn_blocking as creating the digest will block 213 + // this task's executor from switching to other tasks for as long it runs. 214 + tokio::task::spawn_blocking({ 215 + let bytes = bytes.clone(); 216 + let cid = *cid; 217 + move || { 218 + // Enabled Multihashes are set in the multihash-codetable crate features. 219 + let computed_cid = match Code::try_from(cid.hash().code()) { 220 + Ok(code) => Ok(Cid::new_v1( 221 + 0x55, // RaW codec 222 + code.digest(&bytes), 223 + )), 224 + Err(err) => { 225 + tracing::warn!("failed to compute CID: {err:?}"); 226 + Err(BlobDownloadError::CidUnsupportedMultihash) 227 + } 228 + }?; 229 + 230 + if computed_cid != *cid { 231 + tracing::warn!( 232 + "cid mismatch: computed {computed_cid} expected {cid}" 233 + ); 234 + return Err(BlobDownloadError::CidMismatch); 235 + } 236 + 237 + Ok(()) 238 + } 239 + }) 240 + .await 241 + .expect("CID computing task should not panic")?; 242 + 243 + bytes 244 + }; 245 + 246 + // Infer MIME type from content bytes rather than headers; this is fallible 247 + // and falls back to application/octet-stream if the type is unrecognised. 248 + let mime_type = sniff_mime(&validated_bytes); 249 + if !is_mime_allowed(&mime_type, allowed_mimetypes) { 250 + tracing::debug!("blob was inferred to be a disallowed mime type: {mime_type}"); 251 + return Err(BlobDownloadError::ForbiddenMimeType); 252 + } 253 + 254 + // Mark this DID+CID pair as ownership-verified since we just fetched it from the origin. 255 + self.ownership_cache.insert((*cid, did.clone()), ()).await; 256 + 257 + Ok(BlobData { 258 + bytes: validated_bytes, 259 + mime_type, 260 + }) 261 + }) 262 + .await 263 + } 264 + 265 + pub async fn invalidate_blob(&self, cid: &BlobCid) { 266 + self.data_cache.invalidate(cid).await 267 + } 268 + 269 + /// Fetch whether the user owns the given blob either from the cache if available or the upstream source. 270 + /// 271 + /// The internal cache will be automatically populated if the blob was previously fetched from the same user. 272 + #[instrument(skip_all, fields(did = %did, cid = %cid))] 273 + pub async fn fetch_blob_ownership( 274 + &self, 275 + did: &Did<'static>, 276 + cid: BlobCid, 277 + url_resolver: BlobUrlResolver<'_>, 278 + ) -> Result<(), Arc<BlobOwnershipError>> { 279 + tracing::debug!("verifying ownership of blob"); 280 + 281 + self.ownership_cache 282 + // TODO: Remove clone on DID. 283 + .try_get_with((cid, did.clone()), async { 284 + let blob_url = match url_resolver { 285 + BlobUrlResolver::Pds { 286 + identity_service: identity_resolver, 287 + } => { 288 + let mut url = identity_resolver 289 + .pds_for_did(did) 290 + .await 291 + .map_err(|_| BlobOwnershipError::BlobResolutionFailure)?; 292 + url.set_path("/xrpc/com.atproto.sync.getBlob"); 293 + url.query_pairs_mut() 294 + .append_pair("did", did.as_str()) 295 + .append_pair("cid", &cid.to_string()); 296 + url 297 + } 298 + }; 299 + 300 + // Request the blob with as little of the actual body as we can. 301 + // 302 + // While some origins (bsky pds, tranquil pds) may support HTTP HEAD, it is not 303 + // actually a part of the XRPC specification and we cannot rely on it (for now). 304 + // Use a range request to avoid downloading the full body on servers that support it instead. 305 + match self 306 + .http_client 307 + .get(blob_url) 308 + .header( 309 + header::RANGE, 310 + const { HeaderValue::from_static("bytes=0-1023") }, 311 + ) 312 + .send() 313 + .await 314 + .map_err(|err| { 315 + tracing::warn!("failed to request blob from origin: {err:?}"); 316 + BlobOwnershipError::FetchFailure 317 + })? 318 + .status() 319 + { 320 + status if status.is_success() => { 321 + tracing::debug!("verified ownership of blob"); 322 + Ok(()) 323 + } 324 + StatusCode::NOT_FOUND => { 325 + tracing::debug!("origin returned 404 for blob"); 326 + Err(BlobOwnershipError::NotFound) 327 + } 328 + status => { 329 + tracing::debug!("origin returned error status for blob: {}", status); 330 + Err(BlobOwnershipError::ErrorStatusCode) 331 + } 332 + } 333 + }) 334 + .await 335 + } 336 + 337 + pub fn invalidate_blob_ownership< 338 + F: Fn(&(BlobCid, Did<'static>), &()) -> bool + Send + Sync + 'static, 339 + >( 340 + &self, 341 + predicate: F, 342 + ) { 343 + if let Err(err) = self.ownership_cache.invalidate_entries_if(predicate) { 344 + tracing::error!( 345 + "blob service has not enabled support for invalidation closures: {err:?}" 346 + ); 347 + } 348 + } 349 + }
+27 -175
src/cache.rs
··· 1 - use crate::types::blob_cid::BlobCid; 2 - use anyhow::{Context, Result}; 3 - use axum::http::HeaderMap; 4 - use bytes::Bytes; 5 - use jacquard_common::types::did::Did; 6 - use moka::{future::Cache as MokaCache, policy::EvictionPolicy}; 7 - use reqwest::Url; 8 - use std::{cmp, num::NonZeroU64, time::Duration}; 9 - 10 - // Blob Content Cache 11 - 12 - type BlobContentCache = MokaCache<BlobCid, CachedBlobData>; 13 - 14 - #[derive(Clone)] 15 - pub struct CachedBlobData { 16 - pub bytes: Bytes, 17 - pub headers: HeaderMap, 18 - } 19 - 20 - #[must_use] 21 - fn build_blob_content_cache(mem_capacity: u64, tti: Duration) -> BlobContentCache { 22 - tracing::debug!( 23 - "building blob content cache with a mem_capacity of {mem_capacity} bytes and a tti of {}s", 24 - tti.as_secs() 25 - ); 26 - 27 - BlobContentCache::builder() 28 - .name("blob-content") 29 - .weigher(|_key, value| { 30 - (value.bytes.len() 31 - + value 32 - .headers 33 - .iter() 34 - .map(|(k, v)| k.as_str().len() + v.len() + 32) 35 - .sum::<usize>()) 36 - .try_into() 37 - .unwrap_or(u32::MAX) 38 - }) 39 - .eviction_policy(EvictionPolicy::tiny_lfu()) 40 - .max_capacity(mem_capacity) 41 - .time_to_idle(tti) 42 - .build() 43 - } 44 - 45 - // Blob Ownership Cache 46 - 47 - type BlobOwnershipCache = MokaCache<(BlobCid, Did<'static>), ()>; 48 - 49 - #[must_use] 50 - fn build_blob_ownership_cache(mem_capacity: u64, ttl: Duration) -> BlobOwnershipCache { 51 - tracing::debug!( 52 - "building blob ownership cache with a mem_capacity of {mem_capacity} bytes and a ttl of {}s", 53 - ttl.as_secs() 54 - ); 55 - 56 - BlobOwnershipCache::builder() 57 - .name("blob-ownership") 58 - .weigher(|key, _value| { 59 - (key.0.encoded_len() + key.1.len()) 60 - .try_into() 61 - .unwrap_or(u32::MAX) 62 - }) 63 - .eviction_policy(EvictionPolicy::tiny_lfu()) 64 - .max_capacity(mem_capacity) 65 - .time_to_live(ttl) 66 - .support_invalidation_closures() 67 - .build() 68 - } 69 - 70 - // Policy Cache 71 - 72 - type BlobPolicyCache = MokaCache<(Did<'static>, BlobCid), CachedBlobPolicy>; 73 - 74 - #[derive(Debug, Copy, Clone)] 75 - pub struct CachedBlobPolicy { 76 - pub can_serve: bool, 77 - } 78 - 79 - #[must_use] 80 - pub fn build_blob_policy_cache(mem_capacity: u64, ttl: Duration) -> BlobPolicyCache { 81 - tracing::debug!( 82 - "building blob policy cache with a mem_capacity of {mem_capacity} bytes and a ttl of {}s", 83 - ttl.as_secs() 84 - ); 85 - 86 - BlobPolicyCache::builder() 87 - .name("blob-policy") 88 - .weigher(|key, _value| { 89 - (key.0.len() + key.1.encoded_len()) 90 - .try_into() 91 - .unwrap_or(u32::MAX) 92 - }) 93 - .eviction_policy(EvictionPolicy::tiny_lfu()) 94 - .max_capacity(mem_capacity) 95 - .time_to_live(ttl) 96 - .support_invalidation_closures() 97 - .build() 98 - } 99 - 100 - // Identity DID <-> PDS Cache 101 - 102 - type IdentityCache = MokaCache<Did<'static>, Url>; 1 + use std::{cmp, num::NonZeroU64}; 2 + use thiserror::Error; 103 3 104 - #[must_use] 105 - pub fn build_identity_cache(mem_capacity: u64, ttl: Duration) -> IdentityCache { 106 - tracing::debug!( 107 - "building identity cache with a mem_capacity of {mem_capacity} bytes and a ttl of {}s", 108 - ttl.as_secs() 109 - ); 110 - 111 - IdentityCache::builder() 112 - .name("identity") 113 - .weigher(|key, value| { 114 - (key.len() + value.as_str().len()) 115 - .try_into() 116 - .unwrap_or(u32::MAX) 117 - }) 118 - .eviction_policy(EvictionPolicy::tiny_lfu()) 119 - .max_capacity(mem_capacity) 120 - .time_to_live(ttl) 121 - .support_invalidation_closures() 122 - .build() 4 + #[derive(Debug, Error)] 5 + pub enum ComputeCacheSizeError { 6 + #[error("cache size underflowed capacity")] 7 + AllocationUnderflow, 123 8 } 124 9 125 - // Builder 126 - 127 - pub struct Caches { 128 - pub blob_content: BlobContentCache, 129 - pub blob_ownership: BlobOwnershipCache, 130 - pub blob_policy: BlobPolicyCache, 131 - pub identity: IdentityCache, 10 + pub struct CacheSizes { 11 + pub blob: u64, 12 + pub ownership: u64, 13 + pub policy: u64, 14 + pub identity: u64, 132 15 } 133 16 134 - pub struct CacheBuildOptions { 135 - pub memory_capacity: NonZeroU64, 136 - pub blob_content_ttl: Duration, 137 - pub blob_ownership_ttl: Duration, 138 - pub blob_policy_ttl: Duration, 139 - pub identity_cache_ttl: Duration, 140 - } 141 - 142 - pub fn build_caches(options: &CacheBuildOptions) -> Result<Caches> { 143 - let sizes = { 144 - struct CacheSizes { 145 - pub blob: u64, 146 - pub ownership: u64, 147 - pub policy: u64, 148 - pub identity: u64, 149 - } 150 - let policy = cmp::min( 151 - (options.memory_capacity.get() as f64 * 0.06) as u64, 152 - 48_000_000, 153 - ); // 6% up to 48mb max. 154 - let ownership = cmp::min( 155 - (options.memory_capacity.get() as f64 * 0.06) as u64, 156 - 48_000_000, 157 - ); // 6% up to 48mb max. 158 - let identity = cmp::min( 159 - (options.memory_capacity.get() as f64 * 0.06) as u64, 160 - 48_000_000, 161 - ); // 6% up to 48mb max. 162 - CacheSizes { 163 - policy, 164 - ownership, 165 - identity, 166 - blob: options 167 - .memory_capacity 168 - .get() 169 - .checked_sub(policy) 170 - .and_then(|r| r.checked_sub(ownership)) 171 - .and_then(|r| r.checked_sub(identity)) 172 - .context("cache size allocation underflow")?, 173 - } 174 - }; 175 - 176 - Ok(Caches { 177 - blob_content: build_blob_content_cache(sizes.blob, options.blob_content_ttl), 178 - blob_ownership: build_blob_ownership_cache(sizes.ownership, options.blob_ownership_ttl), 179 - blob_policy: build_blob_policy_cache(sizes.policy, options.blob_policy_ttl), 180 - identity: build_identity_cache(sizes.identity, options.identity_cache_ttl), 17 + pub fn compute_cache_sizes( 18 + memory_capacity: NonZeroU64, 19 + ) -> Result<CacheSizes, ComputeCacheSizeError> { 20 + let policy = cmp::min((memory_capacity.get() as f64 * 0.10) as u64, 48_000_000); // 10% up to 48mb max. 21 + let ownership = cmp::min((memory_capacity.get() as f64 * 0.10) as u64, 48_000_000); // 10% up to 48mb max 22 + let identity = cmp::min((memory_capacity.get() as f64 * 0.10) as u64, 48_000_000); // 10% up to 48mb max. 23 + Ok(CacheSizes { 24 + policy, 25 + ownership, 26 + identity, 27 + blob: memory_capacity 28 + .get() 29 + .checked_sub(policy) 30 + .and_then(|r| r.checked_sub(ownership)) 31 + .and_then(|r| r.checked_sub(identity)) 32 + .ok_or(ComputeCacheSizeError::AllocationUnderflow)?, 181 33 }) 182 34 }
+40 -43
src/http.rs
··· 1 - use bytes::{Bytes, BytesMut}; 1 + use bytes::Bytes; 2 2 use futures_util::StreamExt; 3 - use reqwest::redirect::Policy; 4 - use std::{num::NonZeroU64, time::Duration}; 3 + use std::num::NonZeroU64; 5 4 use thiserror::Error; 6 5 7 - #[inline] 8 - pub fn build_http_client( 9 - timeout: Duration, 10 - connect_timeout: Duration, 11 - https_only: bool, 12 - ) -> Result<reqwest::Client, reqwest::Error> { 13 - reqwest::Client::builder() 14 - .user_agent(concat!( 15 - env!("CARGO_PKG_NAME"), 16 - "/", 17 - env!("CARGO_PKG_VERSION_MAJOR"), 18 - ".", 19 - env!("CARGO_PKG_VERSION_MINOR"), 20 - " (", 21 - env!("CARGO_PKG_REPOSITORY"), 22 - ")" 23 - )) 24 - .https_only(https_only) 25 - .redirect(Policy::limited(3)) 26 - .gzip(true) 27 - .brotli(true) 28 - .zstd(true) 29 - .deflate(true) 30 - .connect_timeout(connect_timeout) 31 - .timeout(timeout) 32 - .build() 33 - } 6 + pub const PORXIE_USER_AGENT: &str = concat!( 7 + env!("CARGO_PKG_NAME"), 8 + "/", 9 + env!("CARGO_PKG_VERSION_MAJOR"), 10 + ".", 11 + env!("CARGO_PKG_VERSION_MINOR"), 12 + " (", 13 + env!("CARGO_PKG_REPOSITORY"), 14 + ")" 15 + ); 34 16 35 17 #[derive(Debug, Error)] 36 18 pub enum BytesStreamCappedError { ··· 43 25 ClientError(#[from] reqwest::Error), 44 26 } 45 27 46 - /// A wrapper around `Response::bytes_stream()` that acts like `Response::bytes()` 47 - /// but enforces a maximum size limit while streaming the response. 28 + /// Stream a response into [`Bytes`], aborting if the buffer exceeds `max_size`. 29 + /// 30 + /// Pre-allocates a buffer based on response size heuristics when available, otherwise starts small 31 + /// and grows as data is streamed. If the buffer capacity differs from the buffer length after, 32 + /// the buffer may be shrunk to fit. 48 33 pub async fn bytes_stream_capped( 49 34 response: reqwest::Response, 50 35 max_size: NonZeroU64, 51 36 ) -> Result<Bytes, BytesStreamCappedError> { 52 - if let Some(content_length) = response.content_length() 53 - && content_length > max_size.get() 54 - { 37 + let max_size = max_size.get(); 38 + 39 + // Use body size hint, fallback to content-length header. 40 + let inferred_size = response.content_length().or_else(|| { 41 + response 42 + .headers() 43 + .get(reqwest::header::CONTENT_LENGTH) 44 + .and_then(|v| v.to_str().ok()) 45 + .and_then(|v| v.parse::<u64>().ok()) 46 + }); 47 + 48 + // Skip stream if the inferred size exceeds max size. 49 + if inferred_size.is_some_and(|size| size > max_size) { 55 50 return Err(BytesStreamCappedError::TooLarge); 56 51 } 57 52 58 - let mut buffer = BytesMut::with_capacity( 59 - response 60 - .content_length() 53 + // Stream bytes in chunks and abort if we exceed max size. 54 + let mut stream = response.bytes_stream(); 55 + let mut buffer = Vec::with_capacity( 56 + inferred_size 61 57 .unwrap_or(64 * 1024) 62 - .min(max_size.get()) 58 + .min(max_size) 63 59 .try_into() 64 - .unwrap_or(usize::MAX), 60 + .expect("buffer allocation should not exceed usize"), 65 61 ); 66 - let mut stream = response.bytes_stream(); 67 62 while let Some(chunk) = stream.next().await { 68 63 let chunk = chunk.map_err(BytesStreamCappedError::ClientError)?; 69 - if buffer.len() as u64 + chunk.len() as u64 > max_size.get() { 64 + if buffer.len() as u64 + chunk.len() as u64 > max_size { 70 65 return Err(BytesStreamCappedError::TooLarge); 71 66 } 72 67 buffer.extend_from_slice(&chunk); 73 68 } 74 69 75 - Ok(buffer.freeze()) 70 + Ok(Bytes::from( 71 + buffer.into_boxed_slice(), // shrink capacity to fit 72 + )) 76 73 }
+141
src/identity_service.rs
··· 1 + use crate::http::PORXIE_USER_AGENT; 2 + use jacquard_common::types::did::Did; 3 + use jacquard_identity::{ 4 + JacquardResolver, 5 + resolver::{IdentityError, IdentityResolver as _, PlcSource, ResolverOptions}, 6 + }; 7 + use moka::{future::Cache as MokaCache, policy::EvictionPolicy}; 8 + use reqwest::Url; 9 + use std::{sync::Arc, time::Duration}; 10 + use thiserror::Error; 11 + use tracing::instrument; 12 + 13 + #[derive(Debug, Error)] 14 + #[non_exhaustive] 15 + pub enum CreateIdentityServiceError { 16 + /// An internal http client error occurred, see [`reqwest::Error`]. 17 + #[error(transparent)] 18 + HttpClient(#[from] reqwest::Error), 19 + } 20 + 21 + #[derive(Debug, Clone)] 22 + pub struct IdentityServiceOptions { 23 + /// Maximum size in memory this cache is permitted to grow to. 24 + pub cache_memory_allocation: u64, 25 + /// Time-to-live duration of items in the cache. 26 + pub cache_ttl: Duration, 27 + /// HTTP timeout to apply to all identity fetches. 28 + pub http_timeout: Duration, 29 + /// HTTP connection-phase timeout to apply to all identity requests. 30 + pub http_connect_timeout: Duration, 31 + /// URL to the PLC directory to query for `did:plc` requests. 32 + pub plc_directory_url: Url, 33 + } 34 + 35 + pub struct IdentityService { 36 + resolver: JacquardResolver, 37 + cache: MokaCache<Did<'static>, Url>, 38 + } 39 + 40 + impl IdentityService { 41 + /// Create a new identity service. 42 + pub fn new(options: IdentityServiceOptions) -> Result<Self, CreateIdentityServiceError> { 43 + tracing::debug!("creating identity service with options: {options:?}"); 44 + Ok(Self { 45 + resolver: JacquardResolver::new( 46 + reqwest::Client::builder() 47 + .user_agent(PORXIE_USER_AGENT) 48 + .https_only(cfg!(debug_assertions)) 49 + .redirect(reqwest::redirect::Policy::limited(2)) 50 + .gzip(true) 51 + .brotli(true) 52 + .zstd(true) 53 + .deflate(true) 54 + .connect_timeout(options.http_connect_timeout) 55 + .timeout(options.http_timeout) 56 + .build() 57 + .map_err(CreateIdentityServiceError::HttpClient)?, 58 + ResolverOptions { 59 + plc_source: PlcSource::PlcDirectory { 60 + base: options.plc_directory_url, 61 + }, 62 + public_fallback_for_handle: true, 63 + validate_doc_id: true, 64 + request_timeout: Some(options.http_timeout), 65 + ..Default::default() 66 + }, 67 + ), 68 + cache: MokaCache::<Did<'static>, Url>::builder() 69 + .name("identity") 70 + .weigher(|key, value| { 71 + (key.len() + value.as_str().len()) 72 + .try_into() 73 + .unwrap_or(u32::MAX) 74 + }) 75 + .eviction_policy(EvictionPolicy::tiny_lfu()) 76 + .max_capacity(options.cache_memory_allocation) 77 + .time_to_live(options.cache_ttl) 78 + .build(), 79 + }) 80 + } 81 + 82 + /// Resolve the PDS assigned by the given Did. 83 + /// 84 + // Concurrent requests for the same key are coalesced. 85 + #[instrument(skip_all, fields(did = %did))] 86 + pub async fn pds_for_did(&self, did: &Did<'static>) -> Result<Url, Arc<IdentityError>> { 87 + self.cache 88 + .try_get_with_by_ref(did, self.resolver.pds_for_did(did)) 89 + .await 90 + } 91 + 92 + /// Clears all cached data for the given Di. 93 + pub async fn invalidate_did_cache(&self, did: &Did<'static>) { 94 + self.cache.invalidate(did).await 95 + } 96 + } 97 + 98 + #[cfg(test)] 99 + mod tests { 100 + use crate::identity_service::{IdentityService, IdentityServiceOptions}; 101 + use jacquard_common::types::did::Did; 102 + use reqwest::Url; 103 + use std::time::Duration; 104 + 105 + fn make_service() -> IdentityService { 106 + IdentityService::new(IdentityServiceOptions { 107 + cache_memory_allocation: 500, 108 + cache_ttl: Duration::from_hours(24), 109 + http_timeout: Duration::from_secs(30), 110 + http_connect_timeout: Duration::from_secs(15), 111 + plc_directory_url: Url::parse("https://plc.directory").unwrap(), 112 + }) 113 + .expect("service constructor should be always be valid") 114 + } 115 + 116 + #[tokio::test] 117 + async fn resolve_and_cache() { 118 + let resolver = make_service(); 119 + let did = Did::new_static("did:plc:ewvi7nxzyoun6zhxrhs64oiz") 120 + .expect("test did should always be valid"); // atproto.com 121 + 122 + // Test cold resolve and cache. 123 + assert!(resolver.pds_for_did(&did).await.is_ok()); 124 + assert!(resolver.cache.contains_key(&did)); 125 + 126 + // Test invalidation 127 + resolver.invalidate_did_cache(&did).await; 128 + assert!(!resolver.cache.contains_key(&did)); 129 + } 130 + 131 + #[tokio::test] 132 + async fn resolve_error_uncached() { 133 + let resolver = make_service(); 134 + let did = Did::new_static("did:plc:aaaaaaaaaaaaaaaaaaaaaaaa") 135 + .expect("test did should always be valid"); 136 + 137 + // Test cold resolve and cache. 138 + assert!(resolver.pds_for_did(&did).await.is_err()); 139 + assert!(!resolver.cache.contains_key(&did)); 140 + } 141 + }
+73 -77
src/main.rs
··· 1 + mod blob_service; 1 2 mod cache; 2 3 mod http; 4 + mod identity_service; 3 5 mod mime; 6 + mod policy_client; 4 7 mod routes; 5 8 mod types; 6 9 7 10 use crate::{ 8 - cache::{CacheBuildOptions, Caches, build_caches}, 9 - http::build_http_client, 10 - routes::{delete_cache_handler, get_blob_handler, get_index_handler}, 11 + blob_service::{BlobService, BlobServiceOptions}, 12 + cache::compute_cache_sizes, 13 + identity_service::{IdentityService, IdentityServiceOptions}, 14 + policy_client::{PolicyClient, PolicyClientOptions}, 15 + routes::{delete_cache_handler, get_blob_handler, get_health_handler, get_index_handler}, 11 16 }; 12 17 use ::mime::Mime; 13 - use anyhow::{Context, Result, bail}; 18 + use anyhow::{Context, bail}; 14 19 use axum::{ 15 20 Router, 16 21 extract::Request, 17 22 http::{HeaderName, HeaderValue, StatusCode, header}, 18 23 middleware::{self as axum_middleware, Next}, 24 + response::Response, 19 25 routing::{delete, get}, 20 26 }; 21 27 use bytesize::ByteSize; 22 28 use clap::{Args, Parser}; 23 29 use dotenvy::dotenv; 24 - use jacquard_identity::{ 25 - JacquardResolver, 26 - resolver::{PlcSource, ResolverOptions}, 27 - }; 28 30 use reqwest::Url; 29 31 use std::{net::SocketAddr, num::NonZeroU64, path::PathBuf, str::FromStr, sync::Arc}; 30 32 use tower_http::{ ··· 33 35 timeout::TimeoutLayer, 34 36 trace::{self, DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, TraceLayer}, 35 37 }; 36 - use tracing::{Level, info}; 38 + use tracing::Level; 37 39 use tracing_subscriber::EnvFilter; 38 40 39 - #[derive(Clone)] 41 + #[derive(Debug, Clone)] 40 42 enum AddressType { 43 + /// An IP socket address. 41 44 Ip(SocketAddr), 45 + /// A UNIX socket path. 42 46 #[cfg(unix)] 43 47 Unix(PathBuf), 44 48 } 45 49 46 50 impl FromStr for AddressType { 47 51 type Err = anyhow::Error; 48 - fn from_str(s: &str) -> std::result::Result<Self, Self::Err> { 52 + 53 + fn from_str(s: &str) -> Result<Self, Self::Err> { 49 54 #[cfg(unix)] 50 55 if let Some(path) = s.strip_prefix("unix:") { 51 56 return Ok(AddressType::Unix(PathBuf::from(path))); ··· 173 178 id = "BA_BLOB_CACHE_HEADER", 174 179 long = "blob-cache-header", 175 180 env = "PORXIE_BLOB_CACHE_HEADER", 176 - default_value = "public, max-age=604800, must-revalidate, immutable" 181 + default_value = "public, max-age=604800, immutable" 177 182 )] 178 183 cache_header: HeaderValue, 179 184 ··· 391 396 } 392 397 393 398 struct AppState { 394 - // Core. 395 - identity_resolver: JacquardResolver, 396 - policy_http_client: reqwest::Client, 397 - blob_fetch_http_client: reqwest::Client, 398 - cache: Caches, 399 399 // Authentication. 400 400 auth_token: Option<String>, 401 401 // Blob handling. 402 402 allowed_mimetypes: Vec<Mime>, 403 + blob_service: BlobService, 404 + cache_control_header: HeaderValue, 403 405 max_blob_size: NonZeroU64, 404 - cache_control_header: HeaderValue, 405 - // Policy service. 406 - policy_service_url: Option<Url>, 407 - policy_service_headers: Vec<(HeaderName, HeaderValue)>, 408 - policy_service_fail_open: bool, 406 + // Policy. 407 + policy_client: Option<PolicyClient>, 408 + policy_fail_open: bool, 409 + // Identity. 410 + identity_service: IdentityService, 409 411 } 410 412 411 413 #[tokio::main(flavor = "multi_thread")] 412 - async fn main() -> Result<()> { 414 + async fn main() -> anyhow::Result<()> { 413 415 dotenv().ok(); 414 416 json_subscriber::fmt() 415 417 .with_env_filter(EnvFilter::try_from_default_env().unwrap_or(EnvFilter::new("info"))) ··· 417 419 let args = AppArgs::parse(); 418 420 419 421 // Setup state. 422 + let cache_sizes = compute_cache_sizes(args.cache.size)?; 420 423 let app_state = Arc::new(AppState { 421 - identity_resolver: JacquardResolver::new( 422 - build_http_client( 423 - args.identity.http_timeout.into(), 424 - args.identity.http_connect_timeout.into(), 425 - !cfg!(debug_assertions), 426 - ) 427 - .context("failed to build identity http client")?, 428 - ResolverOptions { 429 - plc_source: PlcSource::PlcDirectory { 430 - base: args.identity.plc_url, 431 - }, 432 - public_fallback_for_handle: true, 433 - validate_doc_id: true, 434 - request_timeout: Some(args.identity.http_timeout.into()), 435 - ..Default::default() 436 - }, 437 - ), 438 - blob_fetch_http_client: build_http_client( 439 - args.blob.http_timeout.into(), 440 - args.blob.http_connect_timeout.into(), 441 - !cfg!(debug_assertions), 442 - ) 443 - .context("failed to build blob fetch http client")?, 444 - policy_http_client: build_http_client( 445 - args.policy.http_timeout.into(), 446 - args.policy.http_connect_timeout.into(), 447 - !cfg!(debug_assertions), 448 - ) 449 - .context("failed to build policy http client")?, 450 - cache: build_caches(&CacheBuildOptions { 451 - memory_capacity: args.cache.size, 452 - blob_content_ttl: args.cache.blob_tti.into(), 453 - blob_ownership_ttl: args.cache.ownership_ttl.into(), 454 - blob_policy_ttl: args.cache.policy_ttl.into(), 455 - identity_cache_ttl: args.cache.identity_ttl.into(), 456 - }) 457 - .context("failed to build caches")?, 424 + identity_service: IdentityService::new(IdentityServiceOptions { 425 + cache_memory_allocation: cache_sizes.identity, 426 + cache_ttl: args.cache.identity_ttl.into(), 427 + http_timeout: args.identity.http_timeout.into(), 428 + http_connect_timeout: args.identity.http_connect_timeout.into(), 429 + plc_directory_url: args.identity.plc_url, 430 + })?, 431 + policy_client: args 432 + .policy 433 + .url 434 + .map(|url| { 435 + PolicyClient::new(PolicyClientOptions { 436 + policy_service_url: url, 437 + policy_service_req_headers: args.policy.request_headers, 438 + cache_max_memory_allocation: cache_sizes.policy, 439 + cache_ttl: args.cache.policy_ttl.into(), 440 + http_timeout: args.policy.http_timeout.into(), 441 + http_connect_timeout: args.policy.http_connect_timeout.into(), 442 + }) 443 + }) 444 + .transpose()?, 445 + blob_service: BlobService::new(BlobServiceOptions { 446 + http_timeout: args.blob.http_timeout.into(), 447 + http_connect_timeout: args.blob.http_connect_timeout.into(), 448 + data_cache_max_capacity: cache_sizes.blob, 449 + data_cache_tti: args.cache.blob_tti.into(), 450 + ownership_cache_max_capacity: cache_sizes.ownership, 451 + ownership_cache_ttl: args.cache.ownership_ttl.into(), 452 + })?, 453 + 458 454 auth_token: args.server.auth_token, 459 455 allowed_mimetypes: args.blob.allowed_mimetypes, 460 456 max_blob_size: args.blob.max_size, 461 457 cache_control_header: args.blob.cache_header, 462 - policy_service_url: args.policy.url, 463 - policy_service_headers: args.policy.request_headers, 464 - policy_service_fail_open: args.policy.fail_open, 458 + 459 + policy_fail_open: args.policy.fail_open, 465 460 }); 466 461 467 462 // Setup router. 468 463 let router = Router::new() 469 464 .route("/", get(get_index_handler)) 465 + .route("/health", get(get_health_handler)) 470 466 .route( 471 467 "/{did}/{cid}", 472 468 get(get_blob_handler).layer(TimeoutLayer::with_status_code( ··· 484 480 ) 485 481 .layer(NormalizePathLayer::trim_trailing_slash()) 486 482 .layer(CatchPanicLayer::new()) 487 - .layer(axum_middleware::from_fn( 488 - async |req: Request, next: Next| { 489 - let mut res = next.run(req).await; 490 - let res_headers = res.headers_mut(); 491 - res_headers.insert( 492 - header::SERVER, 493 - const { HeaderValue::from_static(env!("CARGO_PKG_NAME")) }, 494 - ); 495 - res_headers.insert("X-Robots-Tag", const { HeaderValue::from_static("none") }); 496 - res 497 - }, 498 - )) 483 + .layer(axum_middleware::from_fn(additional_headers_middleware)) 499 484 .with_state(app_state); 500 485 501 486 // Start server listener on specified address. ··· 504 489 let listener = tokio::net::TcpListener::bind(ip) 505 490 .await 506 491 .context("failed to bind tcp listener")?; 507 - info!("listening on http://{ip}"); 492 + tracing::info!("server listening on http://{ip}"); 508 493 axum::serve(listener, router) 509 494 .with_graceful_shutdown(shutdown_signal()) 510 495 .await?; ··· 514 499 let _ = std::fs::remove_file(&path); 515 500 let listener = 516 501 tokio::net::UnixListener::bind(&path).context("failed to bind unix listener")?; 517 - info!("listening on unix:{}", path.display()); 502 + tracing::info!("server listening on unix:{}", path.display()); 518 503 axum::serve(listener, router) 519 504 .with_graceful_shutdown(shutdown_signal()) 520 505 .await?; ··· 523 508 } 524 509 525 510 Ok(()) 511 + } 512 + 513 + async fn additional_headers_middleware(req: Request, next: Next) -> Response { 514 + let mut res = next.run(req).await; 515 + let res_headers = res.headers_mut(); 516 + res_headers.insert( 517 + header::SERVER, 518 + const { HeaderValue::from_static(env!("CARGO_PKG_NAME")) }, 519 + ); 520 + res_headers.insert("X-Robots-Tag", const { HeaderValue::from_static("none") }); 521 + res 526 522 } 527 523 528 524 // https://github.com/tokio-rs/axum/blob/15917c6dbcb4a48707a20e9cfd021992a279a662/examples/graceful-shutdown/src/main.rs#L55
+71 -58
src/mime.rs
··· 1 1 use mime::Mime; 2 2 3 - /// Sniff the MIME type from the given bytes, returning `application/octet-stream` if unknown. 3 + /// Sniff the MIME type from the given bytes. 4 + /// 5 + /// Returns [`mime::APPLICATION_OCTET_STREAM`] when unknown. 4 6 #[must_use] 5 7 pub fn sniff_mime(buf: &[u8]) -> Mime { 6 - // WORKAROUND: infer does not correctly detect SVG. 7 - // I have created PR to fix this at https://github.com/bojand/infer/pull/119 8 - // Until that is merged, this case will work around that limitation. 9 - const SVG_MARKER: &[u8; 4] = b"<svg"; 10 - const XML_MARKER: &[u8; 5] = b"<?xml"; 11 - const XML_SNIFFAHEAD: usize = 256; // How far after the XML marker to sniff ahead for the SVG marker. 12 - if buf.starts_with(SVG_MARKER) 13 - || (buf.starts_with(XML_MARKER) 14 - && buf 15 - .get(..XML_SNIFFAHEAD) 16 - .unwrap_or(buf) 17 - .windows(SVG_MARKER.len()) 18 - .any(|w| w == SVG_MARKER)) 19 - { 20 - return mime::IMAGE_SVG; 21 - } 22 - 23 8 match infer::get(buf) { 24 9 Some(m) => m 25 10 .mime_type() 26 11 .parse() 27 12 .expect("infer mimetype should always be valid"), 28 - None => mime::APPLICATION_OCTET_STREAM, 13 + None => { 14 + // WORKAROUND: infer does not correctly detect SVG. 15 + // I have created PR to fix this at https://github.com/bojand/infer/pull/119 16 + // Until that is merged, this case will work around that limitation. 17 + const SVG_MARKER: &[u8; 4] = b"<svg"; 18 + const XML_MARKER: &[u8; 5] = b"<?xml"; 19 + const XML_SNIFFAHEAD: usize = 256; // How far after the XML marker to sniff ahead for the SVG marker. 20 + if buf.len() >= 4 && buf.starts_with(SVG_MARKER) 21 + || (buf.starts_with(XML_MARKER) 22 + && buf 23 + .get(..XML_SNIFFAHEAD) 24 + .unwrap_or(buf) 25 + .windows(SVG_MARKER.len()) 26 + .any(|w| w == SVG_MARKER)) 27 + { 28 + tracing::debug!("used svg workaround instead of regular inference"); 29 + return mime::IMAGE_SVG; 30 + } 31 + tracing::debug!("infer was unable to determine mimetype, using fallback value"); 32 + mime::APPLICATION_OCTET_STREAM 33 + } 29 34 } 30 35 } 31 36 37 + /// Whether the given [`Mime`] is apart of the allowed array by 38 + /// checking if it matches directly or by wildcard. 32 39 #[must_use] 33 40 pub fn is_mime_allowed(mime: &Mime, allowed: &[Mime]) -> bool { 34 41 const STAR: &str = "*"; ··· 63 70 use std::str::FromStr; 64 71 65 72 #[test] 66 - fn test_is_mime_allowed() { 67 - // Test PNG when nothing is allowed. 68 - assert_eq!( 69 - super::is_mime_allowed(&Mime::from_str("image/png").unwrap(), &[]), 70 - false 71 - ); 73 + fn no_match() { 74 + // PNG when nothing is allowed. 75 + assert!(!super::is_mime_allowed( 76 + &Mime::from_str("image/png").unwrap(), 77 + &[] 78 + )); 79 + } 80 + 81 + #[test] 82 + fn exact_match() { 83 + // PNG when PNG is allowed. 84 + assert!(super::is_mime_allowed( 85 + &Mime::from_str("image/png").unwrap(), 86 + &[mime::IMAGE_PNG], 87 + )); 72 88 73 - // Test PNG when PNG is allowed. 74 - assert_eq!( 75 - super::is_mime_allowed(&Mime::from_str("image/png").unwrap(), &[mime::IMAGE_PNG],), 76 - true 77 - ); 89 + // PNG when only JPG is allowed. 90 + assert!(!super::is_mime_allowed( 91 + &Mime::from_str("image/png").unwrap(), 92 + &[mime::IMAGE_JPEG], 93 + )); 94 + } 78 95 79 - // Test PNG when only JPG is allowed. 80 - assert_eq!( 81 - super::is_mime_allowed(&Mime::from_str("image/png").unwrap(), &[mime::IMAGE_JPEG],), 82 - false 83 - ); 96 + #[test] 97 + fn full_wildcard() { 98 + // PNG when anything is allowed. 99 + assert!(super::is_mime_allowed( 100 + &Mime::from_str("image/png").unwrap(), 101 + &[mime::STAR_STAR], 102 + )); 103 + } 84 104 85 - // Test PNG when any image subtype is allowed. 86 - assert_eq!( 87 - super::is_mime_allowed(&Mime::from_str("image/png").unwrap(), &[mime::IMAGE_STAR],), 88 - true 89 - ); 105 + #[test] 106 + fn subtype_wildcard() { 107 + // PNG when any image subtype is allowed. 108 + assert!(super::is_mime_allowed( 109 + &Mime::from_str("image/png").unwrap(), 110 + &[mime::IMAGE_STAR] 111 + )); 90 112 91 - // Test PNG when anything is allowed. 92 - assert_eq!( 93 - super::is_mime_allowed(&Mime::from_str("image/png").unwrap(), &[mime::STAR_STAR],), 94 - true 95 - ); 113 + // PNG when images and text are enabled. 114 + assert!(super::is_mime_allowed( 115 + &Mime::from_str("image/png").unwrap(), 116 + &[mime::TEXT_STAR, mime::IMAGE_STAR], 117 + )); 96 118 97 119 // Test HTML when any image subtype is enabled. 98 - assert_eq!( 99 - super::is_mime_allowed(&Mime::from_str("text/html").unwrap(), &[mime::IMAGE_STAR],), 100 - false 101 - ); 102 - 103 - // Test PNG when images and text are enabled. 104 - assert_eq!( 105 - super::is_mime_allowed( 106 - &Mime::from_str("image/png").unwrap(), 107 - &[mime::TEXT_STAR, mime::IMAGE_STAR], 108 - ), 109 - true 110 - ); 120 + assert!(!super::is_mime_allowed( 121 + &Mime::from_str("text/html").unwrap(), 122 + &[mime::IMAGE_STAR], 123 + )); 111 124 } 112 125 }
+161
src/policy_client.rs
··· 1 + use crate::{http::PORXIE_USER_AGENT, types::blob_cid::BlobCid}; 2 + use jacquard_common::types::did::Did; 3 + use moka::{future::Cache as MokaCache, policy::EvictionPolicy}; 4 + use reqwest::{ 5 + StatusCode, Url, 6 + header::{HeaderName, HeaderValue}, 7 + }; 8 + use std::{sync::Arc, time::Duration}; 9 + use thiserror::Error; 10 + use tracing::instrument; 11 + 12 + #[derive(Debug, Clone)] 13 + pub struct PolicyDecision { 14 + /// Whether the service allows this blob can be served. 15 + pub can_serve: bool, 16 + } 17 + 18 + #[derive(Debug, Error)] 19 + #[non_exhaustive] 20 + pub enum CreatePolicyClientError { 21 + /// An internal http client error occurred, see [`reqwest::Error`]. 22 + #[error(transparent)] 23 + HttpClient(#[from] reqwest::Error), 24 + } 25 + 26 + #[derive(Debug, Error)] 27 + #[non_exhaustive] 28 + pub enum GetBlobPolicyError { 29 + /// Policy service returned an unhandled status code (Not 200 OK or 410 GONE). 30 + #[error("received an unhandled status code from the policy service: {0}")] 31 + UnhandledStatusCode(StatusCode), 32 + /// An internal http client error occurred, see [`reqwest::Error`]. 33 + #[error(transparent)] 34 + HttpClient(#[from] reqwest::Error), 35 + } 36 + 37 + #[derive(Debug, Clone)] 38 + pub struct PolicyClientOptions { 39 + /// Maximum size in memory this cache is permitted to grow to. 40 + pub cache_max_memory_allocation: u64, 41 + /// Time-to-live duration of items in the cache. 42 + pub cache_ttl: Duration, 43 + /// HTTP timeout to apply to all identity requests. 44 + pub http_timeout: Duration, 45 + /// HTTP connection-phase timeout to apply to all policy requests. 46 + pub http_connect_timeout: Duration, 47 + /// URL to the policy service to query. 48 + pub policy_service_url: Url, 49 + /// Additional request headers to append to each policy service request. 50 + pub policy_service_req_headers: Vec<(HeaderName, HeaderValue)>, 51 + } 52 + 53 + pub struct PolicyClient { 54 + cache: MokaCache<(Did<'static>, BlobCid), PolicyDecision>, 55 + http_client: reqwest::Client, 56 + policy_service_req_headers: Vec<(HeaderName, HeaderValue)>, 57 + policy_service_url: Url, 58 + } 59 + 60 + impl PolicyClient { 61 + /// Create a new policy client. 62 + pub fn new(options: PolicyClientOptions) -> Result<Self, CreatePolicyClientError> { 63 + tracing::debug!("creating policy service client with options: {options:?}"); 64 + Ok(Self { 65 + cache: MokaCache::<(Did<'static>, BlobCid), PolicyDecision>::builder() 66 + .name("blob-policy") 67 + .weigher(|key, _value| { 68 + (key.0.len() + key.1.encoded_len()) 69 + .try_into() 70 + .unwrap_or(u32::MAX) 71 + }) 72 + .eviction_policy(EvictionPolicy::tiny_lfu()) 73 + .max_capacity(options.cache_max_memory_allocation) 74 + .time_to_live(options.cache_ttl) 75 + .support_invalidation_closures() 76 + .build(), 77 + http_client: reqwest::Client::builder() 78 + .user_agent(PORXIE_USER_AGENT) 79 + .https_only(false) 80 + .redirect(reqwest::redirect::Policy::limited(2)) 81 + .gzip(true) 82 + .brotli(true) 83 + .zstd(true) 84 + .deflate(true) 85 + .connect_timeout(options.http_connect_timeout) 86 + .timeout(options.http_timeout) 87 + .build() 88 + .map_err(CreatePolicyClientError::HttpClient)?, 89 + policy_service_url: options.policy_service_url, 90 + policy_service_req_headers: options.policy_service_req_headers, 91 + }) 92 + } 93 + 94 + /// Query the policy service for the policy decision of this blob. 95 + /// 96 + /// Concurrent requests for the same policy are coalesced. 97 + #[instrument(skip_all, fields(did = %did, cid = %cid))] 98 + pub async fn get_policy_for_blob( 99 + &self, 100 + did: &Did<'static>, 101 + cid: BlobCid, 102 + ) -> Result<PolicyDecision, Arc<GetBlobPolicyError>> { 103 + self.cache 104 + .try_get_with_by_ref(&(did.clone(), cid), async { 105 + tracing::debug!("querying policy service for the status"); 106 + 107 + let mut policy_service_url = self.policy_service_url.clone(); 108 + policy_service_url 109 + .path_segments_mut() 110 + .expect("policy service URL should not be cannot-be-a-base") 111 + .push(did.as_str()) 112 + .push(&cid.to_string()); 113 + 114 + let mut request = self.http_client.get(policy_service_url); 115 + for (name, value) in &self.policy_service_req_headers { 116 + request = request.header(name, value); 117 + } 118 + 119 + match request.send().await { 120 + Ok(response) => match response.status() { 121 + StatusCode::OK => { 122 + tracing::debug!("policy service allowed blob serving"); 123 + Ok(PolicyDecision { can_serve: true }) 124 + } 125 + StatusCode::GONE => { 126 + tracing::debug!("policy service forbids blob serving"); 127 + Ok(PolicyDecision { can_serve: false }) 128 + } 129 + status => { 130 + tracing::error!("policy service returned unexpected status: {status}"); 131 + Err(GetBlobPolicyError::UnhandledStatusCode(status)) 132 + } 133 + }, 134 + Err(err) => { 135 + tracing::error!("error occurred contacting the policy service: {err:?}"); 136 + Err(GetBlobPolicyError::HttpClient(err)) 137 + } 138 + } 139 + }) 140 + .await 141 + } 142 + 143 + /// Invalidate cached policy decisions with the given predicate. 144 + pub fn invalidate_policies< 145 + F: Fn(&(Did<'static>, BlobCid), &PolicyDecision) -> bool + Send + Sync + 'static, 146 + >( 147 + &self, 148 + predicate: F, 149 + ) { 150 + if let Err(err) = self.cache.invalidate_entries_if(predicate) { 151 + tracing::error!( 152 + "policy client cache has not enabled support for invalidation closures: {err:?}" 153 + ); 154 + } 155 + } 156 + } 157 + 158 + #[cfg(test)] 159 + mod tests { 160 + // TODO: Create an in-process mock policy service to write tests against. 161 + }
+83 -321
src/routes/blob/get.rs
··· 1 - use crate::http::{BytesStreamCappedError, bytes_stream_capped}; 2 - use crate::routes::ErrorResponse; 3 - use crate::types::blob_cid::BlobCid; 4 1 use crate::{ 5 2 AppState, 6 - cache::{CachedBlobData, CachedBlobPolicy}, 7 - mime::{is_mime_allowed, sniff_mime}, 3 + blob_service::{BlobDownloadError, BlobOwnershipError, BlobUrlResolver}, 4 + routes::{CACHE_CONTROL_NOCACHE_VALUE, ErrorResponse}, 5 + types::blob_cid::BlobCid, 8 6 }; 9 - use axum::Json; 10 7 use axum::{ 8 + Json, 11 9 body::Body, 12 10 extract::{Path, State}, 13 - http::{HeaderMap, HeaderValue, Response, StatusCode, header}, 11 + http::{HeaderName, HeaderValue, StatusCode, header}, 12 + response::Response, 14 13 }; 15 - use cid::Cid; 16 14 use jacquard_common::types::did::Did; 17 - use jacquard_identity::resolver::IdentityResolver; 18 - use multihash_codetable::{Code, MultihashDigest}; 19 - use reqwest::Url; 20 15 use std::sync::Arc; 21 16 22 - enum BlobPolicyError { 23 - /// The policy service returned an unexpected status code. 24 - UnhandledStatusCode, 25 - /// The request to the policy service failed, for example due to the server being unavailable. 26 - FetchFailed, 27 - } 28 - 29 - enum BlobDownloadError { 30 - /// Failed to resolve the PDS for the given DID. The DID may be invalid or the 31 - /// resolver may be unavailable. 32 - DidPdsResolutionFailure, 33 - /// The blob's computed CID does not match the requested CID. 34 - CidMismatch, 35 - /// The requested CID uses a multihash algorithm unsupported by this server. 36 - CidUnsupportedMultihash, 37 - /// The blob could not be found in the user's repository. 38 - NotFound, 39 - /// The blob exceeds the maximum size permitted by this server. 40 - TooLarge, 41 - /// The PDS returned a non-successful status code while fetching the blob, 42 - /// excluding 404 which is handled by [`Self::NotFound`]. 43 - ErrorStatusCode, 44 - /// The request to the PDS failed, for example due to the server being unavailable. 45 - FetchFailure, 46 - /// The blob stream was interrupted before it could be fully downloaded, 47 - /// for example due to the connection being unexpectedly reset. 48 - StreamFailed, 49 - /// The blob's detected MIME type is not permitted by this server. 50 - ForbiddenMimeType, 51 - } 52 - 53 - enum BlobOwnershipError { 54 - /// Failed to resolve the PDS for the given DID. The DID may be invalid or the 55 - /// resolver may be unavailable. 56 - DidPdsResolutionFailure, 57 - /// The blob could not be found in the user's repository. 58 - NotFound, 59 - /// The PDS returned a non-successful status code while fetching the blob, 60 - /// excluding 404 which is handled by [`Self::NotFound`]. 61 - ErrorStatusCode, 62 - /// The request to the PDS failed, for example due to the server being unavailable. 63 - FetchFailure, 64 - } 65 - 66 - /// Create a `/xrpc/com.atproto.sync.getBlob` Url for the DID+CID. 67 - #[inline] 68 - #[must_use] 69 - fn to_pds_blob_url(mut pds_url: Url, did: &Did<'_>, cid: &BlobCid) -> Url { 70 - pds_url.set_path("/xrpc/com.atproto.sync.getBlob"); 71 - pds_url 72 - .query_pairs_mut() 73 - .append_pair("did", did.as_str()) 74 - .append_pair("cid", &cid.to_string()); 75 - pds_url 76 - } 77 - 17 + /// Fetch a blob from a given upstream and return it. 78 18 pub async fn get_blob_handler( 79 19 Path((raw_did, raw_cid)): Path<(String, String)>, 80 20 State(state): State<Arc<AppState>>, 81 - ) -> Result<axum::response::Response, (StatusCode, Json<ErrorResponse>)> { 21 + ) -> Result< 22 + Response, 23 + ( 24 + StatusCode, 25 + [(HeaderName, &'static str); 1], 26 + Json<ErrorResponse>, 27 + ), 28 + > { 82 29 let (did, cid) = ( 83 30 match Did::new_owned(raw_did.as_str()) { 84 31 Ok(did) => did, 85 32 Err(_) => { 86 33 return Err(( 87 34 StatusCode::UNPROCESSABLE_ENTITY, 35 + [(header::CACHE_CONTROL, CACHE_CONTROL_NOCACHE_VALUE)], 88 36 Json(ErrorResponse { 89 37 error: "MalformedDid", 90 38 message: Some("Invalid or unprocessable DID"), ··· 97 45 Err(_) => { 98 46 return Err(( 99 47 StatusCode::UNPROCESSABLE_ENTITY, 48 + [(header::CACHE_CONTROL, CACHE_CONTROL_NOCACHE_VALUE)], 100 49 Json(ErrorResponse { 101 50 error: "MalformedCid", 102 51 message: Some("Invalid or unprocessable CID"), ··· 106 55 }, 107 56 ); 108 57 109 - // Check policy for this DID+CID; concurrent requests for the same key are coalesced. 110 - if let Some(ref policy_service_url) = state.policy_service_url { 111 - match state 112 - .cache 113 - .blob_policy 114 - .try_get_with_by_ref(&(did.clone(), cid), async { 115 - tracing::debug!("querying policy service for the status of blob"); 116 - 117 - let mut policy_service_url = policy_service_url.clone(); 118 - policy_service_url 119 - .path_segments_mut() 120 - .expect("policy service URL should not be a base") 121 - .push(did.as_str()) 122 - .push(raw_cid.as_str()); 123 - 124 - let mut request = state.policy_http_client.get(policy_service_url); 125 - for (name, value) in &state.policy_service_headers { 126 - request = request.header(name, value); 127 - } 128 - 129 - match request.send().await { 130 - Ok(response) => match response.status() { 131 - StatusCode::OK => { 132 - tracing::debug!("policy service returned 200 status, can serve blob"); 133 - Ok(CachedBlobPolicy { can_serve: true }) 134 - } 135 - StatusCode::GONE => { 136 - tracing::debug!( 137 - "policy service returned 410 status, cannot serve blob" 138 - ); 139 - Ok(CachedBlobPolicy { can_serve: false }) 140 - } 141 - status => { 142 - tracing::error!("policy service returned unexpected status: {status}"); 143 - Err(BlobPolicyError::UnhandledStatusCode) 144 - } 145 - }, 146 - Err(err) => { 147 - tracing::error!("error occurred contacting the policy service: {err:?}"); 148 - Err(BlobPolicyError::FetchFailed) 149 - } 150 - } 151 - }) 152 - .await 153 - { 58 + // Check the policy status of the blob. 59 + if let Some(ref policy_client) = state.policy_client { 60 + match policy_client.get_policy_for_blob(&did, cid).await { 154 61 Ok(policy) => { 155 62 if !policy.can_serve { 156 63 return Err(( 157 64 StatusCode::GONE, 65 + [(header::CACHE_CONTROL, CACHE_CONTROL_NOCACHE_VALUE)], 158 66 Json(ErrorResponse { 159 - error: "BlobUnavailable", 160 - message: Some("Blob is not available through this service"), 67 + error: "PolicyForbidden", 68 + message: Some("Requested blob cannot be served by this service"), 161 69 }), 162 70 )); 163 71 } 164 72 } 165 73 Err(_) => { 166 - if !state.policy_service_fail_open { 74 + if !state.policy_fail_open { 75 + // TODO: Maybe give a more precise error? 167 76 return Err(( 168 77 StatusCode::INTERNAL_SERVER_ERROR, 78 + [(header::CACHE_CONTROL, CACHE_CONTROL_NOCACHE_VALUE)], 169 79 Json(ErrorResponse { 170 80 error: "InternalServerError", 171 - message: Some("Internal Server Error"), 81 + message: Some("An internal server error occured."), 172 82 }), 173 83 )); 174 84 } ··· 176 86 } 177 87 } 178 88 179 - // Serve from cache, or fetch from upstream. Concurrent requests for the same key are 180 - // coalesced — if the initial fetch fails, the next pending request will try instead, 181 - // continuing until one succeeds or all have failed. 89 + // Fetch the blob from cache/origin. 182 90 let blob = match state 183 - .cache 184 - .blob_content 185 - .try_get_with_by_ref(&cid, async { 186 - tracing::debug!("fetching blob from PDS"); 187 - let blob_url = to_pds_blob_url( 188 - state 189 - .cache 190 - .identity 191 - .try_get_with_by_ref(&did, state.identity_resolver.pds_for_did(&did)) 192 - .await 193 - .map_err(|err| { 194 - tracing::debug!("failed to resolve PDS: {:?}", *err); 195 - BlobDownloadError::DidPdsResolutionFailure 196 - })?, 197 - &did, 198 - &cid, 199 - ); 200 - 201 - let validated_bytes = { 202 - let response = state 203 - .blob_fetch_http_client 204 - .get(blob_url) 205 - .send() 206 - .await 207 - .map_err(|err| { 208 - tracing::warn!("failed to request blob from PDS: {err:?}"); 209 - BlobDownloadError::FetchFailure 210 - })?; 211 - 212 - // Gracefully handle & abort if we do not receive a successful status code. 213 - if !response.status().is_success() { 214 - // Note: Bluesky's PDS implementation sends 400 instead of 404 when a blob is 215 - // not found. This will skip the 404 handler and instead count as an error. 216 - // This is not our responsibility to work around as other implementations do it right. 217 - return Err(match response.status() { 218 - StatusCode::NOT_FOUND => { 219 - tracing::debug!("pds returned 404 for blob"); 220 - BlobDownloadError::NotFound 221 - } 222 - status => { 223 - tracing::debug!("pds returned error status for blob: {status}"); 224 - BlobDownloadError::ErrorStatusCode 225 - } 226 - }); 227 - } 228 - 229 - // Download bytes as a stream, enforcing a max size limit 230 - // and aborting if it's crossed. 231 - let bytes = bytes_stream_capped(response, state.max_blob_size) 232 - .await 233 - .map_err(|err| match err { 234 - BytesStreamCappedError::TooLarge => { 235 - tracing::debug!( 236 - "blob exceeds max size of {} bytes", 237 - state.max_blob_size 238 - ); 239 - BlobDownloadError::TooLarge 240 - } 241 - BytesStreamCappedError::ClientError(err) => { 242 - tracing::warn!("error reading blob stream: {err:?}"); 243 - BlobDownloadError::StreamFailed 244 - } 245 - })?; 246 - 247 - // Verify request CID matches the blob's computed CID. 248 - // 249 - // This operation is done via spawn_blocking as creating the digest will block 250 - // this task's executor from switching to other tasks for as long it runs. 251 - tokio::task::spawn_blocking({ 252 - let bytes = bytes.clone(); 253 - move || { 254 - // Enabled Multihashes are set in the multihash-codetable crate features. 255 - let computed_cid = match Code::try_from(cid.hash().code()) { 256 - Ok(code) => Ok(Cid::new_v1(0x55, code.digest(&bytes))), 257 - Err(err) => { 258 - tracing::warn!("failed to compute CID: {err:?}"); 259 - Err(BlobDownloadError::CidUnsupportedMultihash) 260 - } 261 - }?; 262 - 263 - if computed_cid != *cid { 264 - tracing::warn!("cid mismatch: computed {computed_cid} expected {cid}"); 265 - return Err(BlobDownloadError::CidMismatch); 266 - } 267 - 268 - Ok(()) 269 - } 270 - }) 271 - .await 272 - .expect("CID computing task should not panic")?; 273 - 274 - bytes 275 - }; 276 - 277 - // Infer MIME type from content bytes rather than headers; this is fallible 278 - // and falls back to application/octet-stream if the type is unrecognised. 279 - let mime_type = sniff_mime(&validated_bytes); 280 - if !is_mime_allowed(&mime_type, &state.allowed_mimetypes) { 281 - tracing::debug!("blob was inferred to be a disallowed mime type: {mime_type}"); 282 - return Err(BlobDownloadError::ForbiddenMimeType); 283 - } 284 - 285 - // Build reusable cached headers. 286 - let mut headers = HeaderMap::new(); 287 - headers.insert( 288 - header::CONTENT_TYPE, 289 - mime_type 290 - .essence_str() 291 - .parse() 292 - .expect("should parse mime type as header value"), 293 - ); 294 - headers.insert(header::CACHE_CONTROL, state.cache_control_header.clone()); 295 - headers.insert( 296 - header::CONTENT_SECURITY_POLICY, 297 - const { HeaderValue::from_static("default-src 'none'; sandbox") }, 298 - ); 299 - headers.insert( 300 - header::X_CONTENT_TYPE_OPTIONS, 301 - const { HeaderValue::from_static("nosniff") }, 302 - ); 303 - headers.insert( 304 - header::CONTENT_DISPOSITION, 305 - const { HeaderValue::from_static("attachment") }, 306 - ); 307 - 308 - // Mark this key as verified in the ownership cache. 309 - state 310 - .cache 311 - .blob_ownership 312 - .insert((cid, did.clone()), ()) 313 - .await; 314 - 315 - Ok(CachedBlobData { 316 - bytes: validated_bytes, 317 - headers, 318 - }) 319 - }) 91 + .blob_service 92 + .fetch_blob( 93 + &did, 94 + &cid, 95 + BlobUrlResolver::Pds { 96 + identity_service: &state.identity_service, 97 + }, 98 + state.max_blob_size, 99 + &state.allowed_mimetypes, 100 + ) 320 101 .await 321 102 { 322 103 Ok(blob) => blob, ··· 324 105 return Err(match *err { 325 106 BlobDownloadError::NotFound => ( 326 107 StatusCode::NOT_FOUND, 108 + [(header::CACHE_CONTROL, CACHE_CONTROL_NOCACHE_VALUE)], 327 109 Json(ErrorResponse { 328 110 error: "BlobNotFound", 329 111 message: Some("Blob not found"), ··· 331 113 ), 332 114 BlobDownloadError::TooLarge => ( 333 115 StatusCode::PAYLOAD_TOO_LARGE, 116 + [(header::CACHE_CONTROL, CACHE_CONTROL_NOCACHE_VALUE)], 334 117 Json(ErrorResponse { 335 118 error: "BlobTooLarge", 336 119 message: Some("Blob exceeds maximum allowed size"), ··· 338 121 ), 339 122 BlobDownloadError::ForbiddenMimeType => ( 340 123 StatusCode::FORBIDDEN, 124 + [(header::CACHE_CONTROL, CACHE_CONTROL_NOCACHE_VALUE)], 341 125 Json(ErrorResponse { 342 126 error: "BlobForbiddenType", 343 127 message: Some("Content type is not allowed"), ··· 345 129 ), 346 130 BlobDownloadError::CidMismatch => ( 347 131 StatusCode::BAD_GATEWAY, 132 + [(header::CACHE_CONTROL, CACHE_CONTROL_NOCACHE_VALUE)], 348 133 Json(ErrorResponse { 349 134 error: "BlobCidMismatch", 350 135 message: Some("Blob content does not match CID"), ··· 352 137 ), 353 138 BlobDownloadError::CidUnsupportedMultihash => ( 354 139 StatusCode::NOT_IMPLEMENTED, 140 + [(header::CACHE_CONTROL, CACHE_CONTROL_NOCACHE_VALUE)], 355 141 Json(ErrorResponse { 356 142 error: "CidUnsupported", 357 143 message: Some("Unsupported CID multihash"), 358 144 }), 359 145 ), 360 - BlobDownloadError::DidPdsResolutionFailure => ( 146 + BlobDownloadError::BlobResolutionFailure => ( 361 147 StatusCode::BAD_GATEWAY, 148 + [(header::CACHE_CONTROL, CACHE_CONTROL_NOCACHE_VALUE)], 362 149 Json(ErrorResponse { 363 - error: "CannotResolvePds", 364 - message: Some("Failed to resolve PDS for DID"), 150 + error: "CannotResolve", 151 + message: Some("Failed to resolve source of blob"), 365 152 }), 366 153 ), 367 154 BlobDownloadError::FetchFailure 368 155 | BlobDownloadError::ErrorStatusCode 369 156 | BlobDownloadError::StreamFailed => ( 370 157 StatusCode::BAD_GATEWAY, 158 + [(header::CACHE_CONTROL, CACHE_CONTROL_NOCACHE_VALUE)], 371 159 Json(ErrorResponse { 372 160 error: "BlobFetchFailed", 373 - message: Some("Failed to fetch blob from PDS"), 161 + message: Some("Failed to fetch blob from origin"), 374 162 }), 375 163 ), 376 164 }); 377 165 } 378 166 }; 379 167 380 - // Verify this DID owns the blob; will skip if we just fetched the blob from the same DID+CID pair. 381 - // Concurrent requests for the same key are coalesced. 168 + // Check if the user has a copy of this blob via cache/origin. 169 + // 170 + // Note: This will just return from cache if the blob was just fetched 171 + // using the same key. This check does not validate the blob cid matches, 172 + // just that the blob is reported to exist. 382 173 if let Err(err) = state 383 - .cache 384 - .blob_ownership 385 - .try_get_with((cid, did.clone()), async { 386 - tracing::debug!("verifying ownership of blob"); 387 - let blob_url = to_pds_blob_url( 388 - state 389 - .cache 390 - .identity 391 - .try_get_with_by_ref(&did, state.identity_resolver.pds_for_did(&did)) 392 - .await 393 - .map_err(|err| { 394 - tracing::debug!("failed to resolve PDS: {:?}", *err); 395 - BlobOwnershipError::DidPdsResolutionFailure 396 - })?, 397 - &did, 398 - &cid, 399 - ); 400 - 401 - // Request the blob with as little of the actual body as we can. 402 - // 403 - // While some PDS implementations (bsky, tranquil) support HTTP HEAD, it is not 404 - // actually a part of the XRPC specification and we cannot rely on it (for now). 405 - // Use a range request to avoid downloading the full body on servers that support it instead. 406 - match state 407 - .blob_fetch_http_client 408 - .get(blob_url) 409 - .header( 410 - header::RANGE, 411 - const { HeaderValue::from_static("bytes=0-1023") }, 412 - ) 413 - .send() 414 - .await 415 - .map_err(|err| { 416 - tracing::warn!("failed to request blob from PDS: {err:?}"); 417 - BlobOwnershipError::FetchFailure 418 - })? 419 - .status() 420 - { 421 - status if status.is_success() => { 422 - tracing::debug!("verified ownership of blob"); 423 - Ok(()) 424 - } 425 - StatusCode::NOT_FOUND | StatusCode::BAD_REQUEST => { 426 - tracing::debug!("pds returned 404 for blob"); 427 - Err(BlobOwnershipError::NotFound) 428 - } 429 - status => { 430 - tracing::debug!("pds returned error status for blob: {}", status); 431 - Err(BlobOwnershipError::ErrorStatusCode) 432 - } 433 - } 434 - }) 174 + .blob_service 175 + .fetch_blob_ownership( 176 + &did, 177 + cid, 178 + BlobUrlResolver::Pds { 179 + identity_service: &state.identity_service, 180 + }, 181 + ) 435 182 .await 436 183 { 437 184 return Err(match *err { 438 185 BlobOwnershipError::NotFound => ( 439 186 StatusCode::NOT_FOUND, 187 + [(header::CACHE_CONTROL, CACHE_CONTROL_NOCACHE_VALUE)], 440 188 Json(ErrorResponse { 441 189 error: "BlobNotFound", 442 190 message: Some("Blob not found"), 443 191 }), 444 192 ), 445 - BlobOwnershipError::DidPdsResolutionFailure => ( 193 + BlobOwnershipError::BlobResolutionFailure => ( 446 194 StatusCode::BAD_GATEWAY, 195 + [(header::CACHE_CONTROL, CACHE_CONTROL_NOCACHE_VALUE)], 447 196 Json(ErrorResponse { 448 197 error: "CannotResolvePds", 449 198 message: Some("Failed to resolve PDS for DID"), ··· 451 200 ), 452 201 BlobOwnershipError::ErrorStatusCode | BlobOwnershipError::FetchFailure => ( 453 202 StatusCode::BAD_GATEWAY, 203 + [(header::CACHE_CONTROL, CACHE_CONTROL_NOCACHE_VALUE)], 454 204 Json(ErrorResponse { 455 205 error: "BlobFetchFailed", 456 - message: Some("Failed to fetch blob from PDS"), 206 + message: Some("Failed to fetch blob from origin"), 457 207 }), 458 208 ), 459 209 }); 460 210 } 461 211 462 - let mut response = Response::builder() 212 + Ok(Response::builder() 463 213 .status(StatusCode::OK) 214 + .header(header::CONTENT_TYPE, blob.mime_type.essence_str()) 215 + .header(header::CACHE_CONTROL, &state.cache_control_header) 216 + .header( 217 + header::CONTENT_SECURITY_POLICY, 218 + const { HeaderValue::from_static("default-src 'none'; sandbox") }, 219 + ) 220 + .header( 221 + header::X_CONTENT_TYPE_OPTIONS, 222 + const { HeaderValue::from_static("nosniff") }, 223 + ) 224 + .header( 225 + header::CONTENT_DISPOSITION, 226 + const { HeaderValue::from_static("attachment") }, 227 + ) 464 228 .body(Body::from(blob.bytes)) 465 - .expect("response should always build successfully"); 466 - response.headers_mut().extend(blob.headers); 467 - Ok(response) 229 + .expect("response should always build successfully")) 468 230 }
+1
src/routes/blob/mod.rs
··· 1 1 mod get; 2 + 2 3 pub use get::get_blob_handler;
+38 -80
src/routes/cache/delete.rs
··· 1 - use crate::{AppState, routes::ErrorResponse, types::blob_cid::BlobCid}; 1 + use crate::{ 2 + AppState, 3 + routes::{CACHE_CONTROL_NOCACHE_VALUE, ErrorResponse}, 4 + types::blob_cid::BlobCid, 5 + }; 2 6 use axum::{ 3 7 Json, 4 8 extract::{Path, State}, 5 - http::StatusCode, 9 + http::{HeaderName, StatusCode, header}, 6 10 }; 7 11 use axum_extra::{ 8 12 TypedHeader, ··· 10 14 }; 11 15 use jacquard_common::types::did::Did; 12 16 use std::sync::Arc; 17 + use subtle::ConstantTimeEq; 13 18 14 19 pub async fn delete_cache_handler( 15 20 Path(identifier): Path<String>, 16 21 State(state): State<Arc<AppState>>, 17 22 TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>, 18 - ) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> { 19 - if state.auth_token.as_deref() != Some(bearer.token()) { 23 + ) -> Result< 24 + StatusCode, 25 + ( 26 + StatusCode, 27 + [(HeaderName, &'static str); 1], 28 + Json<ErrorResponse>, 29 + ), 30 + > { 31 + if state 32 + .auth_token 33 + .as_ref() 34 + .map(|expected| expected.as_bytes().ct_eq(bearer.token().as_bytes()).into()) 35 + .unwrap_or(false) 36 + { 20 37 return Err(( 21 38 StatusCode::UNAUTHORIZED, 39 + [(header::CACHE_CONTROL, CACHE_CONTROL_NOCACHE_VALUE)], 22 40 Json(ErrorResponse { 23 41 error: "Unauthorized", 24 42 message: None, ··· 26 44 )); 27 45 } 28 46 47 + // TODO: Really need to expose a nicer cache purging API, 48 + // matching on prefix sucks. 29 49 if identifier.starts_with("did:") { 30 50 tracing::info!("invalidating DID cache entries"); 31 51 let did = Did::new_owned(identifier).map_err(|_| { 32 52 ( 33 53 StatusCode::UNPROCESSABLE_ENTITY, 54 + [(header::CACHE_CONTROL, CACHE_CONTROL_NOCACHE_VALUE)], 34 55 Json(ErrorResponse { 35 56 error: "MalformedDid", 36 57 message: Some("Invalid or unprocessable DID"), 37 58 }), 38 59 ) 39 60 })?; 40 - 41 - // Clear all identity, ownership and policy data for this DID. 42 - state 43 - .cache 44 - .identity 45 - .invalidate_entries_if({ 46 - let did = did.clone(); 47 - move |k, _v| *k == did 48 - }) 49 - .map_err(|err| { 50 - tracing::error!("failed to schedule identity cache invalidation: {err:?}"); 51 - ( 52 - StatusCode::INTERNAL_SERVER_ERROR, 53 - Json(ErrorResponse { 54 - error: "InternalServerError", 55 - message: Some("Failed to schedule cache invalidation"), 56 - }), 57 - ) 58 - })?; 59 - state 60 - .cache 61 - .blob_policy 62 - .invalidate_entries_if({ 61 + state.identity_service.invalidate_did_cache(&did).await; 62 + if let Some(ref policy_client) = state.policy_client { 63 + policy_client.invalidate_policies({ 63 64 let did = did.clone(); 64 65 move |k, _v| k.0 == did 65 66 }) 66 - .map_err(|err| { 67 - tracing::error!("failed to schedule blob policy cache invalidation: {err:?}"); 68 - ( 69 - StatusCode::INTERNAL_SERVER_ERROR, 70 - Json(ErrorResponse { 71 - error: "InternalServerError", 72 - message: Some("Failed to schedule cache invalidation"), 73 - }), 74 - ) 75 - })?; 67 + } 76 68 state 77 - .cache 78 - .blob_ownership 79 - .invalidate_entries_if(move |k, _v| k.1 == did) 80 - .map_err(|err| { 81 - tracing::error!("failed to schedule blob ownership cache invalidation: {err:?}"); 82 - ( 83 - StatusCode::INTERNAL_SERVER_ERROR, 84 - Json(ErrorResponse { 85 - error: "InternalServerError", 86 - message: Some("Failed to schedule cache invalidation"), 87 - }), 88 - ) 89 - })?; 69 + .blob_service 70 + .invalidate_blob_ownership(move |k, _v| k.1 == did); 90 71 } else { 91 72 tracing::info!("invalidating CID cache entries"); 92 73 let cid = BlobCid::try_from(identifier.as_str()).map_err(|_| { 93 74 ( 94 75 StatusCode::UNPROCESSABLE_ENTITY, 76 + [(header::CACHE_CONTROL, CACHE_CONTROL_NOCACHE_VALUE)], 95 77 Json(ErrorResponse { 96 78 error: "MalformedCid", 97 79 message: Some("Invalid or unprocessable CID"), 98 80 }), 99 81 ) 100 82 })?; 101 - 102 - // Clear blob content from memory as well as ownership and policy data for this CID. 103 - state.cache.blob_content.invalidate(&cid).await; 83 + state.blob_service.invalidate_blob(&cid).await; 104 84 state 105 - .cache 106 - .blob_ownership 107 - .invalidate_entries_if(move |k, _v| k.0 == cid) 108 - .map_err(|err| { 109 - tracing::error!("failed to schedule blob ownership cache invalidation: {err:?}"); 110 - ( 111 - StatusCode::INTERNAL_SERVER_ERROR, 112 - Json(ErrorResponse { 113 - error: "InternalServerError", 114 - message: Some("Failed to schedule cache invalidation"), 115 - }), 116 - ) 117 - })?; 118 - state 119 - .cache 120 - .blob_policy 121 - .invalidate_entries_if(move |k, _v| k.1 == cid) 122 - .map_err(|err| { 123 - tracing::error!("failed to schedule blob policy cache invalidation: {err:?}"); 124 - ( 125 - StatusCode::INTERNAL_SERVER_ERROR, 126 - Json(ErrorResponse { 127 - error: "InternalServerError", 128 - message: Some("Failed to schedule cache invalidation"), 129 - }), 130 - ) 131 - })?; 85 + .blob_service 86 + .invalidate_blob_ownership(move |k, _v| k.0 == cid); 87 + if let Some(ref policy_client) = state.policy_client { 88 + policy_client.invalidate_policies(move |k, _v| k.1 == cid) 89 + } 132 90 } 133 91 134 92 Ok(StatusCode::OK)
+1
src/routes/cache/mod.rs
··· 1 1 mod delete; 2 + 2 3 pub use delete::*;
+26
src/routes/health/get.rs
··· 1 + use crate::routes::CACHE_CONTROL_NOCACHE_VALUE; 2 + use axum::{ 3 + Json, 4 + http::{StatusCode, header}, 5 + response::IntoResponse, 6 + }; 7 + use serde::Serialize; 8 + 9 + #[derive(Serialize)] 10 + struct GetHealthResponse { 11 + version: &'static str, 12 + } 13 + 14 + pub async fn get_health_handler() -> impl IntoResponse { 15 + ( 16 + StatusCode::OK, 17 + [(header::CACHE_CONTROL, CACHE_CONTROL_NOCACHE_VALUE)], 18 + Json(GetHealthResponse { 19 + version: concat!( 20 + env!("CARGO_PKG_VERSION_MAJOR"), 21 + ".", 22 + env!("CARGO_PKG_VERSION_MINOR") 23 + ), 24 + }), 25 + ) 26 + }
+3
src/routes/health/mod.rs
··· 1 + pub mod get; 2 + 3 + pub use get::get_health_handler;
+9 -13
src/routes/mod.rs
··· 1 1 mod blob; 2 2 mod cache; 3 + mod health; 3 4 4 5 pub use blob::get_blob_handler; 5 6 pub use cache::delete_cache_handler; 7 + pub use health::get_health_handler; 6 8 7 - use axum::http::{HeaderName, HeaderValue, header}; 8 - use serde::Serialize; 9 + /// A header value for [`header::CACHE_CONTROL`] indicating the response cannot be cached at all. 10 + pub const CACHE_CONTROL_NOCACHE_VALUE: &str = "must-understand, no-store"; 9 11 10 - #[derive(Serialize)] 12 + #[derive(serde::Serialize)] 11 13 pub struct ErrorResponse { 12 14 error: &'static str, 13 15 message: Option<&'static str>, 14 16 } 15 17 16 - pub async fn get_index_handler() -> ([(HeaderName, HeaderValue); 1], &'static str) { 17 - ( 18 - [( 19 - header::CACHE_CONTROL, 20 - const { HeaderValue::from_static("public, max-age=31536000, immutable") }, 21 - )], 22 - r#" 18 + pub async fn get_index_handler() -> &'static str { 19 + r#" 23 20 _____ _ 24 21 | __ \ (_) 25 22 | |__) |__ _ ____ ___ ___ ··· 36 33 37 34 Routes: 38 35 - HTTP GET /{did}/{cid} - Resolve and fetch a blob from its origin. 39 - - HTTP DELETE /cache/{cid or did} - Invalidate cache for either a CID (blob, policy, ownership) or for a DID (ownerships and policies). Requires configured bearer auth token. 40 - "#, 41 - ) 36 + - HTTP DELETE /cache/{cid or did} - Invalidate cache for either a CID (blob, policy, ownership) or for a DID (ownerships and policies). Requires auth. 37 + "# 42 38 }
+39 -17
src/types/blob_cid.rs
··· 1 1 // TODO: Transfer this implementation to a standalone ATProto types crate in the future. 2 2 3 + use cid::Version; 3 4 use serde::Serialize; 4 5 use thiserror::Error; 5 6 6 - pub mod codecs { 7 - pub const RAW: u64 = 0x55; 8 - } 9 - 10 7 #[derive(Debug, Error)] 11 8 pub enum BlobCidError { 12 - /// The CID uses a codec other than raw (`0x55`), which is the only codec 13 - /// permitted for ATProto blobs. 9 + /// The CID uses an invalid codec type. 14 10 #[error("invalid blob codec 0x{0:x}, the only supported codec is raw (0x55)")] 15 11 InvalidBlobCodec(u64), 16 - 17 - /// The underlying CID could not be parsed. 12 + /// The CID uses an invalid version. 13 + #[error("invalid blob version {0:?}, the only supported version is v1")] 14 + InvalidBlobVersion(Version), 15 + /// The CID uses an invalid multihash. 16 + #[error("invalid multihash {0:?}, the only supported version is sha256")] 17 + InvalidMultihash(multihash_codetable::Multihash), 18 + /// An error from the CID crate. 18 19 #[error(transparent)] 19 20 CidError(#[from] cid::Error), 20 21 } 21 22 22 - /// A [`cid::Cid`] wrapper that guarantees the codec is raw (`0x55`), conforming 23 - /// to the ATProto blob CID specification. 23 + /// A [`cid::Cid`] wrapper that guarantees that data conforms to the 24 + /// ATProto blob CID specification where possible. 24 25 /// 25 - /// Specification: <https://atproto.com/specs/blob> (Conformant as of **13/03/26**). 26 + /// Note: BlobCid does not currently attempt to validate the 27 + /// encoding representation of the given value. 28 + /// 29 + /// Specification: <https://atproto.com/specs/blob>. 26 30 #[derive(Copy, PartialEq, Eq, Clone, PartialOrd, Ord, Hash, Debug, Serialize)] 27 31 pub struct BlobCid(cid::Cid); 28 32 29 33 impl BlobCid { 30 - pub fn new(cid: cid::Cid) -> Result<Self, BlobCidError> { 31 - if cid.codec() != codecs::RAW { 34 + pub fn try_from_cid(cid: cid::Cid) -> Result<Self, BlobCidError> { 35 + // Ensure the cid uses an accepted codec. 36 + if !matches!( 37 + cid.codec(), 38 + 0x55 // Raw 39 + ) { 32 40 return Err(BlobCidError::InvalidBlobCodec(cid.codec())); 33 41 } 42 + 43 + // Ensure the cid uses an accepted version. 44 + if !matches!(cid.version(), Version::V1) { 45 + return Err(BlobCidError::InvalidBlobVersion(cid.version())); 46 + } 47 + 48 + // Ensure the cid uses an accepted multihash. 49 + if !matches!( 50 + multihash_codetable::Code::try_from(cid.hash().code()), 51 + Ok(multihash_codetable::Code::Sha2_256) 52 + ) { 53 + return Err(BlobCidError::InvalidMultihash(*cid.hash())); 54 + } 55 + 34 56 Ok(Self(cid)) 35 57 } 36 58 } ··· 38 60 impl<'de> serde::Deserialize<'de> for BlobCid { 39 61 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> { 40 62 let cid = cid::Cid::deserialize(deserializer)?; 41 - Self::new(cid).map_err(serde::de::Error::custom) 63 + Self::try_from_cid(cid).map_err(serde::de::Error::custom) 42 64 } 43 65 } 44 66 45 67 impl core::convert::TryFrom<&str> for BlobCid { 46 68 type Error = BlobCidError; 47 69 fn try_from(value: &str) -> Result<Self, Self::Error> { 48 - Self::new(cid::Cid::try_from(value)?) 70 + Self::try_from_cid(cid::Cid::try_from(value)?) 49 71 } 50 72 } 51 73 52 74 impl core::convert::TryFrom<String> for BlobCid { 53 75 type Error = BlobCidError; 54 76 fn try_from(value: String) -> Result<Self, Self::Error> { 55 - Self::new(cid::Cid::try_from(value)?) 77 + Self::try_from_cid(cid::Cid::try_from(value)?) 56 78 } 57 79 } 58 80 59 81 impl core::convert::TryFrom<Vec<u8>> for BlobCid { 60 82 type Error = BlobCidError; 61 83 fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> { 62 - Self::new(cid::Cid::try_from(value)?) 84 + Self::try_from_cid(cid::Cid::try_from(value)?) 63 85 } 64 86 } 65 87
-1
src/types/mod.rs
··· 1 1 pub mod blob_cid; 2 - // pub mod validated_blob;s
-80
src/types/validated_blob.rs
··· 1 - // // TODO: Consider transferring this implementation to a standalone ATProto crate in the future. 2 - 3 - // use crate::types::blob_cid::{self}; 4 - // use multihash_codetable::{Code, MultihashDigest}; 5 - // use thiserror::Error; 6 - 7 - // #[derive(Debug, Error)] 8 - // pub enum ValidatedBlobError { 9 - // /// The CID's multihash codec is not supported by the codetable. 10 - // #[error("unsupported multihash codec 0x{0:x}")] 11 - // CidUnsupportedMultihash(u64), 12 - // /// The computed CID of the blob content does not match the expected CID. 13 - // #[error("CID mismatch: computed {computed} but expected {expected}")] 14 - // CidMismatch { 15 - // computed: blob_cid::BlobCid, 16 - // expected: blob_cid::BlobCid, 17 - // }, 18 - // } 19 - 20 - // /// Blob content whose integrity has been verified against a [`blob_cid::BlobCid`]. 21 - // #[derive(Debug, PartialEq, Eq, Clone, PartialOrd, Ord, Hash)] 22 - // pub struct ValidatedBlob(bytes::Bytes); 23 - 24 - // impl ValidatedBlob { 25 - // /// Verify that `bytes` matches the expected `checksum` CID. 26 - // pub fn new<B: Into<bytes::Bytes>>( 27 - // bytes: B, 28 - // checksum: blob_cid::BlobCid, 29 - // ) -> Result<Self, ValidatedBlobError> { 30 - // let bytes = bytes.into(); 31 - 32 - // // Enabled Multihashes are set in the multihash-codetable crate features. 33 - // let hash_code = checksum.hash().code(); 34 - // let computed_cid = match Code::try_from(hash_code) { 35 - // Ok(code) => Ok(blob_cid::BlobCid::new(cid::Cid::new_v1( 36 - // blob_cid::codecs::RAW, 37 - // code.digest(&bytes), 38 - // )) 39 - // .expect("computed CID with raw codec should always be a valid BlobCid")), 40 - // Err(err) => { 41 - // tracing::warn!("failed to compute CID: {err:?}"); 42 - // Err(ValidatedBlobError::CidUnsupportedMultihash(hash_code)) 43 - // } 44 - // }?; 45 - 46 - // if computed_cid != checksum { 47 - // tracing::warn!("cid mismatch: computed {computed_cid} expected {checksum}"); 48 - // return Err(ValidatedBlobError::CidMismatch { 49 - // computed: computed_cid, 50 - // expected: checksum, 51 - // }); 52 - // } 53 - 54 - // Ok(Self(bytes)) 55 - // } 56 - 57 - // #[must_use] 58 - // pub fn into_inner(self) -> bytes::Bytes { 59 - // self.0 60 - // } 61 - // } 62 - 63 - // impl core::convert::AsRef<bytes::Bytes> for ValidatedBlob { 64 - // fn as_ref(&self) -> &bytes::Bytes { 65 - // &self.0 66 - // } 67 - // } 68 - 69 - // impl core::ops::Deref for ValidatedBlob { 70 - // type Target = bytes::Bytes; 71 - // fn deref(&self) -> &Self::Target { 72 - // &self.0 73 - // } 74 - // } 75 - 76 - // impl core::borrow::Borrow<bytes::Bytes> for ValidatedBlob { 77 - // fn borrow(&self) -> &bytes::Bytes { 78 - // &self.0 79 - // } 80 - // }