Protocol Buffers codec for hand-written schemas
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

protobuf: add oneof combinator

Protobuf [oneof] groups a set of mutually exclusive optional fields at
distinct tags. Encoding emits whichever case matches; decoding picks
the case with the highest wire-order sequence (protobuf "last wins").

API:

val case : int -> 'a t -> inject:('a -> 'b) -> extract:('b -> 'a option)
-> 'b case
val oneof : default:'a -> ('o -> 'a) -> 'a case list -> ('o, 'a) field

Typical usage lifts the oneof alternatives into an OCaml polymorphic
variant:

type payload = [ `None | `Text of string | `Num of int32 ]

let msg_codec =
finish
(let* payload =
oneof ~default:`None (fun r -> r.payload)
[ case 1 string ~inject:(fun s -> `Text s)
~extract:(function `Text s -> Some s | _ -> None);
case 2 int32 ~inject:(fun n -> `Num n)
~extract:(function `Num n -> Some n | _ -> None) ] in
return { payload })

Internals:

- [parse_wire] now stamps each wire entry with a sequence counter so
[take_oneof_last] can find the case whose tag came last in wire
order. Hashtbl buckets still record per-tag wire order (reversed,
prepend-on-insert); the counter adds cross-tag ordering.
- [decode_fields] handles the new [Oneof] GADT constructor.
- [encode_fields] iterates the case list, picks the first whose
[extract] returns [Some], and emits that tag. If every extractor
returns [None] (e.g. value is the default variant), no wire bytes
are written -- matching protoc's behaviour for unset oneofs.
- [take_oneof_last] consumes every case tag from the table on exit
so oneof fields don't leak into the unknown-fields bag.

Tests: roundtrip through each case variant, empty-wire for the
default variant, and a "last wins" test where three consecutive
oneof tags appear on the wire and the decoder picks the third.

All 53 unit + 17 fuzz + 2 protoc interop tests pass.

+187 -10
+96 -10
lib/protobuf.ml
··· 379 379 cont : 'x list -> ('o, 'a) field; 380 380 } 381 381 -> ('o, 'a) field 382 + | Oneof : { 383 + get : 'o -> 'x; 384 + default : 'x; 385 + cases : 'x case list; 386 + cont : 'x -> ('o, 'a) field; 387 + } 388 + -> ('o, 'a) field 389 + 390 + and 'a case = 391 + | Case : { 392 + tag : int; 393 + codec : 'x t; 394 + inject : 'x -> 'a; 395 + extract : 'a -> 'x option; 396 + } 397 + -> 'a case 382 398 383 399 let return v = Return v 384 400 ··· 393 409 394 410 let packed tag get codec = 395 411 Repeated { tag; get; codec; packed = true; cont = (fun x -> Return x) } 412 + 413 + let case tag codec ~inject ~extract = Case { tag; codec; inject; extract } 414 + 415 + let oneof ~default get cases = 416 + Oneof { get; default; cases; cont = (fun x -> Return x) } 396 417 397 418 let rec ( let* ) : type o a b. 398 419 (o, a) field -> (a -> (o, b) field) -> (o, b) field = ··· 426 447 let* y = r.cont x in 427 448 f y); 428 449 } 450 + | Oneof r -> 451 + Oneof 452 + { 453 + r with 454 + cont = 455 + (fun x -> 456 + let* y = r.cont x in 457 + f y); 458 + } 429 459 430 460 (* -- Encoding -- *) 431 461 ··· 480 510 | _ when packed -> write_packed buf ~tag codec vs 481 511 | _ -> List.iter (write_field buf ~tag codec) vs); 482 512 encode_fields buf o (cont vs) 513 + | Oneof { get; cases; cont; _ } -> 514 + let v = get o in 515 + let rec emit_case = function 516 + | [] -> () 517 + | Case { tag; codec; extract; _ } :: rest -> ( 518 + match extract v with 519 + | Some x -> write_field buf ~tag codec x 520 + | None -> emit_case rest) 521 + in 522 + emit_case cases; 523 + encode_fields buf o (cont v) 483 524 484 525 (* -- Decoding helpers -- *) 485 526 486 - (* Parse the wire into a tag -> [wire_value] table. Order within each 487 - bucket reflects wire order (first to last). *) 488 - let parse_wire s start end_ : (int, wire_value list ref) Hashtbl.t = 527 + (* Parse the wire into a tag -> [(seq, wire_value) list] table. The 528 + sequence counter captures global wire order across tags, so 529 + [oneof] can determine which of its alternative cases came last. 530 + Within each bucket the list is stored in reverse wire order 531 + (prepend on insert), so [List.hd] is the last-added entry. *) 532 + let parse_wire s start end_ : (int, (int * wire_value) list ref) Hashtbl.t = 489 533 let table = Hashtbl.create 8 in 534 + let seq = ref 0 in 490 535 let push tag v = 536 + let entry = (!seq, v) in 537 + incr seq; 491 538 match Hashtbl.find_opt table tag with 492 - | Some r -> r := v :: !r 493 - | None -> Hashtbl.add table tag (ref [ v ]) 539 + | Some r -> r := entry :: !r 540 + | None -> Hashtbl.add table tag (ref [ entry ]) 494 541 in 495 542 let pos = ref start in 496 543 while !pos < end_ do ··· 528 575 match Hashtbl.find_opt table tag with 529 576 | None -> None 530 577 | Some r -> ( 531 - match List.rev !r with 578 + (* The list is in reverse wire order (last-added first). *) 579 + match !r with 532 580 | [] -> None 533 - | v :: _ -> 581 + | (_, v) :: _ -> 534 582 Hashtbl.remove table tag; 535 583 Some v) 536 584 ··· 539 587 | None -> [] 540 588 | Some r -> 541 589 Hashtbl.remove table tag; 542 - List.rev !r 590 + List.rev_map snd !r 591 + 592 + (* For oneof: find the case whose tag has the highest wire-sequence 593 + number. Returns [Some (case, wire_value)] if any case was on the 594 + wire, [None] otherwise. Removes the consumed case's entry from 595 + the table. *) 596 + let take_oneof_last : type a. 597 + (int, (int * wire_value) list ref) Hashtbl.t -> 598 + a case list -> 599 + (a case * wire_value) option = 600 + fun table cases -> 601 + let best = ref None in 602 + List.iter 603 + (fun (Case { tag; _ } as c) -> 604 + match Hashtbl.find_opt table tag with 605 + | None -> () 606 + | Some r -> ( 607 + match !r with 608 + | [] -> () 609 + | (seq, v) :: _ -> ( 610 + match !best with 611 + | None -> best := Some (seq, c, v) 612 + | Some (best_seq, _, _) when seq > best_seq -> 613 + best := Some (seq, c, v) 614 + | Some _ -> ()))) 615 + cases; 616 + (* Consume every case's tag from the table so they don't leak to 617 + unknowns. *) 618 + List.iter (fun (Case { tag; _ }) -> Hashtbl.remove table tag) cases; 619 + match !best with None -> None | Some (_, c, v) -> Some (c, v) 543 620 544 621 let write_unknown_field buf tag = function 545 622 | WV_varint v -> ··· 563 640 List.iter 564 641 (fun tag -> 565 642 let rvals = Hashtbl.find table tag in 566 - List.iter (fun wv -> write_unknown_field buf tag wv) (List.rev !rvals)) 643 + List.iter 644 + (fun (_, wv) -> write_unknown_field buf tag wv) 645 + (List.rev !rvals)) 567 646 tags; 568 647 Buffer.contents buf 569 648 ··· 592 671 List.rev !acc 593 672 594 673 let rec decode_fields : type o a. 595 - (int, wire_value list ref) Hashtbl.t -> (o, a) field -> a = 674 + (int, (int * wire_value) list ref) Hashtbl.t -> (o, a) field -> a = 596 675 fun table m -> 597 676 match m with 598 677 | Return a -> a ··· 613 692 | Repeated { tag; codec; cont; _ } -> 614 693 let vs = decode_packed_or_repeated codec (take_all table tag) in 615 694 decode_fields table (cont vs) 695 + | Oneof { default; cases; cont; _ } -> 696 + let v = 697 + match take_oneof_last table cases with 698 + | None -> default 699 + | Some (Case { codec; inject; _ }, w) -> inject (decode_value codec w) 700 + in 701 + decode_fields table (cont v) 616 702 617 703 (* A [map<K, V>] field is sugar for [repeated Entry { K key = 1; V value 618 704 = 2 }] on the wire — each entry is a length-delimited submessage with
+23
lib/protobuf.mli
··· 145 145 146 146 Protobuf restricts map keys to the integer/bool/string scalars; this API 147 147 does not enforce that — use a valid key codec. *) 148 + 149 + (** {2 Oneof} *) 150 + 151 + type 'a case 152 + (** One alternative in an [oneof] group. *) 153 + 154 + val case : 155 + int -> 'a t -> inject:('a -> 'b) -> extract:('b -> 'a option) -> 'b case 156 + (** [case tag codec ~inject ~extract] declares one oneof alternative at [tag] 157 + carrying an [a]. [inject] lifts the decoded value into the oneof's sum 158 + type; [extract] is its inverse, returning [Some] when the oneof value 159 + matches this case, [None] otherwise. *) 160 + 161 + val oneof : default:'a -> ('o -> 'a) -> 'a case list -> ('o, 'a) field 162 + (** [oneof ~default get cases] declares a oneof group: at most one of the 163 + listed cases may be set on the wire. 164 + 165 + - Encoding: [get o] is called to obtain the current oneof value; the first 166 + [case] whose [extract] returns [Some x] is emitted. If every [extract] 167 + returns [None], no tag is written. 168 + - Decoding: the case with the highest wire position wins (protobuf "last 169 + one on the wire" rule for oneofs). If no case appears on the wire, 170 + [~default] is used. *) 148 171 end 149 172 150 173 (** {1 Recursive codecs} *)
+68
test/test_protobuf.ml
··· 374 374 | Ok (_, unknowns) -> 375 375 Alcotest.(check int) "no unknowns" 0 (String.length unknowns) 376 376 377 + (* --- Test 13: oneof --- *) 378 + 379 + type payload = [ `None | `Text of string | `Num of int32 | `Blob of string ] 380 + 381 + type msg_with_payload = { payload : payload } 382 + 383 + let msg_with_payload_codec : msg_with_payload Protobuf.t = 384 + let open Protobuf.Message in 385 + finish 386 + (let* payload = 387 + oneof ~default:`None 388 + (fun r -> r.payload) 389 + [ 390 + case 1 Protobuf.string 391 + ~inject:(fun s -> `Text s) 392 + ~extract:(function `Text s -> Some s | _ -> None); 393 + case 2 Protobuf.int32 394 + ~inject:(fun n -> `Num n) 395 + ~extract:(function `Num n -> Some n | _ -> None); 396 + case 3 Protobuf.bytes 397 + ~inject:(fun b -> `Blob b) 398 + ~extract:(function `Blob b -> Some b | _ -> None); 399 + ] 400 + in 401 + return { payload }) 402 + 403 + let test_oneof_text () = 404 + let v = { payload = `Text "hello" } in 405 + let wire = Protobuf.encode_string msg_with_payload_codec v in 406 + match Protobuf.decode_string msg_with_payload_codec wire with 407 + | Error msg -> Alcotest.fail msg 408 + | Ok r -> Alcotest.(check bool) "roundtrip" true (r.payload = `Text "hello") 409 + 410 + let test_oneof_num () = 411 + let v = { payload = `Num 42l } in 412 + let wire = Protobuf.encode_string msg_with_payload_codec v in 413 + match Protobuf.decode_string msg_with_payload_codec wire with 414 + | Error msg -> Alcotest.fail msg 415 + | Ok r -> Alcotest.(check bool) "roundtrip" true (r.payload = `Num 42l) 416 + 417 + let test_oneof_none () = 418 + let v = { payload = `None } in 419 + let wire = Protobuf.encode_string msg_with_payload_codec v in 420 + Alcotest.(check int) "empty wire" 0 (String.length wire); 421 + match Protobuf.decode_string msg_with_payload_codec wire with 422 + | Error msg -> Alcotest.fail msg 423 + | Ok r -> Alcotest.(check bool) "payload is None" true (r.payload = `None) 424 + 425 + let test_oneof_last_wins () = 426 + (* Wire has three oneof tags back-to-back; decoder takes the last. *) 427 + let buf = Buffer.create 16 in 428 + Protobuf.Wire.write_tag buf ~field_number:1 429 + ~wire_type:Protobuf.Wire.Length_delimited; 430 + Protobuf.Wire.write_string buf "first"; 431 + Protobuf.Wire.write_tag buf ~field_number:2 ~wire_type:Protobuf.Wire.Varint; 432 + Protobuf.Wire.write_int32 buf 7l; 433 + Protobuf.Wire.write_tag buf ~field_number:3 434 + ~wire_type:Protobuf.Wire.Length_delimited; 435 + Protobuf.Wire.write_string buf "winner"; 436 + let wire = Buffer.contents buf in 437 + match Protobuf.decode_string msg_with_payload_codec wire with 438 + | Error msg -> Alcotest.fail msg 439 + | Ok r -> Alcotest.(check bool) "last wins" true (r.payload = `Blob "winner") 440 + 377 441 let test_map_empty () = 378 442 let v = { entries = [] } in 379 443 let wire = Protobuf.encode_string dict_codec v in ··· 730 794 test_unknown_fields_preserved; 731 795 Alcotest.test_case "unknowns empty when schema matches" `Quick 732 796 test_unknowns_empty_when_schema_matches; 797 + Alcotest.test_case "oneof: text case" `Quick test_oneof_text; 798 + Alcotest.test_case "oneof: num case" `Quick test_oneof_num; 799 + Alcotest.test_case "oneof: none -> empty wire" `Quick test_oneof_none; 800 + Alcotest.test_case "oneof: last wins" `Quick test_oneof_last_wins; 733 801 ] 734 802 @ hostile_cases )