this repo has no description
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

make varint encode/decode robust, add tests

+38 -6
+10 -6
src/atmst/blockstore/car_file.py
···
 6  6   from . import BlockStore
 7  7
 8  8   # should be equivalent to multiformats.varint.decode(), but not extremely slow for no reason.
 9    - def parse_varint(stream: BinaryIO):
    9 + def decode_varint(stream: BinaryIO):
10 10       n = 0
11    -     shift = 0
12    -     while True:
   11 +     for shift in range(0, 63, 7):
13 12           val = stream.read(1)
14 13           if not val:
15    -             raise ValueError("eof") # match varint.decode()
   14 +             raise ValueError("unexpected end of varint input")
16 15           val = val[0]
17 16           n |= (val & 0x7f) << shift
18 17           if not val & 0x80:
   18 +             if shift and not val:
   19 +                 raise ValueError("varint not minimally encoded")
19 20               return n
20 21           shift += 7
   22 +     raise ValueError("varint too long")
21 23
22 24   def encode_varint(n: int) -> bytes:
   25 +     if not 0 <= n < 2**63:
   26 +         raise ValueError("integer out of encodable varint range")
23 27       res = []
24 28       while n > 0x7f:
25 29           res.append(0x80 | (n & 0x7f))
···
47 51           file.seek(0)
48 52
49 53           # parse out CAR header
50    -         header_len = parse_varint(file)
   54 +         header_len = decode_varint(file)
51 55           header = file.read(header_len)
52 56           if len(header) != header_len:
53 57               raise EOFError("not enough CAR header bytes")
···
62 66           self.block_offsets = {}
63 67           while True:
64 68               try:
65    -                 length = parse_varint(file)
   69 +                 length = decode_varint(file)
66 70               except ValueError:
67 71                   break # EOF
68 72               start = file.tell()
+28
tests/test_varint.py
···
    1 + import unittest
    2 + import io
    3 +
    4 + from atmst.blockstore.car_file import decode_varint, encode_varint
    5 +
    6 + class MSTDiffTestCase(unittest.TestCase):
    7 +     def test_varint_encode(self):
    8 +         self.assertEqual(encode_varint(0), b"\x00")
    9 +         self.assertEqual(encode_varint(1), b"\x01")
   10 +         self.assertEqual(encode_varint(127), b"\x7f")
   11 +         self.assertEqual(encode_varint(128), b"\x80\x01")
   12 +         self.assertEqual(encode_varint(2**63-1), b'\xff\xff\xff\xff\xff\xff\xff\xff\x7f')
   13 +         self.assertRaises(ValueError, encode_varint, 2**63)
   14 +         self.assertRaises(ValueError, encode_varint, -1)
   15 +
   16 +     def test_varint_decode(self):
   17 +         self.assertEqual(decode_varint(io.BytesIO(b"\x00")), 0)
   18 +         self.assertEqual(decode_varint(io.BytesIO(b"\x01")), 1)
   19 +         self.assertEqual(decode_varint(io.BytesIO(b"\x7f")), 127)
   20 +         self.assertEqual(decode_varint(io.BytesIO(b"\x80\x01")), 128)
   21 +         self.assertEqual(decode_varint(io.BytesIO(b'\xff\xff\xff\xff\xff\xff\xff\xff\x7f')), 2**63-1)
   22 +         self.assertRaises(ValueError, decode_varint, io.BytesIO(b'\xff\xff\xff\xff\xff\xff\xff\xff\xff\x7f')) # too big
   23 +         self.assertRaises(ValueError, decode_varint, io.BytesIO(b"")) # too short
   24 +         self.assertRaises(ValueError, decode_varint, io.BytesIO(b'\xff')) # truncated
   25 +         self.assertRaises(ValueError, decode_varint, io.BytesIO(b"\x80\x00")) # too minimally encoded
   26 +
   27 + if __name__ == '__main__':
   28 +     unittest.main(module="tests.test_varint")