feat: onnxrt library — OCaml bindings to ONNX Runtime Web

Type-safe OCaml bindings to onnxruntime-web for browser-based ML inference
via js_of_ocaml or wasm_of_ocaml. Supports WebAssembly (CPU) and WebGPU
(GPU) execution providers.

Modules:
- Dtype: GADT tensor element types (float32, int32, uint8, etc.)
- Tensor: create from Bigarray (zero-copy), read back, GPU download, dispose
- Session: load ONNX models, run inference, manage lifecycle
- Env: configure WASM threads/SIMD/paths and WebGPU power preference

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

+1812
+4
.gitignore
_build/
*.install
*.merlin
.merlin
+805
docs/plans/2026-03-04-onnxrt-implementation.md
# Onnxrt Implementation Plan

> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.

**Goal:** Implement `onnxrt`, an OCaml library providing type-safe bindings to ONNX Runtime Web for browser-based ML inference via js_of_ocaml.

**Architecture:** Two-layer design. An internal `Promise_lwt` helper bridges JS Promises to Lwt. The public `Onnxrt` module exposes pure OCaml types (Bigarray, Lwt.t) and uses `Js.Unsafe` internally to call the `onnxruntime-web` JavaScript API. No `Js.t` types in the public API.

**Tech Stack:** OCaml 5.2+, js_of_ocaml 5.8+, Lwt, dune 3.17, onnxruntime-web (npm, loaded externally)

**Key JS API surface being bound:**
- `ort.env.wasm.*` / `ort.env.webgpu.*` — global config
- `new ort.Tensor(type, data, dims)` — tensor construction
- `ort.InferenceSession.create(model, options)` — returns Promise<session>
- `session.run(feeds)` — returns Promise<results>
- `session.inputNames` / `session.outputNames` — string arrays
- `session.release()` — returns Promise<void>
- `tensor.data` / `tensor.dims` / `tensor.type` / `tensor.size` — properties
- `tensor.location` — "cpu" | "gpu-buffer"
- `tensor.getData()` — returns Promise<TypedArray> (GPU download)
- `tensor.dispose()` — void

---

### Task 1: Promise_lwt bridge helper

The ONNX API is entirely Promise-based. We need a minimal helper to convert
JS Promises to Lwt threads. This is an internal module, not exposed publicly.

**Files:**
- Create: `lib/promise_lwt.ml`
- Create: `lib/promise_lwt.mli`

**Step 1: Write `lib/promise_lwt.mli`**

```ocaml
(** Internal: bridge JavaScript Promises to Lwt.

    Not part of the public API. *)

val to_lwt : 'a Js_of_ocaml.Js.t -> 'a Lwt.t
(** [to_lwt js_promise] converts a JavaScript Promise to an Lwt thread.
    If the Promise rejects, the Lwt thread fails with [Failure msg]. *)
```

**Step 2: Write `lib/promise_lwt.ml`**

```ocaml
open Js_of_ocaml

let to_lwt (promise : 'a Js.t) : 'a Lwt.t =
  let lwt_promise, resolver = Lwt.wait () in
  let on_resolve result = Lwt.wakeup resolver result in
  let on_reject error =
    let msg =
      Js.Opt.case
        (Js.Unsafe.meth_call error "toString" [||] : Js.js_string Js.t Js.Opt.t)
        (fun () -> "unknown error")
        Js.to_string
    in
    Lwt.wakeup_exn resolver (Failure msg)
  in
  let _ignored : 'b Js.t =
    Js.Unsafe.meth_call promise "then"
      [| Js.Unsafe.inject (Js.wrap_callback on_resolve);
         Js.Unsafe.inject (Js.wrap_callback on_reject) |]
  in
  lwt_promise
```

**Step 3: Verify it compiles**

Run: `cd /home/jons-agent/workspace/onnxrt && opam exec -- dune build 2>&1`
Expected: Build succeeds (the modules are internal, linked into the library)

**Step 4: Commit**

```
feat: add internal Promise_lwt bridge
```

---

### Task 2: Dtype module

Pure OCaml, no JS interop. Implements the GADT and conversion functions.

**Files:**
- Create: `lib/onnxrt.ml` (start with Dtype module only)

**Step 1: Write the Dtype implementation in `lib/onnxrt.ml`**

```ocaml
module Dtype = struct
  type ('ocaml, 'elt) t =
    | Float32 : (float, Bigarray.float32_elt) t
    | Float64 : (float, Bigarray.float64_elt) t
    | Int8 : (int, Bigarray.int8_signed_elt) t
    | Uint8 : (int, Bigarray.int8_unsigned_elt) t
    | Int16 : (int, Bigarray.int16_signed_elt) t
    | Uint16 : (int, Bigarray.int16_unsigned_elt) t
    | Int32 : (int32, Bigarray.int32_elt) t

  type packed = Pack : ('ocaml, 'elt) t -> packed

  let to_string : type a b. (a, b) t -> string = function
    | Float32 -> "float32"
    | Float64 -> "float64"
    | Int8 -> "int8"
    | Uint8 -> "uint8"
    | Int16 -> "int16"
    | Uint16 -> "uint16"
    | Int32 -> "int32"

  let of_string = function
    | "float32" -> Some (Pack Float32)
    | "float64" -> Some (Pack Float64)
    | "int8" -> Some (Pack Int8)
    | "uint8" -> Some (Pack Uint8)
    | "int16" -> Some (Pack Int16)
    | "uint16" -> Some (Pack Uint16)
    | "int32" -> Some (Pack Int32)
    | _ -> None

  let equal : type a b c d. (a, b) t -> (c, d) t -> bool =
    fun a b ->
    match (a, b) with
    | Float32, Float32 -> true
    | Float64, Float64 -> true
    | Int8, Int8 -> true
    | Uint8, Uint8 -> true
    | Int16, Int16 -> true
    | Uint16, Uint16 -> true
    | Int32, Int32 -> true
    | _ -> false

  (* Internal: return the Bigarray kind for a dtype *)
  let to_bigarray_kind : type a b. (a, b) t -> (a, b) Bigarray.kind = function
    | Float32 -> Bigarray.float32
    | Float64 -> Bigarray.float64
    | Int8 -> Bigarray.int8_signed
    | Uint8 -> Bigarray.int8_unsigned
    | Int16 -> Bigarray.int16_signed
    | Uint16 -> Bigarray.int16_unsigned
    | Int32 -> Bigarray.int32

  (* Internal: return the JS TypedArray constructor name *)
  let typed_array_name : type a b. (a, b) t -> string = function
    | Float32 -> "Float32Array"
    | Float64 -> "Float64Array"
    | Int8 -> "Int8Array"
    | Uint8 -> "Uint8Array"
    | Int16 -> "Int16Array"
    | Uint16 -> "Uint16Array"
    | Int32 -> "Int32Array"
end
```

**Step 2: Add stub modules so the .ml satisfies the .mli**

Append these stubs to `lib/onnxrt.ml` so dune can type-check against the .mli:

```ocaml
module Tensor = struct
  type t = { js_tensor : 'a. 'a } [@@warning "-37"]
  type location = Cpu | Gpu_buffer
  let of_bigarray1 _ _ ~dims:_ = assert false
  let of_bigarray _ _ = assert false
  let of_float32s _ ~dims:_ = assert false
  let to_bigarray1_exn _ _ = assert false
  let to_bigarray_exn _ _ = assert false
  let download _ _ = assert false
  let dims _ = assert false
  let dtype _ = assert false
  let size _ = assert false
  let location _ = assert false
  let dispose _ = assert false
end

module Execution_provider = struct
  type t = Wasm | Webgpu
  let to_string = function Wasm -> "wasm" | Webgpu -> "webgpu"
end

type output_location = Cpu | Gpu_buffer
type graph_optimization = Disabled | Basic | Extended | All

module Session = struct
  type t = { js_session : 'a. 'a } [@@warning "-37"]
  let create ?execution_providers:_ ?graph_optimization:_ ?preferred_output_location:_ ?log_level:_ _ () = assert false
  let create_from_buffer ?execution_providers:_ ?graph_optimization:_ ?preferred_output_location:_ ?log_level:_ _ () = assert false
  let run _ _ = assert false
  let run_with_outputs _ _ ~output_names:_ = assert false
  let input_names _ = assert false
  let output_names _ = assert false
  let release _ = assert false
end

module Env = struct
  module Wasm = struct
    let set_num_threads _ = assert false
    let set_simd _ = assert false
    let set_proxy _ = assert false
    let set_wasm_paths _ = assert false
  end
  module Webgpu = struct
    let set_power_preference _ = assert false
  end
end
```

**Step 3: Verify it compiles**

Run: `cd /home/jons-agent/workspace/onnxrt && opam exec -- dune build 2>&1`
Expected: Build succeeds. All stubs satisfy the .mli signatures.
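
As a side illustration of why `Dtype.packed` is needed: a dtype recovered from a string at runtime has no static type parameters, so it has to travel inside the existential `Pack` until a function unpacks it. The sketch below repeats a cut-down copy of the `Dtype` module (three constructors only) so that it compiles on its own. `byte_length_of_name` is a hypothetical helper used only for illustration; it is not part of the planned API.

```ocaml
(* Cut-down copy of the plan's Dtype module, kept small so the sketch is
   self-contained. Only three of the seven constructors are reproduced. *)
module Dtype = struct
  type ('ocaml, 'elt) t =
    | Float32 : (float, Bigarray.float32_elt) t
    | Uint8 : (int, Bigarray.int8_unsigned_elt) t
    | Int32 : (int32, Bigarray.int32_elt) t

  type packed = Pack : ('ocaml, 'elt) t -> packed

  let of_string = function
    | "float32" -> Some (Pack Float32)
    | "uint8" -> Some (Pack Uint8)
    | "int32" -> Some (Pack Int32)
    | _ -> None

  let to_bigarray_kind : type a b. (a, b) t -> (a, b) Bigarray.kind = function
    | Float32 -> Bigarray.float32
    | Uint8 -> Bigarray.int8_unsigned
    | Int32 -> Bigarray.int32
end

(* Hypothetical helper: number of bytes needed to store [elements] values of
   the dtype called [name]. Unpacking [Pack dt] brings the hidden type
   parameters into scope just long enough to look up the Bigarray kind. *)
let byte_length_of_name (name : string) ~(elements : int) : int option =
  match Dtype.of_string name with
  | None -> None
  | Some (Dtype.Pack dt) ->
    let kind = Dtype.to_bigarray_kind dt in
    Some (elements * Bigarray.kind_size_in_bytes kind)
```

With these definitions, `byte_length_of_name "float32" ~elements:6` evaluates to `Some 24`, and an unrecognised name yields `None`.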

**Step 4: Commit**

```
feat: add Dtype implementation and skeleton stubs
```

---

### Task 3: Internal JS helpers

Shared helpers for accessing the `ort` global object and converting between
OCaml and JS types used across Tensor, Session, and Env modules.

**Files:**
- Create: `lib/js_helpers.ml`

**Step 1: Write `lib/js_helpers.ml`**

```ocaml
(** Internal JS interop helpers. Not part of the public API. *)

open Js_of_ocaml

(** Access the global [ort] object (onnxruntime-web). *)
let ort () : 'a Js.t =
  let o = Js.Unsafe.global##.ort in
  if Js.Optdef.test o then (Js.Unsafe.coerce o : 'a Js.t)
  else failwith "onnxruntime-web is not loaded: global 'ort' object not found"

(** Convert an OCaml string list to a JS array of JS strings. *)
let js_string_array (strs : string list) : Js.js_string Js.t Js.js_array Js.t =
  Js.array (Array.of_list (List.map Js.string strs))

(** Convert a JS array of JS strings to an OCaml string list. *)
let string_list_of_js_array (arr : Js.js_string Js.t Js.js_array Js.t) : string list =
  Array.to_list (Array.map Js.to_string (Js.to_array arr))

(** Convert an OCaml int array to a JS array of ints. *)
let js_int_array (dims : int array) : int Js.js_array Js.t =
  Js.array dims

(** Convert a JS array of ints to an OCaml int array. *)
let int_array_of_js (arr : int Js.js_array Js.t) : int array =
  Js.to_array arr

(** Read a string property from a JS object. *)
let get_string (obj : 'a Js.t) (key : string) : string =
  Js.to_string (Js.Unsafe.get obj (Js.string key))

(** Read an int property from a JS object. *)
let get_int (obj : 'a Js.t) (key : string) : int =
  Js.Unsafe.get obj (Js.string key)

(** Set a property on a JS object. *)
let set (obj : 'a Js.t) (key : string) (value : 'b) : unit =
  Js.Unsafe.set obj (Js.string key) value

(** Get a nested property: obj.key1.key2 *)
let get_nested (obj : 'a Js.t) (key1 : string) (key2 : string) : 'b Js.t =
  Js.Unsafe.get (Js.Unsafe.get obj (Js.string key1)) (Js.string key2)
```

**Step 2: Verify it compiles**

Run: `cd /home/jons-agent/workspace/onnxrt && opam exec -- dune build 2>&1`
Expected: Build succeeds

**Step 3: Commit**

```
feat: add internal JS interop helpers
```

---

### Task 4: Env module implementation

Replace the Env stubs with real implementations that set properties on the
global `ort.env` object.

**Files:**
- Modify: `lib/onnxrt.ml` — replace Env module

**Step 1: Replace the Env stub**

Replace the `module Env = struct ... end` block in `onnxrt.ml` with:

```ocaml
module Env = struct
  module Wasm = struct
    let set_num_threads n =
      let env = Js_helpers.ort () in
      Js_helpers.set
        (Js_helpers.get_nested env "env" "wasm")
        "numThreads" n

    let set_simd enabled =
      let env = Js_helpers.ort () in
      Js_helpers.set
        (Js_helpers.get_nested env "env" "wasm")
        "simd" (Js_of_ocaml.Js.bool enabled)

    let set_proxy enabled =
      let env = Js_helpers.ort () in
      Js_helpers.set
        (Js_helpers.get_nested env "env" "wasm")
        "proxy" (Js_of_ocaml.Js.bool enabled)

    let set_wasm_paths prefix =
      let env = Js_helpers.ort () in
      Js_helpers.set
        (Js_helpers.get_nested env "env" "wasm")
        "wasmPaths" (Js_of_ocaml.Js.string prefix)
  end

  module Webgpu = struct
    let set_power_preference pref =
      let env = Js_helpers.ort () in
      let s = match pref with
        | `High_performance -> "high-performance"
        | `Low_power -> "low-power"
      in
      Js_helpers.set
        (Js_helpers.get_nested env "env" "webgpu")
        "powerPreference" (Js_of_ocaml.Js.string s)
  end
end
```

**Step 2: Verify it compiles**

Run: `cd /home/jons-agent/workspace/onnxrt && opam exec -- dune build 2>&1`
Expected: Build succeeds

**Step 3: Commit**

```
feat: implement Env module (WASM and WebGPU configuration)
```

---

### Task 5: Execution_provider and top-level types

These are trivial but let's make sure the real implementations are in place.

**Files:**
- Modify: `lib/onnxrt.ml` — replace stubs

**Step 1: Replace Execution_provider and type stubs**

The Execution_provider stub from Task 2 is already correct. Verify the types
`output_location` and `graph_optimization` are also correct (they are pure
OCaml types with no JS interop). No changes needed — these are already final.

Add an internal helper for converting `graph_optimization` and `output_location`
to JS strings, used by Session:

```ocaml
(* Internal: place after the type definitions, before Session module *)

let graph_optimization_to_string = function
  | Disabled -> "disabled"
  | Basic -> "basic"
  | Extended -> "extended"
  | All -> "all"

let output_location_to_js = function
  | Cpu -> Js_of_ocaml.Js.string "cpu"
  | Gpu_buffer -> Js_of_ocaml.Js.string "gpu-buffer"

let log_level_to_string = function
  | `Verbose -> "verbose"
  | `Info -> "info"
  | `Warning -> "warning"
  | `Error -> "error"
  | `Fatal -> "fatal"
```

**Step 2: Verify it compiles**

Run: `cd /home/jons-agent/workspace/onnxrt && opam exec -- dune build 2>&1`
Expected: Build succeeds

**Step 3: Commit**

```
feat: add internal conversion helpers for session options
```

---

### Task 6: Tensor module implementation

The core of the bindings. Creates and reads ONNX tensors by constructing
`new ort.Tensor(type, typedArray, dims)` via `Js.Unsafe`.

**Files:**
- Modify: `lib/onnxrt.ml` — replace Tensor module

**Step 1: Replace the Tensor stub**

Replace the entire `module Tensor = struct ... end` block with:

```ocaml
module Tensor = struct
  type t = {
    js_tensor : Js_of_ocaml.Js.Unsafe.any;
    mutable disposed : bool;
  }

  type location = Cpu | Gpu_buffer

  let check_not_disposed t =
    if t.disposed then invalid_arg "Tensor has been disposed"

  let check_cpu t =
    check_not_disposed t;
    let loc = Js_helpers.get_string
      (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "location" in
    if loc <> "cpu" then
      invalid_arg "Tensor data is on GPU; use Tensor.download first"

  let check_dtype : type a b. (a, b) Dtype.t -> t -> unit =
    fun expected t ->
    let actual_str = Js_helpers.get_string
      (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "type" in
    let expected_str = Dtype.to_string expected in
    if actual_str <> expected_str then
      failwith (Printf.sprintf "Dtype mismatch: tensor is %s, expected %s"
        actual_str expected_str)

  (* Create a JS TypedArray from a Bigarray *)
  let typed_array_of_bigarray :
    type a b. (a, b) Dtype.t ->
    (a, b, Bigarray.c_layout) Bigarray.Array1.t ->
    Js_of_ocaml.Js.Unsafe.any =
    fun dtype ba ->
    let open Js_of_ocaml in
    let ga = Bigarray.genarray_of_array1 ba in
    let ta = Typed_array.from_genarray ga in
    Js.Unsafe.coerce ta

  let of_bigarray1 :
    type a b. (a, b) Dtype.t ->
    (a, b, Bigarray.c_layout) Bigarray.Array1.t ->
    dims:int array -> t =
    fun dtype ba ~dims ->
    let expected_size = Array.fold_left ( * ) 1 dims in
    let actual_size = Bigarray.Array1.dim ba in
    if expected_size <> actual_size then
      invalid_arg (Printf.sprintf
        "Tensor.of_bigarray1: dims product (%d) <> bigarray length (%d)"
        expected_size actual_size);
    let open Js_of_ocaml in
    let ta = typed_array_of_bigarray dtype ba in
    let js_tensor =
      Js.Unsafe.new_obj
        (Js.Unsafe.get (Js_helpers.ort ()) (Js.string "Tensor"))
        [| Js.Unsafe.inject (Js.string (Dtype.to_string dtype));
           ta;
           Js.Unsafe.inject (Js_helpers.js_int_array dims) |]
    in
    { js_tensor = Js.Unsafe.coerce js_tensor; disposed = false }

  let of_bigarray :
    type a b. (a, b) Dtype.t ->
    (a, b, Bigarray.c_layout) Bigarray.Genarray.t -> t =
    fun dtype ga ->
    let dims = Bigarray.Genarray.dims ga in
    let flat = Bigarray.reshape_1 ga (Array.fold_left ( * ) 1 dims) in
    of_bigarray1 dtype flat ~dims

  let of_float32s data ~dims =
    let expected_size = Array.fold_left ( * ) 1 dims in
    if Array.length data <> expected_size then
      invalid_arg (Printf.sprintf
        "Tensor.of_float32s: array length (%d) <> dims product (%d)"
        (Array.length data) expected_size);
    let ba = Bigarray.Array1.create Bigarray.float32 Bigarray.c_layout
      expected_size in
    Array.iteri (fun i v -> Bigarray.Array1.set ba i v) data;
    of_bigarray1 Float32 ba ~dims

  let to_bigarray1_exn :
    type a b. (a, b) Dtype.t -> t ->
    (a, b, Bigarray.c_layout) Bigarray.Array1.t =
    fun dtype t ->
    check_cpu t;
    check_dtype dtype t;
    let open Js_of_ocaml in
    let data : Js.Unsafe.any = Js.Unsafe.get t.js_tensor (Js.string "data") in
    let ta = (Js.Unsafe.coerce data : Typed_array.arrayBufferView Js.t) in
    let ga = Typed_array.to_genarray ta in
    let size = Bigarray.Genarray.nth_dim ga 0 in
    let ba = Bigarray.reshape_1 ga size in
    (* The Bigarray kind from Typed_array.to_genarray matches the JS typed array.
       We need to coerce it to match the expected dtype. The check_dtype call
       above ensures this is safe. *)
    (Obj.magic ba : (a, b, Bigarray.c_layout) Bigarray.Array1.t)

  let to_bigarray_exn :
    type a b. (a, b) Dtype.t -> t ->
    (a, b, Bigarray.c_layout) Bigarray.Genarray.t =
    fun dtype t ->
    let flat = to_bigarray1_exn dtype t in
    let dims_js : Js_of_ocaml.Js.Unsafe.any =
      Js_of_ocaml.Js.Unsafe.get t.js_tensor (Js_of_ocaml.Js.string "dims") in
    let dims = Js_helpers.int_array_of_js (Js_of_ocaml.Js.Unsafe.coerce dims_js) in
    Bigarray.genarray_of_array1 flat |> fun ga -> Bigarray.reshape ga dims

  let download :
    type a b. (a, b) Dtype.t -> t ->
    (a, b, Bigarray.c_layout) Bigarray.Array1.t Lwt.t =
    fun dtype t ->
    check_not_disposed t;
    check_dtype dtype t;
    let open Js_of_ocaml in
    let promise = Js.Unsafe.meth_call t.js_tensor "getData" [||] in
    let open Lwt.Syntax in
    let+ data = Promise_lwt.to_lwt promise in
    let ta = (Js.Unsafe.coerce data : Typed_array.arrayBufferView Js.t) in
    let ga = Typed_array.to_genarray ta in
    let size = Bigarray.Genarray.nth_dim ga 0 in
    let ba = Bigarray.reshape_1 ga size in
    (Obj.magic ba : (a, b, Bigarray.c_layout) Bigarray.Array1.t)

  let dims t =
    check_not_disposed t;
    let open Js_of_ocaml in
    let dims_js = Js.Unsafe.get t.js_tensor (Js.string "dims") in
    Js_helpers.int_array_of_js (Js.Unsafe.coerce dims_js)

  let dtype t =
    check_not_disposed t;
    let type_str = Js_helpers.get_string
      (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "type" in
    match Dtype.of_string type_str with
    | Some p -> p
    | None -> failwith (Printf.sprintf "Unknown tensor dtype: %s" type_str)

  let size t =
    check_not_disposed t;
    Js_helpers.get_int (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "size"

  let location t =
    check_not_disposed t;
    let loc = Js_helpers.get_string
      (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "location" in
    match loc with
    | "cpu" -> Cpu
    | "gpu-buffer" -> Gpu_buffer
    | s -> failwith (Printf.sprintf "Unknown tensor location: %s" s)

  let dispose t =
    if not t.disposed then begin
      let open Js_of_ocaml in
      ignore (Js.Unsafe.meth_call t.js_tensor "dispose" [||] : Js.Unsafe.any);
      t.disposed <- true
    end
end
```

**Step 2: Verify it compiles**

Run: `cd /home/jons-agent/workspace/onnxrt && opam exec -- dune build 2>&1`
Expected: Build succeeds

**Step 3: Commit**

```
feat: implement Tensor module with CPU and GPU support
```

---

### Task 7: Session module implementation

Creates inference sessions, runs models, manages lifecycle.

**Files:**
- Modify: `lib/onnxrt.ml` — replace Session module

**Step 1: Replace the Session stub**

Replace the entire `module Session = struct ... end` block with:

```ocaml
module Session = struct
  type t = {
    js_session : Js_of_ocaml.Js.Unsafe.any;
    input_names_ : string list;
    output_names_ : string list;
  }

  let build_options ?execution_providers ?graph_optimization
      ?preferred_output_location ?log_level () =
    let open Js_of_ocaml in
    let pairs = ref [] in
    (match execution_providers with
     | Some eps ->
       let js_eps = Js.array (Array.of_list
         (List.map (fun ep ->
           Js.Unsafe.inject (Js.string (Execution_provider.to_string ep)))
           eps)) in
       pairs := ("executionProviders", Js.Unsafe.inject js_eps) :: !pairs
     | None -> ());
    (match graph_optimization with
     | Some go ->
       pairs := ("graphOptimizationLevel",
         Js.Unsafe.inject (Js.string (graph_optimization_to_string go)))
         :: !pairs
     | None -> ());
    (match preferred_output_location with
     | Some loc ->
       pairs := ("preferredOutputLocation",
         Js.Unsafe.inject (output_location_to_js loc))
         :: !pairs
     | None -> ());
    (match log_level with
     | Some level ->
       pairs := ("logSeverityLevel",
         Js.Unsafe.inject (Js.string (log_level_to_string level)))
         :: !pairs
     | None -> ());
    Js.Unsafe.obj (Array.of_list !pairs)

  let wrap_session js_session =
    let open Js_of_ocaml in
    let input_names_ = Js_helpers.string_list_of_js_array
      (Js.Unsafe.coerce (Js.Unsafe.get js_session (Js.string "inputNames"))) in
    let output_names_ = Js_helpers.string_list_of_js_array
      (Js.Unsafe.coerce (Js.Unsafe.get js_session (Js.string "outputNames"))) in
    { js_session = Js.Unsafe.coerce js_session; input_names_; output_names_ }

  let create ?execution_providers ?graph_optimization
      ?preferred_output_location ?log_level model_url () =
    let open Js_of_ocaml in
    let ort = Js_helpers.ort () in
    let inference_session = Js.Unsafe.get ort (Js.string "InferenceSession") in
    let options = build_options ?execution_providers ?graph_optimization
      ?preferred_output_location ?log_level () in
    let promise = Js.Unsafe.meth_call inference_session "create"
      [| Js.Unsafe.inject (Js.string model_url);
         Js.Unsafe.inject options |] in
    let open Lwt.Syntax in
    let+ js_session = Promise_lwt.to_lwt promise in
    wrap_session js_session

  let create_from_buffer (type a b) ?execution_providers ?graph_optimization
      ?preferred_output_location ?log_level
      (buffer : (a, b, Bigarray.c_layout) Bigarray.Array1.t) () =
    let open Js_of_ocaml in
    let ort = Js_helpers.ort () in
    let inference_session = Js.Unsafe.get ort (Js.string "InferenceSession") in
    let options = build_options ?execution_providers ?graph_optimization
      ?preferred_output_location ?log_level () in
    (* Convert bigarray to Uint8Array via ArrayBuffer *)
    let ga = Bigarray.genarray_of_array1 buffer in
    let ta = Typed_array.from_genarray ga in
    let ab : Typed_array.arrayBuffer Js.t =
      Js.Unsafe.get (Js.Unsafe.coerce ta) (Js.string "buffer") in
    let uint8 = Js.Unsafe.new_obj
      (Js.Unsafe.global##._Uint8Array)
      [| Js.Unsafe.inject ab |] in
    let promise = Js.Unsafe.meth_call inference_session "create"
      [| Js.Unsafe.inject uint8;
         Js.Unsafe.inject options |] in
    let open Lwt.Syntax in
    let+ js_session = Promise_lwt.to_lwt promise in
    wrap_session js_session

  let run t inputs =
    let open Js_of_ocaml in
    let feeds = Js.Unsafe.obj
      (Array.of_list
        (List.map (fun (name, (tensor : Tensor.t)) ->
          (name, Js.Unsafe.inject tensor.js_tensor))
          inputs)) in
    let promise = Js.Unsafe.meth_call t.js_session "run"
      [| Js.Unsafe.inject feeds |] in
    let open Lwt.Syntax in
    let+ results = Promise_lwt.to_lwt promise in
    List.map (fun name ->
      let js_tensor = Js.Unsafe.get results (Js.string name) in
      (name, Tensor.{ js_tensor = Js.Unsafe.coerce js_tensor;
                      disposed = false }))
      t.output_names_

  let run_with_outputs t inputs ~output_names =
    let open Js_of_ocaml in
    let feeds = Js.Unsafe.obj
      (Array.of_list
        (List.map (fun (name, (tensor : Tensor.t)) ->
          (name, Js.Unsafe.inject tensor.js_tensor))
          inputs)) in
    let promise = Js.Unsafe.meth_call t.js_session "run"
      [| Js.Unsafe.inject feeds |] in
    let open Lwt.Syntax in
    let+ results = Promise_lwt.to_lwt promise in
    List.map (fun name ->
      let js_tensor = Js.Unsafe.get results (Js.string name) in
      (name, Tensor.{ js_tensor = Js.Unsafe.coerce js_tensor;
                      disposed = false }))
      output_names

  let input_names t = t.input_names_
  let output_names t = t.output_names_

  let release t =
    let open Js_of_ocaml in
    let promise = Js.Unsafe.meth_call t.js_session "release" [||] in
    Promise_lwt.to_lwt promise |> Lwt.map (fun (_ : Js.Unsafe.any) -> ())
end
```

**Step 2: Verify it compiles**

Run: `cd /home/jons-agent/workspace/onnxrt && opam exec -- dune build 2>&1`
Expected: Build succeeds

**Step 3: Commit**

```
feat: implement Session module (create, run, release)
```

---

### Task 8: Final build verification and cleanup

Make sure the full library compiles cleanly with no warnings.

**Files:**
- Review: `lib/onnxrt.ml`
- Review: `lib/onnxrt.mli`

**Step 1: Full build with warnings enabled**

Run: `cd /home/jons-agent/workspace/onnxrt && opam exec -- dune build 2>&1`
Expected: Clean build, no warnings

**Step 2: Check module structure**

Run: `cd /home/jons-agent/workspace/onnxrt && opam exec -- dune describe pp lib/onnxrt.ml 2>&1 | head -20`
Expected: Shows the preprocessed output with js_of_ocaml-ppx applied

**Step 3: Commit**

```
feat: onnxrt library complete — ONNX Runtime Web bindings for OCaml
```

---

## File Summary

| File | Purpose |
|------|---------|
| `lib/onnxrt.mli` | Public API (already written) |
| `lib/onnxrt.ml` | Implementation of all public modules |
| `lib/promise_lwt.mli` | Internal: Promise→Lwt bridge signature |
| `lib/promise_lwt.ml` | Internal: Promise→Lwt bridge implementation |
| `lib/js_helpers.ml` | Internal: shared JS interop utilities |
| `lib/dune` | Build config |
| `dune-project` | Project metadata and opam generation |

## Key Implementation Notes

1. **`Js.Unsafe` throughout**: All JS interop uses `Js.Unsafe.get`, `Js.Unsafe.set`,
   `Js.Unsafe.meth_call`, `Js.Unsafe.new_obj`, and `Js.Unsafe.obj`. No typed class
   bindings.

2. **`Obj.magic` in tensor conversion**: `to_bigarray1_exn` uses `Obj.magic` to cast
   the Bigarray kind after a runtime dtype check. This is safe because
   `Typed_array.to_genarray` returns a bigarray whose kind matches the JS
   TypedArray, and `check_dtype` verifies the match. The GADT prevents misuse
   at the public API boundary.

3. **Promise bridging**: Uses `.then(resolve, reject)` pattern. The reject handler
   calls `error.toString()` to extract a message string.

4. **Tensor.t record**: Holds `js_tensor` as `Js.Unsafe.any` (erased type) plus a
   `disposed` flag. The Session module constructs Tensor.t values directly
   via record syntax (same-module access isn't needed since we use the dot
   path `Tensor.{ ... }`).

5. **`ort` global**: All JS calls go through `Js_helpers.ort()` which checks that
   the onnxruntime-web library is loaded before proceeding.
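
Note 2 relies on `Obj.magic` guarded by a string comparison. As a possible alternative (not what the plan implements), the standard library's `Bigarray.kind` is itself a GADT, so two kinds can be compared to obtain a type-equality proof, and the cast is then checked by the compiler. The sketch below is plain OCaml with no js_of_ocaml dependency. The names `eq`, `any_array1`, `kind_eq`, and `cast_array1` are hypothetical, and `any_array1` merely stands in for data read back from a JS tensor.

```ocaml
(* Type-equality witness: a value of type [(a, b) eq] proves that [a = b]. *)
type (_, _) eq = Refl : ('a, 'a) eq

(* Hypothetical wrapper for a Bigarray whose element kind is not known
   statically, standing in for data obtained from a JS typed array. *)
type any_array1 =
  | Any : ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t -> any_array1

(* Compare two Bigarray kinds. On success, return a proof that their type
   parameters coincide. Only the kinds used by the plan's Dtype are listed;
   every other combination is treated as a mismatch. *)
let kind_eq : type a b c d.
    (a, b) Bigarray.kind -> (c, d) Bigarray.kind -> (a * b, c * d) eq option =
 fun k1 k2 ->
  match (k1, k2) with
  | Bigarray.Float32, Bigarray.Float32 -> Some Refl
  | Bigarray.Float64, Bigarray.Float64 -> Some Refl
  | Bigarray.Int8_signed, Bigarray.Int8_signed -> Some Refl
  | Bigarray.Int8_unsigned, Bigarray.Int8_unsigned -> Some Refl
  | Bigarray.Int16_signed, Bigarray.Int16_signed -> Some Refl
  | Bigarray.Int16_unsigned, Bigarray.Int16_unsigned -> Some Refl
  | Bigarray.Int32, Bigarray.Int32 -> Some Refl
  | _ -> None

(* Recover a statically typed array from [any_array1] if, and only if, its
   runtime kind equals [expected]. No [Obj.magic] is involved: matching on
   [Refl] is what lets the type checker accept [Some arr]. *)
let cast_array1 : type a b.
    (a, b) Bigarray.kind ->
    any_array1 ->
    (a, b, Bigarray.c_layout) Bigarray.Array1.t option =
 fun expected (Any arr) ->
  match kind_eq (Bigarray.Array1.kind arr) expected with
  | Some Refl -> Some arr
  | None -> None
```

A caller would pass the kind obtained from a dtype (for example via the plan's `Dtype.to_bigarray_kind`) as `expected`. A mismatch then surfaces as `None` instead of an unchecked cast.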
+19
dune-project
(lang dune 3.17)

(name onnxrt)

(generate_opam_files true)

(license ISC)

(package
 (name onnxrt)
 (synopsis "OCaml bindings to ONNX Runtime Web for browser-based ML inference")
 (description
  "Type-safe OCaml bindings to onnxruntime-web, enabling ML model inference in the browser via js_of_ocaml or wasm_of_ocaml. Supports WebAssembly (CPU) and WebGPU (GPU) execution providers.")
 (depends
  (ocaml (>= 5.2))
  (js_of_ocaml (>= 5.8))
  (js_of_ocaml-ppx (>= 5.8))
  (lwt (>= 5.7))
  (js_of_ocaml-lwt (>= 5.8))))
+5
lib/dune
(library
 (public_name onnxrt)
 (preprocess
  (pps js_of_ocaml-ppx))
 (libraries js_of_ocaml js_of_ocaml-lwt lwt))
+41
lib/js_helpers.ml
(** Internal JS interop helpers. Not part of the public API. *)

open Js_of_ocaml

(** Access the global [ort] object (onnxruntime-web). *)
let ort () : 'a Js.t =
  let o = Js.Unsafe.global##.ort in
  if Js.Optdef.test o then (Js.Unsafe.coerce o : 'a Js.t)
  else failwith "onnxruntime-web is not loaded: global 'ort' object not found"

(** Convert an OCaml string list to a JS array of JS strings. *)
let js_string_array (strs : string list) : Js.js_string Js.t Js.js_array Js.t =
  Js.array (Array.of_list (List.map Js.string strs))

(** Convert a JS array of JS strings to an OCaml string list. *)
let string_list_of_js_array (arr : Js.js_string Js.t Js.js_array Js.t) : string list =
  Array.to_list (Array.map Js.to_string (Js.to_array arr))

(** Convert an OCaml int array to a JS array of ints. *)
let js_int_array (dims : int array) : int Js.js_array Js.t =
  Js.array dims

(** Convert a JS array of ints to an OCaml int array. *)
let int_array_of_js (arr : int Js.js_array Js.t) : int array =
  Js.to_array arr

(** Read a string property from a JS object. *)
let get_string (obj : 'a Js.t) (key : string) : string =
  Js.to_string (Js.Unsafe.get obj (Js.string key))

(** Read an int property from a JS object. *)
let get_int (obj : 'a Js.t) (key : string) : int =
  Js.Unsafe.get obj (Js.string key)

(** Set a property on a JS object. *)
let set (obj : 'a Js.t) (key : string) (value : 'b) : unit =
  Js.Unsafe.set obj (Js.string key) value

(** Get a nested property: obj.key1.key2 *)
let get_nested (obj : 'a Js.t) (key1 : string) (key2 : string) : 'b Js.t =
  Js.Unsafe.get (Js.Unsafe.get obj (Js.string key1)) (Js.string key2)
+409
lib/onnxrt.ml
module Dtype = struct
  type ('ocaml, 'elt) t =
    | Float32 : (float, Bigarray.float32_elt) t
    | Float64 : (float, Bigarray.float64_elt) t
    | Int8 : (int, Bigarray.int8_signed_elt) t
    | Uint8 : (int, Bigarray.int8_unsigned_elt) t
    | Int16 : (int, Bigarray.int16_signed_elt) t
    | Uint16 : (int, Bigarray.int16_unsigned_elt) t
    | Int32 : (int32, Bigarray.int32_elt) t

  type packed = Pack : ('ocaml, 'elt) t -> packed

  let to_string : type a b. (a, b) t -> string = function
    | Float32 -> "float32"
    | Float64 -> "float64"
    | Int8 -> "int8"
    | Uint8 -> "uint8"
    | Int16 -> "int16"
    | Uint16 -> "uint16"
    | Int32 -> "int32"

  let of_string = function
    | "float32" -> Some (Pack Float32)
    | "float64" -> Some (Pack Float64)
    | "int8" -> Some (Pack Int8)
    | "uint8" -> Some (Pack Uint8)
    | "int16" -> Some (Pack Int16)
    | "uint16" -> Some (Pack Uint16)
    | "int32" -> Some (Pack Int32)
    | _ -> None

  let equal : type a b c d. (a, b) t -> (c, d) t -> bool =
   fun a b ->
    match (a, b) with
    | Float32, Float32 -> true
    | Float64, Float64 -> true
    | Int8, Int8 -> true
    | Uint8, Uint8 -> true
    | Int16, Int16 -> true
    | Uint16, Uint16 -> true
    | Int32, Int32 -> true
    | _ -> false

  let _to_bigarray_kind : type a b. (a, b) t -> (a, b) Bigarray.kind = function
    | Float32 -> Bigarray.float32
    | Float64 -> Bigarray.float64
    | Int8 -> Bigarray.int8_signed
    | Uint8 -> Bigarray.int8_unsigned
    | Int16 -> Bigarray.int16_signed
    | Uint16 -> Bigarray.int16_unsigned
    | Int32 -> Bigarray.int32
end

module Tensor = struct
  type t = {
    js_tensor : Js_of_ocaml.Js.Unsafe.any;
    mutable disposed : bool;
  }

  type location = Cpu | Gpu_buffer

  let check_not_disposed t =
    if t.disposed then invalid_arg "Tensor has been disposed"

  let check_cpu t =
    check_not_disposed t;
    let loc = Js_helpers.get_string
        (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "location" in
    if loc <> "cpu" then
      invalid_arg "Tensor data is on GPU; use Tensor.download first"

  let check_dtype : type a b. (a, b) Dtype.t -> t -> unit =
   fun expected t ->
    let actual_str = Js_helpers.get_string
        (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "type" in
    let expected_str = Dtype.to_string expected in
    if actual_str <> expected_str then
      failwith (Printf.sprintf "Dtype mismatch: tensor is %s, expected %s"
                  actual_str expected_str)

  let typed_array_of_bigarray :
    type a b. (a, b) Dtype.t ->
    (a, b, Bigarray.c_layout) Bigarray.Array1.t ->
    Js_of_ocaml.Js.Unsafe.any =
   fun dtype ba ->
    let open Js_of_ocaml in
    let ga = Bigarray.genarray_of_array1 ba in
    match dtype with
    | Dtype.Float32 ->
      Js.Unsafe.coerce (Typed_array.from_genarray Typed_array.Float32 ga)
    | Dtype.Float64 ->
      Js.Unsafe.coerce (Typed_array.from_genarray Typed_array.Float64 ga)
    | Dtype.Int8 ->
      Js.Unsafe.coerce (Typed_array.from_genarray Typed_array.Int8_signed ga)
    | Dtype.Uint8 ->
      Js.Unsafe.coerce (Typed_array.from_genarray Typed_array.Int8_unsigned ga)
    | Dtype.Int16 ->
      Js.Unsafe.coerce (Typed_array.from_genarray Typed_array.Int16_signed ga)
    | Dtype.Uint16 ->
      Js.Unsafe.coerce (Typed_array.from_genarray Typed_array.Int16_unsigned ga)
    | Dtype.Int32 ->
      Js.Unsafe.coerce (Typed_array.from_genarray Typed_array.Int32_signed ga)

  let of_bigarray1 :
    type a b. (a, b) Dtype.t ->
    (a, b, Bigarray.c_layout) Bigarray.Array1.t ->
    dims:int array -> t =
   fun dtype ba ~dims ->
    let expected_size = Array.fold_left ( * ) 1 dims in
    let actual_size = Bigarray.Array1.dim ba in
    if expected_size <> actual_size then
      invalid_arg (Printf.sprintf
                     "Tensor.of_bigarray1: dims product (%d) <> bigarray length (%d)"
                     expected_size actual_size);
    let open Js_of_ocaml in
    let ta = typed_array_of_bigarray dtype ba in
    let js_tensor =
      Js.Unsafe.new_obj
        (Js.Unsafe.get (Js_helpers.ort ()) (Js.string "Tensor"))
        [| Js.Unsafe.inject (Js.string (Dtype.to_string dtype));
           ta;
           Js.Unsafe.inject (Js_helpers.js_int_array dims) |]
    in
    { js_tensor = Js.Unsafe.coerce js_tensor; disposed = false }

  let of_bigarray :
    type a b. (a, b) Dtype.t ->
    (a, b, Bigarray.c_layout) Bigarray.Genarray.t -> t =
   fun dtype ga ->
    let dims = Bigarray.Genarray.dims ga in
    let flat = Bigarray.reshape_1 ga (Array.fold_left ( * ) 1 dims) in
    of_bigarray1 dtype flat ~dims

  let of_float32s data ~dims =
    let expected_size = Array.fold_left ( * ) 1 dims in
    if Array.length data <> expected_size then
      invalid_arg (Printf.sprintf
                     "Tensor.of_float32s: array length (%d) <> dims product (%d)"
                     (Array.length data) expected_size);
    let ba = Bigarray.Array1.create Bigarray.float32 Bigarray.c_layout
        expected_size in
    Array.iteri (fun i v -> Bigarray.Array1.set ba i v) data;
    of_bigarray1 Float32 ba ~dims

  let to_bigarray1_exn :
    type a b. (a, b) Dtype.t -> t ->
    (a, b, Bigarray.c_layout) Bigarray.Array1.t =
   fun dtype t ->
    check_cpu t;
    check_dtype dtype t;
    let open Js_of_ocaml in
    let data : Js.Unsafe.any = Js.Unsafe.get t.js_tensor (Js.string "data") in
    let ta = (Js.Unsafe.coerce data : (_, _, _) Typed_array.typedArray Js.t) in
    let ga = Typed_array.to_genarray ta in
    let size = Bigarray.Genarray.nth_dim ga 0 in
    let ba = Bigarray.reshape_1 ga size in
    (* [check_dtype] above has verified the element type at runtime. *)
    (Obj.magic ba : (a, b, Bigarray.c_layout) Bigarray.Array1.t)

  let to_bigarray_exn :
    type a b. (a, b) Dtype.t -> t ->
    (a, b, Bigarray.c_layout) Bigarray.Genarray.t =
   fun dtype t ->
    let flat = to_bigarray1_exn dtype t in
    let dims_js : Js_of_ocaml.Js.Unsafe.any =
      Js_of_ocaml.Js.Unsafe.get t.js_tensor (Js_of_ocaml.Js.string "dims") in
    let dims = Js_helpers.int_array_of_js (Js_of_ocaml.Js.Unsafe.coerce dims_js) in
    Bigarray.genarray_of_array1 flat |> fun ga -> Bigarray.reshape ga dims

  let download :
    type a b. (a, b) Dtype.t -> t ->
    (a, b, Bigarray.c_layout) Bigarray.Array1.t Lwt.t =
   fun dtype t ->
    check_not_disposed t;
    check_dtype dtype t;
    let open Js_of_ocaml in
    let promise = Js.Unsafe.meth_call t.js_tensor "getData" [||] in
    let open Lwt.Syntax in
    let+ data = Promise_lwt.to_lwt promise in
    let ta = (Js.Unsafe.coerce data : (_, _, _) Typed_array.typedArray Js.t) in
    let ga = Typed_array.to_genarray ta in
    let size = Bigarray.Genarray.nth_dim ga 0 in
    let ba = Bigarray.reshape_1 ga size in
    (* [check_dtype] above has verified the element type at runtime. *)
    (Obj.magic ba : (a, b, Bigarray.c_layout) Bigarray.Array1.t)

  let dims t =
    check_not_disposed t;
    let open Js_of_ocaml in
    let dims_js = Js.Unsafe.get t.js_tensor (Js.string "dims") in
    Js_helpers.int_array_of_js (Js.Unsafe.coerce dims_js)

  let dtype t =
    check_not_disposed t;
    let type_str = Js_helpers.get_string
        (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "type" in
    match Dtype.of_string type_str with
    | Some p -> p
    | None -> failwith (Printf.sprintf "Unknown tensor dtype: %s" type_str)

  let size t =
    check_not_disposed t;
    Js_helpers.get_int (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "size"

  let location t =
    check_not_disposed t;
    let loc = Js_helpers.get_string
        (Js_of_ocaml.Js.Unsafe.coerce t.js_tensor) "location" in
    match loc with
    | "cpu" -> Cpu
    | "gpu-buffer" -> Gpu_buffer
    | s -> failwith (Printf.sprintf "Unknown tensor location: %s" s)

  let dispose t =
    if not t.disposed then begin
      let open Js_of_ocaml in
      ignore (Js.Unsafe.meth_call t.js_tensor "dispose" [||] : Js.Unsafe.any);
      t.disposed <- true
    end
end

module Execution_provider = struct
  type t = Wasm | Webgpu
  let to_string = function Wasm -> "wasm" | Webgpu -> "webgpu"
end

type output_location = Cpu | Gpu_buffer
type graph_optimization = Disabled | Basic | Extended | All

let graph_optimization_to_string = function
  | Disabled -> "disabled"
  | Basic -> "basic"
  | Extended -> "extended"
  | All -> "all"

let output_location_to_js = function
  | Cpu -> Js_of_ocaml.Js.string "cpu"
  | Gpu_buffer -> Js_of_ocaml.Js.string "gpu-buffer"

(* The [logSeverityLevel] session option in onnxruntime-web is numeric
   (0 = verbose .. 4 = fatal); the string names are only accepted by the
   global [ort.env.logLevel] setting. *)
let log_level_to_int = function
  | `Verbose -> 0
  | `Info -> 1
  | `Warning -> 2
  | `Error -> 3
  | `Fatal -> 4

module Session = struct
  type t = {
    js_session : Js_of_ocaml.Js.Unsafe.any;
    input_names_ : string list;
    output_names_ : string list;
  }

  let build_options ?execution_providers ?graph_optimization
      ?preferred_output_location ?log_level () =
    let open Js_of_ocaml in
    let pairs = ref [] in
    (match execution_providers with
     | Some eps ->
       let js_eps = Js.array (Array.of_list
                                (List.map (fun ep ->
                                     Js.Unsafe.inject (Js.string (Execution_provider.to_string ep)))
                                    eps)) in
       pairs := ("executionProviders", Js.Unsafe.inject js_eps) :: !pairs
     | None -> ());
    (match graph_optimization with
     | Some go ->
       pairs := ("graphOptimizationLevel",
                 Js.Unsafe.inject (Js.string (graph_optimization_to_string go)))
                :: !pairs
     | None -> ());
    (match preferred_output_location with
     | Some loc ->
       pairs := ("preferredOutputLocation",
                 Js.Unsafe.inject (output_location_to_js loc))
                :: !pairs
     | None -> ());
    (match log_level with
     | Some level ->
       pairs := ("logSeverityLevel",
                 Js.Unsafe.inject (log_level_to_int level))
                :: !pairs
     | None -> ());
    Js.Unsafe.obj (Array.of_list !pairs)

  let wrap_session js_session =
    let open Js_of_ocaml in
    let input_names_ = Js_helpers.string_list_of_js_array
        (Js.Unsafe.coerce (Js.Unsafe.get js_session (Js.string "inputNames"))) in
    let output_names_ = Js_helpers.string_list_of_js_array
        (Js.Unsafe.coerce (Js.Unsafe.get js_session (Js.string "outputNames"))) in
    { js_session = Js.Unsafe.coerce js_session; input_names_; output_names_ }

  let create ?execution_providers ?graph_optimization
      ?preferred_output_location ?log_level model_url () =
    let open Js_of_ocaml in
    let ort = Js_helpers.ort () in
    let inference_session = Js.Unsafe.get ort (Js.string "InferenceSession") in
    let options = build_options ?execution_providers ?graph_optimization
        ?preferred_output_location ?log_level () in
    let promise = Js.Unsafe.meth_call inference_session "create"
        [| Js.Unsafe.inject (Js.string model_url);
           Js.Unsafe.inject options |] in
    let open Lwt.Syntax in
    let+ js_session = Promise_lwt.to_lwt promise in
    wrap_session js_session

  let create_from_buffer (type a b) ?execution_providers ?graph_optimization
      ?preferred_output_location ?log_level
      (buffer : (a, b, Bigarray.c_layout) Bigarray.Array1.t) () =
    let open Js_of_ocaml in
    let ort = Js_helpers.ort () in
    let inference_session = Js.Unsafe.get ort (Js.string "InferenceSession") in
    let options = build_options ?execution_providers ?graph_optimization
        ?preferred_output_location ?log_level () in
    let ga = Bigarray.genarray_of_array1 buffer in
    let ta = Typed_array.from_genarray Typed_array.Int8_unsigned (Obj.magic ga) in
    let ab : Typed_array.arrayBuffer Js.t =
      Js.Unsafe.get (Js.Unsafe.coerce ta) (Js.string "buffer") in
    (* Restrict the byte view to the bigarray's own window: a bigarray made
       with [Bigarray.Array1.sub] shares a larger ArrayBuffer. *)
    let byte_offset : int =
      Js.Unsafe.get (Js.Unsafe.coerce ta) (Js.string "byteOffset") in
    let byte_length : int =
      Js.Unsafe.get (Js.Unsafe.coerce ta) (Js.string "byteLength") in
    let uint8 = Js.Unsafe.new_obj
        (Js.Unsafe.global##._Uint8Array)
        [| Js.Unsafe.inject ab;
           Js.Unsafe.inject byte_offset;
           Js.Unsafe.inject byte_length |] in
    let promise = Js.Unsafe.meth_call inference_session "create"
        [| Js.Unsafe.inject uint8;
           Js.Unsafe.inject options |] in
    let open Lwt.Syntax in
    let+ js_session = Promise_lwt.to_lwt promise in
    wrap_session js_session

  let run t inputs =
    let open Js_of_ocaml in
    let feeds = Js.Unsafe.obj
        (Array.of_list
           (List.map (fun (name, (tensor : Tensor.t)) ->
                (name, Js.Unsafe.inject tensor.js_tensor))
               inputs)) in
    let promise = Js.Unsafe.meth_call t.js_session "run"
        [| Js.Unsafe.inject feeds |] in
    let open Lwt.Syntax in
    let+ results = Promise_lwt.to_lwt promise in
    List.map (fun name ->
        let js_tensor = Js.Unsafe.get results (Js.string name) in
        (name, Tensor.{ js_tensor = Js.Unsafe.coerce js_tensor;
                        disposed = false }))
      t.output_names_

  let run_with_outputs t inputs ~output_names =
    let open Js_of_ocaml in
    let feeds = Js.Unsafe.obj
        (Array.of_list
           (List.map (fun (name, (tensor : Tensor.t)) ->
                (name, Js.Unsafe.inject tensor.js_tensor))
               inputs)) in
    (* Pass the requested names as the [fetches] argument so that the runtime
       only produces those outputs, instead of computing all of them and
       selecting afterwards. *)
    let promise = Js.Unsafe.meth_call t.js_session "run"
        [| Js.Unsafe.inject feeds;
           Js.Unsafe.inject (Js_helpers.js_string_array output_names) |] in
    let open Lwt.Syntax in
    let+ results = Promise_lwt.to_lwt promise in
    List.map (fun name ->
        let js_tensor = Js.Unsafe.get results (Js.string name) in
        (name, Tensor.{ js_tensor = Js.Unsafe.coerce js_tensor;
                        disposed = false }))
      output_names

  let input_names t = t.input_names_
  let output_names t = t.output_names_

  let release t =
    let open Js_of_ocaml in
    let promise = Js.Unsafe.meth_call t.js_session "release" [||] in
    Promise_lwt.to_lwt promise |> Lwt.map (fun (_ : Js.Unsafe.any) -> ())
end

module Env = struct
  module Wasm = struct
    let set_num_threads n =
      let ort = Js_helpers.ort () in
      Js_helpers.set
        (Js_helpers.get_nested ort "env" "wasm")
        "numThreads" n

    let set_simd enabled =
      let ort = Js_helpers.ort () in
      Js_helpers.set
        (Js_helpers.get_nested ort "env" "wasm")
        "simd" (Js_of_ocaml.Js.bool enabled)

    let set_proxy enabled =
      let ort = Js_helpers.ort () in
      Js_helpers.set
        (Js_helpers.get_nested ort "env" "wasm")
        "proxy" (Js_of_ocaml.Js.bool enabled)

    let set_wasm_paths prefix =
      let ort = Js_helpers.ort () in
      Js_helpers.set
        (Js_helpers.get_nested ort "env" "wasm")
        "wasmPaths" (Js_of_ocaml.Js.string prefix)
  end

  module Webgpu = struct
    let set_power_preference pref =
      let ort = Js_helpers.ort () in
      let s = match pref with
        | `High_performance -> "high-performance"
        | `Low_power -> "low-power"
      in
      Js_helpers.set
        (Js_helpers.get_nested ort "env" "webgpu")
        "powerPreference" (Js_of_ocaml.Js.string s)
  end
end
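The `Dtype` GADT and the shape check used by `Tensor.of_bigarray1` can be tried out without a browser. The following is a minimal sketch in plain OCaml (no js_of_ocaml), assuming a reduced two-constructor `dtype`; the names `dtype`, `kind_of`, `check_shape`, and `make` are hypothetical and are not part of the library.

```ocaml
(* Illustrative sketch only: [dtype], [kind_of], [check_shape] and [make]
   are hypothetical names, not part of onnxrt. *)
type ('ocaml, 'elt) dtype =
  | Float32 : (float, Bigarray.float32_elt) dtype
  | Int32 : (int32, Bigarray.int32_elt) dtype

(* The GADT lets each branch return a differently typed Bigarray kind. *)
let kind_of : type a b. (a, b) dtype -> (a, b) Bigarray.kind = function
  | Float32 -> Bigarray.float32
  | Int32 -> Bigarray.int32

(* Same rule as Tensor.of_bigarray1: the product of [dims] must equal the
   length of the flat buffer. *)
let check_shape ~dims ~len =
  let expected = Array.fold_left ( * ) 1 dims in
  if expected <> len then
    invalid_arg
      (Printf.sprintf "dims product (%d) <> buffer length (%d)" expected len)

(* Allocate a flat buffer whose element type is dictated by [dtype] and
   fill it with [init]. *)
let make : type a b.
    (a, b) dtype -> dims:int array -> a ->
    (a, b, Bigarray.c_layout) Bigarray.Array1.t =
 fun dtype ~dims init ->
  let len = Array.fold_left ( * ) 1 dims in
  let ba = Bigarray.Array1.create (kind_of dtype) Bigarray.c_layout len in
  Bigarray.Array1.fill ba init;
  ba

let () =
  let dims = [| 1; 3; 2; 2 |] in
  let ba = make Float32 ~dims 0.5 in
  check_shape ~dims ~len:(Bigarray.Array1.dim ba);
  Printf.printf "%d elements, first = %.1f\n" (Bigarray.Array1.dim ba) ba.{0}
```

Because the dtype fixes both the OCaml value type and the Bigarray element kind, `make Int32 ~dims 0.5` is rejected at compile time, which is the property the real `Tensor.of_bigarray1` signature relies on.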
+476
lib/onnxrt.mli
··· 1 + (** ONNX Runtime Web bindings for OCaml. 2 + 3 + This library provides OCaml bindings to 4 + {{:https://onnxruntime.ai/} ONNX Runtime Web}, enabling ML model inference 5 + in the browser via [js_of_ocaml] or [wasm_of_ocaml]. 6 + 7 + The bindings target the [onnxruntime-web] npm package and support both the 8 + WebAssembly (CPU) and WebGPU (GPU) execution providers. 9 + 10 + {1 Quick start} 11 + 12 + {[ 13 + open Onnxrt 14 + 15 + let () = 16 + Lwt.async @@ fun () -> 17 + let open Lwt.Syntax in 18 + (* Configure before creating any session *) 19 + Env.Wasm.set_num_threads 2; 20 + (* Load model *) 21 + let* session = Session.create "model.onnx" () in 22 + (* Prepare input *) 23 + let ba = Bigarray.Array1.create Bigarray.float32 Bigarray.c_layout (3 * 224 * 224) in 24 + (* ... fill ba with image data ... *) 25 + let input = Tensor.of_bigarray1 Dtype.Float32 ba ~dims:[| 1; 3; 224; 224 |] in 26 + (* Run inference *) 27 + let* outputs = Session.run session [ "input", input ] in 28 + let output = List.assoc "output" outputs in 29 + let result = Tensor.to_bigarray1_exn Dtype.Float32 output in 30 + (* Clean up *) 31 + Tensor.dispose output; 32 + let* () = Session.release session in 33 + Lwt.return_unit 34 + ]} 35 + 36 + {1 Architecture} 37 + 38 + The library is structured in two layers: 39 + 40 + - {b Low-level}: Direct bindings to the onnxruntime-web JavaScript API via 41 + [Js.Unsafe]. Not exposed publicly. 42 + - {b High-level}: Pure OCaml types with {!Bigarray} for tensor data and 43 + {!Lwt.t} for async operations. This is the public API documented here. 44 + 45 + {1 Execution providers} 46 + 47 + ONNX Runtime Web supports multiple backends for executing model operators: 48 + 49 + - {!Execution_provider.Wasm}: CPU inference via WebAssembly with SIMD and 50 + optional multi-threading. Supports {b all} ONNX operators. This is the 51 + default and most portable backend. 52 + - {!Execution_provider.Webgpu}: GPU inference via WebGPU compute shaders. 
53 + Supports ~140 operators; unsupported operators fall back to WASM 54 + automatically, though each fallback incurs a GPU↔CPU data transfer. 55 + 56 + Execution providers are specified as a preference list when creating a 57 + session. The runtime tries each in order and falls back to the next: 58 + 59 + {[ 60 + Session.create "model.onnx" 61 + ~execution_providers:[ Webgpu; Wasm ] 62 + () 63 + ]} 64 + 65 + {1 Threading model} 66 + 67 + All operations that may block return [Lwt.t] promises. The WASM backend may 68 + use internal Web Workers for multi-threading (transparent to the caller, but 69 + requires [SharedArrayBuffer] and cross-origin isolation headers). WebGPU 70 + dispatches compute shaders on the GPU asynchronously. 71 + 72 + {1 GPU tensors} 73 + 74 + When using the WebGPU backend, tensors can reside on the GPU to avoid 75 + CPU↔GPU transfers between chained inference calls. See {!Tensor.location}, 76 + {!Tensor.download}, and {!Session.create} with 77 + [~preferred_output_location:`Gpu_buffer]. 78 + 79 + {1 Prerequisites} 80 + 81 + The [onnxruntime-web] npm package must be loaded in the JavaScript 82 + environment before using this library. For WebGPU support, import from 83 + [onnxruntime-web/webgpu]. The WASM files ([ort-wasm-simd-threaded.wasm] 84 + etc.) must be served at a path configured via {!Env.Wasm.set_wasm_paths}. 85 + *) 86 + 87 + (** {1 Data types} *) 88 + 89 + (** Tensor element types. 90 + 91 + Each constructor carries the correspondence between the ONNX type name, 92 + the OCaml value type, and the {!Bigarray} element type. This allows 93 + type-safe tensor creation and extraction via GADTs. *) 94 + module Dtype : sig 95 + (** A tensor element type, parameterised by the OCaml value type ['ocaml] 96 + and the Bigarray element kind ['elt]. *) 97 + type ('ocaml, 'elt) t = 98 + | Float32 : (float, Bigarray.float32_elt) t 99 + (** 32-bit floating point. The most common type for ML models. 
*) 100 + | Float64 : (float, Bigarray.float64_elt) t 101 + (** 64-bit floating point. *) 102 + | Int8 : (int, Bigarray.int8_signed_elt) t 103 + (** Signed 8-bit integer. Used in quantized models. *) 104 + | Uint8 : (int, Bigarray.int8_unsigned_elt) t 105 + (** Unsigned 8-bit integer. Common for image data and quantized 106 + models. *) 107 + | Int16 : (int, Bigarray.int16_signed_elt) t 108 + (** Signed 16-bit integer. *) 109 + | Uint16 : (int, Bigarray.int16_unsigned_elt) t 110 + (** Unsigned 16-bit integer. *) 111 + | Int32 : (int32, Bigarray.int32_elt) t 112 + (** 32-bit integer. Common for token IDs in NLP models. *) 113 + 114 + (** An existentially packed dtype for cases where the element type is only 115 + known at runtime (e.g. reading a model's output dtype). *) 116 + type packed = Pack : ('ocaml, 'elt) t -> packed 117 + 118 + val to_string : ('ocaml, 'elt) t -> string 119 + (** [to_string dtype] returns the ONNX type name (e.g. ["float32"], 120 + ["int32"]). *) 121 + 122 + val of_string : string -> packed option 123 + (** [of_string s] parses an ONNX type name. Returns [None] for unsupported 124 + types. *) 125 + 126 + val equal : ('a, 'b) t -> ('c, 'd) t -> bool 127 + (** [equal a b] returns [true] if [a] and [b] represent the same element 128 + type. *) 129 + end 130 + 131 + (** {1 Tensors} *) 132 + 133 + (** Multi-dimensional typed arrays for model input and output. 134 + 135 + Tensors are the primary data exchange type between OCaml and the ONNX 136 + runtime. On the CPU side, they are backed by JavaScript TypedArrays which 137 + share memory with OCaml {!Bigarray} values (zero-copy in [js_of_ocaml]). 138 + 139 + {2 Lifecycle} 140 + 141 + Tensors obtained from {!Session.run} should be {!dispose}d when no longer 142 + needed. For CPU tensors this is a hint to the garbage collector; for GPU 143 + tensors it releases the underlying [GPUBuffer] and failure to dispose will 144 + leak GPU memory. 
145 + 146 + {2 GPU tensors} 147 + 148 + When a session is configured with 149 + [~preferred_output_location:`Gpu_buffer], output tensors reside on the GPU. 150 + Their data is not accessible synchronously — use {!download} to transfer 151 + to CPU, or pass them directly as input to another {!Session.run} call to 152 + keep computation on the GPU. *) 153 + module Tensor : sig 154 + (** An opaque tensor handle. *) 155 + type t 156 + 157 + (** Where the tensor's data is stored. *) 158 + type location = 159 + | Cpu 160 + (** Data is in CPU memory (a JavaScript TypedArray). Accessible 161 + synchronously via {!to_bigarray1_exn}. *) 162 + | Gpu_buffer 163 + (** Data is in a WebGPU GPUBuffer. Must be {!download}ed before 164 + CPU-side access, or passed directly to {!Session.run}. *) 165 + 166 + (** {2 Creating tensors} *) 167 + 168 + val of_bigarray1 : 169 + ('a, 'b) Dtype.t -> 170 + ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t -> 171 + dims:int array -> 172 + t 173 + (** [of_bigarray1 dtype ba ~dims] creates a tensor from a 1-dimensional 174 + bigarray. The [dims] array specifies the logical shape (e.g. 175 + [[| 1; 3; 224; 224 |]]). The product of [dims] must equal the length 176 + of [ba]. 177 + 178 + The bigarray's underlying buffer is shared with the tensor (zero-copy 179 + in [js_of_ocaml]). Modifying [ba] after tensor creation will affect the 180 + tensor's data. 181 + 182 + @raise Invalid_argument if [Array.fold_left ( * ) 1 dims <> Bigarray.Array1.dim ba] *) 183 + 184 + val of_bigarray : 185 + ('a, 'b) Dtype.t -> 186 + ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t -> 187 + t 188 + (** [of_bigarray dtype ga] creates a tensor from a generic bigarray, using 189 + the bigarray's dimensions as the tensor shape. Zero-copy. *) 190 + 191 + val of_float32s : float array -> dims:int array -> t 192 + (** [of_float32s data ~dims] creates a Float32 tensor from an OCaml float 193 + array. Copies the data into a new Float32Array. 
194 + 195 + @raise Invalid_argument if [Array.length data] doesn't match the product 196 + of [dims] *) 197 + 198 + (** {2 Reading tensor data} *) 199 + 200 + val to_bigarray1_exn : 201 + ('a, 'b) Dtype.t -> 202 + t -> 203 + ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t 204 + (** [to_bigarray1_exn dtype tensor] returns the tensor's data as a 205 + flat 1-dimensional bigarray. Zero-copy when possible. 206 + 207 + @raise Invalid_argument if the tensor is on the GPU (use {!download} 208 + first) 209 + @raise Failure if [dtype] does not match the tensor's actual dtype *) 210 + 211 + val to_bigarray_exn : 212 + ('a, 'b) Dtype.t -> 213 + t -> 214 + ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t 215 + (** [to_bigarray_exn dtype tensor] returns the tensor's data as a generic 216 + bigarray with the tensor's shape as dimensions. Zero-copy when possible. 217 + 218 + @raise Invalid_argument if the tensor is on the GPU 219 + @raise Failure if [dtype] does not match the tensor's actual dtype *) 220 + 221 + val download : 222 + ('a, 'b) Dtype.t -> 223 + t -> 224 + ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t Lwt.t 225 + (** [download dtype tensor] retrieves the tensor's data, transferring from 226 + GPU to CPU if necessary. For CPU tensors, this resolves immediately. 227 + 228 + This is the only way to access data from a GPU tensor. 229 + 230 + @raise Failure if [dtype] does not match the tensor's actual dtype *) 231 + 232 + (** {2 Tensor metadata} *) 233 + 234 + val dims : t -> int array 235 + (** [dims tensor] returns the tensor's shape (e.g. [[| 1; 3; 224; 224 |]]). *) 236 + 237 + val dtype : t -> Dtype.packed 238 + (** [dtype tensor] returns the tensor's element type as a packed value. 239 + Use pattern matching to recover the type: 240 + 241 + {[ 242 + match Tensor.dtype t with 243 + | Dtype.Pack Float32 -> (* ... *) 244 + | Dtype.Pack Int32 -> (* ... 
*) 245 + | _ -> failwith "unexpected dtype" 246 + ]} *) 247 + 248 + val size : t -> int 249 + (** [size tensor] returns the total number of elements (product of dims). *) 250 + 251 + val location : t -> location 252 + (** [location tensor] returns where the tensor's data currently resides. *) 253 + 254 + (** {2 Lifecycle} *) 255 + 256 + val dispose : t -> unit 257 + (** [dispose tensor] releases the tensor's resources. For CPU tensors, drops 258 + the internal reference (data may still be accessible via a bigarray alias). 259 + For GPU tensors, destroys the underlying [GPUBuffer]. Always dispose GPU 260 + tensors to avoid memory leaks. 261 + 262 + After disposal, any access to the tensor's data raises. *) 263 + end 264 + 265 + (** {1 Inference sessions} *) 266 + 267 + (** Execution providers determine how model operators are executed. 268 + 269 + Providers are specified as a preference list when creating a session. The 270 + runtime tries each in order, falling back to the next if unavailable. 271 + The WASM provider is always available as a final fallback. *) 272 + module Execution_provider : sig 273 + type t = 274 + | Wasm 275 + (** CPU inference via WebAssembly. Supports all ONNX operators. Uses 276 + SIMD and optional multi-threading. This is the default. *) 277 + | Webgpu 278 + (** GPU inference via WebGPU compute shaders. Requires a browser with 279 + WebGPU support (Chrome 113+, Firefox 141+, Safari 26+). Operators 280 + without WebGPU kernels fall back to WASM automatically. *) 281 + 282 + val to_string : t -> string 283 + (** [to_string ep] returns the JavaScript name (["wasm"] or ["webgpu"]). *) 284 + end 285 + 286 + (** Where session outputs should be placed. *) 287 + type output_location = 288 + | Cpu 289 + (** Transfer results to CPU (default). Data is immediately accessible. *) 290 + | Gpu_buffer 291 + (** Keep results on the GPU. Avoids GPU→CPU transfer overhead when 292 + chaining inference calls. Use {!Tensor.download} to read the data. 
*) 293 + 294 + (** Graph optimization level applied during session creation. *) 295 + type graph_optimization = 296 + | Disabled (** No graph optimizations. *) 297 + | Basic (** Basic optimizations (constant folding, redundancy elimination). *) 298 + | Extended (** Extended optimizations (includes basic + more advanced rewrites). *) 299 + | All (** All available optimizations (default). *) 300 + 301 + (** An inference session: a loaded and optimized ONNX model ready to run. 302 + 303 + {2 Session lifecycle} 304 + 305 + 1. Create a session with {!create}, which loads the model, applies graph 306 + optimizations, and partitions operators across execution providers. 307 + 2. Run inference with {!run}, passing named input tensors and receiving 308 + named output tensors. 309 + 3. Release with {!release} when done, to free model weights and any 310 + GPU resources. 311 + 312 + {2 Warm-up} 313 + 314 + When using WebGPU, compute shaders are compiled lazily on the first 315 + {!run} call. The first inference will be significantly slower than 316 + subsequent ones. Run a warm-up inference with dummy data after session 317 + creation if latency matters. 318 + 319 + {2 Thread safety} 320 + 321 + Sessions do not support concurrent {!run} calls. Await each result before 322 + starting the next inference. *) 323 + module Session : sig 324 + (** An opaque inference session handle. *) 325 + type t 326 + 327 + val create : 328 + ?execution_providers:Execution_provider.t list -> 329 + ?graph_optimization:graph_optimization -> 330 + ?preferred_output_location:output_location -> 331 + ?log_level:[ `Verbose | `Info | `Warning | `Error | `Fatal ] -> 332 + string -> 333 + unit -> 334 + t Lwt.t 335 + (** [create ?execution_providers ?graph_optimization ?preferred_output_location 336 + ?log_level model_url ()] loads an ONNX model and creates an inference session. 337 + 338 + @param execution_providers Preference-ordered list of backends to try. 339 + Defaults to [[Wasm]]. 
340 + @param graph_optimization Level of graph optimization to apply. 341 + Defaults to [All]. 342 + @param preferred_output_location Where to place output tensors. Defaults 343 + to [Cpu]. Set to [Gpu_buffer] when chaining inference calls on the GPU. 344 + @param log_level Minimum severity for runtime log messages. 345 + Defaults to [`Warning]. 346 + @param model_url URL or path to the [.onnx] or [.ort] model file. 347 + 348 + @raise Failure if the model cannot be loaded or parsed *) 349 + 350 + val create_from_buffer : 351 + ?execution_providers:Execution_provider.t list -> 352 + ?graph_optimization:graph_optimization -> 353 + ?preferred_output_location:output_location -> 354 + ?log_level:[ `Verbose | `Info | `Warning | `Error | `Fatal ] -> 355 + ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t -> 356 + unit -> 357 + t Lwt.t 358 + (** [create_from_buffer ?... buffer ()] creates a session from model bytes 359 + already in memory. The [buffer] should contain the raw [.onnx] or [.ort] 360 + file content (typically fetched separately and cached in IndexedDB). 361 + 362 + Takes any Bigarray element type so you can pass [int8_unsigned] bytes 363 + directly. *) 364 + 365 + val run : 366 + t -> 367 + (string * Tensor.t) list -> 368 + (string * Tensor.t) list Lwt.t 369 + (** [run session inputs] runs inference on the model. 370 + 371 + [inputs] is an association list mapping input names to tensors. Use 372 + {!input_names} to discover the expected names. 373 + 374 + Returns an association list mapping output names to result tensors. 375 + The caller is responsible for {!Tensor.dispose}ing the returned tensors. 
376 + 377 + @raise Failure if an input name is not recognised, a tensor shape is 378 + incompatible with the model, or inference fails *) 379 + 380 + val run_with_outputs : 381 + t -> 382 + (string * Tensor.t) list -> 383 + output_names:string list -> 384 + (string * Tensor.t) list Lwt.t 385 + (** [run_with_outputs session inputs ~output_names] runs inference, fetching 386 + only the specified outputs. This can be more efficient than {!run} when 387 + a model has multiple outputs but you only need some of them. 388 + 389 + @raise Failure if an output name is not recognised *) 390 + 391 + val input_names : t -> string list 392 + (** [input_names session] returns the model's expected input tensor names, 393 + in the order defined by the model. *) 394 + 395 + val output_names : t -> string list 396 + (** [output_names session] returns the model's output tensor names, in the 397 + order defined by the model. *) 398 + 399 + val release : t -> unit Lwt.t 400 + (** [release session] frees all resources held by the session, including 401 + model weights and any GPU resources. The session must not be used after 402 + this call. *) 403 + end 404 + 405 + (** {1 Environment configuration} 406 + 407 + Global settings that affect all sessions. These {b must} be set before 408 + the first call to {!Session.create}; changing them afterwards has no 409 + effect. *) 410 + module Env : sig 411 + (** WebAssembly backend configuration. *) 412 + module Wasm : sig 413 + val set_num_threads : int -> unit 414 + (** [set_num_threads n] sets the number of threads for the WASM backend. 
+
+         - [0] (default): auto-detect ([navigator.hardwareConcurrency / 2],
+           capped at 4)
+         - [1]: single-threaded (no Web Workers, no [SharedArrayBuffer] needed)
+         - [n]: use [n] threads (requires cross-origin isolation)
+
+         Multi-threading requires the page to be served with:
+         {v
+         Cross-Origin-Opener-Policy: same-origin
+         Cross-Origin-Embedder-Policy: require-corp
+         v} *)
+
+     val set_simd : bool -> unit
+     (** [set_simd enabled] enables or disables WASM SIMD. Defaults to [true]
+         (auto-detect). SIMD provides ~2x speedup on supported hardware. *)
+
+     val set_proxy : bool -> unit
+     (** [set_proxy enabled] enables the proxy worker, which offloads WASM
+         inference to a dedicated Web Worker for UI responsiveness.
+
+         {b Incompatible with the WebGPU execution provider.}
+
+         Defaults to [false]. *)
+
+     val set_wasm_paths : string -> unit
+     (** [set_wasm_paths prefix] sets the URL prefix where [.wasm] files are
+         served. For example, [set_wasm_paths "/static/wasm/"] causes the
+         runtime to load [/static/wasm/ort-wasm-simd-threaded.wasm].
+
+         By default, files are loaded relative to the current page or worker
+         script location. *)
+   end
+
+   (** WebGPU backend configuration. *)
+   module Webgpu : sig
+     val set_power_preference : [ `High_performance | `Low_power ] -> unit
+     (** [set_power_preference pref] sets the GPU adapter power preference.
+         Defaults to [`High_performance].
+
+         - [`High_performance]: prefer discrete GPU (better throughput)
+         - [`Low_power]: prefer integrated GPU (better battery life) *)
+   end
+ end
+
+ (** {1 Errors}
+
+     All functions that interact with the ONNX runtime raise [Failure] on
+     error with a descriptive message from the runtime. Async operations may
+     also reject the [Lwt.t] promise with [Failure].
+
+     In a production application, wrap calls in [Lwt.catch] (the example assumes [open Lwt.Syntax] for [let*]):
+
+     {[
+       Lwt.catch
+         (fun () ->
+           let* session = Session.create "model.onnx" () in
+           (* ... *))
+         (fun exn ->
+           Logs.err (fun m -> m "ONNX error: %s" (Printexc.to_string exn));
+           Lwt.return_unit)
+     ]}
+ *)
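Review note: the `run` contract in this interface takes an association list keyed by input name and fails with `Failure` when a name is not recognised. A minimal sketch of checking the feed names up front, so the error names the offending inputs before any inference is attempted, might look like the following. `check_feeds` and `find_output` are hypothetical helpers, not part of the interface in this diff, and they are written against plain lists so they can be compiled without the library.

```ocaml
(* Hypothetical helpers, not part of the onnxrt interface shown in this diff.
   [expected] stands in for the result of [Session.input_names session];
   the feed values are polymorphic so the sketch needs no Tensor type. *)

(* Report inputs the model expects but the caller did not supply, and
   supplied names the model does not know about. *)
let check_feeds ~(expected : string list) (feeds : (string * 'a) list) :
    (unit, string) result =
  let provided = List.map fst feeds in
  let missing =
    List.filter (fun name -> not (List.mem name provided)) expected
  in
  let unknown =
    List.filter (fun name -> not (List.mem name expected)) provided
  in
  match (missing, unknown) with
  | [], [] -> Ok ()
  | _ ->
      Error
        (Printf.sprintf "missing inputs: [%s]; unrecognised inputs: [%s]"
           (String.concat ", " missing)
           (String.concat ", " unknown))

(* Look up one result by name in the association list returned by a run. *)
let find_output (name : string) (outputs : (string * 'a) list) : 'a option =
  List.assoc_opt name outputs
```

In real use, `expected` would presumably come from `Session.input_names`, and the values in both lists would be `Tensor.t`. Returning a `result` rather than raising keeps the check usable before entering Lwt code.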
+17
lib/promise_lwt.ml
+ open Js_of_ocaml
+
+ let to_lwt (promise : 'a Js.t) : 'a Lwt.t =
+   let lwt_promise, resolver = Lwt.wait () in
+   let on_resolve result = Lwt.wakeup resolver result in
+   let on_reject error =
+     let msg =
+       Js.to_string (Js.Unsafe.meth_call error "toString" [||] : Js.js_string Js.t)
+     in
+     Lwt.wakeup_exn resolver (Failure msg)
+   in
+   let _ignored : 'b Js.t =
+     Js.Unsafe.meth_call promise "then"
+       [| Js.Unsafe.inject (Js.wrap_callback on_resolve);
+          Js.Unsafe.inject (Js.wrap_callback on_reject) |]
+   in
+   lwt_promise
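Review note: the bridge above rests on one idea. Exactly one of the two callbacks passed to `then` fires, and a rejection is turned into `Failure msg`. The control flow can be modelled without a browser as a rough sketch. Everything here is an assumption for illustration: `mock_promise` is a toy, synchronous stand-in for a JS Promise, and a plain `result` stands in for the `Lwt.t` that the real code resolves through `Lwt.wakeup` and `Lwt.wakeup_exn`.

```ocaml
(* Toy model only: a "promise" is a function that calls one of two callbacks.
   The real bridge is asynchronous and goes through Js.Unsafe and Lwt. *)
type 'a mock_promise =
  on_resolve:('a -> unit) -> on_reject:(string -> unit) -> unit

let settle (promise : 'a mock_promise) : ('a, exn) result =
  let outcome = ref None in
  (* Keep only the first settlement, mirroring how a promise settles once. *)
  let record r =
    match !outcome with
    | None -> outcome := Some r
    | Some _ -> ()
  in
  promise
    ~on_resolve:(fun v -> record (Ok v))
    ~on_reject:(fun msg -> record (Error (Failure msg)));
  match !outcome with
  | Some r -> r
  | None -> Error (Failure "promise never settled")
```

The model also shows why the rejection path needs a string: whatever the runtime rejects with has to be rendered as a message before it can become a `Failure`.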
+7
lib/promise_lwt.mli
+ (** Internal: bridge JavaScript Promises to Lwt.
+
+     Not part of the public API. *)
+
+ val to_lwt : 'a Js_of_ocaml.Js.t -> 'a Lwt.t
+ (** [to_lwt js_promise] converts a JavaScript Promise to an Lwt thread.
+     If the Promise rejects, the Lwt thread fails with [Failure msg]. *)
+29
onnxrt.opam
+ # This file is generated by dune, edit dune-project instead
+ opam-version: "2.0"
+ synopsis: "OCaml bindings to ONNX Runtime Web for browser-based ML inference"
+ description:
+   "Type-safe OCaml bindings to onnxruntime-web, enabling ML model inference in the browser via js_of_ocaml or wasm_of_ocaml. Supports WebAssembly (CPU) and WebGPU (GPU) execution providers."
+ license: "ISC"
+ depends: [
+   "dune" {>= "3.17"}
+   "ocaml" {>= "5.2"}
+   "js_of_ocaml" {>= "5.8"}
+   "js_of_ocaml-ppx" {>= "5.8"}
+   "lwt" {>= "5.7"}
+   "js_of_ocaml-lwt" {>= "5.8"}
+   "odoc" {with-doc}
+ ]
+ build: [
+   ["dune" "subst"] {dev}
+   [
+     "dune"
+     "build"
+     "-p"
+     name
+     "-j"
+     jobs
+     "@install"
+     "@runtest" {with-test}
+     "@doc" {with-doc}
+   ]
+ ]