WebSocket frame codec (RFC 6455)
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

ocaml-websocket: import websocket handshake from ocaml-requests

+679 -18
+11 -3
dune-project
··· 12 12 13 13 (package 14 14 (name nox-websocket) 15 - (synopsis "WebSocket frame codec (RFC 6455)") 15 + (synopsis "WebSocket protocol (RFC 6455): frame codec and HTTP handshake") 16 16 (tags (org:blacksun network http)) 17 17 (description 18 - "Encode and decode WebSocket frames. Handles masking, fragmentation, 19 - ping/pong, and close frames. Does not handle the HTTP upgrade handshake.") 18 + "Both layers of RFC 6455: the frame codec (encode/decode, masking, 19 + fragmentation, ping/pong, close) at the top of the [Websocket] module, 20 + and the HTTP upgrade handshake (Sec-WebSocket-Key generation, 21 + Sec-WebSocket-Accept computation, protocol/extension negotiation) 22 + under [Websocket.Handshake].") 20 23 (depends 21 24 (ocaml (>= 5.2)) 22 25 (fmt (>= 0.9)) 26 + logs 27 + base64 28 + digestif 29 + nox-crypto-rng 30 + nox-http 23 31 (alcotest :with-test) 24 32 (mdx :with-test) 25 33 (alcobar :with-test)))
+1 -1
lib/dune
··· 1 1 (library 2 2 (name websocket) 3 3 (public_name nox-websocket) 4 - (libraries fmt)) 4 + (libraries fmt logs base64 digestif nox-crypto-rng nox-http))
+161
lib/handshake.ml
··· 1 + (*--------------------------------------------------------------------------- 2 + Copyright (c) 2025 Anil Madhavapeddy <anil@recoil.org>. All rights reserved. 3 + SPDX-License-Identifier: ISC 4 + ---------------------------------------------------------------------------*) 5 + 6 + open Http 7 + 8 + let src = Logs.Src.create "websocket.handshake" ~doc:"WebSocket handshake" 9 + 10 + module Log = (val Logs.src_log src : Logs.LOG) 11 + 12 + let err_unexpected_status status = 13 + Error (Fmt.str "Expected status 101, got %d" status) 14 + 15 + let err_bad_upgrade upgrade = 16 + Error (Fmt.str "Upgrade header is '%s', expected 'websocket'" upgrade) 17 + 18 + let protocol_version = "13" 19 + let magic_guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" 20 + 21 + let generate_key () = 22 + let random_bytes = Crypto_rng.generate 16 in 23 + let key = Base64.encode_exn random_bytes in 24 + Log.debug (fun m -> m "Generated WebSocket key: %s" key); 25 + key 26 + 27 + let compute_accept ~key = 28 + let combined = key ^ magic_guid in 29 + let hash = Digestif.SHA1.(digest_string combined |> to_raw_string) in 30 + let accept = Base64.encode_exn hash in 31 + Log.debug (fun m -> m "Computed WebSocket accept for key %s: %s" key accept); 32 + accept 33 + 34 + let validate_accept ~key ~accept = 35 + let expected = compute_accept ~key in 36 + let valid = String.equal expected accept in 37 + if not valid then 38 + Log.warn (fun m -> 39 + m "WebSocket accept validation failed: expected %s, got %s" expected 40 + accept); 41 + valid 42 + 43 + let parse_protocols s = 44 + String.split_on_char ',' s |> List.map String.trim 45 + |> List.filter (fun s -> String.length s > 0) 46 + 47 + let protocols_to_string protocols = String.concat ", " protocols 48 + 49 + let select_protocol ~offered ~supported = 50 + List.find_opt (fun s -> List.mem s offered) supported 51 + 52 + type extension = { name : string; params : (string * string option) list } 53 + 54 + let parse_single_extension s = 55 + let parts = String.split_on_char ';' s |> List.map String.trim in 56 + match parts with 57 + | [] -> None 58 + | name :: params -> 59 + let parse_param p = 60 + match String.index_opt p '=' with 61 + | None -> (String.trim p, None) 62 + | Some eq_idx -> 63 + let key = String.trim (String.sub p 0 eq_idx) in 64 + let value = 65 + String.trim 66 + (String.sub p (eq_idx + 1) (String.length p - eq_idx - 1)) 67 + in 68 + let value = 69 + if String.length value >= 2 && value.[0] = '"' then 70 + String.sub value 1 (String.length value - 2) 71 + else value 72 + in 73 + (key, Some value) 74 + in 75 + Some { name = String.trim name; params = List.map parse_param params } 76 + 77 + let parse_extensions s = 78 + let extensions = String.split_on_char ',' s in 79 + List.filter_map parse_single_extension extensions 80 + 81 + let extensions_to_string extensions = 82 + let ext_to_string ext = 83 + let params_str = 84 + List.map 85 + (fun (k, v) -> match v with None -> k | Some v -> Fmt.str "%s=%s" k v) 86 + ext.params 87 + in 88 + String.concat "; " (ext.name :: params_str) 89 + in 90 + String.concat ", " (List.map ext_to_string extensions) 91 + 92 + let has_extension ~name extensions = 93 + List.exists (fun ext -> String.equal ext.name name) extensions 94 + 95 + let extension_params ~name extensions = 96 + match List.find_opt (fun ext -> String.equal ext.name name) extensions with 97 + | Some ext -> Some ext.params 98 + | None -> None 99 + 100 + let upgrade_headers ~key ?protocols ?extensions ?origin () = 101 + let headers = 102 + Headers.empty 103 + |> Headers.set `Upgrade "websocket" 104 + |> Headers.set `Connection "Upgrade" 105 + |> Headers.set `Sec_websocket_key key 106 + |> Headers.set `Sec_websocket_version protocol_version 107 + in 108 + let headers = 109 + match protocols with 110 + | Some ps when ps <> [] -> 111 + Headers.set `Sec_websocket_protocol (protocols_to_string ps) headers 112 + | _ -> headers 113 + in 114 + let headers = 115 + match extensions with 116 + | Some exts when exts <> [] -> 117 + Headers.set `Sec_websocket_extensions 118 + (extensions_to_string exts) 119 + headers 120 + | _ -> headers 121 + in 122 + let headers = 123 + match origin with 124 + | Some o -> Headers.set `Origin o headers 125 + | None -> headers 126 + in 127 + headers 128 + 129 + let string_contains ~needle haystack = 130 + let nlen = String.length needle in 131 + let hlen = String.length haystack in 132 + if nlen > hlen then false 133 + else 134 + let rec check i = 135 + if i + nlen > hlen then false 136 + else if String.sub haystack i nlen = needle then true 137 + else check (i + 1) 138 + in 139 + check 0 140 + 141 + let validate_upgrade_response ~key ~status ~headers = 142 + if status <> 101 then err_unexpected_status status 143 + else 144 + match Headers.find `Upgrade headers with 145 + | None -> Error "Missing Upgrade header" 146 + | Some upgrade when String.lowercase_ascii upgrade <> "websocket" -> 147 + err_bad_upgrade upgrade 148 + | Some _ -> ( 149 + match Headers.find `Connection headers with 150 + | None -> Error "Missing Connection header" 151 + | Some conn -> ( 152 + let conn_lower = String.lowercase_ascii conn in 153 + if not (string_contains ~needle:"upgrade" conn_lower) then 154 + Error 155 + (Fmt.str "Connection header is '%s', expected 'Upgrade'" conn) 156 + else 157 + match Headers.find `Sec_websocket_accept headers with 158 + | None -> Error "Missing Sec-WebSocket-Accept header" 159 + | Some accept -> 160 + if validate_accept ~key ~accept then Ok () 161 + else Error "Sec-WebSocket-Accept validation failed"))
+170
lib/handshake.mli
··· 1 + (*--------------------------------------------------------------------------- 2 + Copyright (c) 2025 Anil Madhavapeddy <anil@recoil.org>. All rights reserved. 3 + SPDX-License-Identifier: ISC 4 + ---------------------------------------------------------------------------*) 5 + 6 + open Http 7 + 8 + (** WebSocket HTTP upgrade handshake (RFC 6455 §4). 9 + 10 + A WebSocket connection is established by upgrading an HTTP/1.1 connection 11 + using the [Upgrade] mechanism. This module provides the helpers used by 12 + that handshake: client key generation, server-side accept-value 13 + computation, validation, and protocol/extension negotiation. 14 + 15 + The on-the-wire frame codec lives in {!module:Websocket} (sibling 16 + module). Use this module for the handshake; switch to {!Websocket.encode} 17 + / {!Websocket.decode} once the connection is upgraded. 18 + 19 + Sec-WebSocket-Accept is deterministic given the client key (RFC 6455 §4.2.2 20 + test vector): 21 + {[ 22 + let () = 23 + assert 24 + (Websocket.Handshake.compute_accept ~key:"dGhlIHNhbXBsZSBub25jZQ==" 25 + = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); 26 + assert 27 + (Websocket.Handshake.validate_accept 28 + ~key:"dGhlIHNhbXBsZSBub25jZQ==" 29 + ~accept:"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); 30 + assert (Websocket.Handshake.protocol_version = "13") 31 + ]} 32 + 33 + @see <https://www.rfc-editor.org/rfc/rfc6455> 34 + RFC 6455: The WebSocket Protocol *) 35 + 36 + (** {1 Constants} *) 37 + 38 + val protocol_version : string 39 + (** [protocol_version] is the WebSocket protocol version string per RFC 6455 40 + (always ["13"]). Used as the value for the Sec-WebSocket-Version header. *) 41 + 42 + val magic_guid : string 43 + (** The magic GUID used in Sec-WebSocket-Accept computation. 44 + @see <https://www.rfc-editor.org/rfc/rfc6455#section-1.3> 45 + RFC 6455 Section 1.3. *) 46 + 47 + (** {1 Sec-WebSocket-Key} 48 + 49 + @see <https://www.rfc-editor.org/rfc/rfc6455#section-4.1> 50 + RFC 6455 Section 4.1 *) 51 + 52 + val generate_key : unit -> string 53 + (** [generate_key ()] creates a random Sec-WebSocket-Key value. 54 + 55 + Generates a cryptographically random 16-byte nonce and base64-encodes it. 56 + The result is suitable for use in the Sec-WebSocket-Key header. *) 57 + 58 + (** {1 Sec-WebSocket-Accept} 59 + 60 + @see <https://www.rfc-editor.org/rfc/rfc6455#section-4.2.2> 61 + RFC 6455 Section 4.2.2 *) 62 + 63 + val compute_accept : key:string -> string 64 + (** [compute_accept ~key] computes the expected Sec-WebSocket-Accept value. 65 + 66 + The computation is: [base64(SHA-1(key ^ magic_guid))] 67 + 68 + @param key The Sec-WebSocket-Key sent by the client 69 + @return The expected Sec-WebSocket-Accept value. *) 70 + 71 + val validate_accept : key:string -> accept:string -> bool 72 + (** [validate_accept ~key ~accept] validates a server's Sec-WebSocket-Accept. 73 + 74 + @param key The Sec-WebSocket-Key that was sent 75 + @param accept The Sec-WebSocket-Accept received from the server 76 + @return [true] if the accept value is correct. *) 77 + 78 + (** {1 Sec-WebSocket-Protocol} 79 + 80 + @see <https://www.rfc-editor.org/rfc/rfc6455#section-11.3.4> 81 + RFC 6455 Section 11.3.4 *) 82 + 83 + val parse_protocols : string -> string list 84 + (** [parse_protocols s] parses a Sec-WebSocket-Protocol header value. 85 + 86 + Example: ["graphql-ws, graphql-transport-ws"] -> 87 + [["graphql-ws"; "graphql-transport-ws"]]. *) 88 + 89 + val protocols_to_string : string list -> string 90 + (** [protocols_to_string protocols] formats protocols as a header value. *) 91 + 92 + val select_protocol : 93 + offered:string list -> supported:string list -> string option 94 + (** [select_protocol ~offered ~supported] selects a mutually acceptable 95 + protocol. 96 + 97 + @param offered The protocols offered by the client 98 + @param supported The protocols we support (in preference order) 99 + @return The selected protocol, or [None] if no match. *) 100 + 101 + (** {1 Sec-WebSocket-Extensions} 102 + 103 + @see <https://www.rfc-editor.org/rfc/rfc6455#section-9> RFC 6455 Section 9 104 + @see <https://www.rfc-editor.org/rfc/rfc7692> 105 + RFC 7692: Compression Extensions *) 106 + 107 + type extension = { name : string; params : (string * string option) list } 108 + (** An extension with optional parameters. 109 + 110 + Example: [permessage-deflate; client_max_window_bits] *) 111 + 112 + val parse_extensions : string -> extension list 113 + (** [parse_extensions s] parses a Sec-WebSocket-Extensions header value. 114 + 115 + Example: ["permessage-deflate; client_max_window_bits"]. *) 116 + 117 + val extensions_to_string : extension list -> string 118 + (** [extensions_to_string extensions] formats extensions as a header value. *) 119 + 120 + val has_extension : name:string -> extension list -> bool 121 + (** [has_extension ~name extensions] checks if an extension is present. *) 122 + 123 + val extension_params : 124 + name:string -> extension list -> (string * string option) list option 125 + (** [extension_params ~name extensions] gets parameters for an extension. *) 126 + 127 + (** {1 Handshake Helpers} *) 128 + 129 + val upgrade_headers : 130 + key:string -> 131 + ?protocols:string list -> 132 + ?extensions:extension list -> 133 + ?origin:string -> 134 + unit -> 135 + Headers.t 136 + (** [upgrade_headers ~key ?protocols ?extensions ?origin ()] builds headers for 137 + a WebSocket upgrade request. 138 + 139 + Sets the following headers: 140 + - [Upgrade: websocket] 141 + - [Connection: Upgrade] 142 + - [Sec-WebSocket-Key: {key}] 143 + - [Sec-WebSocket-Version: 13] 144 + - [Sec-WebSocket-Protocol: ...] (if protocols provided) 145 + - [Sec-WebSocket-Extensions: ...] (if extensions provided) 146 + - [Origin: ...] (if origin provided) 147 + 148 + @param key The Sec-WebSocket-Key (use {!generate_key} to create) 149 + @param protocols Optional list of subprotocols to request 150 + @param extensions Optional list of extensions to request 151 + @param origin Optional Origin header value. *) 152 + 153 + val validate_upgrade_response : 154 + key:string -> status:int -> headers:Headers.t -> (unit, string) result 155 + (** [validate_upgrade_response ~key ~status ~headers] validates a WebSocket 156 + upgrade response. 157 + 158 + Checks that: 159 + - Status code is 101 (Switching Protocols) 160 + - Upgrade header is "websocket" 161 + - Connection header includes "Upgrade" 162 + - Sec-WebSocket-Accept is correct for the given key 163 + 164 + @param key The Sec-WebSocket-Key that was sent 165 + @param status The HTTP status code 166 + @param headers The response headers 167 + @return [Ok ()] if valid, [Error reason] if invalid. *) 168 + 169 + val src : Logs.Src.t 170 + (** Log source for handshake operations. *)
+2
lib/websocket.ml
··· 1 1 (* WebSocket frame codec — RFC 6455 §5 *) 2 2 3 + module Handshake = Handshake 4 + 3 5 type opcode = Continuation | Text | Binary | Close | Ping | Pong 4 6 5 7 let pp_opcode ppf = function
+21 -9
lib/websocket.mli
··· 1 - (** WebSocket frame codec (RFC 6455). 1 + (** WebSocket protocol (RFC 6455). 2 + 3 + This module exposes both layers of RFC 6455: 4 + 5 + - The frame codec at the top of the module — encode/decode frames, 6 + masking, fragmentation, ping/pong, close. 7 + - The HTTP upgrade handshake under {!module:Handshake} — Sec-WebSocket-Key 8 + generation, Sec-WebSocket-Accept computation, protocol/extension 9 + negotiation, request/response header builders. 2 10 3 - Encode and decode WebSocket frames. Handles masking, fragmentation, 4 - ping/pong, and close frames. Does NOT handle the HTTP upgrade handshake — 5 - see [ocaml-requests] for that. 11 + Use {!module:Handshake} to upgrade an HTTP/1.1 connection to a WebSocket; 12 + once upgraded, switch to {!encode}/{!decode} to drive the wire protocol. 6 13 7 14 {[ 8 - let frame = Frame.text "hello" in 9 - let bytes = Frame.encode frame in 10 - match Frame.decode bytes with 11 - | Ok (frame, rest) -> ... 12 - | Error `Need_more -> (* incomplete frame *) 15 + let frame = Websocket.text "hello" 16 + let bytes = Websocket.encode frame 17 + 18 + let () = 19 + match Websocket.decode bytes with 20 + | Ok (frame, _rest) -> assert (frame.payload = "hello") 21 + | Error _ -> assert false 13 22 ]} *) 23 + 24 + module Handshake = Handshake 25 + (** RFC 6455 §4: HTTP upgrade handshake. *) 14 26 15 27 (** {1 Opcode} *) 16 28
+11 -3
nox-websocket.opam
··· 1 1 # This file is generated by dune, edit dune-project instead 2 2 opam-version: "2.0" 3 - synopsis: "WebSocket frame codec (RFC 6455)" 3 + synopsis: "WebSocket protocol (RFC 6455): frame codec and HTTP handshake" 4 4 description: """ 5 - Encode and decode WebSocket frames. Handles masking, fragmentation, 6 - ping/pong, and close frames. Does not handle the HTTP upgrade handshake.""" 5 + Both layers of RFC 6455: the frame codec (encode/decode, masking, 6 + fragmentation, ping/pong, close) at the top of the [Websocket] module, 7 + and the HTTP upgrade handshake (Sec-WebSocket-Key generation, 8 + Sec-WebSocket-Accept computation, protocol/extension negotiation) 9 + under [Websocket.Handshake].""" 7 10 maintainer: ["Thomas Gazagnaire <thomas@gazagnaire.org>"] 8 11 authors: ["Thomas Gazagnaire <thomas@gazagnaire.org>"] 9 12 license: "ISC" ··· 14 17 "dune" {>= "3.21"} 15 18 "ocaml" {>= "5.2"} 16 19 "fmt" {>= "0.9"} 20 + "logs" 21 + "base64" 22 + "digestif" 23 + "nox-crypto-rng" 24 + "nox-http" 17 25 "alcotest" {with-test} 18 26 "mdx" {with-test} 19 27 "alcobar" {with-test}
+1 -1
test/dune
··· 1 1 (test 2 2 (name test) 3 - (libraries nox-websocket alcotest)) 3 + (libraries nox-websocket nox-http base64 nox-crypto-rng.unix alcotest))
+3 -1
test/test.ml
··· 1 - let () = Alcotest.run "websocket" [ Test_websocket.suite ] 1 + let () = 2 + Crypto_rng_unix.use_default (); 3 + Alcotest.run "websocket" [ Test_websocket.suite; Test_handshake.suite ]
+298
test/test_handshake.ml
··· 1 + (*--------------------------------------------------------------------------- 2 + Copyright (c) 2025 Anil Madhavapeddy <anil@recoil.org>. All rights reserved. 3 + SPDX-License-Identifier: ISC 4 + ---------------------------------------------------------------------------*) 5 + 6 + (** Tests for Websocket.Handshake (RFC 6455 §4) *) 7 + 8 + module Handshake = Websocket.Handshake 9 + module Headers = Http.Headers 10 + 11 + (** Helper for string contains *) 12 + let string_contains ~affix s = 13 + let alen = String.length affix in 14 + let slen = String.length s in 15 + if alen > slen then false 16 + else 17 + let rec check i = 18 + if i + alen > slen then false 19 + else if String.sub s i alen = affix then true 20 + else check (i + 1) 21 + in 22 + check 0 23 + 24 + (** {1 Key Generation Tests} *) 25 + 26 + let test_generate_key_length () = 27 + let key = Handshake.generate_key () in 28 + let decoded = Base64.decode_exn key in 29 + Alcotest.(check int) "decoded length" 16 (String.length decoded) 30 + 31 + let test_generate_key_unique () = 32 + let key1 = Handshake.generate_key () in 33 + let key2 = Handshake.generate_key () in 34 + Alcotest.(check bool) "keys are different" true (key1 <> key2) 35 + 36 + let test_generate_key_valid_base64 () = 37 + let key = Handshake.generate_key () in 38 + let _ = Base64.decode_exn key in 39 + Alcotest.(check pass) "valid base64" () () 40 + 41 + (** {1 Accept Computation Tests (RFC 6455 Section 4.2.2)} *) 42 + 43 + let test_compute_accept_rfc_example () = 44 + let key = "dGhlIHNhbXBsZSBub25jZQ==" in 45 + let expected = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" in 46 + let accept = Handshake.compute_accept ~key in 47 + Alcotest.(check string) "RFC example accept" expected accept 48 + 49 + let test_validate_accept_correct () = 50 + let key = "dGhlIHNhbXBsZSBub25jZQ==" in 51 + let accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" in 52 + Alcotest.(check bool) 53 + "valid accept" true 54 + (Handshake.validate_accept ~key ~accept) 55 + 56 + let test_validate_accept_incorrect () = 57 + let key = "dGhlIHNhbXBsZSBub25jZQ==" in 58 + let accept = "wrongvalue" in 59 + Alcotest.(check bool) 60 + "invalid accept" false 61 + (Handshake.validate_accept ~key ~accept) 62 + 63 + let test_compute_validate_roundtrip () = 64 + let key = Handshake.generate_key () in 65 + let accept = Handshake.compute_accept ~key in 66 + Alcotest.(check bool) 67 + "roundtrip validation" true 68 + (Handshake.validate_accept ~key ~accept) 69 + 70 + (** {1 Protocol Negotiation Tests} *) 71 + 72 + let test_parse_protocols_basic () = 73 + let protos = Handshake.parse_protocols "graphql-ws, graphql-transport-ws" in 74 + Alcotest.(check int) "count" 2 (List.length protos); 75 + Alcotest.(check bool) "has graphql-ws" true (List.mem "graphql-ws" protos); 76 + Alcotest.(check bool) 77 + "has graphql-transport-ws" true 78 + (List.mem "graphql-transport-ws" protos) 79 + 80 + let test_parse_protocols_single () = 81 + let protos = Handshake.parse_protocols "chat" in 82 + Alcotest.(check (list string)) "single" [ "chat" ] protos 83 + 84 + let test_parse_protocols_empty () = 85 + let protos = Handshake.parse_protocols "" in 86 + Alcotest.(check (list string)) "empty" [] protos 87 + 88 + let test_protocols_to_string () = 89 + let protos = [ "graphql-ws"; "graphql-transport-ws" ] in 90 + let s = Handshake.protocols_to_string protos in 91 + Alcotest.(check string) "to_string" "graphql-ws, graphql-transport-ws" s 92 + 93 + let test_select_protocol_match () = 94 + let offered = [ "chat"; "superchat" ] in 95 + let supported = [ "superchat"; "chat" ] in 96 + let selected = Handshake.select_protocol ~offered ~supported in 97 + Alcotest.(check (option string)) "selected" (Some "superchat") selected 98 + 99 + let test_select_protocol_no_match () = 100 + let offered = [ "chat" ] in 101 + let supported = [ "other" ] in 102 + let selected = Handshake.select_protocol ~offered ~supported in 103 + Alcotest.(check (option string)) "no match" None selected 104 + 105 + (** {1 Extension Parsing Tests} *) 106 + 107 + let test_parse_extensions_basic () = 108 + let exts = Handshake.parse_extensions "permessage-deflate" in 109 + Alcotest.(check int) "count" 1 (List.length exts); 110 + Alcotest.(check string) "name" "permessage-deflate" (List.hd exts).name; 111 + Alcotest.(check int) "params count" 0 (List.length (List.hd exts).params) 112 + 113 + let test_parse_extensions_with_params () = 114 + let exts = 115 + Handshake.parse_extensions 116 + "permessage-deflate; client_max_window_bits; server_no_context_takeover" 117 + in 118 + Alcotest.(check int) "count" 1 (List.length exts); 119 + let ext = List.hd exts in 120 + Alcotest.(check string) "name" "permessage-deflate" ext.name; 121 + Alcotest.(check int) "params count" 2 (List.length ext.params) 122 + 123 + let test_parse_extensions_with_values () = 124 + let exts = 125 + Handshake.parse_extensions "permessage-deflate; client_max_window_bits=15" 126 + in 127 + let ext = List.hd exts in 128 + Alcotest.(check string) "name" "permessage-deflate" ext.name; 129 + match ext.params with 130 + | [ (key, Some value) ] -> 131 + Alcotest.(check string) "param key" "client_max_window_bits" key; 132 + Alcotest.(check string) "param value" "15" value 133 + | _ -> Alcotest.fail "Expected one param with value" 134 + 135 + let test_parse_extensions_multiple () = 136 + let exts = Handshake.parse_extensions "permessage-deflate, x-custom" in 137 + Alcotest.(check int) "count" 2 (List.length exts); 138 + Alcotest.(check bool) 139 + "has permessage-deflate" true 140 + (Handshake.has_extension ~name:"permessage-deflate" exts); 141 + Alcotest.(check bool) 142 + "has x-custom" true 143 + (Handshake.has_extension ~name:"x-custom" exts) 144 + 145 + let test_extensions_to_string () = 146 + let exts = 147 + [ 148 + { 149 + Handshake.name = "permessage-deflate"; 150 + params = [ ("client_max_window_bits", None) ]; 151 + }; 152 + ] 153 + in 154 + let s = Handshake.extensions_to_string exts in 155 + Alcotest.(check string) 156 + "to_string" "permessage-deflate; client_max_window_bits" s 157 + 158 + let test_get_extension_params () = 159 + let exts = 160 + Handshake.parse_extensions "permessage-deflate; client_max_window_bits=15" 161 + in 162 + match Handshake.extension_params ~name:"permessage-deflate" exts with 163 + | Some params -> Alcotest.(check int) "params count" 1 (List.length params) 164 + | None -> Alcotest.fail "Expected Some params" 165 + 166 + (** {1 Upgrade Headers Tests} *) 167 + 168 + let test_make_upgrade_headers_basic () = 169 + let key = "dGhlIHNhbXBsZSBub25jZQ==" in 170 + let headers = Handshake.upgrade_headers ~key () in 171 + Alcotest.(check (option string)) 172 + "Upgrade" (Some "websocket") 173 + (Headers.find `Upgrade headers); 174 + Alcotest.(check (option string)) 175 + "Connection" (Some "Upgrade") 176 + (Headers.find `Connection headers); 177 + Alcotest.(check (option string)) 178 + "Sec-WebSocket-Key" (Some key) 179 + (Headers.find `Sec_websocket_key headers); 180 + Alcotest.(check (option string)) 181 + "Sec-WebSocket-Version" (Some "13") 182 + (Headers.find `Sec_websocket_version headers) 183 + 184 + let test_upgrade_headers_with_protocols () = 185 + let key = Handshake.generate_key () in 186 + let headers = 187 + Handshake.upgrade_headers ~key 188 + ~protocols:[ "graphql-ws"; "graphql-transport-ws" ] 189 + () 190 + in 191 + match Headers.find `Sec_websocket_protocol headers with 192 + | Some proto -> 193 + Alcotest.(check bool) 194 + "contains graphql-ws" true 195 + (string_contains ~affix:"graphql-ws" proto) 196 + | None -> Alcotest.fail "Expected Sec-WebSocket-Protocol header" 197 + 198 + let test_upgrade_headers_with_origin () = 199 + let key = Handshake.generate_key () in 200 + let headers = 201 + Handshake.upgrade_headers ~key ~origin:"https://example.com" () 202 + in 203 + Alcotest.(check (option string)) 204 + "Origin" (Some "https://example.com") 205 + (Headers.find `Origin headers) 206 + 207 + (** {1 Upgrade Response Validation Tests} *) 208 + 209 + let test_validate_response_success () = 210 + let key = "dGhlIHNhbXBsZSBub25jZQ==" in 211 + let accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" in 212 + let headers = 213 + Headers.empty 214 + |> Headers.set `Upgrade "websocket" 215 + |> Headers.set `Connection "Upgrade" 216 + |> Headers.set `Sec_websocket_accept accept 217 + in 218 + let result = Handshake.validate_upgrade_response ~key ~status:101 ~headers in 219 + match result with 220 + | Ok () -> Alcotest.(check pass) "valid response" () () 221 + | Error msg -> Alcotest.fail msg 222 + 223 + let test_validate_response_wrong_status () = 224 + let key = Handshake.generate_key () in 225 + let headers = Headers.empty in 226 + let result = Handshake.validate_upgrade_response ~key ~status:200 ~headers in 227 + match result with 228 + | Error msg -> 229 + Alcotest.(check bool) 230 + "mentions 101" true 231 + (string_contains ~affix:"101" msg) 232 + | Ok () -> Alcotest.fail "Expected error for wrong status" 233 + 234 + let test_validate_response_missing_upgrade () = 235 + let key = Handshake.generate_key () in 236 + let headers = Headers.empty |> Headers.set `Connection "Upgrade" in 237 + let result = Handshake.validate_upgrade_response ~key ~status:101 ~headers in 238 + match result with 239 + | Error msg -> 240 + Alcotest.(check bool) 241 + "mentions Upgrade" true 242 + (string_contains ~affix:"Upgrade" msg) 243 + | Ok () -> Alcotest.fail "Expected error for missing Upgrade" 244 + 245 + let test_validate_response_wrong_accept () = 246 + let key = "dGhlIHNhbXBsZSBub25jZQ==" in 247 + let headers = 248 + Headers.empty 249 + |> Headers.set `Upgrade "websocket" 250 + |> Headers.set `Connection "Upgrade" 251 + |> Headers.set `Sec_websocket_accept "wrongvalue" 252 + in 253 + let result = Handshake.validate_upgrade_response ~key ~status:101 ~headers in 254 + match result with 255 + | Error msg -> 256 + Alcotest.(check bool) 257 + "mentions accept" true 258 + (string_contains ~affix:"Accept" msg 259 + || string_contains ~affix:"accept" msg) 260 + | Ok () -> Alcotest.fail "Expected error for wrong accept" 261 + 262 + (** {1 Test Suite} *) 263 + 264 + let suite = 265 + ( "handshake", 266 + [ 267 + Alcotest.test_case "key length" `Quick test_generate_key_length; 268 + Alcotest.test_case "keys unique" `Quick test_generate_key_unique; 269 + Alcotest.test_case "valid base64" `Quick test_generate_key_valid_base64; 270 + Alcotest.test_case "RFC example" `Quick test_compute_accept_rfc_example; 271 + Alcotest.test_case "validate correct" `Quick test_validate_accept_correct; 272 + Alcotest.test_case "validate incorrect" `Quick 273 + test_validate_accept_incorrect; 274 + Alcotest.test_case "roundtrip" `Quick test_compute_validate_roundtrip; 275 + Alcotest.test_case "parse basic" `Quick test_parse_protocols_basic; 276 + Alcotest.test_case "parse single" `Quick test_parse_protocols_single; 277 + Alcotest.test_case "parse empty" `Quick test_parse_protocols_empty; 278 + Alcotest.test_case "to string" `Quick test_protocols_to_string; 279 + Alcotest.test_case "select match" `Quick test_select_protocol_match; 280 + Alcotest.test_case "select no match" `Quick test_select_protocol_no_match; 281 + Alcotest.test_case "basic" `Quick test_parse_extensions_basic; 282 + Alcotest.test_case "with params" `Quick test_parse_extensions_with_params; 283 + Alcotest.test_case "with values" `Quick test_parse_extensions_with_values; 284 + Alcotest.test_case "multiple" `Quick test_parse_extensions_multiple; 285 + Alcotest.test_case "to string" `Quick test_extensions_to_string; 286 + Alcotest.test_case "get params" `Quick test_get_extension_params; 287 + Alcotest.test_case "basic headers" `Quick test_make_upgrade_headers_basic; 288 + Alcotest.test_case "with protocols" `Quick 289 + test_upgrade_headers_with_protocols; 290 + Alcotest.test_case "with origin" `Quick test_upgrade_headers_with_origin; 291 + Alcotest.test_case "success" `Quick test_validate_response_success; 292 + Alcotest.test_case "wrong status" `Quick 293 + test_validate_response_wrong_status; 294 + Alcotest.test_case "missing Upgrade" `Quick 295 + test_validate_response_missing_upgrade; 296 + Alcotest.test_case "wrong accept" `Quick 297 + test_validate_response_wrong_accept; 298 + ] )