diff --git a/crates/compression-codecs/src/gzip/decoder.rs b/crates/compression-codecs/src/gzip/decoder.rs index 647959c..44064b0 100644 --- a/crates/compression-codecs/src/gzip/decoder.rs +++ b/crates/compression-codecs/src/gzip/decoder.rs @@ -64,7 +64,8 @@ impl GzipDecoder { loop { match &mut self.state { State::Header(parser) => { - if parser.input(input)?.is_some() { + if parser.input(&mut self.crc, input)?.is_some() { + self.crc.reset(); self.state = State::Decoding; } } diff --git a/crates/compression-codecs/src/gzip/header.rs b/crates/compression-codecs/src/gzip/header.rs index 9c894e9..f28231b 100644 --- a/crates/compression-codecs/src/gzip/header.rs +++ b/crates/compression-codecs/src/gzip/header.rs @@ -1,4 +1,5 @@ use compression_core::util::PartialBuffer; +use flate2::Crc; use std::io; #[derive(Debug, Default)] @@ -61,18 +62,39 @@ impl Header { } } +fn consume_input(crc: &mut Crc, n: usize, input: &mut PartialBuffer<&[u8]>) { + crc.update(&input.unwritten()[..n]); + input.advance(n); +} + +fn consume_cstr(crc: &mut Crc, input: &mut PartialBuffer<&[u8]>) -> Option<()> { + if let Some(len) = memchr::memchr(0, input.unwritten()) { + consume_input(crc, len + 1, input); + Some(()) + } else { + consume_input(crc, input.unwritten().len(), input); + None + } +} + impl Parser { - pub(super) fn input(&mut self, input: &mut PartialBuffer<&[u8]>) -> io::Result> { + pub(super) fn input( + &mut self, + crc: &mut Crc, + input: &mut PartialBuffer<&[u8]>, + ) -> io::Result> { loop { match &mut self.state { State::Fixed(data) => { data.copy_unwritten_from(input); if data.unwritten().is_empty() { - self.header = Header::parse(&data.take().into_inner())?; + let data = data.get_mut(); + crc.update(data); + self.header = Header::parse(data)?; self.state = State::ExtraLen(<_>::default()); } else { - return Ok(None); + break Ok(None); } } @@ -85,22 +107,24 @@ impl Parser { data.copy_unwritten_from(input); if data.unwritten().is_empty() { - let len = u16::from_le_bytes(data.take().into_inner()); + let data = data.get_mut(); + crc.update(data); + let len = u16::from_le_bytes(*data); self.state = State::Extra(len.into()); } else { - return Ok(None); + break Ok(None); } } State::Extra(bytes_to_consume) => { let n = input.unwritten().len().min(*bytes_to_consume); *bytes_to_consume -= n; - input.advance(n); + consume_input(crc, n, input); if *bytes_to_consume == 0 { self.state = State::Filename; } else { - return Ok(None); + break Ok(None); } } @@ -110,13 +134,11 @@ impl Parser { continue; } - if let Some(len) = memchr::memchr(0, input.unwritten()) { - input.advance(len + 1); - self.state = State::Comment; - } else { - input.advance(input.unwritten().len()); - return Ok(None); + if consume_cstr(crc, input).is_none() { + break Ok(None); } + + self.state = State::Comment; } State::Comment => { @@ -125,35 +147,43 @@ impl Parser { continue; } - if let Some(len) = memchr::memchr(0, input.unwritten()) { - input.advance(len + 1); - self.state = State::Crc(<_>::default()); - } else { - input.advance(input.unwritten().len()); - return Ok(None); + if consume_cstr(crc, input).is_none() { + break Ok(None); } + + self.state = State::Crc(<_>::default()); } State::Crc(data) => { + let header = std::mem::take(&mut self.header); + if !self.header.flags.crc { self.state = State::Done; - return Ok(Some(std::mem::take(&mut self.header))); + break Ok(Some(header)); } data.copy_unwritten_from(input); - if data.unwritten().is_empty() { + break if data.unwritten().is_empty() { + let data = data.take().into_inner(); self.state = State::Done; - return Ok(Some(std::mem::take(&mut self.header))); + let checksum = crc.sum().to_le_bytes(); + + if data == checksum[..2] { + Ok(Some(header)) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "CRC computed for header does not match", + )) + } } else { - return Ok(None); - } + Ok(None) + }; } - State::Done => { - return Err(io::Error::other("parser used after done")); - } - }; + State::Done => break Err(io::Error::other("parser used after done")), + } } } }