this repo has no description
1use crate::progress::{copy_with_progress, ProgressArgs};
2use crate::utils::{
3 cmprss_error, CmprssInput, CmprssOutput, CommonArgs, CompressionLevelValidator, Compressor,
4 LevelArgs,
5};
6use clap::Args;
7use std::fs::File;
8use std::io::{self, BufReader, BufWriter, Read, Write};
9use zstd::stream::{read::Decoder, write::Encoder};
10
11/// Zstd-specific compression validator (-7 to 22 range)
12#[derive(Debug, Clone, Copy)]
13pub struct ZstdCompressionValidator;
14
15impl CompressionLevelValidator for ZstdCompressionValidator {
16 fn min_level(&self) -> i32 {
17 -7
18 }
19 fn max_level(&self) -> i32 {
20 22
21 }
22 fn default_level(&self) -> i32 {
23 1
24 }
25
26 fn name_to_level(&self, name: &str) -> Option<i32> {
27 match name.to_lowercase().as_str() {
28 "none" => Some(-7),
29 "fast" => Some(1),
30 "best" => Some(22),
31 _ => None,
32 }
33 }
34}
35
36#[derive(Args, Debug)]
37pub struct ZstdArgs {
38 #[clap(flatten)]
39 pub common_args: CommonArgs,
40
41 #[clap(flatten)]
42 pub level_args: LevelArgs,
43
44 #[clap(flatten)]
45 pub progress_args: ProgressArgs,
46}
47
48pub struct Zstd {
49 pub compression_level: i32,
50 pub progress_args: ProgressArgs,
51}
52
53impl Default for Zstd {
54 fn default() -> Self {
55 let validator = ZstdCompressionValidator;
56 Zstd {
57 compression_level: validator.default_level(),
58 progress_args: ProgressArgs::default(),
59 }
60 }
61}
62
63impl Zstd {
64 pub fn new(args: &ZstdArgs) -> Zstd {
65 let validator = ZstdCompressionValidator;
66 let mut level = args.level_args.level.level;
67
68 // Validate and clamp the level to zstd's valid range
69 level = validator.validate_and_clamp_level(level);
70
71 Zstd {
72 compression_level: level,
73 progress_args: args.progress_args,
74 }
75 }
76}
77
78impl Compressor for Zstd {
79 /// The standard extension for the zstd format.
80 fn extension(&self) -> &str {
81 "zst"
82 }
83
84 /// Full name for zstd.
85 fn name(&self) -> &str {
86 "zstd"
87 }
88
89 /// Generate a default extracted filename
90 /// zstd does not support extracting to a directory, so we return a default filename
91 fn default_extracted_filename(&self, in_path: &std::path::Path) -> String {
92 // If the file has no extension, return a default filename
93 if in_path.extension().is_none() {
94 return "archive".to_string();
95 }
96 // Otherwise, return the filename without the extension
97 in_path.file_stem().unwrap().to_str().unwrap().to_string()
98 }
99
100 /// Compress an input file or pipe to a zstd archive
101 fn compress(&self, input: CmprssInput, output: CmprssOutput) -> Result<(), io::Error> {
102 if let CmprssOutput::Path(out_path) = &output {
103 if out_path.is_dir() {
104 return cmprss_error("Zstd does not support compressing to a directory. Please specify an output file.");
105 }
106 }
107 if let CmprssInput::Path(input_paths) = &input {
108 for x in input_paths {
109 if x.is_dir() {
110 return cmprss_error(
111 "Zstd does not support compressing a directory. Please specify only files.",
112 );
113 }
114 }
115 }
116 let mut file_size = None;
117 let mut input_stream: Box<dyn Read + Send> = match input {
118 CmprssInput::Path(paths) => {
119 if paths.len() > 1 {
120 return Err(io::Error::new(
121 io::ErrorKind::InvalidInput,
122 "Multiple input files not supported for zstd",
123 ));
124 }
125 let path = &paths[0];
126 file_size = Some(std::fs::metadata(path)?.len());
127 Box::new(BufReader::new(File::open(path)?))
128 }
129 CmprssInput::Pipe(stdin) => Box::new(BufReader::new(stdin)),
130 };
131
132 let output_stream: Box<dyn Write + Send> = match &output {
133 CmprssOutput::Path(path) => Box::new(BufWriter::new(File::create(path)?)),
134 CmprssOutput::Pipe(stdout) => Box::new(BufWriter::new(stdout)),
135 };
136
137 // Create a zstd encoder with the specified compression level
138 let mut encoder = Encoder::new(output_stream, self.compression_level)?;
139
140 // Copy the input to the encoder with progress reporting
141 copy_with_progress(
142 &mut input_stream,
143 &mut encoder,
144 self.progress_args.chunk_size.size_in_bytes,
145 file_size,
146 self.progress_args.progress,
147 &output,
148 )?;
149
150 // Finish the encoder to ensure all data is written
151 encoder.finish()?;
152
153 Ok(())
154 }
155
156 /// Extract a zstd archive to an output file or pipe
157 fn extract(&self, input: CmprssInput, output: CmprssOutput) -> Result<(), io::Error> {
158 if let CmprssOutput::Path(out_path) = &output {
159 if out_path.is_dir() {
160 return cmprss_error("Zstd does not support extracting to a directory. Please specify an output file.");
161 }
162 }
163
164 let input_stream: Box<dyn Read + Send> = match input {
165 CmprssInput::Path(paths) => {
166 if paths.len() > 1 {
167 return Err(io::Error::new(
168 io::ErrorKind::InvalidInput,
169 "Multiple input files not supported for zstd",
170 ));
171 }
172 let path = &paths[0];
173 Box::new(BufReader::new(File::open(path)?))
174 }
175 CmprssInput::Pipe(stdin) => Box::new(BufReader::new(stdin)),
176 };
177
178 // Create a zstd decoder
179 let mut decoder = Decoder::new(input_stream)?;
180
181 let mut output_stream: Box<dyn Write + Send> = match &output {
182 CmprssOutput::Path(path) => Box::new(BufWriter::new(File::create(path)?)),
183 CmprssOutput::Pipe(stdout) => Box::new(BufWriter::new(stdout)),
184 };
185
186 // Copy the decoded data to the output with progress reporting
187 copy_with_progress(
188 &mut decoder,
189 &mut output_stream,
190 self.progress_args.chunk_size.size_in_bytes,
191 None,
192 self.progress_args.progress,
193 &output,
194 )?;
195
196 Ok(())
197 }
198}
199
200#[cfg(test)]
201mod tests {
202 use super::*;
203 use tempfile::tempdir;
204
205 #[test]
206 fn roundtrip() -> Result<(), Box<dyn std::error::Error>> {
207 let dir = tempdir()?;
208 let input_path = dir.path().join("input.txt");
209 let compressed_path = dir.path().join("input.txt.zst");
210 let output_path = dir.path().join("output.txt");
211
212 // Create a test file
213 let test_data = b"Hello, world! This is a test file for zstd compression.";
214 std::fs::write(&input_path, test_data)?;
215
216 // Compress the file
217 let zstd = Zstd::default();
218 zstd.compress(
219 CmprssInput::Path(vec![input_path.clone()]),
220 CmprssOutput::Path(compressed_path.clone()),
221 )?;
222
223 // Extract the file
224 zstd.extract(
225 CmprssInput::Path(vec![compressed_path]),
226 CmprssOutput::Path(output_path.clone()),
227 )?;
228
229 // Verify the contents
230 let output_data = std::fs::read(output_path)?;
231 assert_eq!(test_data.to_vec(), output_data);
232
233 Ok(())
234 }
235
236 #[test]
237 fn test_zstd_compression_validator() {
238 let validator = ZstdCompressionValidator;
239
240 // Test range
241 assert_eq!(validator.min_level(), -7);
242 assert_eq!(validator.max_level(), 22);
243 assert_eq!(validator.default_level(), 1);
244
245 // Test validation
246 assert!(validator.is_valid_level(-7));
247 assert!(validator.is_valid_level(0));
248 assert!(validator.is_valid_level(22));
249 assert!(!validator.is_valid_level(-8));
250 assert!(!validator.is_valid_level(23));
251
252 // Test clamping
253 assert_eq!(validator.validate_and_clamp_level(-8), -7);
254 assert_eq!(validator.validate_and_clamp_level(0), 0);
255 assert_eq!(validator.validate_and_clamp_level(23), 22);
256
257 // Test special names
258 assert_eq!(validator.name_to_level("none"), Some(-7));
259 assert_eq!(validator.name_to_level("fast"), Some(1));
260 assert_eq!(validator.name_to_level("best"), Some(22));
261 assert_eq!(validator.name_to_level("invalid"), None);
262 }
263}