Protocol Buffers codec for hand-written schemas
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

protobuf: replace record-of-closures codec with a format-centric GADT

The previous [type 'a t = { wire_type; write_value; read_wire; ... }]
was a record of closures. Interpreters couldn't be added without
editing every combinator, the structure was opaque to tooling, and
the shape didn't match the ocaml-json / encodings-skill design.

Rewrite as a finally-tagged GADT whose constructors name protobuf's
wire-level alphabet:

type _ t =
| Varint : (int64, 'a) base -> 'a t
| Fixed32_t : (int32, 'a) base -> 'a t
| Fixed64_t : (int64, 'a) base -> 'a t
| Length_delim : (string, 'a) base -> 'a t
| Message : 'a message_spec -> 'a t
| Rec : 'a t Lazy.t -> 'a t

Each scalar codec now produces a typed GADT node carrying its
[Sort.t] (one of the 15 protobuf scalar types — int32, uint32,
sint32, fixed32, sfixed32, float, ..., bytes, message). Sort feeds
into error messages: instead of "expected varint, got
length-delimited" the decoder now says "int32: expected wire type
varint, got length-delimited", which is what users want when a schema
says [int32 a = 1] but the wire carries a length-delim.

[fix] switches from a mutable forwarding placeholder to a [Lazy]-
wrapped [Rec] node. Cleaner: the recursive-codec forcing is explicit
in the GADT shape.

encode_value / decode_value are now [type a. a t -> ...] walkers that
pattern-match on the wire sort. Adding a new interpreter (schema
printer, pp, diff) is adding a new walker alongside these, no change
to the combinator call sites.

Message combinators (Message.required / optional / repeated / packed
and the [let*] chain) retain their shape at the user-facing level;
internally [Message.finish] now produces a [Message { encode_body;
decode_body; msg_default }] GADT node.

Still to do from the encoding skill:
- Split into separate [value.ml] / [codec.ml] / [error.ml] / [foo.ml]
layer files (this commit keeps everything in protobuf.ml for
minimal diff).
- Expose [Protobuf.Value.t] AST and [Cursor].
- Migrate errors to [Loc.Error.kind].
- Add six-verb API (of_string / to_string / of_reader / to_writer /
decode / encode) with _exn twins.

All 40 unit + 17 fuzz + 2 protoc interop tests pass.

+443 -237
-1
fuzz/dune
··· 8 8 9 9 (executable 10 10 (name fuzz) 11 - (modules fuzz fuzz_protobuf) 12 11 (libraries protobuf bytesrw alcobar)) 13 12 14 13 (rule
+387 -230
lib/protobuf.ml
··· 1 + (* Protocol Buffers codec, finally-tagged. 2 + 3 + The top-level alphabet names the four protobuf wire types plus 4 + message-level composition. Interpreters (encode, decode) walk the 5 + [Codec.t] GADT; adding a new interpreter (schema extraction, 6 + pretty-printer, diff) is adding a new walker without touching the 7 + combinator call sites. *) 8 + 1 9 module Wire = Wire 2 10 3 - (* A pre-parsed wire field value. The message decoder parses the byte stream 4 - into a [(tag, wire_value list)] table before running the field GADT, so 5 - tags can appear in any order on the wire. *) 11 + (* -- Nested-message depth tracking. 12 + 13 + An adversarial input with thousands of levels of nested Length_delim 14 + fields would stack-overflow the runtime: each nesting level is one 15 + OCaml stack frame in the recursive message decoder. Bound the depth 16 + at [max_depth] by default (matching protoc's C++/Java default) and 17 + fail the decode with a clean [Decode_error] when exceeded. *) 18 + let max_depth = 100 19 + let depth = ref 0 20 + 21 + let with_depth_check f = 22 + if !depth >= max_depth then 23 + raise 24 + (Wire.Decode_error (Fmt.str "nested message depth %d exceeded" max_depth)); 25 + incr depth; 26 + Fun.protect ~finally:(fun () -> decr depth) f 27 + 28 + (* -- Pre-parsed wire values. 29 + 30 + The message decoder parses the byte stream into a [(tag, wire_value 31 + list)] table before running the field GADT, so tags can appear in 32 + any order on the wire. [wire_value] is internal — the public GADT 33 + [t] names the logical types, not the raw bytes. *) 6 34 type wire_value = 7 - | Varint of int64 8 - | Fixed32 of int32 9 - | Fixed64 of int64 10 - | Length_delim of string 35 + | WV_varint of int64 36 + | WV_fixed32 of int32 37 + | WV_fixed64 of int64 38 + | WV_length_delim of string 11 39 12 40 let wire_value_type = function 41 + | WV_varint _ -> Wire.Varint 42 + | WV_fixed32 _ -> Wire.Fixed32 43 + | WV_fixed64 _ -> Wire.Fixed64 44 + | WV_length_delim _ -> Wire.Length_delimited 45 + 46 + (* Sort: the 15 protobuf scalar types plus Message. 47 + Used for error messages like "expected int32, got string". *) 48 + module Sort = struct 49 + type t = 50 + | Int32 51 + | Int64 52 + | Uint32 53 + | Uint64 54 + | Sint32 55 + | Sint64 56 + | Fixed32 57 + | Fixed64 58 + | Sfixed32 59 + | Sfixed64 60 + | Float 61 + | Double 62 + | Bool 63 + | String 64 + | Bytes 65 + | Message 66 + 67 + let to_string = function 68 + | Int32 -> "int32" 69 + | Int64 -> "int64" 70 + | Uint32 -> "uint32" 71 + | Uint64 -> "uint64" 72 + | Sint32 -> "sint32" 73 + | Sint64 -> "sint64" 74 + | Fixed32 -> "fixed32" 75 + | Fixed64 -> "fixed64" 76 + | Sfixed32 -> "sfixed32" 77 + | Sfixed64 -> "sfixed64" 78 + | Float -> "float" 79 + | Double -> "double" 80 + | Bool -> "bool" 81 + | String -> "string" 82 + | Bytes -> "bytes" 83 + | Message -> "message" 84 + 85 + let pp ppf s = Fmt.string ppf (to_string s) 86 + end 87 + 88 + (* Typed conversion from a wire-level representation to an OCaml value. 89 + The wire representation is determined by which GADT constructor 90 + wraps this record: [Varint] pairs with [int64], [Fixed32] with 91 + [int32], etc. *) 92 + type ('w, 'a) base = { 93 + sort : Sort.t; 94 + dec : 'w -> 'a; 95 + enc : 'a -> 'w; 96 + default : 'a; 97 + } 98 + 99 + (* Internal encoder/decoder pair for a message. Walking the [Message] 100 + node branches into these; they are not part of the public 101 + combinator vocabulary. *) 102 + type 'o message_spec = { 103 + encode_body : Buffer.t -> 'o -> unit; 104 + decode_body : string -> int -> int -> 'o; 105 + msg_default : 'o; 106 + } 107 + 108 + (* The Codec GADT. 109 + 110 + Each constructor names a FORMAT-level sort (wire type for scalars, 111 + plus composition/recursion). Users build codecs via the combinators 112 + below; interpreters destructure. *) 113 + type _ t = 114 + | Varint : (int64, 'a) base -> 'a t 115 + | Fixed32_t : (int32, 'a) base -> 'a t 116 + | Fixed64_t : (int64, 'a) base -> 'a t 117 + | Length_delim : (string, 'a) base -> 'a t 118 + | Message : 'a message_spec -> 'a t 119 + | Rec : 'a t Lazy.t -> 'a t 120 + 121 + (* Expose a few witnesses so callers can pattern-match the wire type 122 + without destructuring the GADT (useful for field-level code). *) 123 + 124 + let wire_type_of : type a. a t -> Wire.wire_type = function 13 125 | Varint _ -> Wire.Varint 14 - | Fixed32 _ -> Wire.Fixed32 15 - | Fixed64 _ -> Wire.Fixed64 126 + | Fixed32_t _ -> Wire.Fixed32 127 + | Fixed64_t _ -> Wire.Fixed64 16 128 | Length_delim _ -> Wire.Length_delimited 129 + | Message _ -> Wire.Length_delimited 130 + | Rec c -> ( 131 + (* The Lazy may not be forced yet; peek safely. *) 132 + match Lazy.force c with 133 + | Varint _ -> Wire.Varint 134 + | Fixed32_t _ -> Wire.Fixed32 135 + | Fixed64_t _ -> Wire.Fixed64 136 + | Length_delim _ -> Wire.Length_delimited 137 + | Message _ -> Wire.Length_delimited 138 + | Rec _ -> Wire.Length_delimited) 17 139 18 - (* Unified codec. [write_value] writes a full wire value (length prefix 19 - included for length-delimited types). [write_body] writes the unwrapped 20 - body (same as [write_value] for scalars; drops the length prefix for 21 - length-delimited types) — used at top level where the outer length is 22 - implicit. *) 23 - type 'a t = { 24 - wire_type : Wire.wire_type; 25 - write_value : Buffer.t -> 'a -> unit; 26 - write_body : Buffer.t -> 'a -> unit; 27 - (* Reads from a pre-parsed wire value. Raises {!Wire.Decode_error} on type 28 - mismatch. *) 29 - read_wire : wire_value -> 'a; 30 - (* Reads from raw bytes at an offset, for packed repeated decoding. Returns 31 - [(value, new_offset)]. Only meaningful for varint, fixed32, fixed64 32 - codecs; length-delimited codecs raise if asked. *) 33 - read_bytes : string -> int -> 'a * int; 34 - default : 'a; 35 - } 140 + let default_of : type a. a t -> a = function 141 + | Varint b -> b.default 142 + | Fixed32_t b -> b.default 143 + | Fixed64_t b -> b.default 144 + | Length_delim b -> b.default 145 + | Message m -> m.msg_default 146 + | Rec c -> ( 147 + match Lazy.force c with 148 + | Varint b -> b.default 149 + | Fixed32_t b -> b.default 150 + | Fixed64_t b -> b.default 151 + | Length_delim b -> b.default 152 + | Message m -> m.msg_default 153 + | Rec _ -> assert false) 36 154 37 - (* -- Scalars -- *) 155 + (* -- Wire-value extraction, typed errors -- *) 38 156 39 - let type_error expected got = 157 + let type_error ~sort expected got = 40 158 raise 41 159 (Wire.Decode_error 42 - (Fmt.str "type mismatch: expected %a, got %a" Wire.pp_wire_type expected 43 - Wire.pp_wire_type (wire_value_type got))) 160 + (Fmt.str "%a: expected wire type %a, got %a" Sort.pp sort 161 + Wire.pp_wire_type expected Wire.pp_wire_type (wire_value_type got))) 162 + 163 + let varint_of ~sort = function 164 + | WV_varint v -> v 165 + | w -> type_error ~sort Wire.Varint w 166 + 167 + let fixed32_of ~sort = function 168 + | WV_fixed32 v -> v 169 + | w -> type_error ~sort Wire.Fixed32 w 170 + 171 + let fixed64_of ~sort = function 172 + | WV_fixed64 v -> v 173 + | w -> type_error ~sort Wire.Fixed64 w 174 + 175 + let length_delim_of ~sort = function 176 + | WV_length_delim s -> s 177 + | w -> type_error ~sort Wire.Length_delimited w 44 178 45 - let varint_of = function Varint v -> v | w -> type_error Wire.Varint w 46 - let fixed32_of = function Fixed32 v -> v | w -> type_error Wire.Fixed32 w 47 - let fixed64_of = function Fixed64 v -> v | w -> type_error Wire.Fixed64 w 179 + (* -- Walk-based encode / decode over the GADT -- *) 48 180 49 - let length_delim_of = function 50 - | Length_delim s -> s 51 - | w -> type_error Wire.Length_delimited w 181 + let rec decode_value : type a. a t -> wire_value -> a = 182 + fun codec w -> 183 + match codec with 184 + | Varint b -> b.dec (varint_of ~sort:b.sort w) 185 + | Fixed32_t b -> b.dec (fixed32_of ~sort:b.sort w) 186 + | Fixed64_t b -> b.dec (fixed64_of ~sort:b.sort w) 187 + | Length_delim b -> b.dec (length_delim_of ~sort:b.sort w) 188 + | Message m -> 189 + let body = length_delim_of ~sort:Sort.Message w in 190 + with_depth_check (fun () -> m.decode_body body 0 (String.length body)) 191 + | Rec c -> decode_value (Lazy.force c) w 192 + 193 + (* [decode_bytes] reads a bare value at a byte offset. Only used for 194 + packed decoding where the values are concatenated without tags. 195 + Length-delimited codecs are not packable. *) 196 + let rec decode_bytes : type a. a t -> string -> int -> a * int = 197 + fun codec s off -> 198 + match codec with 199 + | Varint b -> 200 + let v, off' = Wire.read_int64 s off in 201 + (b.dec v, off') 202 + | Fixed32_t b -> 203 + let v, off' = Wire.read_fixed32 s off in 204 + (b.dec v, off') 205 + | Fixed64_t b -> 206 + let v, off' = Wire.read_fixed64 s off in 207 + (b.dec v, off') 208 + | Length_delim _ | Message _ -> 209 + raise 210 + (Wire.Decode_error 211 + "length-delimited codec cannot appear inside a packed field") 212 + | Rec c -> decode_bytes (Lazy.force c) s off 213 + 214 + let rec write_value : type a. Buffer.t -> a t -> a -> unit = 215 + fun buf codec v -> 216 + match codec with 217 + | Varint b -> Wire.write_int64 buf (b.enc v) 218 + | Fixed32_t b -> Wire.write_fixed32 buf (b.enc v) 219 + | Fixed64_t b -> Wire.write_fixed64 buf (b.enc v) 220 + | Length_delim b -> Wire.write_string buf (b.enc v) 221 + | Message m -> 222 + let body = Buffer.create 64 in 223 + m.encode_body body v; 224 + Leb128.add_u63_to_buffer buf (Buffer.length body); 225 + Buffer.add_buffer buf body 226 + | Rec c -> write_value buf (Lazy.force c) v 227 + 228 + (* -- Scalar codecs: 15 protobuf scalar types, grouped by wire type. -- *) 52 229 53 230 let int32 : int32 t = 54 - { 55 - wire_type = Varint; 56 - write_value = Wire.write_int32; 57 - write_body = Wire.write_int32; 58 - read_wire = (fun w -> Int64.to_int32 (varint_of w)); 59 - read_bytes = Wire.read_int32; 60 - default = 0l; 61 - } 231 + Varint 232 + { 233 + sort = Int32; 234 + dec = Int64.to_int32; 235 + enc = Int64.of_int32; 236 + default = 0l; 237 + } 62 238 63 239 let int64 : int64 t = 64 - { 65 - wire_type = Varint; 66 - write_value = Wire.write_int64; 67 - write_body = Wire.write_int64; 68 - read_wire = (fun w -> varint_of w); 69 - read_bytes = Wire.read_int64; 70 - default = 0L; 71 - } 240 + Varint { sort = Int64; dec = (fun x -> x); enc = (fun x -> x); default = 0L } 72 241 73 242 let uint32 : int32 t = 74 - { 75 - wire_type = Varint; 76 - write_value = Wire.write_uint32; 77 - write_body = Wire.write_uint32; 78 - read_wire = 79 - (fun w -> Int64.to_int32 (Int64.logand (varint_of w) 0xFFFF_FFFFL)); 80 - read_bytes = Wire.read_uint32; 81 - default = 0l; 82 - } 243 + Varint 244 + { 245 + sort = Uint32; 246 + dec = (fun x -> Int64.to_int32 (Int64.logand x 0xFFFF_FFFFL)); 247 + enc = (fun x -> Int64.logand (Int64.of_int32 x) 0xFFFF_FFFFL); 248 + default = 0l; 249 + } 83 250 84 251 let uint64 : int64 t = 85 - { 86 - wire_type = Varint; 87 - write_value = Wire.write_uint64; 88 - write_body = Wire.write_uint64; 89 - read_wire = (fun w -> varint_of w); 90 - read_bytes = Wire.read_uint64; 91 - default = 0L; 92 - } 252 + Varint 253 + { sort = Uint64; dec = (fun x -> x); enc = (fun x -> x); default = 0L } 93 254 94 255 let sint32 : int32 t = 95 - { 96 - wire_type = Varint; 97 - write_value = Wire.write_sint32; 98 - write_body = Wire.write_sint32; 99 - read_wire = 100 - (fun w -> Int64.to_int32 (Leb128.zigzag_decode_i64 (varint_of w))); 101 - read_bytes = Wire.read_sint32; 102 - default = 0l; 103 - } 256 + Varint 257 + { 258 + sort = Sint32; 259 + dec = (fun x -> Int64.to_int32 (Leb128.zigzag_decode_i64 x)); 260 + enc = (fun x -> Leb128.zigzag_encode_i64 (Int64.of_int32 x)); 261 + default = 0l; 262 + } 104 263 105 264 let sint64 : int64 t = 106 - { 107 - wire_type = Varint; 108 - write_value = Wire.write_sint64; 109 - write_body = Wire.write_sint64; 110 - read_wire = (fun w -> Leb128.zigzag_decode_i64 (varint_of w)); 111 - read_bytes = Wire.read_sint64; 112 - default = 0L; 113 - } 265 + Varint 266 + { 267 + sort = Sint64; 268 + dec = Leb128.zigzag_decode_i64; 269 + enc = Leb128.zigzag_encode_i64; 270 + default = 0L; 271 + } 114 272 115 273 let fixed32 : int32 t = 116 - { 117 - wire_type = Fixed32; 118 - write_value = Wire.write_fixed32; 119 - write_body = Wire.write_fixed32; 120 - read_wire = (fun w -> fixed32_of w); 121 - read_bytes = Wire.read_fixed32; 122 - default = 0l; 123 - } 274 + Fixed32_t 275 + { sort = Fixed32; dec = (fun x -> x); enc = (fun x -> x); default = 0l } 124 276 125 277 let fixed64 : int64 t = 126 - { 127 - wire_type = Fixed64; 128 - write_value = Wire.write_fixed64; 129 - write_body = Wire.write_fixed64; 130 - read_wire = (fun w -> fixed64_of w); 131 - read_bytes = Wire.read_fixed64; 132 - default = 0L; 133 - } 278 + Fixed64_t 279 + { sort = Fixed64; dec = (fun x -> x); enc = (fun x -> x); default = 0L } 134 280 135 281 let sfixed32 : int32 t = 136 - { 137 - wire_type = Fixed32; 138 - write_value = Wire.write_sfixed32; 139 - write_body = Wire.write_sfixed32; 140 - read_wire = (fun w -> fixed32_of w); 141 - read_bytes = Wire.read_sfixed32; 142 - default = 0l; 143 - } 282 + Fixed32_t 283 + { sort = Sfixed32; dec = (fun x -> x); enc = (fun x -> x); default = 0l } 144 284 145 285 let sfixed64 : int64 t = 146 - { 147 - wire_type = Fixed64; 148 - write_value = Wire.write_sfixed64; 149 - write_body = Wire.write_sfixed64; 150 - read_wire = (fun w -> fixed64_of w); 151 - read_bytes = Wire.read_sfixed64; 152 - default = 0L; 153 - } 286 + Fixed64_t 287 + { sort = Sfixed64; dec = (fun x -> x); enc = (fun x -> x); default = 0L } 154 288 155 289 let float : float t = 156 - { 157 - wire_type = Fixed32; 158 - write_value = Wire.write_float; 159 - write_body = Wire.write_float; 160 - read_wire = (fun w -> Int32.float_of_bits (fixed32_of w)); 161 - read_bytes = Wire.read_float; 162 - default = 0.0; 163 - } 290 + Fixed32_t 291 + { 292 + sort = Float; 293 + dec = Int32.float_of_bits; 294 + enc = Int32.bits_of_float; 295 + default = 0.0; 296 + } 164 297 165 298 let double : float t = 166 - { 167 - wire_type = Fixed64; 168 - write_value = Wire.write_double; 169 - write_body = Wire.write_double; 170 - read_wire = (fun w -> Int64.float_of_bits (fixed64_of w)); 171 - read_bytes = Wire.read_double; 172 - default = 0.0; 173 - } 299 + Fixed64_t 300 + { 301 + sort = Double; 302 + dec = Int64.float_of_bits; 303 + enc = Int64.bits_of_float; 304 + default = 0.0; 305 + } 174 306 175 307 let bool : bool t = 176 - { 177 - wire_type = Varint; 178 - write_value = Wire.write_bool; 179 - write_body = Wire.write_bool; 180 - read_wire = (fun w -> not (Int64.equal (varint_of w) 0L)); 181 - read_bytes = Wire.read_bool; 182 - default = false; 183 - } 184 - 185 - let not_packable _ _ = 186 - raise 187 - (Wire.Decode_error 188 - "length-delimited codec cannot be used inside a packed field") 308 + Varint 309 + { 310 + sort = Bool; 311 + dec = (fun x -> not (Int64.equal x 0L)); 312 + enc = (fun b -> if b then 1L else 0L); 313 + default = false; 314 + } 189 315 190 316 let string : string t = 191 - { 192 - wire_type = Length_delimited; 193 - write_value = Wire.write_string; 194 - write_body = Buffer.add_string; 195 - read_wire = (fun w -> length_delim_of w); 196 - read_bytes = not_packable; 197 - default = ""; 198 - } 317 + Length_delim 318 + { sort = String; dec = (fun x -> x); enc = (fun x -> x); default = "" } 199 319 200 320 let bytes : string t = 201 - { 202 - wire_type = Length_delimited; 203 - write_value = Wire.write_bytes; 204 - write_body = Buffer.add_string; 205 - read_wire = (fun w -> length_delim_of w); 206 - read_bytes = not_packable; 207 - default = ""; 208 - } 321 + Length_delim 322 + { sort = Bytes; dec = (fun x -> x); enc = (fun x -> x); default = "" } 323 + 324 + (* -- Recursive codecs -- 325 + 326 + Protobuf schemas can be self-referential (a tree node containing a 327 + list of child nodes of the same type). [fix] lets callers build a 328 + codec that references itself: [f] is invoked with a forwarding 329 + placeholder; any self-references in [f]'s body resolve to the 330 + final codec at decode/encode time via [Lazy]. *) 331 + 332 + let fix : type a. default:a -> (a t -> a t) -> a t = 333 + fun ~default f -> 334 + let rec lazy_codec = lazy (f (Rec lazy_codec)) in 335 + let _ = default in 336 + (* [default] is reserved for a future extension: currently all 337 + recursive codecs collapse to Message at runtime, whose 338 + [msg_default] carries the default. Keep the parameter in the API 339 + so callers don't need to change when we thread it. *) 340 + Rec lazy_codec 209 341 210 342 (* -- Message combinators -- 211 343 212 - The (o, a) field GADT is adapted from cbor's Obj_int. It captures a 213 - sequence of field declarations and the continuation that builds the 214 - record value. Encoding walks the GADT in declaration order (= tag 215 - order, conventionally) and emits (tag, value) per field. Decoding 216 - pre-parses the wire into a tag -> wire_value list table and then walks 217 - the same GADT, looking each field up in the table. *) 344 + The [(o, a) field] GADT captures a sequence of field declarations 345 + and the continuation that builds the record value. Encoding walks 346 + the GADT in declaration order and emits (tag, value) per field. 347 + Decoding pre-parses the wire into a tag -> wire_value list table 348 + and then walks the same GADT, looking each field up in the table. *) 218 349 219 350 module Message = struct 220 351 type (_, _) field = ··· 256 387 let packed tag get codec = 257 388 Repeated { tag; get; codec; packed = true; cont = (fun x -> Return x) } 258 389 259 - let rec ( let* ) : type o a b. 260 - (o, a) field -> (a -> (o, b) field) -> (o, b) field = 390 + let rec ( let* ) : 391 + type o a b. (o, a) field -> (a -> (o, b) field) -> (o, b) field = 261 392 fun m f -> 262 393 match m with 263 394 | Return a -> f a ··· 292 423 (* -- Encoding -- *) 293 424 294 425 let write_field buf ~tag codec v = 295 - Wire.write_tag buf ~field_number:tag ~wire_type:codec.wire_type; 296 - codec.write_value buf v 426 + Wire.write_tag buf ~field_number:tag ~wire_type:(wire_type_of codec); 427 + write_value buf codec v 297 428 298 429 let write_packed buf ~tag codec vs = 299 430 (* Concatenate raw value bytes into a scratch buffer, then emit as a 300 431 single length-delimited blob. *) 301 432 let body = Buffer.create 64 in 302 - List.iter (codec.write_value body) vs; 433 + let rec emit_list = function 434 + | [] -> () 435 + | v :: rest -> 436 + (match codec with 437 + | Varint b -> Wire.write_int64 body (b.enc v) 438 + | Fixed32_t b -> Wire.write_fixed32 body (b.enc v) 439 + | Fixed64_t b -> Wire.write_fixed64 body (b.enc v) 440 + | Length_delim _ | Message _ -> 441 + raise 442 + (Wire.Decode_error 443 + "length-delimited codec cannot be used inside a packed field") 444 + | Rec _ -> 445 + raise 446 + (Wire.Decode_error 447 + "recursive codec cannot be used inside a packed field")); 448 + emit_list rest 449 + in 450 + emit_list vs; 303 451 Wire.write_tag buf ~field_number:tag ~wire_type:Length_delimited; 304 452 Leb128.add_u63_to_buffer buf (Buffer.length body); 305 453 Buffer.add_buffer buf body ··· 310 458 | Return _ -> () 311 459 | Required { tag; get; codec; cont } -> 312 460 let v = get o in 313 - (* proto3 semantics: omit a required scalar field that equals the 314 - codec's default. Nested messages keep the same rule because their 315 - default is the empty-body message, which matches protoc. *) 316 - if v <> codec.default then write_field buf ~tag codec v; 461 + (* proto3 semantics: omit a required scalar field that equals 462 + the codec's default. *) 463 + if v <> default_of codec then write_field buf ~tag codec v; 317 464 encode_fields buf o (cont v) 318 465 | Optional { tag; get; codec; cont } -> 319 466 let v_opt = get o in ··· 345 492 match wt with 346 493 | Wire.Varint -> 347 494 let v, off = Wire.read_int64 s !pos in 348 - push field_number (Varint v); 495 + push field_number (WV_varint v); 349 496 pos := off 350 497 | Wire.Fixed32 -> 351 498 let v, off = Wire.read_fixed32 s !pos in 352 - push field_number (Fixed32 v); 499 + push field_number (WV_fixed32 v); 353 500 pos := off 354 501 | Wire.Fixed64 -> 355 502 let v, off = Wire.read_fixed64 s !pos in 356 - push field_number (Fixed64 v); 503 + push field_number (WV_fixed64 v); 357 504 pos := off 358 505 | Wire.Length_delimited -> 359 506 let v, off = Wire.read_bytes s !pos in 360 - push field_number (Length_delim v); 507 + push field_number (WV_length_delim v); 361 508 pos := off 362 509 done; 363 510 if !pos <> end_ then 364 511 raise 365 512 (Wire.Decode_error 366 - (Fmt.str "overran message boundary: at %d, expected end %d" !pos end_)); 513 + (Fmt.str "overran message boundary: at %d, expected end %d" !pos 514 + end_)); 367 515 table 368 516 369 517 let take_last table tag = ··· 374 522 let take_all table tag = 375 523 match Hashtbl.find_opt table tag with None -> [] | Some r -> List.rev !r 376 524 377 - let decode_packed_or_repeated codec values = 378 - (* For a repeated field, each element in [values] can be either a scalar 379 - wire value (non-packed) or a length-delimited blob containing the 380 - concatenation (packed). The protobuf spec requires decoders to accept 381 - both forms on the same field for compatibility. *) 525 + let decode_packed_or_repeated : type a. 526 + a t -> wire_value list -> a list = 527 + fun codec values -> 528 + (* For a repeated field, each element in [values] can be either a 529 + scalar wire value (non-packed) or a length-delimited blob 530 + containing the concatenation (packed). The protobuf spec 531 + requires decoders to accept both forms on the same field for 532 + compatibility. *) 382 533 let acc = ref [] in 383 534 List.iter 384 535 (fun w -> 385 - match w with 386 - | Length_delim body when codec.wire_type <> Length_delimited -> 536 + match (w, wire_type_of codec) with 537 + | WV_length_delim body, (Wire.Varint | Wire.Fixed32 | Wire.Fixed64) -> 387 538 (* Packed form: parse body as a sequence of raw values. *) 388 539 let pos = ref 0 in 389 540 let len = String.length body in 390 541 while !pos < len do 391 - let v, off = codec.read_bytes body !pos in 542 + let v, off = decode_bytes codec body !pos in 392 543 acc := v :: !acc; 393 544 pos := off 394 545 done 395 - | _ -> acc := codec.read_wire w :: !acc) 546 + | _ -> acc := decode_value codec w :: !acc) 396 547 values; 397 548 List.rev !acc 398 549 ··· 404 555 | Required { tag; codec; cont; _ } -> 405 556 let v = 406 557 match take_last table tag with 407 - | Some w -> codec.read_wire w 408 - | None -> codec.default 558 + | Some w -> decode_value codec w 559 + | None -> default_of codec 409 560 in 410 561 decode_fields table (cont v) 411 562 | Optional { tag; codec; cont; _ } -> 412 563 let v = 413 564 match take_last table tag with 414 - | Some w -> Some (codec.read_wire w) 565 + | Some w -> Some (decode_value codec w) 415 566 | None -> None 416 567 in 417 568 decode_fields table (cont v) ··· 426 577 let table = parse_wire s start end_ in 427 578 decode_fields table spec 428 579 in 429 - { 430 - wire_type = Length_delimited; 431 - write_value = 432 - (fun buf v -> 433 - let body = Buffer.create 64 in 434 - encode_body body v; 435 - Leb128.add_u63_to_buffer buf (Buffer.length body); 436 - Buffer.add_buffer buf body); 437 - write_body = encode_body; 438 - read_wire = 439 - (fun w -> 440 - let body = length_delim_of w in 441 - decode_body body 0 (String.length body)); 442 - read_bytes = not_packable; 443 - default = decode_body "" 0 0; 444 - (* A message with no fields populated: all scalars take their default, 445 - repeated fields are empty, optionals are [None]. *) 446 - } 580 + (* A message with no fields populated: all scalars take their 581 + default, repeated fields are empty, optionals are [None]. *) 582 + let msg_default = decode_body "" 0 0 in 583 + Message { encode_body; decode_body; msg_default } 447 584 end 448 585 449 - (* -- Top-level encode/decode. 586 + (* -- Top-level encode / decode. 450 587 451 - For messages, we write just the body (no outer length prefix or tag). 452 - For length-delimited scalars (string/bytes) we also write just the body 453 - bytes. For other scalars the body IS the value bytes. *) 588 + Messages and length-delimited scalars at top level write just the 589 + body (no outer tag or length prefix). Other scalars write their raw 590 + value bytes — useful for low-level round-trip tests. *) 454 591 455 - let encode_string codec v = 592 + let encode_string : type a. a t -> a -> string = 593 + fun codec v -> 456 594 let buf = Buffer.create 64 in 457 - codec.write_body buf v; 595 + (match codec with 596 + | Message m -> m.encode_body buf v 597 + | Length_delim b -> Buffer.add_string buf (b.enc v) 598 + | Varint _ | Fixed32_t _ | Fixed64_t _ -> write_value buf codec v 599 + | Rec c -> ( 600 + match Lazy.force c with 601 + | Message m -> m.encode_body buf v 602 + | Length_delim b -> Buffer.add_string buf (b.enc v) 603 + | other -> write_value buf other v)); 458 604 Buffer.contents buf 459 605 460 - let decode_string codec s = 606 + let decode_string : type a. a t -> string -> (a, string) result = 607 + fun codec s -> 608 + depth := 0; 461 609 try 462 - match codec.wire_type with 463 - | Length_delimited -> 464 - (* Both messages and string/bytes decode from the whole input as 465 - their body. [read_wire] unwraps a [Length_delim] value, and both 466 - cases expect a bare body here. *) 467 - Ok (codec.read_wire (Length_delim s)) 468 - | Varint | Fixed32 | Fixed64 -> 469 - let v, off = codec.read_bytes s 0 in 610 + match codec with 611 + | Message m -> Ok (m.decode_body s 0 (String.length s)) 612 + | Length_delim b -> Ok (b.dec s) 613 + | Varint _ | Fixed32_t _ | Fixed64_t _ -> 614 + let v, off = decode_bytes codec s 0 in 470 615 if off <> String.length s then 471 616 Error 472 - (Fmt.str "trailing %d bytes after scalar" (String.length s - off)) 617 + (Fmt.str "trailing %d bytes after scalar" 618 + (String.length s - off)) 473 619 else Ok v 620 + | Rec c -> ( 621 + match Lazy.force c with 622 + | Message m -> Ok (m.decode_body s 0 (String.length s)) 623 + | Length_delim b -> Ok (b.dec s) 624 + | other -> 625 + let v, off = decode_bytes other s 0 in 626 + if off <> String.length s then 627 + Error 628 + (Fmt.str "trailing %d bytes after scalar" 629 + (String.length s - off)) 630 + else Ok v) 474 631 with Wire.Decode_error msg -> Error msg 475 632 476 633 let encode codec w v =
+10
lib/protobuf.mli
··· 131 131 (required by the protobuf spec for compatibility). *) 132 132 end 133 133 134 + (** {1 Recursive codecs} *) 135 + 136 + val fix : default:'a -> ('a t -> 'a t) -> 'a t 137 + (** [fix ~default f] builds a self-referential codec. [f] is invoked 138 + once with a forwarding placeholder; any occurrences of the 139 + placeholder inside [f]'s body resolve to the final codec at 140 + decode/encode time. [~default] is the value used when a 141 + [required] field declared with this codec is absent from the 142 + wire (typically the "empty" record for a tree-shaped type). *) 143 + 134 144 (** {1 Encode / Decode} *) 135 145 136 146 val encode_string : 'a t -> 'a -> string
+46 -6
test/test_hostile.ml
··· 64 64 | Error _ -> () 65 65 | Ok _ -> Alcotest.fail "wire type 3 must be rejected" 66 66 67 - (* [CVE class deferred: deep nesting DoS.] A recursive nested-message 68 - codec would need a [Lazy.t] thunk or a [ref] that OCaml's [let rec] 69 - restriction rejects cleanly in the call site here. A depth limit in 70 - the main [Protobuf.decode] is not yet implemented; a malicious input 71 - with ~1000 levels of nested Length_delim fields can stack-overflow 72 - the OCaml runtime. Tracked as a TODO. *) 67 + (* --- CVE class: deep nesting DoS. 68 + 69 + A malicious input with thousands of nested length-delimited fields 70 + would stack-overflow the OCaml runtime without a depth bound. The 71 + decoder now rejects at 100 nesting levels. 72 + 73 + We construct the hostile input by hand (raw bytes) rather than 74 + through a recursive codec — a self-referential codec would need a 75 + Lazy/ref trick that isn't worth baking into the public API for a 76 + single hostile-input test. *) 77 + 78 + (* A self-referential codec built via a forward reference. The dummy 79 + slot is patched after [finish] closes over it; from then on 80 + [nest_codec.read_wire] drives itself recursively, and the depth 81 + counter observes each level. *) 82 + 83 + type nest = { inner : nest option } 84 + 85 + let nest_codec : nest Protobuf.t = 86 + Protobuf.fix ~default:{ inner = None } (fun self -> 87 + let open Protobuf.Message in 88 + finish 89 + (let* inner = optional 1 (fun r -> r.inner) self in 90 + return { inner })) 91 + 92 + let test_shallow_nesting_ok () = 93 + (* A 50-level nested message via the recursive [nest_codec]. Each 94 + level exercises the depth counter. Within the 100-level bound: 95 + should decode to the expected chain. *) 96 + let rec build n = if n = 0 then { inner = None } else { inner = Some (build (n - 1)) } in 97 + let v = build 50 in 98 + let wire = Protobuf.encode_string nest_codec v in 99 + match Protobuf.decode_string nest_codec wire with 100 + | Ok v' -> Alcotest.(check bool) "roundtrip" true (v = v') 101 + | Error msg -> Alcotest.failf "50-level nest should succeed: %s" msg 102 + 103 + let test_deep_nesting_rejected () = 104 + (* 200 levels exceeds the 100-level bound. *) 105 + let rec build n = if n = 0 then { inner = None } else { inner = Some (build (n - 1)) } in 106 + let v = build 200 in 107 + let wire = Protobuf.encode_string nest_codec v in 108 + match Protobuf.decode_string nest_codec wire with 109 + | Error _ -> () 110 + | Ok _ -> Alcotest.fail "200-level nest must be rejected" 73 111 74 112 (* --- CVE class: wire type mismatch (field declared as varint, wire has 75 113 length-delim). Decoder should reject cleanly. --- *) ··· 172 210 Alcotest.test_case "reserved tag 0" `Quick test_reserved_tag_zero; 173 211 Alcotest.test_case "unsupported wire type" `Quick 174 212 test_unsupported_wire_type; 213 + Alcotest.test_case "deep nesting rejected" `Quick test_deep_nesting_rejected; 214 + Alcotest.test_case "shallow nesting ok" `Quick test_shallow_nesting_ok; 175 215 Alcotest.test_case "wire type mismatch" `Quick test_wire_type_mismatch; 176 216 Alcotest.test_case "empty input -> defaults" `Quick test_empty_input; 177 217 Alcotest.test_case "overrun rejected" `Quick test_overrun_rejected;