This repository has no description.
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

at eeb60a130d5c62af7e5265c53679a7a52ce4f786 263 lines 8.5 kB view raw
1use crate::progress::{copy_with_progress, ProgressArgs}; 2use crate::utils::{ 3 cmprss_error, CmprssInput, CmprssOutput, CommonArgs, CompressionLevelValidator, Compressor, 4 LevelArgs, 5}; 6use clap::Args; 7use std::fs::File; 8use std::io::{self, BufReader, BufWriter, Read, Write}; 9use zstd::stream::{read::Decoder, write::Encoder}; 10 11/// Zstd-specific compression validator (-7 to 22 range) 12#[derive(Debug, Clone, Copy)] 13pub struct ZstdCompressionValidator; 14 15impl CompressionLevelValidator for ZstdCompressionValidator { 16 fn min_level(&self) -> i32 { 17 -7 18 } 19 fn max_level(&self) -> i32 { 20 22 21 } 22 fn default_level(&self) -> i32 { 23 1 24 } 25 26 fn name_to_level(&self, name: &str) -> Option<i32> { 27 match name.to_lowercase().as_str() { 28 "none" => Some(-7), 29 "fast" => Some(1), 30 "best" => Some(22), 31 _ => None, 32 } 33 } 34} 35 36#[derive(Args, Debug)] 37pub struct ZstdArgs { 38 #[clap(flatten)] 39 pub common_args: CommonArgs, 40 41 #[clap(flatten)] 42 pub level_args: LevelArgs, 43 44 #[clap(flatten)] 45 pub progress_args: ProgressArgs, 46} 47 48pub struct Zstd { 49 pub compression_level: i32, 50 pub progress_args: ProgressArgs, 51} 52 53impl Default for Zstd { 54 fn default() -> Self { 55 let validator = ZstdCompressionValidator; 56 Zstd { 57 compression_level: validator.default_level(), 58 progress_args: ProgressArgs::default(), 59 } 60 } 61} 62 63impl Zstd { 64 pub fn new(args: &ZstdArgs) -> Zstd { 65 let validator = ZstdCompressionValidator; 66 let mut level = args.level_args.level.level; 67 68 // Validate and clamp the level to zstd's valid range 69 level = validator.validate_and_clamp_level(level); 70 71 Zstd { 72 compression_level: level, 73 progress_args: args.progress_args, 74 } 75 } 76} 77 78impl Compressor for Zstd { 79 /// The standard extension for the zstd format. 80 fn extension(&self) -> &str { 81 "zst" 82 } 83 84 /// Full name for zstd. 
85 fn name(&self) -> &str { 86 "zstd" 87 } 88 89 /// Generate a default extracted filename 90 /// zstd does not support extracting to a directory, so we return a default filename 91 fn default_extracted_filename(&self, in_path: &std::path::Path) -> String { 92 // If the file has no extension, return a default filename 93 if in_path.extension().is_none() { 94 return "archive".to_string(); 95 } 96 // Otherwise, return the filename without the extension 97 in_path.file_stem().unwrap().to_str().unwrap().to_string() 98 } 99 100 /// Compress an input file or pipe to a zstd archive 101 fn compress(&self, input: CmprssInput, output: CmprssOutput) -> Result<(), io::Error> { 102 if let CmprssOutput::Path(out_path) = &output { 103 if out_path.is_dir() { 104 return cmprss_error("Zstd does not support compressing to a directory. Please specify an output file."); 105 } 106 } 107 if let CmprssInput::Path(input_paths) = &input { 108 for x in input_paths { 109 if x.is_dir() { 110 return cmprss_error( 111 "Zstd does not support compressing a directory. 
Please specify only files.", 112 ); 113 } 114 } 115 } 116 let mut file_size = None; 117 let mut input_stream: Box<dyn Read + Send> = match input { 118 CmprssInput::Path(paths) => { 119 if paths.len() > 1 { 120 return Err(io::Error::new( 121 io::ErrorKind::InvalidInput, 122 "Multiple input files not supported for zstd", 123 )); 124 } 125 let path = &paths[0]; 126 file_size = Some(std::fs::metadata(path)?.len()); 127 Box::new(BufReader::new(File::open(path)?)) 128 } 129 CmprssInput::Pipe(stdin) => Box::new(BufReader::new(stdin)), 130 }; 131 132 let output_stream: Box<dyn Write + Send> = match &output { 133 CmprssOutput::Path(path) => Box::new(BufWriter::new(File::create(path)?)), 134 CmprssOutput::Pipe(stdout) => Box::new(BufWriter::new(stdout)), 135 }; 136 137 // Create a zstd encoder with the specified compression level 138 let mut encoder = Encoder::new(output_stream, self.compression_level)?; 139 140 // Copy the input to the encoder with progress reporting 141 copy_with_progress( 142 &mut input_stream, 143 &mut encoder, 144 self.progress_args.chunk_size.size_in_bytes, 145 file_size, 146 self.progress_args.progress, 147 &output, 148 )?; 149 150 // Finish the encoder to ensure all data is written 151 encoder.finish()?; 152 153 Ok(()) 154 } 155 156 /// Extract a zstd archive to an output file or pipe 157 fn extract(&self, input: CmprssInput, output: CmprssOutput) -> Result<(), io::Error> { 158 if let CmprssOutput::Path(out_path) = &output { 159 if out_path.is_dir() { 160 return cmprss_error("Zstd does not support extracting to a directory. 
Please specify an output file."); 161 } 162 } 163 164 let input_stream: Box<dyn Read + Send> = match input { 165 CmprssInput::Path(paths) => { 166 if paths.len() > 1 { 167 return Err(io::Error::new( 168 io::ErrorKind::InvalidInput, 169 "Multiple input files not supported for zstd", 170 )); 171 } 172 let path = &paths[0]; 173 Box::new(BufReader::new(File::open(path)?)) 174 } 175 CmprssInput::Pipe(stdin) => Box::new(BufReader::new(stdin)), 176 }; 177 178 // Create a zstd decoder 179 let mut decoder = Decoder::new(input_stream)?; 180 181 let mut output_stream: Box<dyn Write + Send> = match &output { 182 CmprssOutput::Path(path) => Box::new(BufWriter::new(File::create(path)?)), 183 CmprssOutput::Pipe(stdout) => Box::new(BufWriter::new(stdout)), 184 }; 185 186 // Copy the decoded data to the output with progress reporting 187 copy_with_progress( 188 &mut decoder, 189 &mut output_stream, 190 self.progress_args.chunk_size.size_in_bytes, 191 None, 192 self.progress_args.progress, 193 &output, 194 )?; 195 196 Ok(()) 197 } 198} 199 200#[cfg(test)] 201mod tests { 202 use super::*; 203 use tempfile::tempdir; 204 205 #[test] 206 fn roundtrip() -> Result<(), Box<dyn std::error::Error>> { 207 let dir = tempdir()?; 208 let input_path = dir.path().join("input.txt"); 209 let compressed_path = dir.path().join("input.txt.zst"); 210 let output_path = dir.path().join("output.txt"); 211 212 // Create a test file 213 let test_data = b"Hello, world! 
This is a test file for zstd compression."; 214 std::fs::write(&input_path, test_data)?; 215 216 // Compress the file 217 let zstd = Zstd::default(); 218 zstd.compress( 219 CmprssInput::Path(vec![input_path.clone()]), 220 CmprssOutput::Path(compressed_path.clone()), 221 )?; 222 223 // Extract the file 224 zstd.extract( 225 CmprssInput::Path(vec![compressed_path]), 226 CmprssOutput::Path(output_path.clone()), 227 )?; 228 229 // Verify the contents 230 let output_data = std::fs::read(output_path)?; 231 assert_eq!(test_data.to_vec(), output_data); 232 233 Ok(()) 234 } 235 236 #[test] 237 fn test_zstd_compression_validator() { 238 let validator = ZstdCompressionValidator; 239 240 // Test range 241 assert_eq!(validator.min_level(), -7); 242 assert_eq!(validator.max_level(), 22); 243 assert_eq!(validator.default_level(), 1); 244 245 // Test validation 246 assert!(validator.is_valid_level(-7)); 247 assert!(validator.is_valid_level(0)); 248 assert!(validator.is_valid_level(22)); 249 assert!(!validator.is_valid_level(-8)); 250 assert!(!validator.is_valid_level(23)); 251 252 // Test clamping 253 assert_eq!(validator.validate_and_clamp_level(-8), -7); 254 assert_eq!(validator.validate_and_clamp_level(0), 0); 255 assert_eq!(validator.validate_and_clamp_level(23), 22); 256 257 // Test special names 258 assert_eq!(validator.name_to_level("none"), Some(-7)); 259 assert_eq!(validator.name_to_level("fast"), Some(1)); 260 assert_eq!(validator.name_to_level("best"), Some(22)); 261 assert_eq!(validator.name_to_level("invalid"), None); 262 } 263}