//! A better Rust ATProto crate
1// Buffered writer types, pulling HEAVILY from the std library and tokio
2// Single writer supporting sync and async traits
3
4use crate::io::{self, ErrorKind, SeekFrom, Write};
5
/// Wraps a writer and buffers its output.
///
/// It can be excessively inefficient to work directly with something that
/// implements [`AsyncWrite`]. A `BufWriter` keeps an in-memory buffer of data and
/// writes it to an underlying writer in large, infrequent batches.
///
/// `BufWriter` can improve the speed of programs that make *small* and
/// *repeated* write calls to the same file or network socket. It does not
/// help when writing very large amounts at once, or writing just one or a few
/// times. It also provides no advantage when writing to a destination that is
/// in memory, like a `Vec<u8>`.
///
/// When the `BufWriter` is dropped, the contents of its buffer will be
/// discarded. Creating multiple instances of a `BufWriter` on the same
/// stream can cause data loss. If you need to write out the contents of its
/// buffer, you must manually call flush before the writer is dropped.
///
/// NOTE(review): the paragraph above (inherited from the async docs) says a
/// dropped buffer is discarded, but the comment on `panicked` below implies a
/// `Drop` impl that *does* flush. The `Drop` impl is not visible in this
/// file — confirm which behavior applies and make these docs agree.
///
/// [`AsyncWrite`]: AsyncWrite
/// [`flush`]: super::AsyncWriteExt::flush
///
#[pin_project::pin_project]
pub struct BufWriter<W: ?Sized> {
    /// In-memory staging area for bytes not yet written to `inner`.
    pub(super) buf: Vec<u8>,
    /// Count of bytes of `buf` already written out. Not read by the sync code
    /// in this file — presumably maintained by async poll paths elsewhere;
    /// TODO confirm.
    pub(super) written: usize,
    /// State machine for the buffered seek protocol (see `SeekState`).
    pub(super) seek_state: SeekState,
    // #30888: If the inner writer panics in a call to write, we don't want to
    // write the buffered data a second time in BufWriter's destructor. This
    // flag tells the Drop impl if it should skip the flush.
    panicked: bool,
    /// The wrapped writer. `#[pin]` so pinned projections can reach it from
    /// the async poll methods (not visible in this file).
    #[pin]
    pub(super) inner: W,
}
38
/// Progress through the two-phase (`start_seek` / `poll_complete`) seek
/// protocol. Presumably driven by an async seek impl elsewhere — none of the
/// sync code visible in this file touches it; TODO confirm against the
/// `AsyncSeek` implementation.
#[derive(Debug, Clone, Copy)]
pub(super) enum SeekState {
    /// `start_seek` has not been called.
    Init,
    /// `start_seek` has been called, but `poll_complete` has not yet been called.
    Start(SeekFrom),
    /// Waiting for completion of the first `poll_complete` in the `n.checked_sub(remainder).is_none()` branch.
    PendingOverflowed(i64),
    /// Waiting for completion of `poll_complete`.
    Pending,
}
50
51impl<W: ?Sized + Write> BufWriter<W> {
52 /// Send data in our local buffer into the inner writer, looping as
53 /// necessary until either it's all been sent or an error occurs.
54 ///
55 /// Because all the data in the buffer has been reported to our owner as
56 /// "successfully written" (by returning nonzero success values from
57 /// `write`), any 0-length writes from `inner` must be reported as i/o
58 /// errors from this method.
59 pub(crate) fn flush_buf(&mut self) -> io::Result<()> {
60 /// Helper struct to ensure the buffer is updated after all the writes
61 /// are complete. It tracks the number of written bytes and drains them
62 /// all from the front of the buffer when dropped.
63 struct BufGuard<'a> {
64 buffer: &'a mut Vec<u8>,
65 written: usize,
66 }
67
68 impl<'a> BufGuard<'a> {
69 fn new(buffer: &'a mut Vec<u8>) -> Self {
70 Self { buffer, written: 0 }
71 }
72
73 /// The unwritten part of the buffer
74 fn remaining(&self) -> &[u8] {
75 &self.buffer[self.written..]
76 }
77
78 /// Flag some bytes as removed from the front of the buffer
79 fn consume(&mut self, amt: usize) {
80 self.written += amt;
81 }
82
83 /// true if all of the bytes have been written
84 fn done(&self) -> bool {
85 self.written >= self.buffer.len()
86 }
87 }
88
89 impl Drop for BufGuard<'_> {
90 fn drop(&mut self) {
91 if self.written > 0 {
92 self.buffer.drain(..self.written);
93 }
94 }
95 }
96
97 let mut guard = BufGuard::new(&mut self.buf);
98 while !guard.done() {
99 self.panicked = true;
100 let r = self.inner.write(guard.remaining());
101 self.panicked = false;
102
103 match r {
104 Ok(0) => {
105 return Err(io::Error::new(
106 ErrorKind::WriteZero,
107 "failed to write the buffered data".into(),
108 ));
109 }
110 Ok(n) => guard.consume(n),
111 Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
112 Err(e) => return Err(e),
113 }
114 }
115 Ok(())
116 }
117
118 /// Buffer some data without flushing it, regardless of the size of the
119 /// data. Writes as much as possible without exceeding capacity. Returns
120 /// the number of bytes written.
121 pub(super) fn write_to_buf(&mut self, buf: &[u8]) -> usize {
122 let available = self.spare_capacity();
123 let amt_to_buffer = available.min(buf.len());
124
125 // SAFETY: `amt_to_buffer` is <= buffer's spare capacity by construction.
126 unsafe {
127 self.write_to_buffer_unchecked(&buf[..amt_to_buffer]);
128 }
129
130 amt_to_buffer
131 }
132
133 /// Gets a reference to the underlying writer.
134 ///
135 /// # Examples
136 ///
137 /// ```no_run
138 /// use std::io::BufWriter;
139 /// use std::net::TcpStream;
140 ///
141 /// let mut buffer = BufWriter::new(TcpStream::connect("127.0.0.1:34254").unwrap());
142 ///
143 /// // we can use reference just like buffer
144 /// let reference = buffer.get_ref();
145 /// ```
146 pub fn get_ref(&self) -> &W {
147 &self.inner
148 }
149
150 /// Gets a mutable reference to the underlying writer.
151 ///
152 /// It is inadvisable to directly write to the underlying writer.
153 ///
154 /// # Examples
155 ///
156 /// ```no_run
157 /// use std::io::BufWriter;
158 /// use std::net::TcpStream;
159 ///
160 /// let mut buffer = BufWriter::new(TcpStream::connect("127.0.0.1:34254").unwrap());
161 ///
162 /// // we can use reference just like buffer
163 /// let reference = buffer.get_mut();
164 /// ```
165 pub fn get_mut(&mut self) -> &mut W {
166 &mut self.inner
167 }
168
169 /// Returns a reference to the internally buffered data.
170 ///
171 /// # Examples
172 ///
173 /// ```no_run
174 /// use std::io::BufWriter;
175 /// use std::net::TcpStream;
176 ///
177 /// let buf_writer = BufWriter::new(TcpStream::connect("127.0.0.1:34254").unwrap());
178 ///
179 /// // See how many bytes are currently buffered
180 /// let bytes_buffered = buf_writer.buffer().len();
181 /// ```
182 pub fn buffer(&self) -> &[u8] {
183 &self.buf
184 }
185
186 /// Returns a mutable reference to the internal buffer.
187 ///
188 /// This can be used to write data directly into the buffer without triggering writers
189 /// to the underlying writer.
190 ///
191 /// That the buffer is a `Vec` is an implementation detail.
192 /// Callers should not modify the capacity as there currently is no public API to do so
193 /// and thus any capacity changes would be unexpected by the user.
194 pub(in crate::io) fn buffer_mut(&mut self) -> &mut Vec<u8> {
195 &mut self.buf
196 }
197
198 /// Returns the number of bytes the internal buffer can hold without flushing.
199 ///
200 /// # Examples
201 ///
202 /// ```no_run
203 /// use std::io::BufWriter;
204 /// use std::net::TcpStream;
205 ///
206 /// let buf_writer = BufWriter::new(TcpStream::connect("127.0.0.1:34254").unwrap());
207 ///
208 /// // Check the capacity of the inner buffer
209 /// let capacity = buf_writer.capacity();
210 /// // Calculate how many bytes can be written without flushing
211 /// let without_flush = capacity - buf_writer.buffer().len();
212 /// ```
213 pub fn capacity(&self) -> usize {
214 self.buf.capacity()
215 }
216
217 // Ensure this function does not get inlined into `write`, so that it
218 // remains inlineable and its common path remains as short as possible.
219 // If this function ends up being called frequently relative to `write`,
220 // it's likely a sign that the client is using an improperly sized buffer
221 // or their write patterns are somewhat pathological.
222 #[cold]
223 #[inline(never)]
224 fn write_cold(&mut self, buf: &[u8]) -> io::Result<usize> {
225 if buf.len() > self.spare_capacity() {
226 self.flush_buf()?;
227 }
228
229 // Why not len > capacity? To avoid a needless trip through the buffer when the input
230 // exactly fills it. We'd just need to flush it to the underlying writer anyway.
231 if buf.len() >= self.buf.capacity() {
232 self.panicked = true;
233 let r = self.get_mut().write(buf);
234 self.panicked = false;
235 r
236 } else {
237 // Write to the buffer. In this case, we write to the buffer even if it fills it
238 // exactly. Doing otherwise would mean flushing the buffer, then writing this
239 // input to the inner writer, which in many cases would be a worse strategy.
240
241 // SAFETY: There was either enough spare capacity already, or there wasn't and we
242 // flushed the buffer to ensure that there is. In the latter case, we know that there
243 // is because flushing ensured that our entire buffer is spare capacity, and we entered
244 // this block because the input buffer length is less than that capacity. In either
245 // case, it's safe to write the input buffer to our buffer.
246 unsafe {
247 self.write_to_buffer_unchecked(buf);
248 }
249
250 Ok(buf.len())
251 }
252 }
253
254 // Ensure this function does not get inlined into `write_all`, so that it
255 // remains inlineable and its common path remains as short as possible.
256 // If this function ends up being called frequently relative to `write_all`,
257 // it's likely a sign that the client is using an improperly sized buffer
258 // or their write patterns are somewhat pathological.
259 #[cold]
260 #[inline(never)]
261 fn write_all_cold(&mut self, buf: &[u8]) -> io::Result<()> {
262 // Normally, `write_all` just calls `write` in a loop. We can do better
263 // by calling `self.get_mut().write_all()` directly, which avoids
264 // round trips through the buffer in the event of a series of partial
265 // writes in some circumstances.
266
267 if buf.len() > self.spare_capacity() {
268 self.flush_buf()?;
269 }
270
271 // Why not len > capacity? To avoid a needless trip through the buffer when the input
272 // exactly fills it. We'd just need to flush it to the underlying writer anyway.
273 if buf.len() >= self.buf.capacity() {
274 self.panicked = true;
275 let r = self.get_mut().write_all(buf);
276 self.panicked = false;
277 r
278 } else {
279 // Write to the buffer. In this case, we write to the buffer even if it fills it
280 // exactly. Doing otherwise would mean flushing the buffer, then writing this
281 // input to the inner writer, which in many cases would be a worse strategy.
282
283 // SAFETY: There was either enough spare capacity already, or there wasn't and we
284 // flushed the buffer to ensure that there is. In the latter case, we know that there
285 // is because flushing ensured that our entire buffer is spare capacity, and we entered
286 // this block because the input buffer length is less than that capacity. In either
287 // case, it's safe to write the input buffer to our buffer.
288 unsafe {
289 self.write_to_buffer_unchecked(buf);
290 }
291
292 Ok(())
293 }
294 }
295
    // Appends `buf` to `self.buf` via a raw memcpy, skipping the capacity
    // check (and possible reallocation) that `extend_from_slice` performs.
    //
    // SAFETY: Requires `buf.len() <= self.buf.capacity() - self.buf.len()`,
    // i.e., that input buffer length is less than or equal to spare capacity.
    #[inline]
    unsafe fn write_to_buffer_unchecked(&mut self, buf: &[u8]) {
        // Re-check the caller's contract in debug builds only.
        debug_assert!(buf.len() <= self.spare_capacity());
        let old_len = self.buf.len();
        let buf_len = buf.len();
        let src = buf.as_ptr();
        // SAFETY: the caller guarantees `buf_len` bytes of spare capacity
        // beyond `old_len`, so the copy stays within `self.buf`'s allocation,
        // and `set_len` only exposes bytes initialized by the copy above.
        // `src` and `dst` cannot overlap: `buf` is borrowed shared while
        // `self.buf` is borrowed mutably.
        unsafe {
            let dst = self.buf.as_mut_ptr().add(old_len);
            core::ptr::copy_nonoverlapping(src, dst, buf_len);
            self.buf.set_len(old_len + buf_len);
        }
    }
310
    /// Number of additional bytes the buffer can accept before it is full
    /// (capacity minus current length).
    #[inline]
    fn spare_capacity(&self) -> usize {
        self.buf.capacity() - self.buf.len()
    }
315}
316
317impl<W: ?Sized + Write> Write for BufWriter<W> {
318 #[inline]
319 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
320 // Use < instead of <= to avoid a needless trip through the buffer in some cases.
321 // See `write_cold` for details.
322 if buf.len() < self.spare_capacity() {
323 // SAFETY: safe by above conditional.
324 unsafe {
325 self.write_to_buffer_unchecked(buf);
326 }
327
328 Ok(buf.len())
329 } else {
330 self.write_cold(buf)
331 }
332 }
333
334 #[inline]
335 fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
336 // Use < instead of <= to avoid a needless trip through the buffer in some cases.
337 // See `write_all_cold` for details.
338 if buf.len() < self.spare_capacity() {
339 // SAFETY: safe by above conditional.
340 unsafe {
341 self.write_to_buffer_unchecked(buf);
342 }
343
344 Ok(())
345 } else {
346 self.write_all_cold(buf)
347 }
348 }
349
350 fn flush(&mut self) -> io::Result<()> {
351 self.flush_buf().and_then(|()| self.get_mut().flush())
352 }
353}