this repo has no description
1use crate::progress::{ProgressArgs, copy_with_progress};
2use crate::utils::*;
3use anyhow::bail;
4use clap::Args;
5use std::fs::File;
6use std::io::{self, BufReader, BufWriter, Read, Write};
7use zstd::stream::{read::Decoder, write::Encoder};
8
9/// Zstd-specific compression validator (-7 to 22 range)
10#[derive(Debug, Clone, Copy)]
11pub struct ZstdCompressionValidator;
12
13impl CompressionLevelValidator for ZstdCompressionValidator {
14 fn min_level(&self) -> i32 {
15 -7
16 }
17 fn max_level(&self) -> i32 {
18 22
19 }
20 fn default_level(&self) -> i32 {
21 1
22 }
23
24 fn name_to_level(&self, name: &str) -> Option<i32> {
25 match name.to_lowercase().as_str() {
26 "none" => Some(-7),
27 "fast" => Some(1),
28 "best" => Some(22),
29 _ => None,
30 }
31 }
32}
33
// Command-line arguments for the zstd compressor.
//
// Plain `//` comments are used on purpose: clap's derive turns `///` doc
// comments into CLI help text, and these notes are for maintainers only.
#[derive(Args, Debug)]
pub struct ZstdArgs {
    // Shared `CommonArgs` options (type defined outside this file).
    #[clap(flatten)]
    pub common_args: CommonArgs,

    // Requested compression level; `Zstd::new` passes it through
    // `ZstdCompressionValidator::validate_and_clamp_level` before use.
    #[clap(flatten)]
    pub level_args: LevelArgs,

    // Progress settings; `compress`/`extract` read `progress` and
    // `chunk_size.size_in_bytes` from it.
    #[clap(flatten)]
    pub progress_args: ProgressArgs,
}
45
/// Zstd compressor configuration.
pub struct Zstd {
    /// Level handed to the zstd encoder. `Zstd::new` clamps it with
    /// `ZstdCompressionValidator`; a value assigned directly is used as-is.
    pub compression_level: i32,
    /// Progress display and chunk-size settings, used when the output is a
    /// path or pipe (caller-supplied writers are copied without progress).
    pub progress_args: ProgressArgs,
}
50
51impl Default for Zstd {
52 fn default() -> Self {
53 let validator = ZstdCompressionValidator;
54 Zstd {
55 compression_level: validator.default_level(),
56 progress_args: ProgressArgs::default(),
57 }
58 }
59}
60
61impl Zstd {
62 pub fn new(args: &ZstdArgs) -> Zstd {
63 let validator = ZstdCompressionValidator;
64 let mut level = args.level_args.level.level;
65
66 // Validate and clamp the level to zstd's valid range
67 level = validator.validate_and_clamp_level(level);
68
69 Zstd {
70 compression_level: level,
71 progress_args: args.progress_args,
72 }
73 }
74}
75
76impl Compressor for Zstd {
77 /// The standard extension for the zstd format.
78 fn extension(&self) -> &str {
79 "zst"
80 }
81
82 /// Full name for zstd.
83 fn name(&self) -> &str {
84 "zstd"
85 }
86
87 /// Compress an input file or pipe to a zstd archive
88 fn compress(&self, input: CmprssInput, output: CmprssOutput) -> Result {
89 if let CmprssOutput::Path(out_path) = &output
90 && out_path.is_dir()
91 {
92 bail!(
93 "Zstd does not support compressing to a directory. Please specify an output file."
94 );
95 }
96 if let CmprssInput::Path(input_paths) = &input {
97 for x in input_paths {
98 if x.is_dir() {
99 bail!(
100 "Zstd does not support compressing a directory. Please specify only files."
101 );
102 }
103 }
104 }
105 let mut file_size = None;
106 let mut input_stream: Box<dyn Read + Send> = match input {
107 CmprssInput::Path(paths) => {
108 if paths.len() > 1 {
109 bail!("Multiple input files not supported for zstd");
110 }
111 let path = &paths[0];
112 file_size = Some(std::fs::metadata(path)?.len());
113 Box::new(BufReader::new(File::open(path)?))
114 }
115 CmprssInput::Pipe(stdin) => Box::new(BufReader::new(stdin)),
116 CmprssInput::Reader(reader) => reader.0,
117 };
118
119 if let CmprssOutput::Writer(writer) = output {
120 let mut encoder = Encoder::new(writer, self.compression_level)?;
121 io::copy(&mut input_stream, &mut encoder)?;
122 encoder.finish()?;
123 } else {
124 let output_stream: Box<dyn Write + Send> = match &output {
125 CmprssOutput::Path(path) => Box::new(BufWriter::new(File::create(path)?)),
126 CmprssOutput::Pipe(stdout) => Box::new(BufWriter::new(stdout)),
127 CmprssOutput::Writer(_) => unreachable!(),
128 };
129 let mut encoder = Encoder::new(output_stream, self.compression_level)?;
130 copy_with_progress(
131 &mut input_stream,
132 &mut encoder,
133 self.progress_args.chunk_size.size_in_bytes,
134 file_size,
135 self.progress_args.progress,
136 &output,
137 )?;
138 encoder.finish()?;
139 }
140
141 Ok(())
142 }
143
144 /// Extract a zstd archive to an output file or pipe
145 fn extract(&self, input: CmprssInput, output: CmprssOutput) -> Result {
146 if let CmprssOutput::Path(out_path) = &output
147 && out_path.is_dir()
148 {
149 bail!(
150 "Zstd does not support extracting to a directory. Please specify an output file."
151 );
152 }
153
154 let mut file_size = None;
155 let input_stream: Box<dyn Read + Send> = match input {
156 CmprssInput::Path(paths) => {
157 if paths.len() > 1 {
158 bail!("Multiple input files not supported for zstd extraction");
159 }
160 let path = &paths[0];
161 file_size = Some(std::fs::metadata(path)?.len());
162 Box::new(BufReader::new(File::open(path)?))
163 }
164 CmprssInput::Pipe(stdin) => Box::new(BufReader::new(stdin)),
165 CmprssInput::Reader(reader) => reader.0,
166 };
167
168 let mut decoder = Decoder::new(input_stream)?;
169
170 if let CmprssOutput::Writer(mut writer) = output {
171 io::copy(&mut decoder, &mut writer)?;
172 } else {
173 let mut output_stream: Box<dyn Write + Send> = match &output {
174 CmprssOutput::Path(path) => Box::new(BufWriter::new(File::create(path)?)),
175 CmprssOutput::Pipe(stdout) => Box::new(BufWriter::new(stdout)),
176 CmprssOutput::Writer(_) => unreachable!(),
177 };
178 copy_with_progress(
179 &mut decoder,
180 &mut output_stream,
181 self.progress_args.chunk_size.size_in_bytes,
182 file_size,
183 self.progress_args.progress,
184 &output,
185 )?;
186 }
187
188 Ok(())
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::*;

    /// A zstd compressor fixed at `level`, with default progress settings.
    fn zstd_at_level(level: i32) -> Zstd {
        Zstd {
            compression_level: level,
            progress_args: ProgressArgs::default(),
        }
    }

    /// The compressor reports the expected name and file extension.
    #[test]
    fn test_zstd_interface() {
        test_compressor_interface(&Zstd::default(), "zstd", Some("zst"));
    }

    /// Round-trip at the default level.
    #[test]
    fn test_zstd_default_compression() -> Result {
        test_compression(&Zstd::default())
    }

    /// Round-trip at level 1 (the `fast` preset).
    #[test]
    fn test_zstd_fast_compression() -> Result {
        test_compression(&zstd_at_level(1))
    }

    /// Round-trip at level 22 (the `best` preset).
    #[test]
    fn test_zstd_best_compression() -> Result {
        test_compression(&zstd_at_level(22))
    }

    /// Level bounds, default, and preset names of the zstd validator.
    #[test]
    fn test_zstd_compression_validator() {
        let (min_level, max_level, default_level) = (-7, 22, 1);
        let (fast, best, none) = (Some(1), Some(22), Some(-7));
        test_compression_validator_helper(
            &ZstdCompressionValidator,
            min_level,
            max_level,
            default_level,
            fast,
            best,
            none,
        );
    }
}