This repository has no description.
1use crate::utils::CmprssOutput;
2use clap::Args;
3use indicatif::{HumanBytes, ProgressBar};
4use std::io::{self, Read, Write};
5use std::str::FromStr;
6use std::time::Duration;
7use std::time::Instant;
8
/// When to display a progress bar.
///
/// Selected with `--progress` (see `ProgressArgs`) and interpreted by
/// `create_progress_bar`.
#[derive(clap::ValueEnum, Clone, Copy, Debug, Default)]
pub enum ProgressDisplay {
    // Create a progress bar unless the output is a pipe.
    #[default]
    Auto,
    // Always create a progress bar.
    On,
    // Never create a progress bar.
    Off,
}
16
/// A buffer size in bytes, parsed from the command line.
///
/// Accepts either a bare, non-zero byte count (`"4096"`) or a number followed by a
/// `kb`/`mb`/`gb` unit. Units are case-insensitive and always base 2: `"kb"` and
/// `"kib"` both mean 1024 bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChunkSize {
    pub size_in_bytes: usize,
}

impl Default for ChunkSize {
    /// 8 KiB (8192 bytes), the same value as the `--chunk-size` default of `"8kib"`.
    fn default() -> Self {
        ChunkSize {
            size_in_bytes: 8 * 1024,
        }
    }
}

impl FromStr for ChunkSize {
    type Err = &'static str;

    /// Parses a chunk size such as `"8192"`, `"16kb"`, `"8MiB"` or `"1 gb"`.
    ///
    /// Surrounding whitespace, and whitespace between the number and the unit, is
    /// ignored.
    ///
    /// # Errors
    ///
    /// * `"Invalid number"` if the numeric part is missing, malformed, or the resulting
    ///   size is zero.
    /// * `"Invalid unit"` if the suffix is not `kb`, `mb` or `gb` (or `kib`, `mib`, `gib`).
    /// * `"Chunk size too large"` if the size in bytes does not fit in a `usize`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let s = s.trim();

        // A bare number is a byte count.
        if let Ok(num) = s.parse::<usize>() {
            if num == 0 {
                return Err("Invalid number");
            }
            return Ok(ChunkSize { size_in_bytes: num });
        }

        // Normalise "kib"/"mib"/"gib" to "kb"/"mb"/"gb": both spellings mean base 2.
        let mut s = s.to_lowercase();
        if s.ends_with("ib") {
            s.truncate(s.len() - 2);
            s.push('b');
        }

        // The unit is the last two characters. Locating the split on a char boundary
        // (instead of at `len - 2` bytes) turns short or non-ASCII input into an error
        // rather than a panic.
        let unit_start = s
            .char_indices()
            .rev()
            .nth(1)
            .map(|(idx, _)| idx)
            .ok_or("Invalid number")?;
        let (num_str, unit) = s.split_at(unit_start);
        let num = num_str
            .trim_end()
            .parse::<usize>()
            .map_err(|_| "Invalid number")?;

        let multiplier: usize = match unit {
            "kb" => 1024,
            "mb" => 1024 * 1024,
            "gb" => 1024 * 1024 * 1024,
            _ => return Err("Invalid unit"),
        };
        // Reject oversized values instead of overflowing (a panic in debug builds and a
        // silent wrap-around in release builds).
        let size_in_bytes = num
            .checked_mul(multiplier)
            .ok_or("Chunk size too large")?;
        if size_in_bytes == 0 {
            return Err("Invalid number");
        }

        Ok(ChunkSize { size_in_bytes })
    }
}
64
// Command-line options controlling progress reporting.
//
// Plain `//` comments are used here on purpose: clap turns `///` doc comments into
// `--help` text. The derived `Default` yields the same values as the clap defaults
// below: `ProgressDisplay::Auto` (its `#[default]` variant) and `ChunkSize::default()`
// (8192 bytes, i.e. "8kib").
#[derive(Args, Debug, Default, Clone, Copy)]
pub struct ProgressArgs {
    /// Show progress.
    #[arg(long, value_enum, default_value = "auto")]
    pub progress: ProgressDisplay,

    // NOTE(review): the help text says this is only used when showing the progress bar,
    // but `copy_with_progress` sizes its copy buffer from its `chunk_size` argument
    // whether or not a bar is displayed — confirm how callers pass this value.
    /// Chunk size to use during the copy when showing the progress bar.
    #[arg(long, default_value = "8kib")]
    pub chunk_size: ChunkSize,
}
75
76/// Create a progress bar if necessary based on settings
77pub fn create_progress_bar(
78 input_size: Option<u64>,
79 progress: ProgressDisplay,
80 output: &CmprssOutput,
81) -> Option<ProgressBar> {
82 match (progress, output) {
83 (ProgressDisplay::Auto, CmprssOutput::Pipe(_)) => None,
84 (ProgressDisplay::Off, _) => None,
85 (_, _) => {
86 let bar = match input_size {
87 Some(size) => ProgressBar::new(size),
88 None => ProgressBar::new_spinner(),
89 };
90 bar.set_style(
91 indicatif::ProgressStyle::default_bar()
92 .template("{spinner:.green} [{elapsed_precise}] ({eta}) [{bar:40.cyan/blue}] {bytes}/{total_bytes} => {msg}").unwrap()
93 .progress_chars("#>-"),
94 );
95 bar.enable_steady_tick(Duration::from_millis(100));
96 Some(bar)
97 }
98 }
99}
100
101/// A reader that tracks progress of bytes read
102pub struct ProgressReader<R> {
103 inner: R,
104 bar: Option<ProgressBar>,
105 total_read: u64,
106 last_update: Instant,
107 bytes_since_update: u64,
108 bytes_per_update: u64,
109}
110
111impl<R: Read> ProgressReader<R> {
112 pub fn new(inner: R, bar: Option<ProgressBar>) -> Self {
113 ProgressReader {
114 inner,
115 bar,
116 total_read: 0,
117 last_update: Instant::now(),
118 bytes_since_update: 0,
119 bytes_per_update: 8192, // Start with 8KB, will adjust dynamically
120 }
121 }
122
123 /// Updates the progress bar if enough bytes have been read since the last update.
124 /// Dynamically adjusts the update frequency to target ~100ms between updates by
125 /// tracking the elapsed time and adjusting bytes_per_update accordingly.
126 fn maybe_update_progress(&mut self, bytes_read: u64) {
127 if let Some(ref bar) = self.bar {
128 self.bytes_since_update += bytes_read;
129
130 if self.bytes_since_update >= self.bytes_per_update {
131 let now = Instant::now();
132 let elapsed = now.duration_since(self.last_update);
133
134 // Update the progress
135 bar.set_position(self.total_read);
136
137 // Adjust bytes_per_update to target ~100ms between updates
138 if elapsed < Duration::from_millis(50) {
139 self.bytes_per_update *= 2;
140 } else if elapsed > Duration::from_millis(150) {
141 self.bytes_per_update = (self.bytes_per_update / 2).max(1024);
142 }
143
144 self.last_update = now;
145 self.bytes_since_update = 0;
146 }
147 }
148 }
149}
150
151impl<R: Read> Read for ProgressReader<R> {
152 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
153 let bytes_read = self.inner.read(buf)?;
154 if bytes_read > 0 {
155 self.total_read += bytes_read as u64;
156 self.maybe_update_progress(bytes_read as u64);
157 }
158 Ok(bytes_read)
159 }
160}
161
162/// A writer that tracks progress of bytes written
163pub struct ProgressWriter<W> {
164 inner: W,
165 bar: Option<ProgressBar>,
166 total_written: u64,
167 last_update: Instant,
168 bytes_since_update: u64,
169 bytes_per_update: u64,
170}
171
172impl<W: Write> ProgressWriter<W> {
173 pub fn new(inner: W, bar: Option<ProgressBar>) -> Self {
174 ProgressWriter {
175 inner,
176 bar,
177 total_written: 0,
178 last_update: Instant::now(),
179 bytes_since_update: 0,
180 bytes_per_update: 8192, // Start with 8KB, will adjust dynamically
181 }
182 }
183
184 pub fn finish(self) {
185 if let Some(bar) = self.bar {
186 bar.finish();
187 }
188 }
189
190 /// Updates the progress bar if enough bytes have been written since the last update.
191 /// Dynamically adjusts the update frequency to target ~100ms between updates by
192 /// tracking the elapsed time and adjusting bytes_per_update accordingly.
193 fn maybe_update_progress(&mut self, bytes_written: u64) {
194 if let Some(ref bar) = self.bar {
195 self.bytes_since_update += bytes_written;
196
197 if self.bytes_since_update >= self.bytes_per_update {
198 let now = Instant::now();
199 let elapsed = now.duration_since(self.last_update);
200
201 // Update the progress
202 bar.set_message(HumanBytes(self.total_written).to_string());
203
204 // Adjust bytes_per_update to target ~100ms between updates
205 if elapsed < Duration::from_millis(50) {
206 self.bytes_per_update *= 2;
207 } else if elapsed > Duration::from_millis(150) {
208 self.bytes_per_update = (self.bytes_per_update / 2).max(1024);
209 }
210
211 self.last_update = now;
212 self.bytes_since_update = 0;
213 }
214 }
215 }
216}
217
218impl<W: Write> Write for ProgressWriter<W> {
219 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
220 let bytes_written = self.inner.write(buf)?;
221 if bytes_written > 0 {
222 self.total_written += bytes_written as u64;
223 self.maybe_update_progress(bytes_written as u64);
224 }
225 Ok(bytes_written)
226 }
227
228 fn flush(&mut self) -> io::Result<()> {
229 self.inner.flush()
230 }
231}
232
233/// Process data with progress bar updates
234pub fn copy_with_progress<R: Read, W: Write>(
235 reader: R,
236 writer: W,
237 chunk_size: usize,
238 input_size: Option<u64>,
239 progress_display: ProgressDisplay,
240 output: &CmprssOutput,
241) -> io::Result<()> {
242 // Create the progress bar if needed
243 let progress_bar = create_progress_bar(input_size, progress_display, output);
244
245 // Create reader and writer with progress tracking
246 let mut reader = ProgressReader::new(reader, progress_bar.clone());
247 let mut writer = ProgressWriter::new(writer, progress_bar);
248
249 let mut buffer = vec![0; chunk_size];
250 loop {
251 let bytes_read = reader.read(&mut buffer)?;
252 if bytes_read == 0 {
253 break;
254 }
255 writer.write_all(&buffer[..bytes_read])?;
256 }
257 writer.flush()?;
258 writer.finish();
259 Ok(())
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// `ChunkSize` parsing: zero sizes are rejected, bare numbers are byte counts, and
    /// `kb`/`kib`-style units are all base 2.
    #[test]
    fn chunk_size_parsing() {
        for rejected in ["0", "0mb"] {
            assert!(
                ChunkSize::from_str(rejected).is_err(),
                "{:?} should be rejected",
                rejected
            );
        }

        const KIB: usize = 1024;
        let accepted = [
            ("1", 1),
            ("1kb", KIB),
            ("16kib", 16 * KIB),
            ("8mib", 8 * KIB * KIB),
            ("16mb", 16 * KIB * KIB),
            ("1gb", KIB * KIB * KIB),
            ("16gib", 16 * KIB * KIB * KIB),
        ];
        for (input, size_in_bytes) in accepted {
            let parsed = ChunkSize::from_str(input).unwrap();
            assert_eq!(parsed, ChunkSize { size_in_bytes }, "parsing {:?}", input);
        }
    }
}