don't
5
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat(knot): improve entity tag handling

Signed-off-by: tjh <x@tjh.dev>

tjh 427d18d1 b62483ca

+363 -57
+357 -53
crates/gordian-knot/src/extract/if_none_match.rs
··· 1 1 use core::fmt; 2 2 3 3 use axum::extract::FromRequestParts; 4 - use axum::extract::OptionalFromRequestParts; 5 - use axum::http::HeaderMap; 4 + use axum::http::HeaderValue; 6 5 use axum::http::StatusCode; 6 + use axum::http::header::ETAG; 7 + use axum::http::header::IF_NONE_MATCH; 8 + use axum::http::header::InvalidHeaderValue; 9 + use axum::http::request::Parts; 7 10 use axum::response::IntoResponse; 8 - use reqwest::header::IF_NONE_MATCH; 9 - 10 - #[derive(Debug, PartialEq, Eq)] 11 - pub enum ETag { 12 - Strong(String), 13 - Weak(String), 14 - } 15 - 16 - impl fmt::Display for ETag { 17 - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 18 - match self { 19 - Self::Strong(value) => write!(f, "\"{value}\""), 20 - Self::Weak(value) => write!(f, "W/\"{value}\""), 21 - } 22 - } 23 - } 11 + use axum::response::IntoResponseParts; 12 + use axum::response::ResponseParts; 24 13 25 14 #[derive(Debug)] 26 - pub struct IfNoneMatch(pub Vec<ETag>); 15 + pub struct EntityTag { 16 + weak: bool, 17 + value: String, 18 + } 27 19 28 - impl IfNoneMatch { 29 - pub fn contains_strong(&self, value: &str) -> bool { 30 - for tag in &self.0 { 31 - match tag { 32 - ETag::Strong(strong) if strong == value => return true, 33 - _ => continue, 34 - } 35 - } 36 - false 20 + impl EntityTag { 21 + /// Create a strong [`EntityTag`]. 22 + /// 23 + /// # Example 24 + /// 25 + /// ```rust 26 + /// # use gordian_knot::extract::if_none_match::EntityTag; 27 + /// let etag = EntityTag::strong("a16f46ac"); 28 + /// assert_eq!(etag.to_string(), r#""a16f46ac""#); 29 + /// assert!(!etag.is_weak()); 30 + /// ``` 31 + /// 32 + /// # Panics 33 + /// 34 + /// Panics if `value` contains non-ascii characters. 35 + pub fn strong(value: impl Into<String>) -> Self { 36 + Self::new(false, value.into()) 37 37 } 38 38 39 - pub fn contains(&self, value: &str) -> bool { 40 - for tag in &self.0 { 41 - match tag { 42 - ETag::Strong(strong) if strong == value => return true, 43 - ETag::Weak(weak) if weak == value => return true, 44 - _ => continue, 45 - } 39 + /// Create a weak [`EntityTag`]. 40 + /// 41 + /// # Example 42 + /// 43 + /// ```rust 44 + /// # use gordian_knot::extract::if_none_match::EntityTag; 45 + /// let etag = EntityTag::weak("a16f46ac"); 46 + /// assert_eq!(etag.to_string(), r#"W/"a16f46ac""#); 47 + /// assert!(etag.is_weak()); 48 + /// ``` 49 + /// 50 + /// # Panics 51 + /// 52 + /// Panics if `value` contains non-ascii characters. 53 + pub fn weak(value: impl Into<String>) -> Self { 54 + Self::new(true, value.into()) 55 + } 56 + 57 + pub const fn is_weak(&self) -> bool { 58 + self.weak 59 + } 60 + 61 + /// Compare `self` with `other` using strong comparison. 62 + /// 63 + /// # Example 64 + /// 65 + /// ```rust 66 + /// # use gordian_knot::extract::if_none_match::EntityTag; 67 + /// assert!(!EntityTag::weak("1").strong_eq(&EntityTag::weak("1"))); 68 + /// assert!(!EntityTag::weak("1").strong_eq(&EntityTag::weak("2"))); 69 + /// assert!(!EntityTag::weak("1").strong_eq(&EntityTag::strong("1"))); 70 + /// assert!(EntityTag::strong("1").strong_eq(&EntityTag::strong("1"))); 71 + /// ``` 72 + /// 73 + pub fn strong_eq(&self, other: &Self) -> bool { 74 + !self.weak && !other.weak && self.value == other.value 75 + } 76 + 77 + /// Compare `self` with `other` using weak comparison. 78 + /// 79 + /// # Example 80 + /// 81 + /// ```rust 82 + /// # use gordian_knot::extract::if_none_match::EntityTag; 83 + /// assert!(EntityTag::weak("1").weak_eq(&EntityTag::weak("1"))); 84 + /// assert!(!EntityTag::weak("1").weak_eq(&EntityTag::weak("2"))); 85 + /// assert!(EntityTag::weak("1").weak_eq(&EntityTag::strong("1"))); 86 + /// assert!(EntityTag::strong("1").weak_eq(&EntityTag::strong("1"))); 87 + /// ``` 88 + /// 89 + pub fn weak_eq(&self, other: &Self) -> bool { 90 + self.value == other.value 91 + } 92 + 93 + /// Convert the [`EntityTag`] to a [`HeaderValue`]. 94 + pub fn to_header_value(&self) -> Result<HeaderValue, InvalidHeaderValue> { 95 + self.to_string().parse() 96 + } 97 + 98 + fn new(weak: bool, value: String) -> Self { 99 + assert!(value.is_ascii()); 100 + Self { weak, value } 101 + } 102 + } 103 + 104 + impl fmt::Display for EntityTag { 105 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 106 + if self.weak { 107 + f.write_str("W/")?; 46 108 } 47 - false 109 + write!(f, "\"{}\"", self.value) 110 + } 111 + } 112 + 113 + impl std::str::FromStr for EntityTag { 114 + type Err = &'static str; 115 + 116 + fn from_str(value: &str) -> Result<Self, Self::Err> { 117 + let (weak, value) = match value.strip_prefix("W/") { 118 + Some(value) => (true, value), 119 + None => (false, value), 120 + }; 121 + 122 + if !value.is_ascii() { 123 + return Err("entity tag contains non-ascii characters"); 124 + } 125 + 126 + Ok(Self { 127 + weak, 128 + value: value.trim_matches('"').to_string(), 129 + }) 130 + } 131 + } 132 + 133 + impl IntoResponseParts for EntityTag { 134 + type Error = StatusCode; 135 + 136 + fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> { 137 + let header_value = self 138 + .to_string() 139 + .parse() 140 + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; 141 + res.headers_mut().insert(ETAG, header_value); 142 + Ok(res) 143 + } 144 + } 145 + 146 + impl IntoResponse for EntityTag { 147 + fn into_response(self) -> axum::response::Response { 148 + (self, ()).into_response() 149 + } 150 + } 151 + 152 + pub trait StrongEq { 153 + fn strong_eq(&self, other: &EntityTag) -> bool; 154 + } 155 + 156 + impl StrongEq for EntityTag { 157 + fn strong_eq(&self, other: &EntityTag) -> bool { 158 + EntityTag::strong_eq(self, other) 159 + } 160 + } 161 + 162 + impl StrongEq for str { 163 + fn strong_eq(&self, other: &EntityTag) -> bool { 164 + !other.weak && self == other.value 165 + } 166 + } 167 + 168 + impl StrongEq for String { 169 + fn strong_eq(&self, other: &EntityTag) -> bool { 170 + !other.weak && self == &other.value 171 + } 172 + } 173 + 174 + pub trait WeakEq { 175 + fn weak_eq(&self, other: &EntityTag) -> bool; 176 + } 177 + 178 + impl WeakEq for EntityTag { 179 + fn weak_eq(&self, other: &EntityTag) -> bool { 180 + EntityTag::weak_eq(self, other) 181 + } 182 + } 183 + 184 + impl WeakEq for str { 185 + fn weak_eq(&self, other: &EntityTag) -> bool { 186 + self == other.value 187 + } 188 + } 189 + 190 + impl WeakEq for String { 191 + fn weak_eq(&self, other: &EntityTag) -> bool { 192 + self == &other.value 193 + } 194 + } 195 + 196 + /// Extractor for the `If-None-Match` header. 197 + /// 198 + /// If the header contains invalid utf8, the request will be rejected 199 + /// with a `400 Bad Request` response. 200 + #[derive(Debug)] 201 + pub struct IfNoneMatch { 202 + wild: bool, 203 + tags: Vec<EntityTag>, 204 + } 205 + 206 + impl IfNoneMatch { 207 + pub fn match_strong(&self, value: &(impl StrongEq + ?Sized)) -> bool { 208 + self.wild || self.tags.iter().find(|tag| value.strong_eq(tag)).is_some() 209 + } 210 + 211 + pub fn match_weak(&self, value: &(impl WeakEq + ?Sized)) -> bool { 212 + self.wild || self.tags.iter().find(|tag| value.weak_eq(tag)).is_some() 213 + } 214 + 215 + fn wild() -> Self { 216 + Self { 217 + wild: true, 218 + tags: Vec::new(), 219 + } 48 220 } 49 221 } 50 222 ··· 235 63 } 236 64 } 237 65 238 - impl<S: Send + Sync> OptionalFromRequestParts<S> for IfNoneMatch { 66 + impl<S: Send + Sync> FromRequestParts<S> for IfNoneMatch { 239 67 type Rejection = ETagRejection; 240 68 241 - async fn from_request_parts( 242 - parts: &mut axum::http::request::Parts, 243 - state: &S, 244 - ) -> Result<Option<Self>, Self::Rejection> { 245 - let headers = HeaderMap::from_request_parts(parts, state) 246 - .await 247 - .expect("HeaderMap extractor is infallible"); 248 - 69 + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> { 249 70 let mut tags = Vec::new(); 250 - for header_value in headers.get_all(IF_NONE_MATCH) { 251 - let value = std::str::from_utf8(header_value.as_bytes())?; 252 - for tag in value.split(',').map(|val| val.trim()) { 253 - tags.push(match tag.strip_prefix("W/") { 254 - None => ETag::Strong(tag.trim_matches('"').to_string()), 255 - Some(tag) => ETag::Weak(tag.trim_matches('"').to_string()), 256 - }); 71 + 72 + for header_value in parts.headers.get_all(IF_NONE_MATCH) { 73 + let header_str = std::str::from_utf8(header_value.as_bytes())?; 74 + if let Some("*") = header_str.get(..1) { 75 + return Ok(Self::wild()); 76 + } 77 + 78 + for tag_str in header_str.split(',').map(|val| val.trim()) { 79 + let Ok(tag) = tag_str.parse() else { 80 + tracing::warn!(?tag_str, "failed to parse entity tag from if-none-match"); 81 + continue; 82 + }; 83 + tags.push(tag); 257 84 } 258 85 } 259 86 260 - Ok((!tags.is_empty()).then_some(Self(tags))) 87 + Ok(Self { wild: false, tags }) 88 + } 89 + } 90 + 91 + #[cfg(test)] 92 + mod tests { 93 + use core::error; 94 + 95 + use axum::Router; 96 + use axum::body::Body; 97 + use axum::http::Request; 98 + use axum::http::StatusCode; 99 + use axum::http::header::IF_NONE_MATCH; 100 + use axum::response::IntoResponse; 101 + use tower::ServiceExt as _; 102 + 103 + use crate::extract::if_none_match::EntityTag; 104 + use crate::extract::if_none_match::StrongEq; 105 + use crate::extract::if_none_match::WeakEq as _; 106 + 107 + use super::IfNoneMatch; 108 + 109 + async fn weak_0815(if_none_match: IfNoneMatch) -> impl IntoResponse { 110 + if if_none_match.match_weak("0815") { 111 + StatusCode::NOT_MODIFIED 112 + } else { 113 + StatusCode::OK 114 + } 115 + } 116 + 117 + async fn strong_0815(if_none_match: IfNoneMatch) -> impl IntoResponse { 118 + if if_none_match.match_strong("0815") { 119 + StatusCode::NOT_MODIFIED 120 + } else { 121 + StatusCode::OK 122 + } 123 + } 124 + 125 + fn app() -> Router { 126 + Router::new() 127 + .route("/weak_0815", axum::routing::get(weak_0815)) 128 + .route("/strong_0815", axum::routing::get(strong_0815)) 129 + } 130 + 131 + fn request<'a>(path: &str, etags: impl IntoIterator<Item = &'a str>) -> Request<Body> { 132 + let mut request = Request::get(path); 133 + let headers = request.headers_mut().unwrap(); 134 + for etag in etags { 135 + headers.append(IF_NONE_MATCH, etag.parse().unwrap()); 136 + } 137 + request.body(Body::empty()).unwrap() 138 + } 139 + 140 + #[test] 141 + fn can_parse_weak() { 142 + let tag: EntityTag = r#"W/"a16f46ac""#.parse().unwrap(); 143 + assert!(tag.is_weak()); 144 + assert_eq!(tag.value, "a16f46ac"); 145 + } 146 + 147 + #[test] 148 + fn can_parse_strong() { 149 + let tag: EntityTag = r#""a16f46ac""#.parse().unwrap(); 150 + assert!(!tag.is_weak()); 151 + assert_eq!(tag.value, "a16f46ac"); 152 + } 153 + 154 + #[test] 155 + fn strong_eq_str() { 156 + assert!("a16f46ac".strong_eq(&EntityTag::strong("a16f46ac"))); 157 + assert!(!"a16f46ac".strong_eq(&EntityTag::weak("a16f46ac"))); 158 + 159 + assert!(!"4abc05c9".strong_eq(&EntityTag::strong("a16f46ac"))); 160 + assert!(!"4abc05c9".strong_eq(&EntityTag::weak("a16f46ac"))); 161 + } 162 + 163 + #[test] 164 + fn weak_eq_str() { 165 + assert!("a16f46ac".weak_eq(&EntityTag::strong("a16f46ac"))); 166 + assert!("a16f46ac".weak_eq(&EntityTag::weak("a16f46ac"))); 167 + 168 + assert!(!"4abc05c9".weak_eq(&EntityTag::strong("a16f46ac"))); 169 + assert!(!"4abc05c9".weak_eq(&EntityTag::weak("a16f46ac"))); 170 + } 171 + 172 + #[tokio::test] 173 + async fn if_none_match_weak() -> Result<(), Box<dyn error::Error>> { 174 + let response = app() 175 + .oneshot(request("/weak_0815", Some(r#"W/"0815""#))) 176 + .await?; 177 + assert_eq!(response.status(), StatusCode::NOT_MODIFIED); 178 + 179 + // Strong should also match. 180 + let response = app() 181 + .oneshot(request("/weak_0815", Some(r#""0815""#))) 182 + .await?; 183 + assert_eq!(response.status(), StatusCode::NOT_MODIFIED); 184 + 185 + // Wild-card should also match. 186 + let response = app().oneshot(request("/weak_0815", Some(r#"*"#))).await?; 187 + assert_eq!(response.status(), StatusCode::NOT_MODIFIED); 188 + 189 + // Ensure it doesn't just accept everything. 190 + let response = app() 191 + .oneshot(request("/weak_0815", Some(r#"W/"08159""#))) 192 + .await?; 193 + assert_eq!(response.status(), StatusCode::OK); 194 + 195 + Ok(()) 196 + } 197 + 198 + #[tokio::test] 199 + async fn if_none_match_strong() -> Result<(), Box<dyn error::Error>> { 200 + const PATH: &str = "/strong_0815"; 201 + 202 + let response = app().oneshot(request(PATH, Some(r#"W/"0815""#))).await?; 203 + assert_eq!(response.status(), StatusCode::OK); 204 + 205 + // Strong should also match. 206 + let response = app().oneshot(request(PATH, Some(r#""0815""#))).await?; 207 + assert_eq!(response.status(), StatusCode::NOT_MODIFIED); 208 + 209 + // Wild-card should also match. 210 + let response = app().oneshot(request(PATH, Some(r#"*"#))).await?; 211 + assert_eq!(response.status(), StatusCode::NOT_MODIFIED); 212 + 213 + // Ensure it doesn't just accept everything. 214 + let response = app().oneshot(request(PATH, Some(r#"W/"08159""#))).await?; 215 + assert_eq!(response.status(), StatusCode::OK); 216 + 217 + let response = app().oneshot(request(PATH, Some(r#""08159""#))).await?; 218 + assert_eq!(response.status(), StatusCode::OK); 219 + 220 + Ok(()) 261 221 } 262 222 }
+6 -4
crates/gordian-knot/src/public/xrpc/sh_tangled/repo/impl_blob.rs
··· 18 18 use reqwest::header::ETAG; 19 19 use tokio_rayon::AsyncThreadPool as _; 20 20 21 + use crate::extract::if_none_match::EntityTag; 21 22 use crate::extract::if_none_match::IfNoneMatch; 22 23 use crate::model::Knot; 23 24 use crate::model::repository::ResolveRevspec as _; ··· 36 35 )] 37 36 pub async fn handle( 38 37 State(knot): State<Knot>, 39 - if_none_match: Option<IfNoneMatch>, 38 + if_none_match: IfNoneMatch, 40 39 XrpcQuery(Input { 41 40 repo, 42 41 rev, ··· 50 49 let ResolvedRevspec { commit, immutable } = 51 50 repository.resolve_revspec(&Some(rev.as_str()))?; 52 51 53 - // Use the tree object ID as an ETag. 52 + // Use the tree object ID as an entity tag. 54 53 // 55 54 // 1. If the blob content has changed, the blob object ID will be different, and 56 55 // therefore the tree object ID will also be different. ··· 58 57 // 2. Using the tree object ID avoids searching the tree for the blob path. 59 58 60 59 let tree = repository.get_tree(&commit)?; 61 - if if_none_match.is_some_and(|etags| etags.contains(&tree.id.to_string())) { 60 + let etag = EntityTag::strong(tree.id.to_hex().to_string()); 61 + if if_none_match.match_weak(&etag) { 62 62 return Ok(StatusCode::NOT_MODIFIED.into_response()); 63 63 } 64 64 ··· 85 83 86 84 headers.insert( 87 85 ETAG, 88 - HeaderValue::from_str(&format!("\"{}\"", tree.id)) 86 + etag.to_header_value() 89 87 .expect("Hex-string should be a valid header value"), 90 88 ); 91 89