A fork of attic, a self-hostable Nix Binary Cache server
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

attic/hash_reader: Add AsyncBufRead support

+138 -33
+1
Cargo.lock
··· 242 242 "hex", 243 243 "lazy_static", 244 244 "nix-base32", 245 + "pin-project", 245 246 "regex", 246 247 "serde", 247 248 "serde_json",
+1
attic/Cargo.toml
··· 16 16 hex = "0.4.3" 17 17 lazy_static = "1.5.0" 18 18 nix-base32 = "0.2.0" 19 + pin-project = "1.1.10" 19 20 regex = "1.11.1" 20 21 serde = { version = "1.0.219", features = ["derive"] } 21 22 serde_with = "3.14.0"
+136 -33
attic/src/io/hash_reader.rs
··· 1 1 use std::marker::Unpin; 2 2 use std::pin::Pin; 3 3 use std::sync::Arc; 4 - use std::task::{Context, Poll}; 4 + use std::task::{ready, Context, Poll}; 5 5 6 6 use digest::{Digest, Output as DigestOutput}; 7 - use tokio::io::{AsyncRead, ReadBuf}; 7 + use pin_project::pin_project; 8 + use tokio::io::{self, AsyncBufRead, AsyncRead, ReadBuf}; 8 9 use tokio::sync::OnceCell; 9 10 10 11 /// AsyncRead filter that hashes the bytes that have been read. 11 12 /// 12 13 /// The hash is finalized when EOF is reached. 14 + #[pin_project(project = HashReaderProj)] 13 15 pub struct HashReader<R, D> 14 16 where 15 17 R: AsyncRead + Unpin, 16 18 D: Digest + Unpin, 17 19 { 20 + #[pin] 18 21 inner: R, 22 + state: State<D>, 23 + } 24 + 25 + struct State<D> 26 + where 27 + D: Digest + Unpin, 28 + { 19 29 digest: Option<D>, 20 - bytes_read: usize, 30 + bytes_hashed: usize, 31 + bytes_consumed: usize, 21 32 finalized: Arc<OnceCell<(DigestOutput<D>, usize)>>, 22 33 } 23 34 35 + impl<D> State<D> 36 + where 37 + D: Digest + Unpin, 38 + { 39 + fn hash_unconsumed(&mut self, unconsumed: &[u8]) { 40 + let unhashed_offset = self.bytes_hashed - self.bytes_consumed; 41 + 42 + // It's technically possible for the `poll_read`/`poll_fill_buf` implementation 43 + // to return less data than the unconsumed portion returned by a previous 44 + // call to `AsyncBufRead::poll_fill_buf`. 
45 + if unhashed_offset < unconsumed.len() { 46 + let unhashed = &unconsumed[unhashed_offset..]; 47 + self.bytes_hashed += unhashed.len(); 48 + 49 + let digest = self.digest.as_mut().expect("Stream has data after EOF"); 50 + digest.update(unhashed); 51 + } 52 + } 53 + 54 + fn eof(&mut self) { 55 + if let Some(digest) = self.digest.take() { 56 + assert!(self.bytes_hashed == self.bytes_consumed, "bytes_hashed != bytes_consumed but EOF - Unconsumed bytes disappeared from buffer??"); 57 + self.finalized 58 + .set((digest.finalize(), self.bytes_hashed)) 59 + .expect("Hash has already been finalized"); 60 + } 61 + } 62 + } 63 + 24 64 impl<R, D> HashReader<R, D> 25 65 where 26 66 R: AsyncRead + Unpin, ··· 32 72 ( 33 73 Self { 34 74 inner, 35 - digest: Some(digest), 36 - bytes_read: 0, 37 - finalized: finalized.clone(), 75 + state: State { 76 + digest: Some(digest), 77 + bytes_hashed: 0, 78 + bytes_consumed: 0, 79 + finalized: finalized.clone(), 80 + }, 38 81 }, 39 82 finalized, 40 83 ) ··· 47 90 D: Digest + Unpin, 48 91 { 49 92 fn poll_read( 50 - mut self: Pin<&mut Self>, 93 + self: Pin<&mut Self>, 51 94 cx: &mut Context<'_>, 52 95 buf: &mut ReadBuf<'_>, 53 - ) -> Poll<tokio::io::Result<()>> { 96 + ) -> Poll<io::Result<()>> { 97 + let this = self.project(); 98 + 54 99 let old_filled = buf.filled().len(); 55 - let r = Pin::new(&mut self.inner).poll_read(cx, buf); 56 - let read_len = buf.filled().len() - old_filled; 100 + ready!(this.inner.poll_read(cx, buf))?; 57 101 58 - match r { 59 - Poll::Ready(Ok(())) => { 60 - if read_len == 0 { 61 - // EOF 62 - if let Some(digest) = self.digest.take() { 63 - self.finalized 64 - .set((digest.finalize(), self.bytes_read)) 65 - .expect("Hash has already been finalized"); 66 - } 67 - } else { 68 - // Read something 69 - let digest = self.digest.as_mut().expect("Stream has data after EOF"); 102 + let filled = buf.filled(); 103 + let unconsumed = &filled[old_filled..]; 104 + if unconsumed.len() == 0 { 105 + this.state.eof(); 106 + } else 
{ 107 + this.state.hash_unconsumed(unconsumed); 108 + this.state.bytes_consumed += unconsumed.len(); 109 + } 110 + 111 + debug_assert!(this.state.bytes_consumed <= this.state.bytes_hashed); 112 + Poll::Ready(Ok(())) 113 + } 114 + } 115 + 116 + impl<R, D> AsyncBufRead for HashReader<R, D> 117 + where 118 + R: AsyncBufRead + Unpin, 119 + D: Digest + Unpin, 120 + { 121 + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { 122 + let this = self.project(); 123 + let unconsumed = ready!(this.inner.poll_fill_buf(cx))?; 70 124 71 - let filled = buf.filled(); 72 - digest.update(&filled[filled.len() - read_len..]); 73 - self.bytes_read += read_len; 74 - } 75 - } 76 - Poll::Ready(Err(_)) => { 77 - assert!(read_len == 0); 78 - } 79 - Poll::Pending => {} 125 + if unconsumed.len() == 0 { 126 + this.state.eof(); 127 + } else { 128 + this.state.hash_unconsumed(unconsumed); 80 129 } 81 130 82 - r 131 + debug_assert!(this.state.bytes_consumed <= this.state.bytes_hashed); 132 + Poll::Ready(Ok(unconsumed)) 133 + } 134 + 135 + fn consume(self: Pin<&mut Self>, amt: usize) { 136 + let this = self.project(); 137 + this.inner.consume(amt); 138 + this.state.bytes_consumed += amt; 139 + 140 + debug_assert!(this.state.bytes_consumed <= this.state.bytes_hashed); 83 141 } 84 142 } 85 143 ··· 87 145 mod tests { 88 146 use super::*; 89 147 90 - use tokio::io::AsyncReadExt; 148 + use tokio::io::{AsyncBufReadExt, AsyncReadExt}; 91 149 92 150 #[tokio::test] 93 151 async fn test_hash_reader() { ··· 118 176 .read(&mut buf[bytes_read..bytes_read + 5]) 119 177 .await 120 178 .unwrap(); 179 + 180 + assert_eq!(expected.len(), bytes_read); 181 + assert_eq!(expected, &buf[..bytes_read]); 182 + 183 + let (hash, count) = finalized.get().expect("Hash wasn't finalized"); 184 + 185 + assert_eq!(expected_sha256.as_slice(), hash.as_slice()); 186 + assert_eq!(expected.len(), *count); 187 + eprintln!("finalized = {:x?}", finalized); 188 + } 189 + 190 + #[tokio::test] 191 + async 
fn test_hash_reader_buf() { 192 + let expected = b"hello world"; 193 + let expected_sha256 = 194 + hex::decode("b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9") 195 + .unwrap(); 196 + 197 + let (mut read, finalized) = HashReader::new(expected.as_slice(), sha2::Sha256::new()); 198 + assert!(finalized.get().is_none()); 199 + 200 + let mut buf = vec![0u8; 100]; 201 + let mut bytes_read = 0; 202 + 203 + // Mix AsyncRead::read() and AsyncBufRead::fill_buf() 204 + 205 + bytes_read += read 206 + .read(&mut buf[bytes_read..bytes_read + 1]) 207 + .await 208 + .unwrap(); 209 + 210 + loop { 211 + // Perform multiple AsyncBufRead::fill_buf()s _without_ consuming 212 + let _ = read.fill_buf().await.unwrap(); 213 + let _ = read.fill_buf().await.unwrap(); 214 + let read_buf = read.fill_buf().await.unwrap(); 215 + 216 + if read_buf.is_empty() { 217 + break; 218 + } 219 + 220 + buf[bytes_read] = read_buf[0]; 221 + read.consume(1); 222 + bytes_read += 1; 223 + } 121 224 122 225 assert_eq!(expected.len(), bytes_read); 123 226 assert_eq!(expected, &buf[..bytes_read]);