fuzzy find my records ken.waow.tech
embeddings pds search

ken — fuzzy find any record in your atproto repo

semantic search over an atproto repo. after you sign in, the backend fetches
your repo from your PDS via com.atproto.sync.getRepo (one CAR, parsed locally
via zat) and embeds records with bge-small through llama.cpp. if you opt in,
it writes the resulting vector pack back to your own PDS as a
tech.waow.ken.pack record + blobs. nothing lives anywhere else: delete the
record and the pack is gone.
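
as a rough sketch of the fetch step above: `com.atproto.sync.getRepo` is a plain XRPC GET that takes the repo's DID as a `did` query parameter and returns the repo as one CAR file. the helper name `get_repo_url` and the example host are hypothetical, and python is used only for brevity (the backend itself is zig).

```python
from urllib.parse import urlencode


def get_repo_url(pds_host: str, did: str) -> str:
    """Build the XRPC URL that returns a whole repo as a single CAR file.

    `pds_host` is the base URL of the user's PDS (hypothetical example below).
    """
    query = urlencode({"did": did})  # percent-encodes the colons in the DID
    return f"{pds_host.rstrip('/')}/xrpc/com.atproto.sync.getRepo?{query}"


if __name__ == "__main__":
    # hypothetical host and DID, for illustration only
    print(get_repo_url("https://pds.example.com", "did:plc:abc123"))
```

the response body is binary CAR data, so a real client would stream it to a parser rather than decode it as text.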

- zig backend, std.http.Server, zat for atproto primitives
- llama.cpp batched inference, 16 records per encode
- incremental re-index: unchanged records are reused by (uri, cid)
- partial search works from the moment the first batch finishes
- opt-in save + delete with explicit consent
- auth-first UI — no drive-by indexing of other people's repos
- public share URLs with per-query OG tags
- running at https://ken.waow.tech on fly.io
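
the batching and incremental re-index bullets above can be sketched roughly as follows. this is not the backend's zig code: `reindex`, `embed_batch`, and the cache shape are assumptions made for illustration. only the batch size of 16 and the `(uri, cid)` reuse key come from the commit message.

```python
from typing import Callable

BATCH_SIZE = 16  # records per encode call, per the commit message

Key = tuple[str, str]  # (uri, cid)
Vector = list[float]


def reindex(
    records: list[tuple[str, str, str]],  # (uri, cid, text)
    cache: dict[Key, Vector],             # vectors from a previous pack
    embed_batch: Callable[[list[str]], list[Vector]],  # hypothetical embedder
) -> dict[Key, Vector]:
    """Reuse cached vectors for unchanged records and embed the rest in batches."""
    index: dict[Key, Vector] = {}
    pending: list[tuple[Key, str]] = []

    for uri, cid, text in records:
        key = (uri, cid)
        if key in cache:
            # same uri and same cid means the record content is unchanged
            index[key] = cache[key]
        else:
            pending.append((key, text))

    # embed only new or changed records, BATCH_SIZE at a time
    for start in range(0, len(pending), BATCH_SIZE):
        batch = pending[start:start + BATCH_SIZE]
        vectors = embed_batch([text for _, text in batch])
        for (key, _), vector in zip(batch, vectors):
            index[key] = vector

    return index
```

keying on the cid works because a cid is a content hash: if a record is edited, its cid changes and the old vector is simply never looked up.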

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

zzstoatzz decbc109

+11868
+19
.gitignore
+ data/
+ .venv/
+ __pycache__/
+ *.pyc
+ .DS_Store
+ .fastembed_cache/
+
+ # zig build artifacts
+ backend/.zig-cache/
+ backend/zig-out/
+ backend/zig-pkg/
+ backend/build.log
+ spike/zig-embed/.zig-cache/
+ spike/zig-embed/zig-out/
+ # large binary artifacts — the backend already has its own copy of the
+ # model at backend/models/, and the llama.cpp release tarball is redundant
+ # with the extracted dylibs in spike/llama-bin/
+ spike/bge-small.gguf
+ spike/*.tar.gz
+46
README.md
+ # ken
+
+ fuzzy find any record in your [atproto](https://atproto.com) repo. semantic search over a [PDS](https://atproto.com/guides/data-repos), with the vector pack written back to your own PDS as a record you can inspect or delete.
+
+ running at **[ken.waow.tech](https://ken.waow.tech)**.
+
+ ## how it works
+
+ 1. sign in with your handle. oauth goes to your PDS.
+ 2. backend fetches your whole repo in one call via [`com.atproto.sync.getRepo`](https://atproto.com/specs/sync#getrepo), parses the CAR locally via [zat](https://tangled.org/zat.dev/zat)
+ 3. each record is embedded with [bge-small-en-v1.5](https://huggingface.co/BAAI/bge-small-en-v1.5) running through [llama.cpp](https://github.com/ggerganov/llama.cpp), 16 records per batch
+ 4. optional: click save, and the resulting vector pack is written back to your PDS as a `tech.waow.ken.pack` record + blobs. nothing lives anywhere else. click delete and it's gone.
+ 5. subsequent sign-ins reuse vectors by `(uri, cid)` — only new or changed records get re-embedded
+
+ search is in-memory cosine similarity across whatever the backend currently has cached for you. partial search works from the moment the first batch finishes, so the UI never blocks waiting on a full index.
+
+ ## sharing
+
+ a signed-in user can share a specific search via `https://ken.waow.tech/?handle=X&q=Y`. the backend's `GET /` injects per-query OpenGraph tags so link unfurlers render a real preview. anyone visiting a share URL loads the target's saved pack publicly (via the same PDS read path anyone else could take) and runs the query — no auth needed for readers, and nothing new is exposed because the records were already public on the PDS.
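
the share URL and per-query OpenGraph tags described in the sharing section could look roughly like this. a minimal sketch only: `share_url`, `og_tags`, and the title wording are hypothetical, and the real `GET /` handler lives in the zig backend. the one thing taken from the README is the `?handle=X&q=Y` URL shape.

```python
from html import escape
from urllib.parse import urlencode

BASE_URL = "https://ken.waow.tech/"


def share_url(handle: str, query: str) -> str:
    """Build a share link of the form ?handle=X&q=Y with both values URL-encoded."""
    return BASE_URL + "?" + urlencode({"handle": handle, "q": query})


def og_tags(handle: str, query: str) -> str:
    """Render OpenGraph meta tags for one share link.

    Both values come from the URL, so they are HTML-escaped before being
    placed inside attribute values. The title wording is an assumption.
    """
    title = escape(f"{query} in {handle}'s records", quote=True)
    url = escape(share_url(handle, query), quote=True)
    return (
        f'<meta property="og:title" content="{title}">\n'
        f'<meta property="og:url" content="{url}">'
    )


if __name__ == "__main__":
    print(share_url("alice.example.com", "zig http server"))
    print(og_tags("alice.example.com", "zig http server"))
```

escaping matters here because the tags are built from query-string input that any visitor controls.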
+
+ ## layout
+
+ ```
+ backend/                 zig http server + llama.cpp wrapper + indexer
+   src/                   main source
+   llama-include/         llama.h headers
+   llama-bin/             linux x86_64 .so files (docker build)
+   llama-bin-macos/       arm64 dylibs (local dev)
+   models/                bge-small.gguf
+   fly.toml               production config
+   Dockerfile             multi-stage build for fly
+ lexicons/                atproto lexicon specs
+   tech/waow/ken/pack.json
+ ```
+
+ ## running locally
+
+ ```bash
+ cd backend
+ zig build
+ OAUTH_CLIENT_SECRET_KEY=... MODEL_PATH=models/bge-small.gguf ./zig-out/bin/embed-on-pds
+ ```
+
+ ## license
+
+ MIT
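
the README's "how it works" section describes search as in-memory cosine similarity over whatever vectors are cached so far. a minimal sketch of that idea, assuming a plain dict from record URI to vector; the function names and the example `at://` URIs are hypothetical, and the backend's actual zig implementation may differ.

```python
import math


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def search(
    query_vec: list[float],
    index: dict[str, list[float]],
    top_k: int = 5,
) -> list[tuple[float, str]]:
    """Score every vector currently in the index and return the top_k best."""
    scored = [(cosine(query_vec, vec), uri) for uri, vec in index.items()]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored[:top_k]


if __name__ == "__main__":
    # toy 2-d vectors; real bge-small embeddings are much longer
    index = {
        "at://did:plc:example/app.example.note/a": [1.0, 0.0],
        "at://did:plc:example/app.example.note/b": [0.0, 1.0],
        "at://did:plc:example/app.example.note/c": [1.0, 1.0],
    }
    for score, uri in search([1.0, 0.0], index, top_k=2):
        print(f"{score:.3f} {uri}")
```

because `search` only reads whatever is in `index` at call time, it can run as soon as the first batch of vectors lands; results improve as more batches arrive.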
+6
backend/.dockerignore
+ .zig-cache
+ zig-out
+ fly.toml
+ Dockerfile
+ .dockerignore
+ llama-b8693-bin-ubuntu-x64.tar.gz
+45
backend/Dockerfile
+ # build stage
+ FROM debian:bookworm-slim AS builder
+
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     ca-certificates curl xz-utils \
+     && rm -rf /var/lib/apt/lists/*
+
+ # install zig 0.16-dev (matching the dev environment)
+ ARG ZIG_VERSION=0.16.0-dev.3070+b22eb176b
+ RUN curl -L https://ziglang.org/builds/zig-x86_64-linux-${ZIG_VERSION}.tar.xz | tar -xJ -C /usr/local \
+     && ln -s /usr/local/zig-x86_64-linux-${ZIG_VERSION}/zig /usr/local/bin/zig
+
+ WORKDIR /app
+ COPY build.zig build.zig.zon ./
+ COPY src ./src
+ COPY llama-include ./llama-include
+ COPY llama-bin ./llama-bin
+ # vendored zat package cache — avoids a network fetch from tangled.org inside
+ # the remote builder (which can't reach it reliably)
+ COPY zig-pkg ./zig-pkg
+
+ # target debian:bookworm-slim's glibc (not musl) so pre-built llama.cpp .so
+ # files link correctly. cpu_v3 = AVX2, matches fly's EPYC machines.
+ RUN zig build -Doptimize=ReleaseSafe -Dtarget=x86_64-linux-gnu -Dcpu=x86_64_v3
+
+ # runtime stage — minimal, with the llama libs in ./lib (matches build.zig rpath $ORIGIN/lib)
+ FROM debian:bookworm-slim
+
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     ca-certificates \
+     libstdc++6 \
+     libgomp1 \
+     && rm -rf /var/lib/apt/lists/* \
+     && echo 'precedence ::ffff:0:0/96 100' >> /etc/gai.conf
+
+ WORKDIR /app
+ COPY --from=builder /app/zig-out/bin/embed-on-pds .
+ COPY --from=builder /app/llama-bin/llama-b8693/ ./lib/
+ COPY models ./models
+
+ EXPOSE 3000
+ ENV PORT=3000
+ ENV MODEL_PATH=/app/models/bge-small.gguf
+
+ CMD ["./embed-on-pds"]
+48
backend/build.zig
+ const std = @import("std");
+
+ pub fn build(b: *std.Build) void {
+     const target = b.standardTargetOptions(.{});
+     const optimize = b.standardOptimizeOption(.{});
+
+     const zat = b.dependency("zat", .{ .target = target, .optimize = optimize });
+
+     const root_module = b.createModule(.{
+         .root_source_file = b.path("src/main.zig"),
+         .target = target,
+         .optimize = optimize,
+         .link_libc = true,
+         .link_libcpp = true, // libllama is C++ underneath even though llama.h is C
+         .imports = &.{
+             .{ .name = "zat", .module = zat.module("zat") },
+         },
+     });
+
+     // llama.cpp: headers are shared, libraries differ by platform. both are
+     // version b8693, vendored alongside the backend so the build is
+     // self-contained. llama-bin-macos/ is arm64 dylibs for local dev;
+     // llama-bin/ is ubuntu-x64 .so files for the docker build that ships
+     // to fly.
+     root_module.addIncludePath(b.path("llama-include"));
+
+     const lib_dir: std.Build.LazyPath = switch (target.result.os.tag) {
+         .macos => b.path("llama-bin-macos/llama-b8693"),
+         .linux => b.path("llama-bin/llama-b8693"),
+         else => @panic("unsupported target os"),
+     };
+     root_module.addLibraryPath(lib_dir);
+     root_module.addRPath(.{ .cwd_relative = "$ORIGIN/lib" });
+     root_module.linkSystemLibrary("llama", .{});
+     root_module.linkSystemLibrary("ggml", .{}); // ggml_backend_load_all lives here
+
+     const exe = b.addExecutable(.{
+         .name = "embed-on-pds",
+         .root_module = root_module,
+     });
+
+     b.installArtifact(exe);
+
+     const run = b.addRunArtifact(exe);
+     if (b.args) |args| run.addArgs(args);
+     const run_step = b.step("run", "run the backend");
+     run_step.dependOn(&run.step);
+ }
+17
backend/build.zig.zon
+ .{
+     .name = .embed_on_pds_backend,
+     .version = "0.0.1",
+     .fingerprint = 0x14f599b254b5b8c8,
+     .minimum_zig_version = "0.16.0",
+     .dependencies = .{
+         .zat = .{
+             .url = "https://tangled.org/zat.dev/zat/archive/v0.3.0-alpha.22.tar.gz",
+             .hash = "zat-0.3.0-alpha.22-5PuC7p9OCAAz4jpPNPoM4alGgqehPnpZCzIUy4LBajPh",
+         },
+     },
+     .paths = .{
+         "build.zig",
+         "build.zig.zon",
+         "src",
+     },
+ }
+22
backend/fly.toml
+ app = 'embed-on-pds'
+ primary_region = 'ord'
+
+ [build]
+
+ [env]
+   OAUTH_CLIENT_ID = 'https://ken.waow.tech/oauth-client-metadata.json'
+   OAUTH_REDIRECT_URI = 'https://ken.waow.tech/oauth/callback'
+   FRONTEND_ORIGIN = 'https://ken.waow.tech'
+
+ [http_service]
+   internal_port = 3000
+   force_https = true
+   auto_stop_machines = 'off'
+   auto_start_machines = true
+   min_machines_running = 1
+   processes = ['app']
+
+ [[vm]]
+   memory = '4gb'
+   cpu_kind = 'performance'
+   cpus = 2
backend/llama-b8693-bin-ubuntu-x64.tar.gz

This is a binary file and will not be displayed.

+21
backend/llama-bin-macos/llama-b8693/LICENSE
+ MIT License
+
+ Copyright (c) 2023-2026 The ggml authors
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
backend/llama-bin-macos/llama-b8693/libggml-base.0.9.11.dylib

This is a binary file and will not be displayed.

+1
backend/llama-bin-macos/llama-b8693/libggml-base.0.dylib
··· 1 + libggml-base.0.9.11.dylib
+1
backend/llama-bin-macos/llama-b8693/libggml-base.dylib
··· 1 + libggml-base.0.dylib
backend/llama-bin-macos/llama-b8693/libggml-blas.0.9.11.dylib

This is a binary file and will not be displayed.

+1
backend/llama-bin-macos/llama-b8693/libggml-blas.0.dylib
··· 1 + libggml-blas.0.9.11.dylib
+1
backend/llama-bin-macos/llama-b8693/libggml-blas.dylib
··· 1 + libggml-blas.0.dylib
backend/llama-bin-macos/llama-b8693/libggml-cpu.0.9.11.dylib

This is a binary file and will not be displayed.

+1
backend/llama-bin-macos/llama-b8693/libggml-cpu.0.dylib
··· 1 + libggml-cpu.0.9.11.dylib
+1
backend/llama-bin-macos/llama-b8693/libggml-cpu.dylib
··· 1 + libggml-cpu.0.dylib
backend/llama-bin-macos/llama-b8693/libggml-metal.0.9.11.dylib

This is a binary file and will not be displayed.

+1
backend/llama-bin-macos/llama-b8693/libggml-metal.0.dylib
··· 1 + libggml-metal.0.9.11.dylib
+1
backend/llama-bin-macos/llama-b8693/libggml-metal.dylib
··· 1 + libggml-metal.0.dylib
backend/llama-bin-macos/llama-b8693/libggml-rpc.0.9.11.dylib

This is a binary file and will not be displayed.

+1
backend/llama-bin-macos/llama-b8693/libggml-rpc.0.dylib
··· 1 + libggml-rpc.0.9.11.dylib
+1
backend/llama-bin-macos/llama-b8693/libggml-rpc.dylib
··· 1 + libggml-rpc.0.dylib
backend/llama-bin-macos/llama-b8693/libggml.0.9.11.dylib

This is a binary file and will not be displayed.

+1
backend/llama-bin-macos/llama-b8693/libggml.0.dylib
··· 1 + libggml.0.9.11.dylib
+1
backend/llama-bin-macos/llama-b8693/libggml.dylib
··· 1 + libggml.0.dylib
backend/llama-bin-macos/llama-b8693/libllama.0.0.8693.dylib

This is a binary file and will not be displayed.

+1
backend/llama-bin-macos/llama-b8693/libllama.0.dylib
··· 1 + libllama.0.0.8693.dylib
+1
backend/llama-bin-macos/llama-b8693/libllama.dylib
··· 1 + libllama.0.dylib
backend/llama-bin-macos/llama-b8693/libmtmd.0.0.8693.dylib

This is a binary file and will not be displayed.

+1
backend/llama-bin-macos/llama-b8693/libmtmd.0.dylib
··· 1 + libmtmd.0.0.8693.dylib
+1
backend/llama-bin-macos/llama-b8693/libmtmd.dylib
··· 1 + libmtmd.0.dylib
backend/llama-bin-macos/llama-b8693/llama-batched-bench

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-bench

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-cli

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-completion

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-cvector-generator

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-debug-template-parser

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-export-lora

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-fit-params

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-gemma3-cli

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-gguf-split

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-imatrix

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-llava-cli

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-minicpmv-cli

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-mtmd-cli

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-mtmd-debug

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-perplexity

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-quantize

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-qwen2vl-cli

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-results

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-server

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-template-analysis

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-tokenize

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/llama-tts

This is a binary file and will not be displayed.

backend/llama-bin-macos/llama-b8693/rpc-server

This is a binary file and will not be displayed.

+21
backend/llama-bin/llama-b8693/LICENSE
+ MIT License
+
+ Copyright (c) 2023-2026 The ggml authors
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+1
backend/llama-bin/llama-b8693/libggml-base.so
··· 1 + libggml-base.so.0
+1
backend/llama-bin/llama-b8693/libggml-base.so.0
··· 1 + libggml-base.so.0.9.11
backend/llama-bin/llama-b8693/libggml-base.so.0.9.11

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/libggml-cpu-alderlake.so

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/libggml-cpu-cannonlake.so

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/libggml-cpu-cascadelake.so

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/libggml-cpu-cooperlake.so

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/libggml-cpu-haswell.so

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/libggml-cpu-icelake.so

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/libggml-cpu-ivybridge.so

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/libggml-cpu-piledriver.so

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/libggml-cpu-sandybridge.so

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/libggml-cpu-sapphirerapids.so

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/libggml-cpu-skylakex.so

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/libggml-cpu-sse42.so

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/libggml-cpu-x64.so

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/libggml-cpu-zen4.so

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/libggml-rpc.so

This is a binary file and will not be displayed.

+1
backend/llama-bin/llama-b8693/libggml.so
··· 1 + libggml.so.0
+1
backend/llama-bin/llama-b8693/libggml.so.0
··· 1 + libggml.so.0.9.11
backend/llama-bin/llama-b8693/libggml.so.0.9.11

This is a binary file and will not be displayed.

+1
backend/llama-bin/llama-b8693/libllama.so
··· 1 + libllama.so.0
+1
backend/llama-bin/llama-b8693/libllama.so.0
··· 1 + libllama.so.0.0.8693
backend/llama-bin/llama-b8693/libllama.so.0.0.8693

This is a binary file and will not be displayed.

+1
backend/llama-bin/llama-b8693/libmtmd.so
··· 1 + libmtmd.so.0
+1
backend/llama-bin/llama-b8693/libmtmd.so.0
··· 1 + libmtmd.so.0.0.8693
backend/llama-bin/llama-b8693/libmtmd.so.0.0.8693

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-batched-bench

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-bench

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-cli

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-completion

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-debug-template-parser

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-fit-params

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-gemma3-cli

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-gguf-split

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-imatrix

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-llava-cli

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-minicpmv-cli

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-mtmd-cli

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-mtmd-debug

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-perplexity

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-quantize

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-qwen2vl-cli

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-results

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-server

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-template-analysis

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-tokenize

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/llama-tts

This is a binary file and will not be displayed.

backend/llama-bin/llama-b8693/rpc-server

This is a binary file and will not be displayed.

+85
backend/llama-include/ggml-alloc.h
+ #pragma once
+
+ #include "ggml.h"
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
+ typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+ typedef struct ggml_backend * ggml_backend_t;
+
+ // Tensor allocator
+ struct ggml_tallocr {
+     ggml_backend_buffer_t buffer;
+     void * base;
+     size_t alignment;
+     size_t offset;
+ };
+
+ GGML_API struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer);
+ GGML_API enum ggml_status ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor);
+
+ // Graph allocator
+ /*
+   Example usage:
+     ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
+
+     // optional: create a worst-case graph and reserve the buffers to avoid reallocations
+     ggml_gallocr_reserve(galloc, build_graph(max_batch));
+
+     // allocate the graph
+     struct ggml_cgraph * graph = build_graph(batch);
+     ggml_gallocr_alloc_graph(galloc, graph);
+
+     printf("compute buffer size: %zu bytes\n", ggml_gallocr_get_buffer_size(galloc, 0));
+
+     // evaluate the graph
+     ggml_backend_graph_compute(backend, graph);
+ */
+
+ // special tensor flags for use with the graph allocator:
+ //   ggml_set_input(): all input tensors are allocated at the beginning of the graph in non-overlapping addresses
+ //   ggml_set_output(): output tensors are never freed and never overwritten
+
+ typedef struct ggml_gallocr * ggml_gallocr_t;
+
+ GGML_API ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft);
+ GGML_API ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs);
+ GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc);
+
+ // pre-allocate buffers from a measure graph - does not allocate or modify the graph
+ // call with a worst-case graph to avoid buffer reallocations
+ // not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed
+ // returns false if the buffer allocation failed
+ // ggml_gallocr_resrve_n_size writes the buffer sizes per galloc buffer that would be allocated by ggml_gallocr_reserve_n to sizes
+ GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
+ GGML_API void ggml_gallocr_reserve_n_size(
+     ggml_gallocr_t galloc,
+     struct ggml_cgraph * graph,
+     const int * node_buffer_ids,
+     const int * leaf_buffer_ids,
+     size_t * sizes);
+ GGML_API bool ggml_gallocr_reserve_n(
+     ggml_gallocr_t galloc,
+     struct ggml_cgraph * graph,
+     const int * node_buffer_ids,
+     const int * leaf_buffer_ids);
+
+ // automatic reallocation if the topology changes when using a single buffer
+ // returns false if using multiple buffers and a re-allocation is needed (call ggml_gallocr_reserve_n first to set the node buffers)
+ GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
+
+ GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id);
+
+ // Utils
+ // Create a buffer and allocate all the tensors in a ggml_context
+ // ggml_backend_alloc_ctx_tensors_from_buft_size returns the size of the buffer that would be allocated by ggml_backend_alloc_ctx_tensors_from_buft
+ GGML_API size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
+ GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
+ GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend);
+
+ #ifdef __cplusplus
+ }
+ #endif
+373
backend/llama-include/ggml-backend.h
+ #pragma once
+
+ #include "ggml.h"
+ #include "ggml-alloc.h"
+
+ #ifdef GGML_BACKEND_SHARED
+ #    if defined(_WIN32) && !defined(__MINGW32__)
+ #        ifdef GGML_BACKEND_BUILD
+ #            define GGML_BACKEND_API __declspec(dllexport) extern
+ #        else
+ #            define GGML_BACKEND_API __declspec(dllimport) extern
+ #        endif
+ #    else
+ #        define GGML_BACKEND_API __attribute__ ((visibility ("default"))) extern
+ #    endif
+ #else
+ #    define GGML_BACKEND_API extern
+ #endif
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
+ typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+ typedef struct ggml_backend_event * ggml_backend_event_t;
+ typedef struct ggml_backend * ggml_backend_t;
+ typedef void * ggml_backend_graph_plan_t;
+ typedef struct ggml_backend_reg * ggml_backend_reg_t;
+ typedef struct ggml_backend_device * ggml_backend_dev_t;
+
+
+ //
+ // Backend buffer type
+ //
+
+ GGML_API const char * ggml_backend_buft_name (ggml_backend_buffer_type_t buft);
+ GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size);
+ GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
+ GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft);
+ GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
+ GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft);
+ GGML_API ggml_backend_dev_t ggml_backend_buft_get_device (ggml_backend_buffer_type_t buft);
+
+ //
+ // Backend buffer
+ //
+
+ enum ggml_backend_buffer_usage {
+     GGML_BACKEND_BUFFER_USAGE_ANY = 0,
+     GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1,
+     GGML_BACKEND_BUFFER_USAGE_COMPUTE = 2,
+ };
+
+ GGML_API const char * ggml_backend_buffer_name (ggml_backend_buffer_t buffer);
+ GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
+ GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
+ GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
+ GGML_API enum ggml_status ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+ GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
+ GGML_API size_t ggml_backend_buffer_get_max_size (ggml_backend_buffer_t buffer);
+ GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor);
+ GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value);
+ GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer);
+ GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
+ GGML_API enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage (ggml_backend_buffer_t buffer);
+ GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_get_type (ggml_backend_buffer_t buffer);
+ GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer);
+
+ // tensor copy between different backends
+ GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
+
+ //
+ // Backend (stream)
+ //
+
+ GGML_API ggml_guid_t ggml_backend_guid(ggml_backend_t backend);
+ GGML_API const char * ggml_backend_name(ggml_backend_t backend);
+ GGML_API void ggml_backend_free(ggml_backend_t backend);
+
+ GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
+ GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
+ GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
+ GGML_API size_t ggml_backend_get_max_size(ggml_backend_t backend);
+
+ GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+ GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+
+ // "offset" refers to the offset in tensor->data for setting/getting data
+ GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+ GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+ GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
+
+ GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
+
+ GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+ GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+
+ GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+ GGML_API enum ggml_status ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+ GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+
+ // NOTE: will be removed, use device version instead
+ GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
+ GGML_API bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
+ GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op);
+
+ // asynchronous copy
+ // the copy is performed after all the currently queued operations in backend_src
+ // backend_dst will wait for the copy to complete before performing other operations
+ // automatic fallback to sync copy if async is not supported
+ GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);
+
+ GGML_API ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend);
+
+ //
+ // Events
+ //
+
+ GGML_API ggml_backend_event_t ggml_backend_event_new(ggml_backend_dev_t device);
+ GGML_API void ggml_backend_event_free(ggml_backend_event_t event);
+ GGML_API void ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend);
+ GGML_API void ggml_backend_event_synchronize(ggml_backend_event_t event);
+ GGML_API void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event);
+
+ //
+ // Backend device
+ //
+
+ enum ggml_backend_dev_type {
+     // CPU device using system memory
+     GGML_BACKEND_DEVICE_TYPE_CPU,
+     // GPU device using dedicated memory
+     GGML_BACKEND_DEVICE_TYPE_GPU,
+     // integrated GPU device using host memory
+     GGML_BACKEND_DEVICE_TYPE_IGPU,
+     // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX)
+     GGML_BACKEND_DEVICE_TYPE_ACCEL
+ };
+
+ // functionality supported by the device
+ struct ggml_backend_dev_caps {
+     // asynchronous operations
+     bool async;
+     // pinned host buffer
+     bool host_buffer;
+     // creating buffers from host ptr
+     bool buffer_from_host_ptr;
+     // event synchronization
+     bool events;
+ };
+
+ // all the device properties
+ struct ggml_backend_dev_props {
+     // device name
+     const char * name;
+     // device description
+     const char * description;
+     // device free memory in bytes
+     size_t memory_free;
+     // device total memory in bytes
+     size_t memory_total;
+     // device type
+     enum ggml_backend_dev_type type;
+     // device id
+     // for PCI devices, this should be the PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:01:00.0")
+     // if the id is unknown, this should be NULL
+     const char * device_id;
+     // device capabilities
+     struct ggml_backend_dev_caps caps;
+ };
+
+ GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device);
+ GGML_API const char * ggml_backend_dev_description(ggml_backend_dev_t device);
+ GGML_API void ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total);
+ GGML_API enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device);
+ GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
+ GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
+ GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
+ GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
+ GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
+ GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
+
+ GGML_API bool ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op);
+ GGML_API bool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft);
+ GGML_API bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op);
+
+ //
+ // Backend (reg)
+ //
+
+ GGML_API const char * ggml_backend_reg_name(ggml_backend_reg_t reg);
+ GGML_API size_t ggml_backend_reg_dev_count(ggml_backend_reg_t reg);
+ GGML_API ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index);
+ GGML_API void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name);
+
+ // Common functions that may be obtained using ggml_backend_reg_get_proc_address
+
+ // Split buffer type for tensor parallelism
+ typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split);
+ // Set the number of threads for the backend
+ typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads);
+ // Get additional buffer types provided by the device (returns a NULL-terminated array)
+ typedef ggml_backend_buffer_type_t * (*ggml_backend_dev_get_extra_bufts_t)(ggml_backend_dev_t device);
+ // Set the abort callback for the backend
+ typedef void (*ggml_backend_set_abort_callback_t)(ggml_backend_t backend, ggml_abort_callback abort_callback, void * abort_callback_data);
+ // Get a list of feature flags supported by the backend (returns a NULL-terminated array)
+ struct ggml_backend_feature {
+     const char * name;
+     const char * value;
+ };
+ typedef struct ggml_backend_feature * (*ggml_backend_get_features_t)(ggml_backend_reg_t reg);
+
+ //
+ // Backend registry
+ //
217 + 218 + GGML_API void ggml_backend_register(ggml_backend_reg_t reg); 219 + 220 + GGML_API void ggml_backend_device_register(ggml_backend_dev_t device); 221 + 222 + // Backend (reg) enumeration 223 + GGML_API size_t ggml_backend_reg_count(void); 224 + GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index); 225 + GGML_API ggml_backend_reg_t ggml_backend_reg_by_name(const char * name); 226 + 227 + // Device enumeration 228 + GGML_API size_t ggml_backend_dev_count(void); 229 + GGML_API ggml_backend_dev_t ggml_backend_dev_get(size_t index); 230 + GGML_API ggml_backend_dev_t ggml_backend_dev_by_name(const char * name); 231 + GGML_API ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type); 232 + 233 + // Direct backend (stream) initialization 234 + // = ggml_backend_dev_init(ggml_backend_dev_by_name(name), params) 235 + GGML_API ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params); 236 + // = ggml_backend_dev_init(ggml_backend_dev_by_type(type), params) 237 + GGML_API ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params); 238 + // = ggml_backend_dev_init(ggml_backend_dev_by_type(GPU) OR ggml_backend_dev_by_type(CPU), NULL) 239 + GGML_API ggml_backend_t ggml_backend_init_best(void); 240 + 241 + // Load a backend from a dynamic library and register it 242 + GGML_API ggml_backend_reg_t ggml_backend_load(const char * path); 243 + // Unload a backend if loaded dynamically and unregister it 244 + GGML_API void ggml_backend_unload(ggml_backend_reg_t reg); 245 + // Load all known backends from dynamic libraries 246 + GGML_API void ggml_backend_load_all(void); 247 + GGML_API void ggml_backend_load_all_from_path(const char * dir_path); 248 + 249 + // 250 + // Backend scheduler 251 + // 252 + 253 + // The backend scheduler allows for multiple backend devices to be used together 254 + // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors 
between backends 255 + // The backends are selected based on: 256 + // - the backend that supports the operation 257 + // - the location of the pre-allocated tensors (e.g. the weights) 258 + /* 259 + Example usage: 260 + 261 + // operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be assigned 262 + // preferably to run on the same backend as the buffer 263 + ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); 264 + 265 + sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false, true); 266 + 267 + // initialize buffers from a max size graph (optional) 268 + reserve_graph = build_graph(sched, max_batch_size); 269 + 270 + // manually assign nodes to a backend (optional, should not be needed in most cases) 271 + struct ggml_tensor * node = ggml_mul_mat(ctx, ...); 272 + ggml_backend_sched_set_tensor_backend(sched, node, backend_gpu); 273 + 274 + ggml_backend_sched_reserve(sched, reserve_graph); 275 + 276 + // compute 277 + graph = build_graph(sched); // the graph and its tensors are single-use in terms of allocation, multi-use in terms of computation 278 + for (int i = 0; i < 10; ++i) { 279 + ggml_backend_sched_graph_compute(sched, graph); // on the first iteration the graph is allocated automatically 280 + } 281 + 282 + // if there are graph inputs: 283 + graph = build_graph(sched); // get a new graph that is not allocated (the metadata for the old graph is freed once ggml_free is called) 284 + ggml_backend_sched_reset(sched); // clear the allocation of the previous graph 285 + ggml_backend_sched_alloc_graph(sched, graph); // explicitly allocate the new graph but do not execute it 286 + ggml_backend_tensor_set(input_tensor, ...); // copy data to the newly allocated graph tensors 287 + ggml_backend_sched_graph_compute(sched, graph); // execute the graph 288 + 289 + // as an alternative to the above it is also possible to assign the inputs to a dedicated 
context and 290 + // allocate them statically via ggml_backend_alloc_ctx_tensors 291 + } 292 + */ 293 + 294 + typedef struct ggml_backend_sched * ggml_backend_sched_t; 295 + 296 + // Evaluation callback for each node in the graph (set with ggml_backend_sched_set_eval_callback) 297 + // when ask == true, the scheduler wants to know if the user wants to observe this node 298 + // this allows the scheduler to batch nodes together in order to evaluate them in a single call 299 + // 300 + // when ask == false, the scheduler is passing the node tensor to the user for observation 301 + // if the user returns false, the scheduler will cancel the graph compute 302 + // 303 + typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); 304 + 305 + // Initialize a backend scheduler, backends with low index are given priority over backends with high index 306 + GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload); 307 + GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); 308 + 309 + // Initialize backend buffers from a measure graph 310 + GGML_API void ggml_backend_sched_reserve_size(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph, size_t * sizes); 311 + GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success 312 + 313 + GGML_API int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched); 314 + GGML_API ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i); 315 + 316 + // Get the number of splits of the last graph 317 + GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched); 318 + GGML_API int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched); 319 + 320 + GGML_API ggml_backend_buffer_type_t ggml_backend_sched_get_buffer_type(ggml_backend_sched_t 
sched, ggml_backend_t backend); 321 + GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); 322 + 323 + GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); 324 + GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); 325 + 326 + // Split graph without allocating it 327 + GGML_API void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); 328 + 329 + // Allocate and compute graph on the backend scheduler 330 + GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success 331 + GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph); 332 + GGML_API enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph); 333 + GGML_API void ggml_backend_sched_synchronize(ggml_backend_sched_t sched); 334 + 335 + // Reset all assignments and allocators - must be called before changing the node backends or allocating a new graph. 336 + // This in effect deallocates all tensors that were previously allocated and leaves them with dangling pointers. 337 + // The correct way to use this API is to discard the deallocated tensors and create new ones. 
338 + GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched); 339 + 340 + // Set a callback to be called for each resulting node during graph compute 341 + GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data); 342 + 343 + // 344 + // Utils 345 + // 346 + 347 + struct ggml_backend_graph_copy { 348 + ggml_backend_buffer_t buffer; 349 + struct ggml_context * ctx_allocated; 350 + struct ggml_context * ctx_unallocated; 351 + struct ggml_cgraph * graph; 352 + }; 353 + 354 + // Copy a graph to a different backend 355 + GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph); 356 + GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy); 357 + 358 + typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data); 359 + 360 + // Compare the output of two backends 361 + GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes); 362 + 363 + // Tensor initialization 364 + GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); 365 + GGML_API enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor); 366 + 367 + // CPU buffer types are always available 368 + GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size); 369 + GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void); 370 + 371 + #ifdef __cplusplus 372 + } 373 + #endif
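The two-phase contract documented for `ggml_backend_sched_eval_callback` (first `ask == true` to declare interest in a node, then `ask == false` to receive it, with a `false` return cancelling the compute) can be sketched in isolation. This is a toy model only: `tensor_stub`, `observer`, `observe_cb` and `run_nodes` are hypothetical names, the stub stands in for `struct ggml_tensor`, and the driver loop does not reproduce how the real scheduler batches nodes.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <string.h>

// Hypothetical stand-in for struct ggml_tensor: only the name matters here.
struct tensor_stub {
    const char * name;
};

// Hypothetical user state handed to the callback through user_data.
struct observer {
    const char * wanted; // name of the node we want to inspect
    int          seen;   // how many times that node was delivered
};

// Same shape as ggml_backend_sched_eval_callback, but over the stub type.
static bool observe_cb(struct tensor_stub * t, bool ask, void * user_data) {
    struct observer * obs = user_data;
    assert(obs != NULL && t != NULL);
    const bool match = strcmp(t->name, obs->wanted) == 0;
    if (ask) {
        return match;   // ask phase: "do you want to observe this node?"
    }
    if (match) {
        obs->seen++;    // observe phase: the node's data would be read here
    }
    return true;        // returning false would cancel the graph compute
}

// Toy driver imitating the handshake; returns how many nodes were delivered.
static int run_nodes(struct tensor_stub * nodes, size_t n, void * user_data) {
    int delivered = 0;
    for (size_t i = 0; i < n; i++) {
        if (!observe_cb(&nodes[i], true, user_data)) {
            continue;   // user is not interested in this node
        }
        delivered++;
        if (!observe_cb(&nodes[i], false, user_data)) {
            break;      // user asked to cancel
        }
    }
    return delivered;
}
```

With the real API the callback would take `struct ggml_tensor *` and be registered through `ggml_backend_sched_set_eval_callback(sched, callback, user_data)`.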
+25
backend/llama-include/ggml-blas.h
··· 1 + #pragma once 2 + 3 + #include "ggml.h" 4 + #include "ggml-backend.h" 5 + 6 + 7 + #ifdef __cplusplus 8 + extern "C" { 9 + #endif 10 + 11 + // backend API 12 + GGML_BACKEND_API ggml_backend_t ggml_backend_blas_init(void); 13 + 14 + GGML_BACKEND_API bool ggml_backend_is_blas(ggml_backend_t backend); 15 + 16 + // number of threads used for conversion to float 17 + // for openblas and blis, this will also set the number of threads used for blas operations 18 + GGML_BACKEND_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads); 19 + 20 + GGML_BACKEND_API ggml_backend_reg_t ggml_backend_blas_reg(void); 21 + 22 + 23 + #ifdef __cplusplus 24 + } 25 + #endif
+151
backend/llama-include/ggml-cpu.h
··· 1 + #pragma once 2 + 3 + #include "ggml.h" 4 + #include "ggml-backend.h" 5 + 6 + #ifdef __cplusplus 7 + extern "C" { 8 + #endif 9 + 10 + // the compute plan that needs to be prepared for ggml_graph_compute() 11 + // since https://github.com/ggml-org/ggml/issues/287 12 + struct ggml_cplan { 13 + size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()` 14 + uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()` 15 + 16 + int n_threads; 17 + struct ggml_threadpool * threadpool; 18 + 19 + // abort ggml_graph_compute when true 20 + ggml_abort_callback abort_callback; 21 + void * abort_callback_data; 22 + 23 + // use only reference implementations 24 + bool use_ref; 25 + }; 26 + 27 + // numa strategies 28 + enum ggml_numa_strategy { 29 + GGML_NUMA_STRATEGY_DISABLED = 0, 30 + GGML_NUMA_STRATEGY_DISTRIBUTE = 1, 31 + GGML_NUMA_STRATEGY_ISOLATE = 2, 32 + GGML_NUMA_STRATEGY_NUMACTL = 3, 33 + GGML_NUMA_STRATEGY_MIRROR = 4, 34 + GGML_NUMA_STRATEGY_COUNT 35 + }; 36 + 37 + GGML_BACKEND_API void ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems 38 + GGML_BACKEND_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node 39 + 40 + GGML_BACKEND_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value); 41 + GGML_BACKEND_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value); 42 + 43 + GGML_BACKEND_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value); 44 + GGML_BACKEND_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value); 45 + 46 + GGML_BACKEND_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i); 47 + GGML_BACKEND_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value); 48 + 49 + GGML_BACKEND_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3); 50 + 
GGML_BACKEND_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value); 51 + 52 + GGML_BACKEND_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i); 53 + GGML_BACKEND_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value); 54 + 55 + GGML_BACKEND_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3); 56 + GGML_BACKEND_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value); 57 + 58 + GGML_BACKEND_API struct ggml_threadpool * ggml_threadpool_new (struct ggml_threadpool_params * params); 59 + GGML_BACKEND_API void ggml_threadpool_free (struct ggml_threadpool * threadpool); 60 + GGML_BACKEND_API int ggml_threadpool_get_n_threads (struct ggml_threadpool * threadpool); 61 + GGML_BACKEND_API void ggml_threadpool_pause (struct ggml_threadpool * threadpool); 62 + GGML_BACKEND_API void ggml_threadpool_resume (struct ggml_threadpool * threadpool); 63 + 64 + // ggml_graph_plan() has to be called before ggml_graph_compute() 65 + // when plan.work_size > 0, caller must allocate memory for plan.work_data 66 + GGML_BACKEND_API struct ggml_cplan ggml_graph_plan( 67 + const struct ggml_cgraph * cgraph, 68 + int n_threads, /* = GGML_DEFAULT_N_THREADS */ 69 + struct ggml_threadpool * threadpool /* = NULL */ ); 70 + GGML_BACKEND_API enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); 71 + 72 + // same as ggml_graph_compute() but the work data is allocated as a part of the context 73 + // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data 74 + GGML_BACKEND_API enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads); 75 + 76 + // 77 + // system info 78 + // 79 + 80 + // x86 81 + GGML_BACKEND_API int ggml_cpu_has_sse3 (void); 82 + GGML_BACKEND_API int 
ggml_cpu_has_ssse3 (void); 83 + GGML_BACKEND_API int ggml_cpu_has_avx (void); 84 + GGML_BACKEND_API int ggml_cpu_has_avx_vnni (void); 85 + GGML_BACKEND_API int ggml_cpu_has_avx2 (void); 86 + GGML_BACKEND_API int ggml_cpu_has_bmi2 (void); 87 + GGML_BACKEND_API int ggml_cpu_has_f16c (void); 88 + GGML_BACKEND_API int ggml_cpu_has_fma (void); 89 + GGML_BACKEND_API int ggml_cpu_has_avx512 (void); 90 + GGML_BACKEND_API int ggml_cpu_has_avx512_vbmi(void); 91 + GGML_BACKEND_API int ggml_cpu_has_avx512_vnni(void); 92 + GGML_BACKEND_API int ggml_cpu_has_avx512_bf16(void); 93 + GGML_BACKEND_API int ggml_cpu_has_amx_int8 (void); 94 + // ARM 95 + GGML_BACKEND_API int ggml_cpu_has_neon (void); 96 + GGML_BACKEND_API int ggml_cpu_has_arm_fma (void); 97 + GGML_BACKEND_API int ggml_cpu_has_fp16_va (void); 98 + GGML_BACKEND_API int ggml_cpu_has_dotprod (void); 99 + GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void); 100 + GGML_BACKEND_API int ggml_cpu_has_sve (void); 101 + GGML_BACKEND_API int ggml_cpu_get_sve_cnt (void); // sve vector length in bytes 102 + GGML_BACKEND_API int ggml_cpu_has_sme (void); 103 + // other 104 + GGML_BACKEND_API int ggml_cpu_has_riscv_v (void); 105 + GGML_BACKEND_API int ggml_cpu_get_rvv_vlen (void); // risc-v vector length in bytes 106 + GGML_BACKEND_API int ggml_cpu_has_vsx (void); 107 + GGML_BACKEND_API int ggml_cpu_has_vxe (void); 108 + GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void); 109 + GGML_BACKEND_API int ggml_cpu_has_llamafile (void); 110 + 111 + // Internal types and functions exposed for tests and benchmarks 112 + 113 + typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx, 114 + const void * GGML_RESTRICT y, size_t by, int nrc); 115 + 116 + struct ggml_type_traits_cpu { 117 + ggml_from_float_t from_float; 118 + ggml_vec_dot_t vec_dot; 119 + enum ggml_type vec_dot_type; 120 + int64_t nrows; // number of rows to process simultaneously 121 + }; 122 + 123 + GGML_BACKEND_API const 
struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type); 124 + 125 + GGML_BACKEND_API void ggml_cpu_init(void); 126 + 127 + // 128 + // CPU backend 129 + // 130 + 131 + GGML_BACKEND_API ggml_backend_t ggml_backend_cpu_init(void); 132 + 133 + GGML_BACKEND_API bool ggml_backend_is_cpu (ggml_backend_t backend); 134 + GGML_BACKEND_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads); 135 + GGML_BACKEND_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool); 136 + GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data); 137 + 138 + GGML_BACKEND_API void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref); 139 + 140 + GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void); 141 + 142 + GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t); 143 + GGML_BACKEND_API void ggml_cpu_fp32_to_i32 (const float *, int32_t *, int64_t); 144 + GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t); 145 + GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t); 146 + GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t); 147 + GGML_BACKEND_API void ggml_cpu_bf16_to_fp32(const ggml_bf16_t *, float *, int64_t); 148 + 149 + #ifdef __cplusplus 150 + } 151 + #endif
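The `ggml_cpu_fp32_to_bf16` / `ggml_cpu_bf16_to_fp32` declarations convert between float32 and bfloat16, which keeps the upper 16 bits of an IEEE float32. A minimal single-value sketch of that format follows. It is not ggml's implementation: the helper names are hypothetical and the rounding mode (round to nearest, ties to even) is an assumption made for illustration.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

// Hypothetical helper: float32 -> raw bfloat16 bits.
static uint16_t fp32_to_bf16_bits(float f) {
    uint32_t u;
    memcpy(&u, &f, sizeof u);                 // reinterpret without aliasing UB
    if ((u & 0x7fffffffu) > 0x7f800000u) {    // NaN: force a quiet NaN
        return (uint16_t)((u >> 16) | 0x0040u);
    }
    u += 0x7fffu + ((u >> 16) & 1u);          // round to nearest, ties to even
    return (uint16_t)(u >> 16);
}

// Hypothetical helper: raw bfloat16 bits -> float32 (exact, no rounding).
static float bf16_bits_to_fp32(uint16_t h) {
    const uint32_t u = (uint32_t)h << 16;     // low mantissa bits become zero
    float f;
    memcpy(&f, &u, sizeof f);
    return f;
}
```

The vendored functions operate on whole arrays and take an element count as their last argument; the sketch only shows the per-element bit manipulation.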
+61
backend/llama-include/ggml-metal.h
··· 1 + // Note: this description is outdated 2 + // 3 + // An interface allowing to compute ggml_cgraph with Metal 4 + // 5 + // This is a fully functional interface that extends ggml with GPU support for Apple devices. 6 + // A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.) 7 + // 8 + // How it works? 9 + // 10 + // As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this 11 + // interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you 12 + // use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.) 13 + // 14 + // You only need to make sure that all memory buffers that you used during the graph creation 15 + // are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is 16 + // used during the graph evaluation to determine the arguments of the compute kernels. 17 + // 18 + // Synchronization between device and host memory (for example for input and output tensors) 19 + // is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions. 
20 + // 21 + 22 + #pragma once 23 + 24 + #include "ggml.h" 25 + #include "ggml-backend.h" 26 + 27 + #include <stddef.h> 28 + #include <stdbool.h> 29 + 30 + struct ggml_tensor; 31 + struct ggml_cgraph; 32 + 33 + #ifdef __cplusplus 34 + extern "C" { 35 + #endif 36 + 37 + // 38 + // backend API 39 + // user-code should use only these functions 40 + // 41 + 42 + // TODO: remove in the future 43 + GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void); 44 + 45 + GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend); 46 + 47 + GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data); 48 + 49 + // helper to check if the device supports a specific family 50 + // ideally, the user code should be doing these checks 51 + // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf 52 + GGML_BACKEND_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family); 53 + 54 + // capture all command buffers committed the next time `ggml_backend_graph_compute` is called 55 + GGML_BACKEND_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend); 56 + 57 + GGML_BACKEND_API ggml_backend_reg_t ggml_backend_metal_reg(void); 58 + 59 + #ifdef __cplusplus 60 + } 61 + #endif
+256
backend/llama-include/ggml-opt.h
··· 1 + // This file contains functionality for training models using GGML. 2 + // It is not strictly needed vs. just vanilla GGML but it provides a more high-level interface for common needs such as datasets. 3 + // At the bottom of this file especially there are relatively high-level functions that are suitable for use or adaptation in user code. 4 + // 5 + // Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de) 6 + 7 + #pragma once 8 + 9 + #include "ggml.h" 10 + #include "ggml-backend.h" 11 + 12 + #include <stdint.h> 13 + 14 + #ifdef __cplusplus 15 + extern "C" { 16 + #endif 17 + 18 + struct ggml_opt_dataset; 19 + struct ggml_opt_context; 20 + struct ggml_opt_result; 21 + 22 + typedef struct ggml_opt_dataset * ggml_opt_dataset_t; 23 + typedef struct ggml_opt_context * ggml_opt_context_t; 24 + typedef struct ggml_opt_result * ggml_opt_result_t; 25 + 26 + // ====== Loss ====== 27 + 28 + // built-in loss types, i.e. the built-in quantities minimized by the optimizer 29 + // custom loss types can be defined via mean or sum which simply reduce the outputs for all datapoints to a single value 30 + enum ggml_opt_loss_type { 31 + GGML_OPT_LOSS_TYPE_MEAN, 32 + GGML_OPT_LOSS_TYPE_SUM, 33 + GGML_OPT_LOSS_TYPE_CROSS_ENTROPY, 34 + GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR, 35 + }; 36 + 37 + // ====== Dataset ====== 38 + 39 + GGML_API ggml_opt_dataset_t ggml_opt_dataset_init( 40 + enum ggml_type type_data, // the type for the internal data tensor 41 + enum ggml_type type_label, // the type for the internal labels tensor 42 + int64_t ne_datapoint, // number of elements per datapoint 43 + int64_t ne_label, // number of elements per label 44 + int64_t ndata, // total number of datapoints/labels 45 + int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied) 46 + GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset); 47 + 48 + // get underlying tensors that store the data 49 + GGML_API int64_t 
ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset); 50 + GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata] 51 + GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [ne_label, ndata] 52 + 53 + // shuffle idata first datapoints from dataset with RNG from opt_ctx, shuffle all datapoints if idata is negative 54 + GGML_API void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata); 55 + 56 + // get batch at position ibatch from dataset and copy the data to data_batch and labels_batch 57 + GGML_API void ggml_opt_dataset_get_batch( 58 + ggml_opt_dataset_t dataset, 59 + struct ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch] 60 + struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch] 61 + int64_t ibatch); 62 + GGML_API void ggml_opt_dataset_get_batch_host( 63 + ggml_opt_dataset_t dataset, 64 + void * data_batch, 65 + size_t nb_data_batch, 66 + void * labels_batch, 67 + int64_t ibatch); 68 + 69 + // ====== Model / Context ====== 70 + 71 + enum ggml_opt_build_type { 72 + GGML_OPT_BUILD_TYPE_FORWARD = 10, 73 + GGML_OPT_BUILD_TYPE_GRAD = 20, 74 + GGML_OPT_BUILD_TYPE_OPT = 30, 75 + }; 76 + 77 + enum ggml_opt_optimizer_type { 78 + GGML_OPT_OPTIMIZER_TYPE_ADAMW, 79 + GGML_OPT_OPTIMIZER_TYPE_SGD, 80 + 81 + GGML_OPT_OPTIMIZER_TYPE_COUNT 82 + }; 83 + 84 + // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss 85 + struct ggml_opt_optimizer_params { 86 + struct { 87 + float alpha; // learning rate 88 + float beta1; // first AdamW momentum 89 + float beta2; // second AdamW momentum 90 + float eps; // epsilon for numerical stability 91 + float wd; // weight decay - 0.0f to disable 92 + } adamw; 93 + struct { 94 + float alpha; // learning rate 95 + float wd; // weight decay 96 + } sgd; 97 + }; 98 + 99 + // callback to calculate optimizer parameters prior to a 
backward pass 100 + // userdata can be used to pass arbitrary data 101 + typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata); 102 + 103 + // returns the default optimizer params (constant, hard-coded values) 104 + // userdata is not used 105 + GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata); 106 + 107 + // casts userdata to ggml_opt_optimizer_params and returns it 108 + GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata); 109 + 110 + // parameters for initializing a new optimization context 111 + struct ggml_opt_params { 112 + ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs 113 + 114 + // by default the forward graph needs to be reconstructed for each eval 115 + // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically 116 + struct ggml_context * ctx_compute; 117 + struct ggml_tensor * inputs; 118 + struct ggml_tensor * outputs; 119 + 120 + enum ggml_opt_loss_type loss_type; 121 + enum ggml_opt_build_type build_type; 122 + 123 + int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done 124 + 125 + ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters 126 + void * get_opt_pars_ud; // userdata for calculating optimizer parameters 127 + 128 + // only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor 129 + enum ggml_opt_optimizer_type optimizer; 130 + }; 131 + 132 + // get parameters for an optimization context with defaults set where possible 133 + // parameters for which no sensible defaults exist are supplied as arguments to this function 134 + GGML_API struct ggml_opt_params ggml_opt_default_params( 135 + ggml_backend_sched_t backend_sched, 136 + enum ggml_opt_loss_type loss_type); 137 + 138 + GGML_API ggml_opt_context_t ggml_opt_init(struct 
ggml_opt_params params); 139 + GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx); 140 + 141 + // set gradients to zero, initialize loss, and optionally reset the optimizer 142 + GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer); 143 + 144 + GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated statically 145 + 146 + // get underlying tensors that store data 147 + // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc 148 + GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor 149 + GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor 150 + GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against 151 + GGML_API struct ggml_tensor * ggml_opt_loss( ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss 152 + GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs 153 + GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels 154 + 155 + // get the gradient accumulator for a node from the forward graph 156 + GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node); 157 + 158 + GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t); //TODO consistent naming scheme 159 + 160 + GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type); 161 + 162 + // ====== Optimization Result ====== 163 + 164 + GGML_API ggml_opt_result_t ggml_opt_result_init(void); 165 + GGML_API void ggml_opt_result_free(ggml_opt_result_t result); 166 + GGML_API void ggml_opt_result_reset(ggml_opt_result_t result); 167 + 168 + // get data from result, uncertainties are optional and can be 
ignored by passing NULL 169 + GGML_API void ggml_opt_result_ndata( ggml_opt_result_t result, int64_t * ndata); // writes 1 value, number of datapoints 170 + GGML_API void ggml_opt_result_loss( ggml_opt_result_t result, double * loss, double * unc); // writes 1 value 171 + GGML_API void ggml_opt_result_pred( ggml_opt_result_t result, int32_t * pred); // writes ndata values 172 + GGML_API void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc); // writes 1 value 173 + 174 + // ====== Computation ====== 175 + 176 + // if not using static graphs, this function must be called prior to ggml_opt_alloc 177 + GGML_API void ggml_opt_prepare_alloc( 178 + ggml_opt_context_t opt_ctx, 179 + struct ggml_context * ctx_compute, 180 + struct ggml_cgraph * gf, 181 + struct ggml_tensor * inputs, 182 + struct ggml_tensor * outputs); 183 + 184 + // allocate the next graph for evaluation, either forward or forward + backward 185 + // must be called exactly once prior to calling ggml_opt_eval 186 + GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward); 187 + 188 + // do forward pass, increment result if not NULL, do backward pass if allocated 189 + GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); 190 + 191 + // ############################################################################ 192 + // ## The high-level functions start here. They do not depend on any private ## 193 + // ## functions or structs and can be copied to and adapted for user code. ## 194 + // ############################################################################ 195 + 196 + // ====== Intended Usage ====== 197 + // 198 + // 1. Select the appropriate loss for your problem. 199 + // 2. Create a dataset and set the data for the "data" tensor. Also set the "labels" tensor if your loss needs them. 200 + // Setting the shard size to 1 will be fine, it's the granularity with which data is shuffled/loaded (bigger values are faster). 
201 + // 3. Create a GGML graph for your model with no_alloc == true. Use two separate contexts for the tensors. 202 + // The first context should contain the model parameters and inputs and be allocated statically in user code. 203 + // The second context should contain all other tensors and will be (re)allocated automatically. 204 + // Due to this automated allocation the data of the second context is not defined when accessed in user code. 205 + // Note that the second dimension of the inputs/outputs are interpreted as the number of datapoints in those tensors. 206 + // 4. Call ggml_opt_fit. If you need more control you can use ggml_opt_epoch instead. 207 + 208 + // signature for a callback while evaluating opt_ctx on dataset, called after an evaluation 209 + typedef void (*ggml_opt_epoch_callback)( 210 + bool train, // true after training evaluation, false after validation evaluation 211 + ggml_opt_context_t opt_ctx, 212 + ggml_opt_dataset_t dataset, 213 + ggml_opt_result_t result, // result associated with the dataset subsection 214 + int64_t ibatch, // number of batches that have been evaluated so far 215 + int64_t ibatch_max, // total number of batches in this dataset subsection 216 + int64_t t_start_us); // time at which the evaluation on the dataset subsection was started 217 + 218 + // do training on front of dataset, do evaluation only on back of dataset 219 + GGML_API void ggml_opt_epoch( 220 + ggml_opt_context_t opt_ctx, 221 + ggml_opt_dataset_t dataset, 222 + ggml_opt_result_t result_train, // result to increment during training, ignored if NULL 223 + ggml_opt_result_t result_eval, // result to increment during evaluation, ignored if NULL 224 + int64_t idata_split, // data index at which to split training and evaluation 225 + ggml_opt_epoch_callback callback_train, 226 + ggml_opt_epoch_callback callback_eval); 227 + 228 + // callback that prints a progress bar on stderr 229 + GGML_API void ggml_opt_epoch_callback_progress_bar( 230 + bool train, 231 + 
ggml_opt_context_t opt_ctx, 232 + ggml_opt_dataset_t dataset, 233 + ggml_opt_result_t result, 234 + int64_t ibatch, 235 + int64_t ibatch_max, 236 + int64_t t_start_us); 237 + 238 + // fit model defined by inputs and outputs to dataset 239 + GGML_API void ggml_opt_fit( 240 + ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs 241 + struct ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs 242 + struct ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch] 243 + struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used 244 + ggml_opt_dataset_t dataset, // dataset with data and optionally also labels 245 + enum ggml_opt_loss_type loss_type, // loss to minimize 246 + enum ggml_opt_optimizer_type optimizer, // sgd or adamw 247 + ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t) 248 + int64_t nepoch, // how many times the dataset should be iterated over 249 + int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs 250 + float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f) 251 + bool silent); // whether or not info prints to stderr should be suppressed 252 + 253 + 254 + #ifdef __cplusplus 255 + } 256 + #endif
+35
backend/llama-include/ggml-rpc.h
··· 1 + #pragma once 2 + 3 + #include "ggml-backend.h" 4 + 5 + #ifdef __cplusplus 6 + extern "C" { 7 + #endif 8 + 9 + #define RPC_PROTO_MAJOR_VERSION 3 10 + #define RPC_PROTO_MINOR_VERSION 6 11 + #define RPC_PROTO_PATCH_VERSION 1 12 + 13 + #ifdef __cplusplus 14 + static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION"); 15 + #endif 16 + 17 + #define GGML_RPC_MAX_SERVERS 16 18 + 19 + // backend API 20 + GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device); 21 + GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend); 22 + 23 + GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device); 24 + 25 + GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total); 26 + 27 + GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, 28 + size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices); 29 + 30 + GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void); 31 + GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint); 32 + 33 + #ifdef __cplusplus 34 + } 35 + #endif
+2775
backend/llama-include/ggml.h
··· 1 + #pragma once 2 + 3 + // 4 + // GGML Tensor Library 5 + // 6 + // This documentation is still a work in progress. 7 + // If you wish some specific topics to be covered, feel free to drop a comment: 8 + // 9 + // https://github.com/ggml-org/whisper.cpp/issues/40 10 + // 11 + // ## Overview 12 + // 13 + // This library implements: 14 + // 15 + // - a set of tensor operations 16 + // - automatic differentiation 17 + // - basic optimization algorithms 18 + // 19 + // The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes, 20 + // but is not limited to, the following: 21 + // 22 + // - linear regression 23 + // - support vector machines 24 + // - neural networks 25 + // 26 + // The library allows the user to define a certain function using the available tensor operations. This function 27 + // definition is represented internally via a computation graph. Each tensor operation in the function definition 28 + // corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the 29 + // function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized 30 + // using one of the available optimization algorithms. 
31 + // 32 + // For example, here we define the function: f(x) = a*x^2 + b 33 + // 34 + // { 35 + // struct ggml_init_params params = { 36 + // .mem_size = 16*1024*1024, 37 + // .mem_buffer = NULL, 38 + // }; 39 + // 40 + // // memory allocation happens here 41 + // struct ggml_context * ctx = ggml_init(params); 42 + // 43 + // struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); 44 + // 45 + // ggml_set_param(x); // x is an input variable 46 + // 47 + // struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); 48 + // struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); 49 + // struct ggml_tensor * x2 = ggml_mul(ctx, x, x); 50 + // struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x2), b); 51 + // 52 + // ... 53 + // } 54 + // 55 + // Notice that the function definition above does not involve any actual computation. The computation is performed only 56 + // when the user explicitly requests it. For example, to compute the function's value at x = 2.0: 57 + // 58 + // { 59 + // ... 60 + // 61 + // struct ggml_cgraph * gf = ggml_new_graph(ctx); 62 + // ggml_build_forward_expand(gf, f); 63 + // 64 + // // set the input variable and parameter values 65 + // ggml_set_f32(x, 2.0f); 66 + // ggml_set_f32(a, 3.0f); 67 + // ggml_set_f32(b, 4.0f); 68 + // 69 + // ggml_graph_compute_with_ctx(ctx, gf, n_threads); 70 + // 71 + // printf("f = %f\n", ggml_get_f32_1d(f, 0)); 72 + // 73 + // ... 74 + // } 75 + // 76 + // The actual computation is performed in the ggml_graph_compute() function. 77 + // 78 + // The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the 79 + // ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know 80 + // in advance how much memory you need for your computation. 
Alternatively, you can allocate a large enough memory 81 + // and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was 82 + // actually needed. 83 + // 84 + // The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic 85 + // differentiation and optimization algorithms. 86 + // 87 + // The described approach allows to define the function graph once and then compute its forward or backward graphs 88 + // multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way 89 + // the user can avoid the memory allocation overhead at runtime. 90 + // 91 + // The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class 92 + // citizens, but in theory the library can be extended to support FP8 and integer data types. 93 + // 94 + // Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary 95 + // and binary operations. Most of the available operations fall into one of these two categories. With time, it became 96 + // clear that the library needs to support more complex operations. The way to support these operations is not clear 97 + // yet, but a few examples are demonstrated in the following operations: 98 + // 99 + // - ggml_permute() 100 + // - ggml_conv_1d_1s() 101 + // - ggml_conv_1d_2s() 102 + // 103 + // For each tensor operator, the library implements a forward and backward computation function. The forward function 104 + // computes the output tensor value given the input tensor values. The backward function computes the adjoint of the 105 + // input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a 106 + // calculus class, or watch the following video: 107 + // 108 + // What is Automatic Differentiation? 
109 + // https://www.youtube.com/watch?v=wG_nF1awSSY 110 + // 111 + // 112 + // ## Tensor data (struct ggml_tensor) 113 + // 114 + // The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of 115 + // the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains 116 + // pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example: 117 + // 118 + // { 119 + // struct ggml_tensor * c = ggml_add(ctx, a, b); 120 + // 121 + // assert(c->src[0] == a); 122 + // assert(c->src[1] == b); 123 + // } 124 + // 125 + // The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the 126 + // number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows 127 + // to store tensors that are not contiguous in memory, which is useful for operations such as transposition and 128 + // permutation. All tensor operations have to take the stride into account and not assume that the tensor is 129 + // contiguous in memory. 130 + // 131 + // The data of the tensor is accessed via the "data" pointer. For example: 132 + // 133 + // { 134 + // const int nx = 2; 135 + // const int ny = 3; 136 + // 137 + // struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, ny); 138 + // 139 + // for (int y = 0; y < ny; y++) { 140 + // for (int x = 0; x < nx; x++) { 141 + // *(float *) ((char *) a->data + y*a->nb[1] + x*a->nb[0]) = x + y; 142 + // } 143 + // } 144 + // 145 + // ... 146 + // } 147 + // 148 + // Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used. 
149 + // 150 + // ## The matrix multiplication operator (ggml_mul_mat) 151 + // 152 + // TODO 153 + // 154 + // 155 + // ## Multi-threading 156 + // 157 + // TODO 158 + // 159 + // 160 + // ## Overview of ggml.c 161 + // 162 + // TODO 163 + // 164 + // 165 + // ## SIMD optimizations 166 + // 167 + // TODO 168 + // 169 + // 170 + // ## Debugging ggml 171 + // 172 + // TODO 173 + // 174 + // 175 + 176 + #ifdef GGML_SHARED 177 + # if defined(_WIN32) && !defined(__MINGW32__) 178 + # ifdef GGML_BUILD 179 + # define GGML_API __declspec(dllexport) extern 180 + # else 181 + # define GGML_API __declspec(dllimport) extern 182 + # endif 183 + # else 184 + # define GGML_API __attribute__ ((visibility ("default"))) extern 185 + # endif 186 + #else 187 + # define GGML_API extern 188 + #endif 189 + 190 + // TODO: support for clang 191 + #ifdef __GNUC__ 192 + # define GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) 193 + #elif defined(_MSC_VER) 194 + # define GGML_DEPRECATED(func, hint) __declspec(deprecated(hint)) func 195 + #else 196 + # define GGML_DEPRECATED(func, hint) func 197 + #endif 198 + 199 + #ifndef __GNUC__ 200 + # define GGML_ATTRIBUTE_FORMAT(...) 201 + #elif defined(__MINGW32__) && !defined(__clang__) 202 + # define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) 203 + #else 204 + # define GGML_ATTRIBUTE_FORMAT(...) 
__attribute__((format(printf, __VA_ARGS__))) 205 + #endif 206 + 207 + #if defined(_WIN32) && !defined(_WIN32_WINNT) 208 + # define _WIN32_WINNT 0x0A00 209 + #endif 210 + 211 + #include <stdbool.h> 212 + #include <stddef.h> 213 + #include <stdint.h> 214 + #include <stdio.h> 215 + 216 + #define GGML_FILE_MAGIC 0x67676d6c // "ggml" 217 + #define GGML_FILE_VERSION 2 218 + 219 + #define GGML_QNT_VERSION 2 // bump this on quantization format changes 220 + #define GGML_QNT_VERSION_FACTOR 1000 // do not change this 221 + 222 + #define GGML_MAX_DIMS 4 223 + #define GGML_MAX_PARAMS 2048 224 + #define GGML_MAX_SRC 10 225 + #define GGML_MAX_N_THREADS 512 226 + #define GGML_MAX_OP_PARAMS 64 227 + 228 + #ifndef GGML_MAX_NAME 229 + # define GGML_MAX_NAME 64 230 + #endif 231 + 232 + #define GGML_DEFAULT_N_THREADS 4 233 + #define GGML_DEFAULT_GRAPH_SIZE 2048 234 + 235 + #if UINTPTR_MAX == 0xFFFFFFFF 236 + #define GGML_MEM_ALIGN 4 237 + #elif defined(__EMSCRIPTEN__) 238 + // emscripten uses max_align_t == 8, so we need GGML_MEM_ALIGN == 8 for 64-bit wasm. 239 + // (for 32-bit wasm, the first conditional is true and GGML_MEM_ALIGN stays 4.) 240 + // ref: https://github.com/ggml-org/llama.cpp/pull/18628 241 + #define GGML_MEM_ALIGN 8 242 + #else 243 + #define GGML_MEM_ALIGN 16 244 + #endif 245 + 246 + #define GGML_EXIT_SUCCESS 0 247 + #define GGML_EXIT_ABORTED 1 248 + 249 + // TODO: convert to enum https://github.com/ggml-org/llama.cpp/pull/16187#discussion_r2388538726 250 + #define GGML_ROPE_TYPE_NORMAL 0 251 + #define GGML_ROPE_TYPE_NEOX 2 252 + #define GGML_ROPE_TYPE_MROPE 8 253 + #define GGML_ROPE_TYPE_VISION 24 254 + #define GGML_ROPE_TYPE_IMROPE 40 // binary: 101000 255 + 256 + #define GGML_MROPE_SECTIONS 4 257 + 258 + #define GGML_UNUSED(x) (void)(x) 259 + #ifdef __CUDACC__ 260 + template<typename... Args> 261 + __host__ __device__ constexpr inline void ggml_unused_vars_impl(Args&&...) noexcept {} 262 + #define GGML_UNUSED_VARS(...) 
ggml_unused_vars_impl(__VA_ARGS__) 263 + #else 264 + #define GGML_UNUSED_VARS(...) do { (void)sizeof((__VA_ARGS__, 0)); } while(0) 265 + #endif // __CUDACC__ 266 + 267 + #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) 268 + 269 + #ifndef NDEBUG 270 + # define GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0) 271 + #elif defined(__GNUC__) 272 + # define GGML_UNREACHABLE() __builtin_unreachable() 273 + #elif defined(_MSC_VER) 274 + # define GGML_UNREACHABLE() __assume(0) 275 + #else 276 + # define GGML_UNREACHABLE() ((void) 0) 277 + #endif 278 + 279 + #ifdef __cplusplus 280 + # define GGML_NORETURN [[noreturn]] 281 + #elif defined(_MSC_VER) 282 + # define GGML_NORETURN __declspec(noreturn) 283 + #else 284 + # define GGML_NORETURN _Noreturn 285 + #endif 286 + 287 + #define GGML_ABORT(...) ggml_abort(__FILE__, __LINE__, __VA_ARGS__) 288 + #define GGML_ASSERT(x) if (!(x)) GGML_ABORT("GGML_ASSERT(%s) failed", #x) 289 + 290 + // used to copy the number of elements and stride in bytes of tensors into local variables. 291 + // main purpose is to reduce code duplication and improve readability. 292 + // 293 + // example: 294 + // 295 + // GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); 296 + // GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); 297 + // 298 + #define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \ 299 + const type prefix##0 = (pointer) ? (pointer)->array[0] : 0; \ 300 + GGML_UNUSED(prefix##0); 301 + #define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \ 302 + GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \ 303 + const type prefix##1 = (pointer) ? (pointer)->array[1] : 0; \ 304 + GGML_UNUSED(prefix##1); 305 + #define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \ 306 + GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \ 307 + const type prefix##2 = (pointer) ? 
(pointer)->array[2] : 0; \ 308 + GGML_UNUSED(prefix##2); 309 + #define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \ 310 + GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \ 311 + const type prefix##3 = (pointer) ? (pointer)->array[3] : 0; \ 312 + GGML_UNUSED(prefix##3); 313 + 314 + #define GGML_TENSOR_UNARY_OP_LOCALS \ 315 + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ 316 + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ 317 + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ 318 + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) 319 + 320 + #define GGML_TENSOR_BINARY_OP_LOCALS \ 321 + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ 322 + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ 323 + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \ 324 + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \ 325 + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ 326 + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) 327 + 328 + #define GGML_TENSOR_TERNARY_OP_LOCALS \ 329 + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ 330 + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ 331 + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \ 332 + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \ 333 + GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \ 334 + GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) \ 335 + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ 336 + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) 337 + 338 + #define GGML_TENSOR_BINARY_OP_LOCALS01 \ 339 + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ 340 + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ 341 + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \ 342 + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) 343 + 344 + #ifdef __cplusplus 345 + extern "C" { 346 + #endif 347 + 348 + // Function type used in fatal error callbacks 349 + typedef void (*ggml_abort_callback_t)(const char * error_message); 350 + 351 + // Set the abort callback (passing null will restore original abort functionality: printing a message to stdout) 352 + // Returns the old callback for chaining 353 + GGML_API ggml_abort_callback_t 
ggml_set_abort_callback(ggml_abort_callback_t callback); 354 + 355 + GGML_NORETURN GGML_ATTRIBUTE_FORMAT(3, 4) 356 + GGML_API void ggml_abort(const char * file, int line, const char * fmt, ...); 357 + 358 + enum ggml_status { 359 + GGML_STATUS_ALLOC_FAILED = -2, 360 + GGML_STATUS_FAILED = -1, 361 + GGML_STATUS_SUCCESS = 0, 362 + GGML_STATUS_ABORTED = 1, 363 + }; 364 + 365 + // get ggml_status name string 366 + GGML_API const char * ggml_status_to_string(enum ggml_status status); 367 + 368 + // ieee 754-2008 half-precision float16 369 + // todo: make this not an integral type 370 + typedef uint16_t ggml_fp16_t; 371 + GGML_API float ggml_fp16_to_fp32(ggml_fp16_t); 372 + GGML_API ggml_fp16_t ggml_fp32_to_fp16(float); 373 + GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t *, float *, int64_t); 374 + GGML_API void ggml_fp32_to_fp16_row(const float *, ggml_fp16_t *, int64_t); 375 + 376 + // google brain half-precision bfloat16 377 + typedef struct { uint16_t bits; } ggml_bf16_t; 378 + GGML_API ggml_bf16_t ggml_fp32_to_bf16(float); 379 + GGML_API float ggml_bf16_to_fp32(ggml_bf16_t); // consider just doing << 16 380 + GGML_API void ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t); 381 + GGML_API void ggml_fp32_to_bf16_row_ref(const float *, ggml_bf16_t *, int64_t); 382 + GGML_API void ggml_fp32_to_bf16_row(const float *, ggml_bf16_t *, int64_t); 383 + 384 + struct ggml_object; 385 + struct ggml_context; 386 + struct ggml_cgraph; 387 + 388 + // NOTE: always add types at the end of the enum to keep backward compatibility 389 + enum ggml_type { 390 + GGML_TYPE_F32 = 0, 391 + GGML_TYPE_F16 = 1, 392 + GGML_TYPE_Q4_0 = 2, 393 + GGML_TYPE_Q4_1 = 3, 394 + // GGML_TYPE_Q4_2 = 4, support has been removed 395 + // GGML_TYPE_Q4_3 = 5, support has been removed 396 + GGML_TYPE_Q5_0 = 6, 397 + GGML_TYPE_Q5_1 = 7, 398 + GGML_TYPE_Q8_0 = 8, 399 + GGML_TYPE_Q8_1 = 9, 400 + GGML_TYPE_Q2_K = 10, 401 + GGML_TYPE_Q3_K = 11, 402 + GGML_TYPE_Q4_K = 12, 403 + GGML_TYPE_Q5_K = 13, 
404 + GGML_TYPE_Q6_K = 14, 405 + GGML_TYPE_Q8_K = 15, 406 + GGML_TYPE_IQ2_XXS = 16, 407 + GGML_TYPE_IQ2_XS = 17, 408 + GGML_TYPE_IQ3_XXS = 18, 409 + GGML_TYPE_IQ1_S = 19, 410 + GGML_TYPE_IQ4_NL = 20, 411 + GGML_TYPE_IQ3_S = 21, 412 + GGML_TYPE_IQ2_S = 22, 413 + GGML_TYPE_IQ4_XS = 23, 414 + GGML_TYPE_I8 = 24, 415 + GGML_TYPE_I16 = 25, 416 + GGML_TYPE_I32 = 26, 417 + GGML_TYPE_I64 = 27, 418 + GGML_TYPE_F64 = 28, 419 + GGML_TYPE_IQ1_M = 29, 420 + GGML_TYPE_BF16 = 30, 421 + // GGML_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files 422 + // GGML_TYPE_Q4_0_4_8 = 32, 423 + // GGML_TYPE_Q4_0_8_8 = 33, 424 + GGML_TYPE_TQ1_0 = 34, 425 + GGML_TYPE_TQ2_0 = 35, 426 + // GGML_TYPE_IQ4_NL_4_4 = 36, 427 + // GGML_TYPE_IQ4_NL_4_8 = 37, 428 + // GGML_TYPE_IQ4_NL_8_8 = 38, 429 + GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block) 430 + GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale) 431 + GGML_TYPE_Q1_0 = 41, 432 + GGML_TYPE_COUNT = 42, 433 + }; 434 + 435 + // precision 436 + enum ggml_prec { 437 + GGML_PREC_DEFAULT = 0, // stored as ggml_tensor.op_params, 0 by default 438 + GGML_PREC_F32 = 10, 439 + }; 440 + 441 + // model file types 442 + enum ggml_ftype { 443 + GGML_FTYPE_UNKNOWN = -1, 444 + GGML_FTYPE_ALL_F32 = 0, 445 + GGML_FTYPE_MOSTLY_F16 = 1, // except 1d tensors 446 + GGML_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors 447 + GGML_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors 448 + GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 449 + GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors 450 + GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors 451 + GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors 452 + GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors 453 + GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors 454 + GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors 455 + GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors 456 + GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors 457 + GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors 458 + 
GGML_FTYPE_MOSTLY_IQ2_XS = 16, // except 1d tensors 459 + GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors 460 + GGML_FTYPE_MOSTLY_IQ1_S = 18, // except 1d tensors 461 + GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors 462 + GGML_FTYPE_MOSTLY_IQ3_S = 20, // except 1d tensors 463 + GGML_FTYPE_MOSTLY_IQ2_S = 21, // except 1d tensors 464 + GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors 465 + GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors 466 + GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors 467 + GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors 468 + GGML_FTYPE_MOSTLY_NVFP4 = 26, // except 1d tensors 469 + GGML_FTYPE_MOSTLY_Q1_0 = 27, // except 1d tensors 470 + }; 471 + 472 + // available tensor operations: 473 + enum ggml_op { 474 + GGML_OP_NONE = 0, 475 + 476 + GGML_OP_DUP, 477 + GGML_OP_ADD, 478 + GGML_OP_ADD_ID, 479 + GGML_OP_ADD1, 480 + GGML_OP_ACC, 481 + GGML_OP_SUB, 482 + GGML_OP_MUL, 483 + GGML_OP_DIV, 484 + GGML_OP_SQR, 485 + GGML_OP_SQRT, 486 + GGML_OP_LOG, 487 + GGML_OP_SIN, 488 + GGML_OP_COS, 489 + GGML_OP_SUM, 490 + GGML_OP_SUM_ROWS, 491 + GGML_OP_CUMSUM, 492 + GGML_OP_MEAN, 493 + GGML_OP_ARGMAX, 494 + GGML_OP_COUNT_EQUAL, 495 + GGML_OP_REPEAT, 496 + GGML_OP_REPEAT_BACK, 497 + GGML_OP_CONCAT, 498 + GGML_OP_SILU_BACK, 499 + GGML_OP_NORM, // normalize 500 + GGML_OP_RMS_NORM, 501 + GGML_OP_RMS_NORM_BACK, 502 + GGML_OP_GROUP_NORM, 503 + GGML_OP_L2_NORM, 504 + 505 + GGML_OP_MUL_MAT, 506 + GGML_OP_MUL_MAT_ID, 507 + GGML_OP_OUT_PROD, 508 + 509 + GGML_OP_SCALE, 510 + GGML_OP_SET, 511 + GGML_OP_CPY, 512 + GGML_OP_CONT, 513 + GGML_OP_RESHAPE, 514 + GGML_OP_VIEW, 515 + GGML_OP_PERMUTE, 516 + GGML_OP_TRANSPOSE, 517 + GGML_OP_GET_ROWS, 518 + GGML_OP_GET_ROWS_BACK, 519 + GGML_OP_SET_ROWS, 520 + GGML_OP_DIAG, 521 + GGML_OP_DIAG_MASK_INF, 522 + GGML_OP_DIAG_MASK_ZERO, 523 + GGML_OP_SOFT_MAX, 524 + GGML_OP_SOFT_MAX_BACK, 525 + GGML_OP_ROPE, 526 + GGML_OP_ROPE_BACK, 527 + GGML_OP_CLAMP, 528 + GGML_OP_CONV_TRANSPOSE_1D, 529 + GGML_OP_IM2COL, 530 + 
GGML_OP_IM2COL_BACK, 531 + GGML_OP_IM2COL_3D, 532 + GGML_OP_CONV_2D, 533 + GGML_OP_CONV_3D, 534 + GGML_OP_CONV_2D_DW, 535 + GGML_OP_CONV_TRANSPOSE_2D, 536 + GGML_OP_POOL_1D, 537 + GGML_OP_POOL_2D, 538 + GGML_OP_POOL_2D_BACK, 539 + GGML_OP_UPSCALE, 540 + GGML_OP_PAD, 541 + GGML_OP_PAD_REFLECT_1D, 542 + GGML_OP_ROLL, 543 + GGML_OP_ARANGE, 544 + GGML_OP_TIMESTEP_EMBEDDING, 545 + GGML_OP_ARGSORT, 546 + GGML_OP_TOP_K, 547 + GGML_OP_LEAKY_RELU, 548 + GGML_OP_TRI, 549 + GGML_OP_FILL, 550 + 551 + GGML_OP_FLASH_ATTN_EXT, 552 + GGML_OP_FLASH_ATTN_BACK, 553 + GGML_OP_SSM_CONV, 554 + GGML_OP_SSM_SCAN, 555 + GGML_OP_WIN_PART, 556 + GGML_OP_WIN_UNPART, 557 + GGML_OP_GET_REL_POS, 558 + GGML_OP_ADD_REL_POS, 559 + GGML_OP_RWKV_WKV6, 560 + GGML_OP_GATED_LINEAR_ATTN, 561 + GGML_OP_RWKV_WKV7, 562 + GGML_OP_SOLVE_TRI, 563 + GGML_OP_GATED_DELTA_NET, 564 + 565 + GGML_OP_UNARY, 566 + 567 + GGML_OP_MAP_CUSTOM1, 568 + GGML_OP_MAP_CUSTOM2, 569 + GGML_OP_MAP_CUSTOM3, 570 + 571 + GGML_OP_CUSTOM, 572 + 573 + GGML_OP_CROSS_ENTROPY_LOSS, 574 + GGML_OP_CROSS_ENTROPY_LOSS_BACK, 575 + GGML_OP_OPT_STEP_ADAMW, 576 + GGML_OP_OPT_STEP_SGD, 577 + 578 + GGML_OP_GLU, 579 + 580 + GGML_OP_COUNT, 581 + }; 582 + 583 + enum ggml_unary_op { 584 + GGML_UNARY_OP_ABS, 585 + GGML_UNARY_OP_SGN, 586 + GGML_UNARY_OP_NEG, 587 + GGML_UNARY_OP_STEP, 588 + GGML_UNARY_OP_TANH, 589 + GGML_UNARY_OP_ELU, 590 + GGML_UNARY_OP_RELU, 591 + GGML_UNARY_OP_SIGMOID, 592 + GGML_UNARY_OP_GELU, 593 + GGML_UNARY_OP_GELU_QUICK, 594 + GGML_UNARY_OP_SILU, 595 + GGML_UNARY_OP_HARDSWISH, 596 + GGML_UNARY_OP_HARDSIGMOID, 597 + GGML_UNARY_OP_EXP, 598 + GGML_UNARY_OP_EXPM1, 599 + GGML_UNARY_OP_SOFTPLUS, 600 + GGML_UNARY_OP_GELU_ERF, 601 + GGML_UNARY_OP_XIELU, 602 + GGML_UNARY_OP_FLOOR, 603 + GGML_UNARY_OP_CEIL, 604 + GGML_UNARY_OP_ROUND, 605 + GGML_UNARY_OP_TRUNC, 606 + 607 + GGML_UNARY_OP_COUNT, 608 + }; 609 + 610 + enum ggml_glu_op { 611 + GGML_GLU_OP_REGLU, 612 + GGML_GLU_OP_GEGLU, 613 + GGML_GLU_OP_SWIGLU, 614 + GGML_GLU_OP_SWIGLU_OAI, 615 + 
GGML_GLU_OP_GEGLU_ERF, 616 + GGML_GLU_OP_GEGLU_QUICK, 617 + 618 + GGML_GLU_OP_COUNT, 619 + }; 620 + 621 + enum ggml_object_type { 622 + GGML_OBJECT_TYPE_TENSOR, 623 + GGML_OBJECT_TYPE_GRAPH, 624 + GGML_OBJECT_TYPE_WORK_BUFFER 625 + }; 626 + 627 + enum ggml_log_level { 628 + GGML_LOG_LEVEL_NONE = 0, 629 + GGML_LOG_LEVEL_DEBUG = 1, 630 + GGML_LOG_LEVEL_INFO = 2, 631 + GGML_LOG_LEVEL_WARN = 3, 632 + GGML_LOG_LEVEL_ERROR = 4, 633 + GGML_LOG_LEVEL_CONT = 5, // continue previous log 634 + }; 635 + 636 + // this tensor... 637 + enum ggml_tensor_flag { 638 + GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph 639 + GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph 640 + GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters 641 + GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) 642 + GGML_TENSOR_FLAG_COMPUTE = 16, // ...must be computed 643 + }; 644 + 645 + enum ggml_tri_type { 646 + GGML_TRI_TYPE_UPPER_DIAG = 0, 647 + GGML_TRI_TYPE_UPPER = 1, 648 + GGML_TRI_TYPE_LOWER_DIAG = 2, 649 + GGML_TRI_TYPE_LOWER = 3 650 + }; 651 + 652 + struct ggml_init_params { 653 + // memory pool 654 + size_t mem_size; // bytes 655 + void * mem_buffer; // if NULL, memory will be allocated internally 656 + bool no_alloc; // don't allocate memory for the tensor data 657 + }; 658 + 659 + // n-dimensional tensor 660 + struct ggml_tensor { 661 + enum ggml_type type; 662 + 663 + struct ggml_backend_buffer * buffer; 664 + 665 + int64_t ne[GGML_MAX_DIMS]; // number of elements 666 + size_t nb[GGML_MAX_DIMS]; // stride in bytes: 667 + // nb[0] = ggml_type_size(type) 668 + // nb[1] = nb[0] * (ne[0] / ggml_blck_size(type)) + padding 669 + // nb[i] = nb[i-1] * ne[i-1] 670 + 671 + // compute data 672 + enum ggml_op op; 673 + 674 + // op params - allocated as int32_t for alignment 675 + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; 676 + 677 + int32_t flags; 678 + 679 + struct ggml_tensor 
* src[GGML_MAX_SRC]; 680 + 681 + // source tensor and offset for views 682 + struct ggml_tensor * view_src; 683 + size_t view_offs; 684 + 685 + void * data; 686 + 687 + char name[GGML_MAX_NAME]; 688 + 689 + void * extra; // extra things e.g. for ggml-cuda.cu 690 + 691 + char padding[8]; 692 + }; 693 + 694 + static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); 695 + 696 + // Abort callback 697 + // If not NULL, called before ggml computation 698 + // If it returns true, the computation is aborted 699 + typedef bool (*ggml_abort_callback)(void * data); 700 + 701 + 702 + // 703 + // GUID 704 + // 705 + 706 + // GUID types 707 + typedef uint8_t ggml_guid[16]; 708 + typedef ggml_guid * ggml_guid_t; 709 + 710 + GGML_API bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b); 711 + 712 + // misc 713 + 714 + GGML_API const char * ggml_version(void); 715 + GGML_API const char * ggml_commit(void); 716 + 717 + GGML_API void ggml_time_init(void); // call this once at the beginning of the program 718 + GGML_API int64_t ggml_time_ms(void); 719 + GGML_API int64_t ggml_time_us(void); 720 + GGML_API int64_t ggml_cycles(void); 721 + GGML_API int64_t ggml_cycles_per_ms(void); 722 + 723 + // accepts a UTF-8 path, even on Windows 724 + GGML_API FILE * ggml_fopen(const char * fname, const char * mode); 725 + 726 + GGML_API void ggml_print_object (const struct ggml_object * obj); 727 + GGML_API void ggml_print_objects(const struct ggml_context * ctx); 728 + 729 + GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor); 730 + GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor); 731 + GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor); 732 + GGML_API size_t ggml_nbytes_pad(const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN 733 + 734 + GGML_API int64_t ggml_blck_size(enum ggml_type type); 735 + GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block 736 
+ GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row 737 + 738 + GGML_DEPRECATED( 739 + GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float 740 + "use ggml_row_size() instead"); 741 + 742 + GGML_API const char * ggml_type_name(enum ggml_type type); 743 + GGML_API const char * ggml_op_name (enum ggml_op op); 744 + GGML_API const char * ggml_op_symbol(enum ggml_op op); 745 + 746 + GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op); 747 + GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op); 748 + GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name 749 + 750 + GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor); 751 + 752 + GGML_API bool ggml_is_quantized(enum ggml_type type); 753 + 754 + // TODO: temporary until model loading of ggml examples is refactored 755 + GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype); 756 + 757 + GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor); 758 + GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor); 759 + GGML_API bool ggml_is_empty (const struct ggml_tensor * tensor); 760 + GGML_API bool ggml_is_view (const struct ggml_tensor * tensor); 761 + GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor); 762 + GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor); 763 + GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor); 764 + GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor); 765 + GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars 766 + 767 + // returns whether the tensor elements can be iterated over with a flattened index (no gaps, no permutation) 768 + GGML_API bool ggml_is_contiguous (const struct ggml_tensor * tensor); 769 + GGML_API bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous() 
770 + GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1 771 + GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2 772 + 773 + // returns whether the tensor elements are allocated as one contiguous block of memory (no gaps, but permutation ok) 774 + GGML_API bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor); 775 + 776 + // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN 777 + GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor); 778 + 779 + // true if the elements in dimension 0 are contiguous, or there is just 1 block of elements 780 + GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor); 781 + 782 + GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1); 783 + GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1); 784 + 785 + GGML_API bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1); 786 + 787 + // use this to compute the memory overhead of a tensor 788 + GGML_API size_t ggml_tensor_overhead(void); 789 + 790 + GGML_API bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes); 791 + 792 + // main 793 + 794 + GGML_API struct ggml_context * ggml_init (struct ggml_init_params params); 795 + GGML_API void ggml_reset(struct ggml_context * ctx); 796 + GGML_API void ggml_free (struct ggml_context * ctx); 797 + 798 + GGML_API size_t ggml_used_mem(const struct ggml_context * ctx); 799 + 800 + GGML_API bool ggml_get_no_alloc(struct ggml_context * ctx); 801 + GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc); 802 + 803 + GGML_API void * ggml_get_mem_buffer (const struct ggml_context * ctx); 804 + GGML_API size_t ggml_get_mem_size (const struct ggml_context * ctx); 805 + GGML_API size_t ggml_get_max_tensor_size(const 
struct ggml_context * ctx); 806 + 807 + GGML_API struct ggml_tensor * ggml_new_tensor( 808 + struct ggml_context * ctx, 809 + enum ggml_type type, 810 + int n_dims, 811 + const int64_t *ne); 812 + 813 + GGML_API struct ggml_tensor * ggml_new_tensor_1d( 814 + struct ggml_context * ctx, 815 + enum ggml_type type, 816 + int64_t ne0); 817 + 818 + GGML_API struct ggml_tensor * ggml_new_tensor_2d( 819 + struct ggml_context * ctx, 820 + enum ggml_type type, 821 + int64_t ne0, 822 + int64_t ne1); 823 + 824 + GGML_API struct ggml_tensor * ggml_new_tensor_3d( 825 + struct ggml_context * ctx, 826 + enum ggml_type type, 827 + int64_t ne0, 828 + int64_t ne1, 829 + int64_t ne2); 830 + 831 + GGML_API struct ggml_tensor * ggml_new_tensor_4d( 832 + struct ggml_context * ctx, 833 + enum ggml_type type, 834 + int64_t ne0, 835 + int64_t ne1, 836 + int64_t ne2, 837 + int64_t ne3); 838 + 839 + GGML_API void * ggml_new_buffer(struct ggml_context * ctx, size_t nbytes); 840 + 841 + GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); 842 + GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src); 843 + 844 + // Context tensor enumeration and lookup 845 + GGML_API struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx); 846 + GGML_API struct ggml_tensor * ggml_get_next_tensor (const struct ggml_context * ctx, struct ggml_tensor * tensor); 847 + GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name); 848 + 849 + // Converts a flat index into coordinates 850 + GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3); 851 + 852 + GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor); 853 + GGML_API enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor); 854 + 855 + GGML_API void * ggml_get_data (const struct ggml_tensor * 
tensor); 856 + GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor); 857 + 858 + GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor); 859 + GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name); 860 + GGML_ATTRIBUTE_FORMAT(2, 3) 861 + GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...); 862 + 863 + // Tensor flags 864 + GGML_API void ggml_set_input(struct ggml_tensor * tensor); 865 + GGML_API void ggml_set_output(struct ggml_tensor * tensor); 866 + GGML_API void ggml_set_param(struct ggml_tensor * tensor); 867 + GGML_API void ggml_set_loss(struct ggml_tensor * tensor); 868 + 869 + // 870 + // operations on tensors with backpropagation 871 + // 872 + 873 + GGML_API struct ggml_tensor * ggml_dup( 874 + struct ggml_context * ctx, 875 + struct ggml_tensor * a); 876 + 877 + // in-place, returns view(a) 878 + GGML_API struct ggml_tensor * ggml_dup_inplace( 879 + struct ggml_context * ctx, 880 + struct ggml_tensor * a); 881 + 882 + GGML_API struct ggml_tensor * ggml_add( 883 + struct ggml_context * ctx, 884 + struct ggml_tensor * a, 885 + struct ggml_tensor * b); 886 + 887 + GGML_API struct ggml_tensor * ggml_add_inplace( 888 + struct ggml_context * ctx, 889 + struct ggml_tensor * a, 890 + struct ggml_tensor * b); 891 + 892 + GGML_API struct ggml_tensor * ggml_add_cast( 893 + struct ggml_context * ctx, 894 + struct ggml_tensor * a, 895 + struct ggml_tensor * b, 896 + enum ggml_type type); 897 + 898 + // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]] 899 + GGML_API struct ggml_tensor * ggml_add_id( 900 + struct ggml_context * ctx, 901 + struct ggml_tensor * a, 902 + struct ggml_tensor * b, 903 + struct ggml_tensor * ids); 904 + 905 + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_add1( 906 + struct ggml_context * ctx, 907 + struct ggml_tensor * a, 908 + struct ggml_tensor * b), 909 + "use ggml_add instead"); 910 + 911 + 
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_add1_inplace( 912 + struct ggml_context * ctx, 913 + struct ggml_tensor * a, 914 + struct ggml_tensor * b), 915 + "use ggml_add_inplace instead"); 916 + 917 + // dst = a 918 + // view(dst, nb1, nb2, nb3, offset) += b 919 + // return dst 920 + GGML_API struct ggml_tensor * ggml_acc( 921 + struct ggml_context * ctx, 922 + struct ggml_tensor * a, 923 + struct ggml_tensor * b, 924 + size_t nb1, 925 + size_t nb2, 926 + size_t nb3, 927 + size_t offset); 928 + 929 + GGML_API struct ggml_tensor * ggml_acc_inplace( 930 + struct ggml_context * ctx, 931 + struct ggml_tensor * a, 932 + struct ggml_tensor * b, 933 + size_t nb1, 934 + size_t nb2, 935 + size_t nb3, 936 + size_t offset); 937 + 938 + GGML_API struct ggml_tensor * ggml_sub( 939 + struct ggml_context * ctx, 940 + struct ggml_tensor * a, 941 + struct ggml_tensor * b); 942 + 943 + GGML_API struct ggml_tensor * ggml_sub_inplace( 944 + struct ggml_context * ctx, 945 + struct ggml_tensor * a, 946 + struct ggml_tensor * b); 947 + 948 + GGML_API struct ggml_tensor * ggml_mul( 949 + struct ggml_context * ctx, 950 + struct ggml_tensor * a, 951 + struct ggml_tensor * b); 952 + 953 + GGML_API struct ggml_tensor * ggml_mul_inplace( 954 + struct ggml_context * ctx, 955 + struct ggml_tensor * a, 956 + struct ggml_tensor * b); 957 + 958 + GGML_API struct ggml_tensor * ggml_div( 959 + struct ggml_context * ctx, 960 + struct ggml_tensor * a, 961 + struct ggml_tensor * b); 962 + 963 + GGML_API struct ggml_tensor * ggml_div_inplace( 964 + struct ggml_context * ctx, 965 + struct ggml_tensor * a, 966 + struct ggml_tensor * b); 967 + 968 + GGML_API struct ggml_tensor * ggml_sqr( 969 + struct ggml_context * ctx, 970 + struct ggml_tensor * a); 971 + 972 + GGML_API struct ggml_tensor * ggml_sqr_inplace( 973 + struct ggml_context * ctx, 974 + struct ggml_tensor * a); 975 + 976 + GGML_API struct ggml_tensor * ggml_sqrt( 977 + struct ggml_context * ctx, 978 + struct ggml_tensor * a); 979 + 980 
+ GGML_API struct ggml_tensor * ggml_sqrt_inplace( 981 + struct ggml_context * ctx, 982 + struct ggml_tensor * a); 983 + 984 + GGML_API struct ggml_tensor * ggml_log( 985 + struct ggml_context * ctx, 986 + struct ggml_tensor * a); 987 + 988 + GGML_API struct ggml_tensor * ggml_log_inplace( 989 + struct ggml_context * ctx, 990 + struct ggml_tensor * a); 991 + 992 + GGML_API struct ggml_tensor * ggml_expm1( 993 + struct ggml_context * ctx, 994 + struct ggml_tensor * a); 995 + 996 + GGML_API struct ggml_tensor * ggml_expm1_inplace( 997 + struct ggml_context * ctx, 998 + struct ggml_tensor * a); 999 + 1000 + GGML_API struct ggml_tensor * ggml_softplus( 1001 + struct ggml_context * ctx, 1002 + struct ggml_tensor * a); 1003 + 1004 + GGML_API struct ggml_tensor * ggml_softplus_inplace( 1005 + struct ggml_context * ctx, 1006 + struct ggml_tensor * a); 1007 + 1008 + GGML_API struct ggml_tensor * ggml_sin( 1009 + struct ggml_context * ctx, 1010 + struct ggml_tensor * a); 1011 + 1012 + GGML_API struct ggml_tensor * ggml_sin_inplace( 1013 + struct ggml_context * ctx, 1014 + struct ggml_tensor * a); 1015 + 1016 + GGML_API struct ggml_tensor * ggml_cos( 1017 + struct ggml_context * ctx, 1018 + struct ggml_tensor * a); 1019 + 1020 + GGML_API struct ggml_tensor * ggml_cos_inplace( 1021 + struct ggml_context * ctx, 1022 + struct ggml_tensor * a); 1023 + 1024 + // return scalar 1025 + GGML_API struct ggml_tensor * ggml_sum( 1026 + struct ggml_context * ctx, 1027 + struct ggml_tensor * a); 1028 + 1029 + // sums along rows, with input shape [a,b,c,d] return shape [1,b,c,d] 1030 + GGML_API struct ggml_tensor * ggml_sum_rows( 1031 + struct ggml_context * ctx, 1032 + struct ggml_tensor * a); 1033 + 1034 + GGML_API struct ggml_tensor * ggml_cumsum( 1035 + struct ggml_context * ctx, 1036 + struct ggml_tensor * a); 1037 + 1038 + // mean along rows 1039 + GGML_API struct ggml_tensor * ggml_mean( 1040 + struct ggml_context * ctx, 1041 + struct ggml_tensor * a); 1042 + 1043 + // argmax along 
rows 1044 + GGML_API struct ggml_tensor * ggml_argmax( 1045 + struct ggml_context * ctx, 1046 + struct ggml_tensor * a); 1047 + 1048 + // count number of equal elements in a and b 1049 + GGML_API struct ggml_tensor * ggml_count_equal( 1050 + struct ggml_context * ctx, 1051 + struct ggml_tensor * a, 1052 + struct ggml_tensor * b); 1053 + 1054 + // if a is the same shape as b, and a is not parameter, return a 1055 + // otherwise, return a new tensor: repeat(a) to fit in b 1056 + GGML_API struct ggml_tensor * ggml_repeat( 1057 + struct ggml_context * ctx, 1058 + struct ggml_tensor * a, 1059 + struct ggml_tensor * b); 1060 + 1061 + // repeat a to the specified shape 1062 + GGML_API struct ggml_tensor * ggml_repeat_4d( 1063 + struct ggml_context * ctx, 1064 + struct ggml_tensor * a, 1065 + int64_t ne0, 1066 + int64_t ne1, 1067 + int64_t ne2, 1068 + int64_t ne3); 1069 + 1070 + // sums repetitions in a into shape of b 1071 + GGML_API struct ggml_tensor * ggml_repeat_back( 1072 + struct ggml_context * ctx, 1073 + struct ggml_tensor * a, 1074 + struct ggml_tensor * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride 1075 + 1076 + // concat a and b along dim 1077 + // used in stable-diffusion 1078 + GGML_API struct ggml_tensor * ggml_concat( 1079 + struct ggml_context * ctx, 1080 + struct ggml_tensor * a, 1081 + struct ggml_tensor * b, 1082 + int dim); 1083 + 1084 + GGML_API struct ggml_tensor * ggml_abs( 1085 + struct ggml_context * ctx, 1086 + struct ggml_tensor * a); 1087 + 1088 + GGML_API struct ggml_tensor * ggml_abs_inplace( 1089 + struct ggml_context * ctx, 1090 + struct ggml_tensor * a); 1091 + 1092 + GGML_API struct ggml_tensor * ggml_sgn( 1093 + struct ggml_context * ctx, 1094 + struct ggml_tensor * a); 1095 + 1096 + GGML_API struct ggml_tensor * ggml_sgn_inplace( 1097 + struct ggml_context * ctx, 1098 + struct ggml_tensor * a); 1099 + 1100 + GGML_API struct ggml_tensor * ggml_neg( 1101 + struct ggml_context * ctx, 1102 + struct 
ggml_tensor * a); 1103 + 1104 + GGML_API struct ggml_tensor * ggml_neg_inplace( 1105 + struct ggml_context * ctx, 1106 + struct ggml_tensor * a); 1107 + 1108 + GGML_API struct ggml_tensor * ggml_step( 1109 + struct ggml_context * ctx, 1110 + struct ggml_tensor * a); 1111 + 1112 + GGML_API struct ggml_tensor * ggml_step_inplace( 1113 + struct ggml_context * ctx, 1114 + struct ggml_tensor * a); 1115 + 1116 + GGML_API struct ggml_tensor * ggml_tanh( 1117 + struct ggml_context * ctx, 1118 + struct ggml_tensor * a); 1119 + 1120 + GGML_API struct ggml_tensor * ggml_tanh_inplace( 1121 + struct ggml_context * ctx, 1122 + struct ggml_tensor * a); 1123 + 1124 + GGML_API struct ggml_tensor * ggml_elu( 1125 + struct ggml_context * ctx, 1126 + struct ggml_tensor * a); 1127 + 1128 + GGML_API struct ggml_tensor * ggml_elu_inplace( 1129 + struct ggml_context * ctx, 1130 + struct ggml_tensor * a); 1131 + 1132 + GGML_API struct ggml_tensor * ggml_relu( 1133 + struct ggml_context * ctx, 1134 + struct ggml_tensor * a); 1135 + 1136 + GGML_API struct ggml_tensor * ggml_leaky_relu( 1137 + struct ggml_context * ctx, 1138 + struct ggml_tensor * a, float negative_slope, bool inplace); 1139 + 1140 + GGML_API struct ggml_tensor * ggml_relu_inplace( 1141 + struct ggml_context * ctx, 1142 + struct ggml_tensor * a); 1143 + 1144 + GGML_API struct ggml_tensor * ggml_sigmoid( 1145 + struct ggml_context * ctx, 1146 + struct ggml_tensor * a); 1147 + 1148 + GGML_API struct ggml_tensor * ggml_sigmoid_inplace( 1149 + struct ggml_context * ctx, 1150 + struct ggml_tensor * a); 1151 + 1152 + GGML_API struct ggml_tensor * ggml_gelu( 1153 + struct ggml_context * ctx, 1154 + struct ggml_tensor * a); 1155 + 1156 + GGML_API struct ggml_tensor * ggml_gelu_inplace( 1157 + struct ggml_context * ctx, 1158 + struct ggml_tensor * a); 1159 + 1160 + // GELU using erf (error function) when possible 1161 + // some backends may fallback to approximation based on Abramowitz and Stegun formula 1162 + GGML_API struct 
ggml_tensor * ggml_gelu_erf( 1163 + struct ggml_context * ctx, 1164 + struct ggml_tensor * a); 1165 + 1166 + GGML_API struct ggml_tensor * ggml_gelu_erf_inplace( 1167 + struct ggml_context * ctx, 1168 + struct ggml_tensor * a); 1169 + 1170 + GGML_API struct ggml_tensor * ggml_gelu_quick( 1171 + struct ggml_context * ctx, 1172 + struct ggml_tensor * a); 1173 + 1174 + GGML_API struct ggml_tensor * ggml_gelu_quick_inplace( 1175 + struct ggml_context * ctx, 1176 + struct ggml_tensor * a); 1177 + 1178 + GGML_API struct ggml_tensor * ggml_silu( 1179 + struct ggml_context * ctx, 1180 + struct ggml_tensor * a); 1181 + 1182 + GGML_API struct ggml_tensor * ggml_silu_inplace( 1183 + struct ggml_context * ctx, 1184 + struct ggml_tensor * a); 1185 + 1186 + // a - x 1187 + // b - dy 1188 + GGML_API struct ggml_tensor * ggml_silu_back( 1189 + struct ggml_context * ctx, 1190 + struct ggml_tensor * a, 1191 + struct ggml_tensor * b); 1192 + 1193 + // hardswish(x) = x * relu6(x + 3) / 6 1194 + GGML_API struct ggml_tensor * ggml_hardswish( 1195 + struct ggml_context * ctx, 1196 + struct ggml_tensor * a); 1197 + 1198 + // hardsigmoid(x) = relu6(x + 3) / 6 1199 + GGML_API struct ggml_tensor * ggml_hardsigmoid( 1200 + struct ggml_context * ctx, 1201 + struct ggml_tensor * a); 1202 + 1203 + GGML_API struct ggml_tensor * ggml_exp( 1204 + struct ggml_context * ctx, 1205 + struct ggml_tensor * a); 1206 + 1207 + GGML_API struct ggml_tensor * ggml_exp_inplace( 1208 + struct ggml_context * ctx, 1209 + struct ggml_tensor * a); 1210 + 1211 + GGML_API struct ggml_tensor * ggml_floor( 1212 + struct ggml_context * ctx, 1213 + struct ggml_tensor * a); 1214 + 1215 + GGML_API struct ggml_tensor * ggml_floor_inplace( 1216 + struct ggml_context * ctx, 1217 + struct ggml_tensor * a); 1218 + 1219 + GGML_API struct ggml_tensor * ggml_ceil( 1220 + struct ggml_context * ctx, 1221 + struct ggml_tensor * a); 1222 + 1223 + GGML_API struct ggml_tensor * ggml_ceil_inplace( 1224 + struct ggml_context * ctx, 1225 + 
struct ggml_tensor * a); 1226 + 1227 + GGML_API struct ggml_tensor * ggml_round( 1228 + struct ggml_context * ctx, 1229 + struct ggml_tensor * a); 1230 + 1231 + GGML_API struct ggml_tensor * ggml_round_inplace( 1232 + struct ggml_context * ctx, 1233 + struct ggml_tensor * a); 1234 + 1235 + /** 1236 + * Truncates the fractional part of each element in the tensor (towards zero). 1237 + * For example: trunc(3.7) = 3.0, trunc(-2.9) = -2.0 1238 + * Similar to std::trunc in C/C++. 1239 + */ 1240 + 1241 + GGML_API struct ggml_tensor * ggml_trunc( 1242 + struct ggml_context * ctx, 1243 + struct ggml_tensor * a); 1244 + 1245 + GGML_API struct ggml_tensor * ggml_trunc_inplace( 1246 + struct ggml_context * ctx, 1247 + struct ggml_tensor * a); 1248 + 1249 + 1250 + 1251 + // xIELU activation function 1252 + // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0) 1253 + // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions 1254 + // that constrain the positive and negative source alpha values respectively 1255 + GGML_API struct ggml_tensor * ggml_xielu( 1256 + struct ggml_context * ctx, 1257 + struct ggml_tensor * a, 1258 + float alpha_n, 1259 + float alpha_p, 1260 + float beta, 1261 + float eps); 1262 + 1263 + // gated linear unit ops 1264 + // A: n columns, r rows, 1265 + // result is n / 2 columns, r rows, 1266 + // expects gate in second half of row, unless swapped is true 1267 + GGML_API struct ggml_tensor * ggml_glu( 1268 + struct ggml_context * ctx, 1269 + struct ggml_tensor * a, 1270 + enum ggml_glu_op op, 1271 + bool swapped); 1272 + 1273 + GGML_API struct ggml_tensor * ggml_reglu( 1274 + struct ggml_context * ctx, 1275 + struct ggml_tensor * a); 1276 + 1277 + GGML_API struct ggml_tensor * ggml_reglu_swapped( 1278 + struct ggml_context * ctx, 1279 + struct ggml_tensor * a); 1280 + 1281 + GGML_API struct ggml_tensor * ggml_geglu( 1282 + struct ggml_context * ctx, 1283 + struct ggml_tensor * a); 1284 + 1285 + 
GGML_API struct ggml_tensor * ggml_geglu_swapped( 1286 + struct ggml_context * ctx, 1287 + struct ggml_tensor * a); 1288 + 1289 + GGML_API struct ggml_tensor * ggml_swiglu( 1290 + struct ggml_context * ctx, 1291 + struct ggml_tensor * a); 1292 + 1293 + GGML_API struct ggml_tensor * ggml_swiglu_swapped( 1294 + struct ggml_context * ctx, 1295 + struct ggml_tensor * a); 1296 + 1297 + GGML_API struct ggml_tensor * ggml_geglu_erf( 1298 + struct ggml_context * ctx, 1299 + struct ggml_tensor * a); 1300 + 1301 + GGML_API struct ggml_tensor * ggml_geglu_erf_swapped( 1302 + struct ggml_context * ctx, 1303 + struct ggml_tensor * a); 1304 + 1305 + GGML_API struct ggml_tensor * ggml_geglu_quick( 1306 + struct ggml_context * ctx, 1307 + struct ggml_tensor * a); 1308 + 1309 + GGML_API struct ggml_tensor * ggml_geglu_quick_swapped( 1310 + struct ggml_context * ctx, 1311 + struct ggml_tensor * a); 1312 + 1313 + // A: n columns, r rows, 1314 + // B: n columns, r rows, 1315 + GGML_API struct ggml_tensor * ggml_glu_split( 1316 + struct ggml_context * ctx, 1317 + struct ggml_tensor * a, 1318 + struct ggml_tensor * b, 1319 + enum ggml_glu_op op); 1320 + 1321 + GGML_API struct ggml_tensor * ggml_reglu_split( 1322 + struct ggml_context * ctx, 1323 + struct ggml_tensor * a, 1324 + struct ggml_tensor * b); 1325 + 1326 + GGML_API struct ggml_tensor * ggml_geglu_split( 1327 + struct ggml_context * ctx, 1328 + struct ggml_tensor * a, 1329 + struct ggml_tensor * b); 1330 + 1331 + GGML_API struct ggml_tensor * ggml_swiglu_split( 1332 + struct ggml_context * ctx, 1333 + struct ggml_tensor * a, 1334 + struct ggml_tensor * b); 1335 + 1336 + GGML_API struct ggml_tensor * ggml_geglu_erf_split( 1337 + struct ggml_context * ctx, 1338 + struct ggml_tensor * a, 1339 + struct ggml_tensor * b); 1340 + 1341 + GGML_API struct ggml_tensor * ggml_geglu_quick_split( 1342 + struct ggml_context * ctx, 1343 + struct ggml_tensor * a, 1344 + struct ggml_tensor * b); 1345 + 1346 + GGML_API struct ggml_tensor * 
ggml_swiglu_oai( 1347 + struct ggml_context * ctx, 1348 + struct ggml_tensor * a, 1349 + struct ggml_tensor * b, 1350 + float alpha, 1351 + float limit); 1352 + 1353 + // normalize along rows 1354 + GGML_API struct ggml_tensor * ggml_norm( 1355 + struct ggml_context * ctx, 1356 + struct ggml_tensor * a, 1357 + float eps); 1358 + 1359 + GGML_API struct ggml_tensor * ggml_norm_inplace( 1360 + struct ggml_context * ctx, 1361 + struct ggml_tensor * a, 1362 + float eps); 1363 + 1364 + GGML_API struct ggml_tensor * ggml_rms_norm( 1365 + struct ggml_context * ctx, 1366 + struct ggml_tensor * a, 1367 + float eps); 1368 + 1369 + GGML_API struct ggml_tensor * ggml_rms_norm_inplace( 1370 + struct ggml_context * ctx, 1371 + struct ggml_tensor * a, 1372 + float eps); 1373 + 1374 + // group normalize along ne0*ne1*n_groups 1375 + // used in stable-diffusion 1376 + GGML_API struct ggml_tensor * ggml_group_norm( 1377 + struct ggml_context * ctx, 1378 + struct ggml_tensor * a, 1379 + int n_groups, 1380 + float eps); 1381 + 1382 + GGML_API struct ggml_tensor * ggml_group_norm_inplace( 1383 + struct ggml_context * ctx, 1384 + struct ggml_tensor * a, 1385 + int n_groups, 1386 + float eps); 1387 + 1388 + // l2 normalize along rows 1389 + // used in rwkv v7 1390 + GGML_API struct ggml_tensor * ggml_l2_norm( 1391 + struct ggml_context * ctx, 1392 + struct ggml_tensor * a, 1393 + float eps); 1394 + 1395 + GGML_API struct ggml_tensor * ggml_l2_norm_inplace( 1396 + struct ggml_context * ctx, 1397 + struct ggml_tensor * a, 1398 + float eps); 1399 + 1400 + // a - x 1401 + // b - dy 1402 + GGML_API struct ggml_tensor * ggml_rms_norm_back( 1403 + struct ggml_context * ctx, 1404 + struct ggml_tensor * a, 1405 + struct ggml_tensor * b, 1406 + float eps); 1407 + 1408 + // A: k columns, n rows => [ne03, ne02, n, k] 1409 + // B: k columns, m rows (i.e. 
we transpose it internally) => [ne03 * x, ne02 * y, m, k] 1410 + // result is n columns, m rows => [ne03 * x, ne02 * y, m, n] 1411 + GGML_API struct ggml_tensor * ggml_mul_mat( 1412 + struct ggml_context * ctx, 1413 + struct ggml_tensor * a, 1414 + struct ggml_tensor * b); 1415 + 1416 + // change the precision of a matrix multiplication 1417 + // set to GGML_PREC_F32 for higher precision (useful for phi-2) 1418 + GGML_API void ggml_mul_mat_set_prec( 1419 + struct ggml_tensor * a, 1420 + enum ggml_prec prec); 1421 + 1422 + // indirect matrix multiplication 1423 + GGML_API struct ggml_tensor * ggml_mul_mat_id( 1424 + struct ggml_context * ctx, 1425 + struct ggml_tensor * as, 1426 + struct ggml_tensor * b, 1427 + struct ggml_tensor * ids); 1428 + 1429 + // A: m columns, n rows, 1430 + // B: p columns, n rows, 1431 + // result is m columns, p rows 1432 + GGML_API struct ggml_tensor * ggml_out_prod( 1433 + struct ggml_context * ctx, 1434 + struct ggml_tensor * a, 1435 + struct ggml_tensor * b); 1436 + 1437 + // 1438 + // operations on tensors without backpropagation 1439 + // 1440 + 1441 + GGML_API struct ggml_tensor * ggml_scale( 1442 + struct ggml_context * ctx, 1443 + struct ggml_tensor * a, 1444 + float s); 1445 + 1446 + // in-place, returns view(a) 1447 + GGML_API struct ggml_tensor * ggml_scale_inplace( 1448 + struct ggml_context * ctx, 1449 + struct ggml_tensor * a, 1450 + float s); 1451 + 1452 + // x = s * a + b 1453 + GGML_API struct ggml_tensor * ggml_scale_bias( 1454 + struct ggml_context * ctx, 1455 + struct ggml_tensor * a, 1456 + float s, 1457 + float b); 1458 + 1459 + GGML_API struct ggml_tensor * ggml_scale_bias_inplace( 1460 + struct ggml_context * ctx, 1461 + struct ggml_tensor * a, 1462 + float s, 1463 + float b); 1464 + 1465 + // b -> view(a,offset,nb1,nb2,3), return modified a 1466 + GGML_API struct ggml_tensor * ggml_set( 1467 + struct ggml_context * ctx, 1468 + struct ggml_tensor * a, 1469 + struct ggml_tensor * b, 1470 + size_t nb1, 1471 + size_t 
nb2, 1472 + size_t nb3, 1473 + size_t offset); // in bytes 1474 + 1475 + // b -> view(a,offset,nb1,nb2,3), return view(a) 1476 + GGML_API struct ggml_tensor * ggml_set_inplace( 1477 + struct ggml_context * ctx, 1478 + struct ggml_tensor * a, 1479 + struct ggml_tensor * b, 1480 + size_t nb1, 1481 + size_t nb2, 1482 + size_t nb3, 1483 + size_t offset); // in bytes 1484 + 1485 + GGML_API struct ggml_tensor * ggml_set_1d( 1486 + struct ggml_context * ctx, 1487 + struct ggml_tensor * a, 1488 + struct ggml_tensor * b, 1489 + size_t offset); // in bytes 1490 + 1491 + GGML_API struct ggml_tensor * ggml_set_1d_inplace( 1492 + struct ggml_context * ctx, 1493 + struct ggml_tensor * a, 1494 + struct ggml_tensor * b, 1495 + size_t offset); // in bytes 1496 + 1497 + // b -> view(a,offset,nb1,nb2,3), return modified a 1498 + GGML_API struct ggml_tensor * ggml_set_2d( 1499 + struct ggml_context * ctx, 1500 + struct ggml_tensor * a, 1501 + struct ggml_tensor * b, 1502 + size_t nb1, 1503 + size_t offset); // in bytes 1504 + 1505 + // b -> view(a,offset,nb1,nb2,3), return view(a) 1506 + GGML_API struct ggml_tensor * ggml_set_2d_inplace( 1507 + struct ggml_context * ctx, 1508 + struct ggml_tensor * a, 1509 + struct ggml_tensor * b, 1510 + size_t nb1, 1511 + size_t offset); // in bytes 1512 + 1513 + // a -> b, return view(b) 1514 + GGML_API struct ggml_tensor * ggml_cpy( 1515 + struct ggml_context * ctx, 1516 + struct ggml_tensor * a, 1517 + struct ggml_tensor * b); 1518 + 1519 + // note: casting from f32 to i32 will discard the fractional part 1520 + GGML_API struct ggml_tensor * ggml_cast( 1521 + struct ggml_context * ctx, 1522 + struct ggml_tensor * a, 1523 + enum ggml_type type); 1524 + 1525 + // make contiguous 1526 + GGML_API struct ggml_tensor * ggml_cont( 1527 + struct ggml_context * ctx, 1528 + struct ggml_tensor * a); 1529 + 1530 + // make contiguous, with new shape 1531 + GGML_API struct ggml_tensor * ggml_cont_1d( 1532 + struct ggml_context * ctx, 1533 + struct ggml_tensor 
* a, 1534 + int64_t ne0); 1535 + 1536 + GGML_API struct ggml_tensor * ggml_cont_2d( 1537 + struct ggml_context * ctx, 1538 + struct ggml_tensor * a, 1539 + int64_t ne0, 1540 + int64_t ne1); 1541 + 1542 + GGML_API struct ggml_tensor * ggml_cont_3d( 1543 + struct ggml_context * ctx, 1544 + struct ggml_tensor * a, 1545 + int64_t ne0, 1546 + int64_t ne1, 1547 + int64_t ne2); 1548 + 1549 + GGML_API struct ggml_tensor * ggml_cont_4d( 1550 + struct ggml_context * ctx, 1551 + struct ggml_tensor * a, 1552 + int64_t ne0, 1553 + int64_t ne1, 1554 + int64_t ne2, 1555 + int64_t ne3); 1556 + 1557 + // return view(a), b specifies the new shape 1558 + // TODO: when we start computing gradient, make a copy instead of view 1559 + GGML_API struct ggml_tensor * ggml_reshape( 1560 + struct ggml_context * ctx, 1561 + struct ggml_tensor * a, 1562 + struct ggml_tensor * b); 1563 + 1564 + // return view(a) 1565 + // TODO: when we start computing gradient, make a copy instead of view 1566 + GGML_API struct ggml_tensor * ggml_reshape_1d( 1567 + struct ggml_context * ctx, 1568 + struct ggml_tensor * a, 1569 + int64_t ne0); 1570 + 1571 + GGML_API struct ggml_tensor * ggml_reshape_2d( 1572 + struct ggml_context * ctx, 1573 + struct ggml_tensor * a, 1574 + int64_t ne0, 1575 + int64_t ne1); 1576 + 1577 + // return view(a) 1578 + // TODO: when we start computing gradient, make a copy instead of view 1579 + GGML_API struct ggml_tensor * ggml_reshape_3d( 1580 + struct ggml_context * ctx, 1581 + struct ggml_tensor * a, 1582 + int64_t ne0, 1583 + int64_t ne1, 1584 + int64_t ne2); 1585 + 1586 + GGML_API struct ggml_tensor * ggml_reshape_4d( 1587 + struct ggml_context * ctx, 1588 + struct ggml_tensor * a, 1589 + int64_t ne0, 1590 + int64_t ne1, 1591 + int64_t ne2, 1592 + int64_t ne3); 1593 + 1594 + // offset in bytes 1595 + GGML_API struct ggml_tensor * ggml_view_1d( 1596 + struct ggml_context * ctx, 1597 + struct ggml_tensor * a, 1598 + int64_t ne0, 1599 + size_t offset); 1600 + 1601 + GGML_API struct 
ggml_tensor * ggml_view_2d( 1602 + struct ggml_context * ctx, 1603 + struct ggml_tensor * a, 1604 + int64_t ne0, 1605 + int64_t ne1, 1606 + size_t nb1, // row stride in bytes 1607 + size_t offset); 1608 + 1609 + GGML_API struct ggml_tensor * ggml_view_3d( 1610 + struct ggml_context * ctx, 1611 + struct ggml_tensor * a, 1612 + int64_t ne0, 1613 + int64_t ne1, 1614 + int64_t ne2, 1615 + size_t nb1, // row stride in bytes 1616 + size_t nb2, // slice stride in bytes 1617 + size_t offset); 1618 + 1619 + GGML_API struct ggml_tensor * ggml_view_4d( 1620 + struct ggml_context * ctx, 1621 + struct ggml_tensor * a, 1622 + int64_t ne0, 1623 + int64_t ne1, 1624 + int64_t ne2, 1625 + int64_t ne3, 1626 + size_t nb1, // row stride in bytes 1627 + size_t nb2, // slice stride in bytes 1628 + size_t nb3, 1629 + size_t offset); 1630 + 1631 + GGML_API struct ggml_tensor * ggml_permute( 1632 + struct ggml_context * ctx, 1633 + struct ggml_tensor * a, 1634 + int axis0, 1635 + int axis1, 1636 + int axis2, 1637 + int axis3); 1638 + 1639 + // alias for ggml_permute(ctx, a, 1, 0, 2, 3) 1640 + GGML_API struct ggml_tensor * ggml_transpose( 1641 + struct ggml_context * ctx, 1642 + struct ggml_tensor * a); 1643 + 1644 + // supports 4D a: 1645 + // a [n_embd, ne1, ne2, ne3] 1646 + // b I32 [n_rows, ne2, ne3, 1] 1647 + // 1648 + // return [n_embd, n_rows, ne2, ne3] 1649 + GGML_API struct ggml_tensor * ggml_get_rows( 1650 + struct ggml_context * ctx, 1651 + struct ggml_tensor * a, // data 1652 + struct ggml_tensor * b); // row indices 1653 + 1654 + GGML_API struct ggml_tensor * ggml_get_rows_back( 1655 + struct ggml_context * ctx, 1656 + struct ggml_tensor * a, // gradients of ggml_get_rows result 1657 + struct ggml_tensor * b, // row indices 1658 + struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape 1659 + 1660 + // a TD [n_embd, ne1, ne2, ne3] 1661 + // b TS [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3 1662 + // c I64 [n_rows, ne11, ne12, 1] | c[i] in [0, ne1) 
1663 + // 1664 + // undefined behavior if destination rows overlap 1665 + // 1666 + // broadcast: 1667 + // ne2 % ne11 == 0 1668 + // ne3 % ne12 == 0 1669 + // 1670 + // return view(a) 1671 + GGML_API struct ggml_tensor * ggml_set_rows( 1672 + struct ggml_context * ctx, 1673 + struct ggml_tensor * a, // destination 1674 + struct ggml_tensor * b, // source 1675 + struct ggml_tensor * c); // row indices 1676 + 1677 + GGML_API struct ggml_tensor * ggml_diag( 1678 + struct ggml_context * ctx, 1679 + struct ggml_tensor * a); 1680 + 1681 + // set elements above the diagonal to -INF 1682 + GGML_API struct ggml_tensor * ggml_diag_mask_inf( 1683 + struct ggml_context * ctx, 1684 + struct ggml_tensor * a, 1685 + int n_past); 1686 + 1687 + // in-place, returns view(a) 1688 + GGML_API struct ggml_tensor * ggml_diag_mask_inf_inplace( 1689 + struct ggml_context * ctx, 1690 + struct ggml_tensor * a, 1691 + int n_past); 1692 + 1693 + // set elements above the diagonal to 0 1694 + GGML_API struct ggml_tensor * ggml_diag_mask_zero( 1695 + struct ggml_context * ctx, 1696 + struct ggml_tensor * a, 1697 + int n_past); 1698 + 1699 + // in-place, returns view(a) 1700 + GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace( 1701 + struct ggml_context * ctx, 1702 + struct ggml_tensor * a, 1703 + int n_past); 1704 + 1705 + GGML_API struct ggml_tensor * ggml_soft_max( 1706 + struct ggml_context * ctx, 1707 + struct ggml_tensor * a); 1708 + 1709 + // in-place, returns view(a) 1710 + GGML_API struct ggml_tensor * ggml_soft_max_inplace( 1711 + struct ggml_context * ctx, 1712 + struct ggml_tensor * a); 1713 + 1714 + // a [ne0, ne01, ne02, ne03] 1715 + // mask [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional 1716 + // 1717 + // broadcast: 1718 + // ne02 % ne12 == 0 1719 + // ne03 % ne13 == 0 1720 + // 1721 + // fused soft_max(a*scale + mask*(ALiBi slope)) 1722 + // max_bias = 0.0f for no ALiBi 1723 + GGML_API struct ggml_tensor * ggml_soft_max_ext( 1724 + struct ggml_context * 
ctx, 1725 + struct ggml_tensor * a, 1726 + struct ggml_tensor * mask, 1727 + float scale, 1728 + float max_bias); 1729 + 1730 + GGML_API struct ggml_tensor * ggml_soft_max_ext_inplace( 1731 + struct ggml_context * ctx, 1732 + struct ggml_tensor * a, 1733 + struct ggml_tensor * mask, 1734 + float scale, 1735 + float max_bias); 1736 + 1737 + GGML_API void ggml_soft_max_add_sinks( 1738 + struct ggml_tensor * a, 1739 + struct ggml_tensor * sinks); 1740 + 1741 + GGML_API struct ggml_tensor * ggml_soft_max_ext_back( 1742 + struct ggml_context * ctx, 1743 + struct ggml_tensor * a, 1744 + struct ggml_tensor * b, 1745 + float scale, 1746 + float max_bias); 1747 + 1748 + // in-place, returns view(a) 1749 + GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace( 1750 + struct ggml_context * ctx, 1751 + struct ggml_tensor * a, 1752 + struct ggml_tensor * b, 1753 + float scale, 1754 + float max_bias); 1755 + 1756 + // rotary position embedding 1757 + // if (mode & 1) - skip n_past elements (NOT SUPPORTED) 1758 + // if (mode & GGML_ROPE_TYPE_NEOX) - GPT-NeoX style 1759 + // 1760 + // b is an int32 vector with size a->ne[2], it contains the positions 1761 + GGML_API struct ggml_tensor * ggml_rope( 1762 + struct ggml_context * ctx, 1763 + struct ggml_tensor * a, 1764 + struct ggml_tensor * b, 1765 + int n_dims, 1766 + int mode); 1767 + 1768 + // in-place, returns view(a) 1769 + GGML_API struct ggml_tensor * ggml_rope_inplace( 1770 + struct ggml_context * ctx, 1771 + struct ggml_tensor * a, 1772 + struct ggml_tensor * b, 1773 + int n_dims, 1774 + int mode); 1775 + 1776 + // custom RoPE 1777 + // c is freq factors (e.g. 
phi3-128k), (optional) 1778 + GGML_API struct ggml_tensor * ggml_rope_ext( 1779 + struct ggml_context * ctx, 1780 + struct ggml_tensor * a, 1781 + struct ggml_tensor * b, 1782 + struct ggml_tensor * c, 1783 + int n_dims, 1784 + int mode, 1785 + int n_ctx_orig, 1786 + float freq_base, 1787 + float freq_scale, 1788 + float ext_factor, 1789 + float attn_factor, 1790 + float beta_fast, 1791 + float beta_slow); 1792 + 1793 + GGML_API struct ggml_tensor * ggml_rope_multi( 1794 + struct ggml_context * ctx, 1795 + struct ggml_tensor * a, 1796 + struct ggml_tensor * b, 1797 + struct ggml_tensor * c, 1798 + int n_dims, 1799 + int sections[GGML_MROPE_SECTIONS], 1800 + int mode, 1801 + int n_ctx_orig, 1802 + float freq_base, 1803 + float freq_scale, 1804 + float ext_factor, 1805 + float attn_factor, 1806 + float beta_fast, 1807 + float beta_slow); 1808 + 1809 + // in-place, returns view(a) 1810 + GGML_API struct ggml_tensor * ggml_rope_ext_inplace( 1811 + struct ggml_context * ctx, 1812 + struct ggml_tensor * a, 1813 + struct ggml_tensor * b, 1814 + struct ggml_tensor * c, 1815 + int n_dims, 1816 + int mode, 1817 + int n_ctx_orig, 1818 + float freq_base, 1819 + float freq_scale, 1820 + float ext_factor, 1821 + float attn_factor, 1822 + float beta_fast, 1823 + float beta_slow); 1824 + 1825 + GGML_API struct ggml_tensor * ggml_rope_multi_inplace( 1826 + struct ggml_context * ctx, 1827 + struct ggml_tensor * a, 1828 + struct ggml_tensor * b, 1829 + struct ggml_tensor * c, 1830 + int n_dims, 1831 + int sections[GGML_MROPE_SECTIONS], 1832 + int mode, 1833 + int n_ctx_orig, 1834 + float freq_base, 1835 + float freq_scale, 1836 + float ext_factor, 1837 + float attn_factor, 1838 + float beta_fast, 1839 + float beta_slow); 1840 + 1841 + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom( 1842 + struct ggml_context * ctx, 1843 + struct ggml_tensor * a, 1844 + struct ggml_tensor * b, 1845 + int n_dims, 1846 + int mode, 1847 + int n_ctx_orig, 1848 + float freq_base, 1849 + 
float freq_scale, 1850 + float ext_factor, 1851 + float attn_factor, 1852 + float beta_fast, 1853 + float beta_slow), 1854 + "use ggml_rope_ext instead"); 1855 + 1856 + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace( 1857 + struct ggml_context * ctx, 1858 + struct ggml_tensor * a, 1859 + struct ggml_tensor * b, 1860 + int n_dims, 1861 + int mode, 1862 + int n_ctx_orig, 1863 + float freq_base, 1864 + float freq_scale, 1865 + float ext_factor, 1866 + float attn_factor, 1867 + float beta_fast, 1868 + float beta_slow), 1869 + "use ggml_rope_ext_inplace instead"); 1870 + 1871 + // compute correction dims for YaRN RoPE scaling 1872 + GGML_API void ggml_rope_yarn_corr_dims( 1873 + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]); 1874 + 1875 + // rotary position embedding backward, i.e compute dx from dy 1876 + // a - dy 1877 + GGML_API struct ggml_tensor * ggml_rope_ext_back( 1878 + struct ggml_context * ctx, 1879 + struct ggml_tensor * a, // gradients of ggml_rope result 1880 + struct ggml_tensor * b, // positions 1881 + struct ggml_tensor * c, // freq factors 1882 + int n_dims, 1883 + int mode, 1884 + int n_ctx_orig, 1885 + float freq_base, 1886 + float freq_scale, 1887 + float ext_factor, 1888 + float attn_factor, 1889 + float beta_fast, 1890 + float beta_slow); 1891 + 1892 + GGML_API struct ggml_tensor * ggml_rope_multi_back( 1893 + struct ggml_context * ctx, 1894 + struct ggml_tensor * a, 1895 + struct ggml_tensor * b, 1896 + struct ggml_tensor * c, 1897 + int n_dims, 1898 + int sections[4], 1899 + int mode, 1900 + int n_ctx_orig, 1901 + float freq_base, 1902 + float freq_scale, 1903 + float ext_factor, 1904 + float attn_factor, 1905 + float beta_fast, 1906 + float beta_slow); 1907 + 1908 + 1909 + // clamp 1910 + // in-place, returns view(a) 1911 + GGML_API struct ggml_tensor * ggml_clamp( 1912 + struct ggml_context * ctx, 1913 + struct ggml_tensor * a, 1914 + float min, 1915 + float max); 1916 + 
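review note on the `ggml_rope*` declarations above: the header only says that `mode & GGML_ROPE_TYPE_NEOX` selects "GPT-NeoX style". A minimal sketch of what that changes, assuming the usual RoPE definition (each pair of values is rotated by an angle theta_i = pos * freq_base^(-2i/n_dims)). `rope_apply_sketch` and its table arguments are hypothetical names, not ggml API, and the sketch ignores `freq_scale`, the YaRN factors and the optional freq-factor tensor `c`.

```c
#include <assert.h>
#include <stddef.h>

/* Rotate one head vector x[0..n_dims) in place.
 * cos_t[i] / sin_t[i] hold cos(theta_i) / sin(theta_i) for pair i,
 * precomputed by the caller for the token position (assumption: the
 * plain, unscaled case theta_i = pos * freq_base^(-2i/n_dims)).
 * neox == 0: pair i is (x[2i], x[2i+1])          "normal" layout
 * neox != 0: pair i is (x[i],  x[i + n_dims/2])  GPT-NeoX layout */
static void rope_apply_sketch(float *x, int n_dims,
                              const float *cos_t, const float *sin_t,
                              int neox) {
    assert(n_dims > 0 && n_dims % 2 == 0);
    const int half = n_dims / 2;
    for (int i = 0; i < half; ++i) {
        const int i0 = neox ? i        : 2 * i;
        const int i1 = neox ? i + half : 2 * i + 1;
        const float x0 = x[i0];
        const float x1 = x[i1];
        /* standard 2D rotation of the pair (x0, x1) by theta_i */
        x[i0] = x0 * cos_t[i] - x1 * sin_t[i];
        x[i1] = x0 * sin_t[i] + x1 * cos_t[i];
    }
}
```

the two modes apply the same rotations and differ only in which elements form a pair, which is why a model converted with the wrong mode still runs but produces garbage.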
1917 + // im2col 1918 + // converts data into a format that effectively results in a convolution when combined with matrix multiplication 1919 + GGML_API struct ggml_tensor * ggml_im2col( 1920 + struct ggml_context * ctx, 1921 + struct ggml_tensor * a, // convolution kernel 1922 + struct ggml_tensor * b, // data 1923 + int s0, // stride dimension 0 1924 + int s1, // stride dimension 1 1925 + int p0, // padding dimension 0 1926 + int p1, // padding dimension 1 1927 + int d0, // dilation dimension 0 1928 + int d1, // dilation dimension 1 1929 + bool is_2D, 1930 + enum ggml_type dst_type); 1931 + 1932 + GGML_API struct ggml_tensor * ggml_im2col_back( 1933 + struct ggml_context * ctx, 1934 + struct ggml_tensor * a, // convolution kernel 1935 + struct ggml_tensor * b, // gradient of im2col output 1936 + int64_t * ne, // shape of im2col input 1937 + int s0, // stride dimension 0 1938 + int s1, // stride dimension 1 1939 + int p0, // padding dimension 0 1940 + int p1, // padding dimension 1 1941 + int d0, // dilation dimension 0 1942 + int d1, // dilation dimension 1 1943 + bool is_2D); 1944 + 1945 + GGML_API struct ggml_tensor * ggml_conv_1d( 1946 + struct ggml_context * ctx, 1947 + struct ggml_tensor * a, // convolution kernel 1948 + struct ggml_tensor * b, // data 1949 + int s0, // stride 1950 + int p0, // padding 1951 + int d0); // dilation 1952 + 1953 + // conv_1d with padding = half 1954 + // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d) 1955 + GGML_API struct ggml_tensor* ggml_conv_1d_ph( 1956 + struct ggml_context * ctx, 1957 + struct ggml_tensor * a, // convolution kernel 1958 + struct ggml_tensor * b, // data 1959 + int s, // stride 1960 + int d); // dilation 1961 + 1962 + // depthwise 1963 + // TODO: this is very likely wrong for some cases! 
- needs more testing 1964 + GGML_API struct ggml_tensor * ggml_conv_1d_dw( 1965 + struct ggml_context * ctx, 1966 + struct ggml_tensor * a, // convolution kernel 1967 + struct ggml_tensor * b, // data 1968 + int s0, // stride 1969 + int p0, // padding 1970 + int d0); // dilation 1971 + 1972 + GGML_API struct ggml_tensor * ggml_conv_1d_dw_ph( 1973 + struct ggml_context * ctx, 1974 + struct ggml_tensor * a, // convolution kernel 1975 + struct ggml_tensor * b, // data 1976 + int s0, // stride 1977 + int d0); // dilation 1978 + 1979 + GGML_API struct ggml_tensor * ggml_conv_transpose_1d( 1980 + struct ggml_context * ctx, 1981 + struct ggml_tensor * a, // convolution kernel 1982 + struct ggml_tensor * b, // data 1983 + int s0, // stride 1984 + int p0, // padding 1985 + int d0); // dilation 1986 + 1987 + GGML_API struct ggml_tensor * ggml_conv_2d( 1988 + struct ggml_context * ctx, 1989 + struct ggml_tensor * a, // convolution kernel 1990 + struct ggml_tensor * b, // data 1991 + int s0, // stride dimension 0 1992 + int s1, // stride dimension 1 1993 + int p0, // padding dimension 0 1994 + int p1, // padding dimension 1 1995 + int d0, // dilation dimension 0 1996 + int d1); // dilation dimension 1 1997 + 1998 + GGML_API struct ggml_tensor * ggml_im2col_3d( 1999 + struct ggml_context * ctx, 2000 + struct ggml_tensor * a, 2001 + struct ggml_tensor * b, 2002 + int64_t IC, 2003 + int s0, // stride width 2004 + int s1, // stride height 2005 + int s2, // stride depth 2006 + int p0, // padding width 2007 + int p1, // padding height 2008 + int p2, // padding depth 2009 + int d0, // dilation width 2010 + int d1, // dilation height 2011 + int d2, // dilation depth 2012 + enum ggml_type dst_type); 2013 + 2014 + // a: [OC*IC, KD, KH, KW] 2015 + // b: [N*IC, ID, IH, IW] 2016 + // result: [N*OC, OD, OH, OW] 2017 + GGML_API struct ggml_tensor * ggml_conv_3d( 2018 + struct ggml_context * ctx, 2019 + struct ggml_tensor * a, 2020 + struct ggml_tensor * b, 2021 + int64_t IC, 2022 + int s0, 
// stride width 2023 + int s1, // stride height 2024 + int s2, // stride depth 2025 + int p0, // padding width 2026 + int p1, // padding height 2027 + int p2, // padding depth 2028 + int d0, // dilation width 2029 + int d1, // dilation height 2030 + int d2 // dilation depth 2031 + ); 2032 + 2033 + // kernel size is a->ne[0] x a->ne[1] 2034 + // stride is equal to kernel size 2035 + // padding is zero 2036 + // example: 2037 + // a: 16 16 3 768 2038 + // b: 1024 1024 3 1 2039 + // res: 64 64 768 1 2040 + // used in sam 2041 + GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0( 2042 + struct ggml_context * ctx, 2043 + struct ggml_tensor * a, 2044 + struct ggml_tensor * b); 2045 + 2046 + // kernel size is a->ne[0] x a->ne[1] 2047 + // stride is 1 2048 + // padding is half 2049 + // example: 2050 + // a: 3 3 256 256 2051 + // b: 64 64 256 1 2052 + // res: 64 64 256 1 2053 + // used in sam 2054 + GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph( 2055 + struct ggml_context * ctx, 2056 + struct ggml_tensor * a, 2057 + struct ggml_tensor * b); 2058 + 2059 + // depthwise (via im2col and mul_mat) 2060 + GGML_API struct ggml_tensor * ggml_conv_2d_dw( 2061 + struct ggml_context * ctx, 2062 + struct ggml_tensor * a, // convolution kernel 2063 + struct ggml_tensor * b, // data 2064 + int s0, // stride dimension 0 2065 + int s1, // stride dimension 1 2066 + int p0, // padding dimension 0 2067 + int p1, // padding dimension 1 2068 + int d0, // dilation dimension 0 2069 + int d1); // dilation dimension 1 2070 + 2071 + // Depthwise 2D convolution 2072 + // may be faster than ggml_conv_2d_dw, but not available in all backends 2073 + // a: KW KH 1 C convolution kernel 2074 + // b: W H C N input data 2075 + // res: W_out H_out C N 2076 + GGML_API struct ggml_tensor * ggml_conv_2d_dw_direct( 2077 + struct ggml_context * ctx, 2078 + struct ggml_tensor * a, 2079 + struct ggml_tensor * b, 2080 + int stride0, 2081 + int stride1, 2082 + int pad0, 2083 + int pad1, 2084 + int dilation0, 2085 + 
int dilation1); 2086 + 2087 + GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0( 2088 + struct ggml_context * ctx, 2089 + struct ggml_tensor * a, 2090 + struct ggml_tensor * b, 2091 + int stride); 2092 + 2093 + GGML_API struct ggml_tensor * ggml_conv_2d_direct( 2094 + struct ggml_context * ctx, 2095 + struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC] 2096 + struct ggml_tensor * b, // input data [W, H, C, N] 2097 + int s0, // stride dimension 0 2098 + int s1, // stride dimension 1 2099 + int p0, // padding dimension 0 2100 + int p1, // padding dimension 1 2101 + int d0, // dilation dimension 0 2102 + int d1); // dilation dimension 1 2103 + 2104 + GGML_API struct ggml_tensor * ggml_conv_3d_direct( 2105 + struct ggml_context * ctx, 2106 + struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC] 2107 + struct ggml_tensor * b, // input [W, H, D, C * N] 2108 + int s0, // stride 2109 + int s1, 2110 + int s2, 2111 + int p0, // padding 2112 + int p1, 2113 + int p2, 2114 + int d0, // dilation 2115 + int d1, 2116 + int d2, 2117 + int n_channels, 2118 + int n_batch, 2119 + int n_channels_out); 2120 + 2121 + enum ggml_op_pool { 2122 + GGML_OP_POOL_MAX, 2123 + GGML_OP_POOL_AVG, 2124 + GGML_OP_POOL_COUNT, 2125 + }; 2126 + 2127 + GGML_API struct ggml_tensor * ggml_pool_1d( 2128 + struct ggml_context * ctx, 2129 + struct ggml_tensor * a, 2130 + enum ggml_op_pool op, 2131 + int k0, // kernel size 2132 + int s0, // stride 2133 + int p0); // padding 2134 + 2135 + // the result will have 2*p0 padding for the first dimension 2136 + // and 2*p1 padding for the second dimension 2137 + GGML_API struct ggml_tensor * ggml_pool_2d( 2138 + struct ggml_context * ctx, 2139 + struct ggml_tensor * a, 2140 + enum ggml_op_pool op, 2141 + int k0, 2142 + int k1, 2143 + int s0, 2144 + int s1, 2145 + float p0, 2146 + float p1); 2147 + 2148 + GGML_API struct ggml_tensor * ggml_pool_2d_back( 2149 + struct ggml_context * ctx, 2150 + struct ggml_tensor * a, 2151 + struct ggml_tensor * af, 
// "a"/input used in forward pass 2152 + enum ggml_op_pool op, 2153 + int k0, 2154 + int k1, 2155 + int s0, 2156 + int s1, 2157 + float p0, 2158 + float p1); 2159 + 2160 + enum ggml_scale_mode { 2161 + GGML_SCALE_MODE_NEAREST = 0, 2162 + GGML_SCALE_MODE_BILINEAR = 1, 2163 + GGML_SCALE_MODE_BICUBIC = 2, 2164 + 2165 + GGML_SCALE_MODE_COUNT 2166 + }; 2167 + 2168 + enum ggml_scale_flag { 2169 + GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8), 2170 + GGML_SCALE_FLAG_ANTIALIAS = (1 << 9), 2171 + }; 2172 + 2173 + // interpolate 2174 + // multiplies ne0 and ne1 by scale factor 2175 + GGML_API struct ggml_tensor * ggml_upscale( 2176 + struct ggml_context * ctx, 2177 + struct ggml_tensor * a, 2178 + int scale_factor, 2179 + enum ggml_scale_mode mode); 2180 + 2181 + // interpolate 2182 + // interpolate scale to specified dimensions 2183 + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_upscale_ext( 2184 + struct ggml_context * ctx, 2185 + struct ggml_tensor * a, 2186 + int ne0, 2187 + int ne1, 2188 + int ne2, 2189 + int ne3, 2190 + enum ggml_scale_mode mode), 2191 + "use ggml_interpolate instead"); 2192 + 2193 + // Up- or downsamples the input to the specified size. 2194 + // 2D scale modes (eg. bilinear) are applied to the first two dimensions. 2195 + GGML_API struct ggml_tensor * ggml_interpolate( 2196 + struct ggml_context * ctx, 2197 + struct ggml_tensor * a, 2198 + int64_t ne0, 2199 + int64_t ne1, 2200 + int64_t ne2, 2201 + int64_t ne3, 2202 + uint32_t mode); // ggml_scale_mode [ | ggml_scale_flag...] 
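review note on the convolution comments above: the shape examples for `ggml_conv_2d_sk_p0` (16x16 kernel on 1024x1024 gives 64x64) and `ggml_conv_2d_s1_ph` (3x3 kernel on 64x64 gives 64x64) can be checked with the standard per-axis output-size formula. a small sketch; `conv_out_size` is a hypothetical helper, and that ggml computes sizes exactly this way is an assumption.

```c
#include <assert.h>
#include <stdint.h>

/* Output length along one axis of a convolution.
 * in: input length, k: kernel length, s: stride,
 * p: padding applied on each side, d: dilation. */
static int64_t conv_out_size(int64_t in, int64_t k, int s, int p, int d) {
    assert(s > 0 && d > 0 && k > 0);
    return (in + 2 * (int64_t)p - (int64_t)d * (k - 1) - 1) / s + 1;
}
```

for the `sk_p0` case the stride equals the kernel size and padding is zero, so `conv_out_size(1024, 16, 16, 0, 1)` gives 64; for `s1_ph` the stride is 1 and padding is half the kernel, so `conv_out_size(64, 3, 1, 1, 1)` also gives 64, matching the `res:` lines in the comments.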
2203 + 2204 + // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0] 2205 + GGML_API struct ggml_tensor * ggml_pad( 2206 + struct ggml_context * ctx, 2207 + struct ggml_tensor * a, 2208 + int p0, 2209 + int p1, 2210 + int p2, 2211 + int p3); 2212 + 2213 + // pad each dimension with values on the other side of the torus (looping around) 2214 + GGML_API struct ggml_tensor * ggml_pad_circular( 2215 + struct ggml_context * ctx, 2216 + struct ggml_tensor * a, 2217 + int p0, 2218 + int p1, 2219 + int p2, 2220 + int p3); 2221 + 2222 + GGML_API struct ggml_tensor * ggml_pad_ext( 2223 + struct ggml_context * ctx, 2224 + struct ggml_tensor * a, 2225 + int lp0, 2226 + int rp0, 2227 + int lp1, 2228 + int rp1, 2229 + int lp2, 2230 + int rp2, 2231 + int lp3, 2232 + int rp3 2233 + ); 2234 + 2235 + // pad each dimension with values on the other side of the torus (looping around) 2236 + GGML_API struct ggml_tensor * ggml_pad_ext_circular( 2237 + struct ggml_context * ctx, 2238 + struct ggml_tensor * a, 2239 + int lp0, 2240 + int rp0, 2241 + int lp1, 2242 + int rp1, 2243 + int lp2, 2244 + int rp2, 2245 + int lp3, 2246 + int rp3); 2247 + 2248 + // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c] 2249 + GGML_API struct ggml_tensor * ggml_pad_reflect_1d( 2250 + struct ggml_context * ctx, 2251 + struct ggml_tensor * a, 2252 + int p0, 2253 + int p1); 2254 + 2255 + // Move tensor elements by an offset given for each dimension. Elements that 2256 + // are shifted beyond the last position are wrapped around to the beginning. 
2257 + GGML_API struct ggml_tensor * ggml_roll( 2258 + struct ggml_context * ctx, 2259 + struct ggml_tensor * a, 2260 + int shift0, 2261 + int shift1, 2262 + int shift2, 2263 + int shift3); 2264 + 2265 + // Convert matrix into a triangular one (upper, strict upper, lower or strict lower) by writing 2266 + // zeroes everywhere outside the masked area 2267 + GGML_API struct ggml_tensor * ggml_tri( 2268 + struct ggml_context * ctx, 2269 + struct ggml_tensor * a, 2270 + enum ggml_tri_type type); 2271 + 2272 + // Fill tensor a with constant c 2273 + GGML_API struct ggml_tensor * ggml_fill( 2274 + struct ggml_context * ctx, 2275 + struct ggml_tensor * a, 2276 + float c); 2277 + 2278 + GGML_API struct ggml_tensor * ggml_fill_inplace( 2279 + struct ggml_context * ctx, 2280 + struct ggml_tensor * a, 2281 + float c); 2282 + 2283 + // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151 2284 + // timesteps: [N,] 2285 + // return: [N, dim] 2286 + GGML_API struct ggml_tensor * ggml_timestep_embedding( 2287 + struct ggml_context * ctx, 2288 + struct ggml_tensor * timesteps, 2289 + int dim, 2290 + int max_period); 2291 + 2292 + // sort rows 2293 + enum ggml_sort_order { 2294 + GGML_SORT_ORDER_ASC, 2295 + GGML_SORT_ORDER_DESC, 2296 + }; 2297 + 2298 + GGML_API struct ggml_tensor * ggml_argsort( 2299 + struct ggml_context * ctx, 2300 + struct ggml_tensor * a, 2301 + enum ggml_sort_order order); 2302 + 2303 + // similar to ggml_top_k but implemented as `argsort` + `view` 2304 + GGML_API struct ggml_tensor * ggml_argsort_top_k( 2305 + struct ggml_context * ctx, 2306 + struct ggml_tensor * a, 2307 + int k); 2308 + 2309 + // top k elements per row 2310 + // note: the resulting top k indices are in no particular order 2311 + GGML_API struct ggml_tensor * ggml_top_k( 2312 + struct ggml_context * ctx, 2313 + struct ggml_tensor * a, 2314 + int k); 2315 + 2316 + GGML_API struct ggml_tensor * ggml_arange( 2317 + struct ggml_context * ctx, 2318 + 
float start, 2319 + float stop, 2320 + float step); 2321 + 2322 + // q: [n_embd_k, n_batch, n_head, ne3 ] 2323 + // k: [n_embd_k, n_kv, n_head_kv, ne3 ] 2324 + // v: [n_embd_v, n_kv, n_head_kv, ne3 ] !! not transposed !! 2325 + // mask: [n_kv, n_batch, ne32, ne33] 2326 + // res: [n_embd_v, n_head, n_batch, ne3 ] !! permuted !! 2327 + // 2328 + // broadcast: 2329 + // n_head % n_head_kv == 0 2330 + // n_head % ne32 == 0 2331 + // ne3 % ne33 == 0 2332 + // 2333 + GGML_API struct ggml_tensor * ggml_flash_attn_ext( 2334 + struct ggml_context * ctx, 2335 + struct ggml_tensor * q, 2336 + struct ggml_tensor * k, 2337 + struct ggml_tensor * v, 2338 + struct ggml_tensor * mask, 2339 + float scale, 2340 + float max_bias, 2341 + float logit_softcap); 2342 + 2343 + GGML_API void ggml_flash_attn_ext_set_prec( 2344 + struct ggml_tensor * a, 2345 + enum ggml_prec prec); 2346 + 2347 + GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec( 2348 + const struct ggml_tensor * a); 2349 + 2350 + GGML_API void ggml_flash_attn_ext_add_sinks( 2351 + struct ggml_tensor * a, 2352 + struct ggml_tensor * sinks); 2353 + 2354 + // TODO: needs to be adapted to ggml_flash_attn_ext 2355 + GGML_API struct ggml_tensor * ggml_flash_attn_back( 2356 + struct ggml_context * ctx, 2357 + struct ggml_tensor * q, 2358 + struct ggml_tensor * k, 2359 + struct ggml_tensor * v, 2360 + struct ggml_tensor * d, 2361 + bool masked); 2362 + 2363 + GGML_API struct ggml_tensor * ggml_ssm_conv( 2364 + struct ggml_context * ctx, 2365 + struct ggml_tensor * sx, 2366 + struct ggml_tensor * c); 2367 + 2368 + GGML_API struct ggml_tensor * ggml_ssm_scan( 2369 + struct ggml_context * ctx, 2370 + struct ggml_tensor * s, 2371 + struct ggml_tensor * x, 2372 + struct ggml_tensor * dt, 2373 + struct ggml_tensor * A, 2374 + struct ggml_tensor * B, 2375 + struct ggml_tensor * C, 2376 + struct ggml_tensor * ids); 2377 + 2378 + // partition into non-overlapping windows with padding if needed 2379 + // example: 2380 + // a: 768 64 64 1 
2381 + // w: 14 2382 + // res: 768 14 14 25 2383 + // used in sam 2384 + GGML_API struct ggml_tensor * ggml_win_part( 2385 + struct ggml_context * ctx, 2386 + struct ggml_tensor * a, 2387 + int w); 2388 + 2389 + // reverse of ggml_win_part 2390 + // used in sam 2391 + GGML_API struct ggml_tensor * ggml_win_unpart( 2392 + struct ggml_context * ctx, 2393 + struct ggml_tensor * a, 2394 + int w0, 2395 + int h0, 2396 + int w); 2397 + 2398 + GGML_API struct ggml_tensor * ggml_unary( 2399 + struct ggml_context * ctx, 2400 + struct ggml_tensor * a, 2401 + enum ggml_unary_op op); 2402 + 2403 + GGML_API struct ggml_tensor * ggml_unary_inplace( 2404 + struct ggml_context * ctx, 2405 + struct ggml_tensor * a, 2406 + enum ggml_unary_op op); 2407 + 2408 + // used in sam 2409 + GGML_API struct ggml_tensor * ggml_get_rel_pos( 2410 + struct ggml_context * ctx, 2411 + struct ggml_tensor * a, 2412 + int qh, 2413 + int kh); 2414 + 2415 + // used in sam 2416 + GGML_API struct ggml_tensor * ggml_add_rel_pos( 2417 + struct ggml_context * ctx, 2418 + struct ggml_tensor * a, 2419 + struct ggml_tensor * pw, 2420 + struct ggml_tensor * ph); 2421 + 2422 + GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace( 2423 + struct ggml_context * ctx, 2424 + struct ggml_tensor * a, 2425 + struct ggml_tensor * pw, 2426 + struct ggml_tensor * ph); 2427 + 2428 + GGML_API struct ggml_tensor * ggml_rwkv_wkv6( 2429 + struct ggml_context * ctx, 2430 + struct ggml_tensor * k, 2431 + struct ggml_tensor * v, 2432 + struct ggml_tensor * r, 2433 + struct ggml_tensor * tf, 2434 + struct ggml_tensor * td, 2435 + struct ggml_tensor * state); 2436 + 2437 + GGML_API struct ggml_tensor * ggml_gated_linear_attn( 2438 + struct ggml_context * ctx, 2439 + struct ggml_tensor * k, 2440 + struct ggml_tensor * v, 2441 + struct ggml_tensor * q, 2442 + struct ggml_tensor * g, 2443 + struct ggml_tensor * state, 2444 + float scale); 2445 + 2446 + GGML_API struct ggml_tensor * ggml_rwkv_wkv7( 2447 + struct ggml_context * ctx, 2448 
+ struct ggml_tensor * r, 2449 + struct ggml_tensor * w, 2450 + struct ggml_tensor * k, 2451 + struct ggml_tensor * v, 2452 + struct ggml_tensor * a, 2453 + struct ggml_tensor * b, 2454 + struct ggml_tensor * state); 2455 + 2456 + /* Solves a specific equation of the form Ax=B, where A is a triangular matrix 2457 + * without zeroes on the diagonal (i.e. invertible). 2458 + * B can have any number of columns, but must have the same number of rows as A 2459 + * If A is [n, n] and B is [n, m], then the result will be [n, m] as well 2460 + * Has O(n^3) complexity (unlike most matrix ops out there), so use on cases 2461 + * where n > 100 sparingly, pre-chunk if necessary. 2462 + * 2463 + * If left = false, solves xA=B instead 2464 + * If lower = false, assumes upper triangular instead 2465 + * If uni = true, assumes diagonal of A to be all ones (will override actual values) 2466 + * 2467 + * TODO: currently only lower, right, non-unitriangular variant is implemented 2468 + */ 2469 + GGML_API struct ggml_tensor * ggml_solve_tri( 2470 + struct ggml_context * ctx, 2471 + struct ggml_tensor * a, 2472 + struct ggml_tensor * b, 2473 + bool left, 2474 + bool lower, 2475 + bool uni); 2476 + 2477 + // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] 2478 + // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 2479 + GGML_API struct ggml_tensor * ggml_gated_delta_net( 2480 + struct ggml_context * ctx, 2481 + struct ggml_tensor * q, 2482 + struct ggml_tensor * k, 2483 + struct ggml_tensor * v, 2484 + struct ggml_tensor * g, 2485 + struct ggml_tensor * beta, 2486 + struct ggml_tensor * state); 2487 + 2488 + // custom operators 2489 + 2490 + typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); 2491 + typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct 
ggml_tensor * b, int ith, int nth, void * userdata); 2492 + typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata); 2493 + 2494 + #define GGML_N_TASKS_MAX (-1) 2495 + // n_tasks == GGML_N_TASKS_MAX means to use max number of tasks 2496 + 2497 + GGML_API struct ggml_tensor * ggml_map_custom1( 2498 + struct ggml_context * ctx, 2499 + struct ggml_tensor * a, 2500 + ggml_custom1_op_t fun, 2501 + int n_tasks, 2502 + void * userdata); 2503 + 2504 + GGML_API struct ggml_tensor * ggml_map_custom1_inplace( 2505 + struct ggml_context * ctx, 2506 + struct ggml_tensor * a, 2507 + ggml_custom1_op_t fun, 2508 + int n_tasks, 2509 + void * userdata); 2510 + 2511 + GGML_API struct ggml_tensor * ggml_map_custom2( 2512 + struct ggml_context * ctx, 2513 + struct ggml_tensor * a, 2514 + struct ggml_tensor * b, 2515 + ggml_custom2_op_t fun, 2516 + int n_tasks, 2517 + void * userdata); 2518 + 2519 + GGML_API struct ggml_tensor * ggml_map_custom2_inplace( 2520 + struct ggml_context * ctx, 2521 + struct ggml_tensor * a, 2522 + struct ggml_tensor * b, 2523 + ggml_custom2_op_t fun, 2524 + int n_tasks, 2525 + void * userdata); 2526 + 2527 + GGML_API struct ggml_tensor * ggml_map_custom3( 2528 + struct ggml_context * ctx, 2529 + struct ggml_tensor * a, 2530 + struct ggml_tensor * b, 2531 + struct ggml_tensor * c, 2532 + ggml_custom3_op_t fun, 2533 + int n_tasks, 2534 + void * userdata); 2535 + 2536 + GGML_API struct ggml_tensor * ggml_map_custom3_inplace( 2537 + struct ggml_context * ctx, 2538 + struct ggml_tensor * a, 2539 + struct ggml_tensor * b, 2540 + struct ggml_tensor * c, 2541 + ggml_custom3_op_t fun, 2542 + int n_tasks, 2543 + void * userdata); 2544 + 2545 + typedef void (*ggml_custom_op_t)(struct ggml_tensor * dst , int ith, int nth, void * userdata); 2546 + 2547 + GGML_API struct ggml_tensor * ggml_custom_4d( 2548 + struct ggml_context * ctx, 2549 + enum 
ggml_type type, 2550 + int64_t ne0, 2551 + int64_t ne1, 2552 + int64_t ne2, 2553 + int64_t ne3, 2554 + struct ggml_tensor ** args, 2555 + int n_args, 2556 + ggml_custom_op_t fun, 2557 + int n_tasks, 2558 + void * userdata); 2559 + 2560 + GGML_API struct ggml_tensor * ggml_custom_inplace( 2561 + struct ggml_context * ctx, 2562 + struct ggml_tensor * a, 2563 + struct ggml_tensor ** args, 2564 + int n_args, 2565 + ggml_custom_op_t fun, 2566 + int n_tasks, 2567 + void * userdata); 2568 + 2569 + // loss function 2570 + 2571 + GGML_API struct ggml_tensor * ggml_cross_entropy_loss( 2572 + struct ggml_context * ctx, 2573 + struct ggml_tensor * a, // logits 2574 + struct ggml_tensor * b); // labels 2575 + 2576 + GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back( 2577 + struct ggml_context * ctx, 2578 + struct ggml_tensor * a, // logits 2579 + struct ggml_tensor * b, // labels 2580 + struct ggml_tensor * c); // gradients of cross_entropy_loss result 2581 + 2582 + // AdamW optimizer step 2583 + // Paper: https://arxiv.org/pdf/1711.05101v3.pdf 2584 + // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html 2585 + GGML_API struct ggml_tensor * ggml_opt_step_adamw( 2586 + struct ggml_context * ctx, 2587 + struct ggml_tensor * a, 2588 + struct ggml_tensor * grad, 2589 + struct ggml_tensor * m, 2590 + struct ggml_tensor * v, 2591 + struct ggml_tensor * adamw_params); // parameters such as the learning rate 2592 + 2593 + // stochastic gradient descent step (with weight decay) 2594 + GGML_API struct ggml_tensor * ggml_opt_step_sgd( 2595 + struct ggml_context * ctx, 2596 + struct ggml_tensor * a, 2597 + struct ggml_tensor * grad, 2598 + struct ggml_tensor * sgd_params); // alpha, weight decay 2599 + 2600 + // build forward multiple tensors and select one of them for computing 2601 + // this is useful for creating graphs that have constant topology but compute different things based on the input 2602 + // ref: 
https://github.com/ggml-org/llama.cpp/pull/18550 2603 + // 2604 + // nodes: 2605 + // | - build forward into the graph but do not compute 2606 + // c - build forward into the graph and compute 2607 + // 2608 + // | | ... c ... | 2609 + // | | ... c ... | 2610 + // | | ... c ... | 2611 + // [0 1 ... idx ... n-1] <-- ggml_build_forward_select(..., n, idx) 2612 + // c 2613 + // c 2614 + // 2615 + // example: 2616 + // struct ggml_tensor * curs[3]; 2617 + // 2618 + // curs[0] = compute0(...); 2619 + // curs[1] = compute1(...); 2620 + // curs[2] = compute2(...); 2621 + // 2622 + // int idx = select_branch(some_input); 2623 + // 2624 + // struct ggml_tensor * out = ggml_build_forward_select(cgraph, curs, 3, idx); 2625 + // 2626 + GGML_API struct ggml_tensor * ggml_build_forward_select( 2627 + struct ggml_cgraph * cgraph, 2628 + struct ggml_tensor ** tensors, 2629 + int n_tensors, 2630 + int idx); 2631 + 2632 + GGML_API void ggml_build_forward_expand( 2633 + struct ggml_cgraph * cgraph, 2634 + struct ggml_tensor * tensor); 2635 + 2636 + GGML_API void ggml_build_backward_expand( 2637 + struct ggml_context * ctx, // context for gradient computation 2638 + struct ggml_cgraph * cgraph, 2639 + struct ggml_tensor ** grad_accs); 2640 + 2641 + // graph allocation in a context 2642 + GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false 2643 + GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads); 2644 + GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads); 2645 + GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst); 2646 + GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1 2647 + GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph); 2648 + 2649 + GGML_API int ggml_graph_size (struct 
ggml_cgraph * cgraph); 2650 + GGML_API struct ggml_tensor * ggml_graph_node (struct ggml_cgraph * cgraph, int i); // if i < 0, returns nodes[n_nodes + i] 2651 + GGML_API struct ggml_tensor ** ggml_graph_nodes (struct ggml_cgraph * cgraph); 2652 + GGML_API int ggml_graph_n_nodes(struct ggml_cgraph * cgraph); 2653 + 2654 + GGML_API void ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); 2655 + 2656 + GGML_API size_t ggml_graph_overhead(void); 2657 + GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads); 2658 + 2659 + GGML_API struct ggml_tensor * ggml_graph_get_tensor (const struct ggml_cgraph * cgraph, const char * name); 2660 + GGML_API struct ggml_tensor * ggml_graph_get_grad (const struct ggml_cgraph * cgraph, const struct ggml_tensor * node); 2661 + GGML_API struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node); 2662 + 2663 + // print info and performance information for the graph 2664 + GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph); 2665 + 2666 + // dump the graph into a file using the dot format 2667 + GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename); 2668 + 2669 + // TODO these functions were sandwiched in the old optimization interface, is there a better place for them? 2670 + typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data); 2671 + 2672 + // Set callback for all future logging events. 2673 + // If this is not called, or NULL is supplied, everything is output on stderr. 
2674 + GGML_API void ggml_log_get(ggml_log_callback * log_callback, void ** user_data); 2675 + GGML_API void ggml_log_set(ggml_log_callback log_callback, void * user_data); 2676 + 2677 + GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); 2678 + 2679 + // 2680 + // quantization 2681 + // 2682 + 2683 + // - ggml_quantize_init can be called multiple times with the same type 2684 + // it will only initialize the quantization tables for the first call or after ggml_quantize_free 2685 + // automatically called by ggml_quantize_chunk for convenience 2686 + // 2687 + // - ggml_quantize_free will free any memory allocated by ggml_quantize_init 2688 + // call this at the end of the program to avoid memory leaks 2689 + // 2690 + // note: these are thread-safe 2691 + // 2692 + GGML_API void ggml_quantize_init(enum ggml_type type); 2693 + GGML_API void ggml_quantize_free(void); 2694 + 2695 + // some quantization type cannot be used without an importance matrix 2696 + GGML_API bool ggml_quantize_requires_imatrix(enum ggml_type type); 2697 + 2698 + // calls ggml_quantize_init internally (i.e. 
can allocate memory) 2699 + GGML_API size_t ggml_quantize_chunk( 2700 + enum ggml_type type, 2701 + const float * src, 2702 + void * dst, 2703 + int64_t start, 2704 + int64_t nrows, 2705 + int64_t n_per_row, 2706 + const float * imatrix); 2707 + 2708 + #ifdef __cplusplus 2709 + // restrict not standard in C++ 2710 + # if defined(__GNUC__) 2711 + # define GGML_RESTRICT __restrict__ 2712 + # elif defined(__clang__) 2713 + # define GGML_RESTRICT __restrict 2714 + # elif defined(_MSC_VER) 2715 + # define GGML_RESTRICT __restrict 2716 + # else 2717 + # define GGML_RESTRICT 2718 + # endif 2719 + #else 2720 + # if defined (_MSC_VER) && (__STDC_VERSION__ < 201112L) 2721 + # define GGML_RESTRICT __restrict 2722 + # else 2723 + # define GGML_RESTRICT restrict 2724 + # endif 2725 + #endif 2726 + typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); 2727 + typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); 2728 + 2729 + struct ggml_type_traits { 2730 + const char * type_name; 2731 + int64_t blck_size; 2732 + int64_t blck_size_interleave; // interleave elements in blocks 2733 + size_t type_size; 2734 + bool is_quantized; 2735 + ggml_to_float_t to_float; 2736 + ggml_from_float_t from_float_ref; 2737 + }; 2738 + 2739 + GGML_API const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type); 2740 + 2741 + // ggml threadpool 2742 + // TODO: currently, only a few functions are in the base ggml API, while the rest are in the CPU backend 2743 + // the goal should be to create an API that other backends can use move everything to the ggml base 2744 + 2745 + // scheduling priorities 2746 + enum ggml_sched_priority { 2747 + GGML_SCHED_PRIO_LOW = -1, 2748 + GGML_SCHED_PRIO_NORMAL, 2749 + GGML_SCHED_PRIO_MEDIUM, 2750 + GGML_SCHED_PRIO_HIGH, 2751 + GGML_SCHED_PRIO_REALTIME 2752 + }; 2753 + 2754 + // threadpool params 2755 + // Use ggml_threadpool_params_default() or 
ggml_threadpool_params_init() to populate the defaults 2756 + struct ggml_threadpool_params { 2757 + bool cpumask[GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings) 2758 + int n_threads; // number of threads 2759 + enum ggml_sched_priority prio; // thread priority 2760 + uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling) 2761 + bool strict_cpu; // strict cpu placement 2762 + bool paused; // start in paused state 2763 + }; 2764 + 2765 + struct ggml_threadpool; // forward declaration, see ggml.c 2766 + 2767 + typedef struct ggml_threadpool * ggml_threadpool_t; 2768 + 2769 + GGML_API struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads); 2770 + GGML_API void ggml_threadpool_params_init (struct ggml_threadpool_params * p, int n_threads); 2771 + GGML_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1); 2772 + 2773 + #ifdef __cplusplus 2774 + } 2775 + #endif
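review note on `ggml_solve_tri` earlier in this header: its comment describes solving Ax=B for a triangular A with a non-zero diagonal. a minimal sketch of the lower-triangular, left-hand, non-unit case using forward substitution. `solve_lower_left_sketch` is a hypothetical name, and the row-major layout is an assumption made for readability, not ggml's tensor layout.

```c
#include <assert.h>
#include <stddef.h>

/* Solve A * X = B by forward substitution.
 * A: n x n, lower triangular, row-major, non-zero diagonal.
 * B: n x m, row-major; overwritten in place with the solution X. */
static void solve_lower_left_sketch(const float *A, float *B,
                                    size_t n, size_t m) {
    for (size_t col = 0; col < m; ++col) {
        for (size_t i = 0; i < n; ++i) {
            float acc = B[i * m + col];
            /* subtract the contribution of the already-solved rows */
            for (size_t j = 0; j < i; ++j) {
                acc -= A[i * n + j] * B[j * m + col];
            }
            assert(A[i * n + i] != 0.0f);
            B[i * m + col] = acc / A[i * n + i];
        }
    }
}
```

each row depends on all rows before it, so the inner loops cannot be parallelised across rows, only across the columns of B. that is one plausible reading of the header's advice to pre-chunk large n.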
+204
backend/llama-include/gguf.h
··· 1 + // This file contains functionality related to "GGUF" files, the binary file format used by ggml. 2 + // GGUF files have the following structure: 3 + // 4 + // 1. File magic "GGUF" (4 bytes). 5 + // 2. File version (uint32_t). 6 + // 3. Number of ggml tensors in file (int64_t). 7 + // 4. Number of key-value-pairs in file (int64_t). 8 + // 5. For each KV pair: 9 + // 1. The key (string). 10 + // 2. The value type (gguf_type). 11 + // 3a. If the value type is GGUF_TYPE_ARRAY: 12 + // 1. The type of the array (gguf_type). 13 + // 2. The number of elements in the array (uint64_t). 14 + // 3. The binary representation of each element in the array. 15 + // 3b. Otherwise: 16 + // 1. The binary representation of the value. 17 + // 6. For each ggml tensor: 18 + // 1. The tensor name (string). 19 + // 2. The number of dimensions of the tensor (uint32_t). 20 + // 3. For each dimension: 21 + // 1. The size of the tensor in the dimension (int64_t). 22 + // 4. The tensor data type (ggml_type). 23 + // 5. The tensor data offset in the tensor data binary blob (uint64_t). 24 + // 7. The tensor data binary blob (optional, aligned). 25 + // 26 + // Strings are serialized as the string length (uint64_t) followed by the C string without the null terminator. 27 + // All enums are stored as int32_t. 28 + // All bool values are stored as int8_t. 29 + // If the special key "general.alignment" (uint32_t) is defined it is used for alignment, 30 + // otherwise GGUF_DEFAULT_ALIGNMENT is used. 
31 + // 32 + // Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de) 33 + 34 + #pragma once 35 + 36 + #include "ggml.h" 37 + 38 + #include <stdbool.h> 39 + #include <stdint.h> 40 + 41 + #define GGUF_MAGIC "GGUF" 42 + #define GGUF_VERSION 3 43 + 44 + #define GGUF_KEY_GENERAL_ALIGNMENT "general.alignment" 45 + 46 + #define GGUF_DEFAULT_ALIGNMENT 32 47 + 48 + #ifdef __cplusplus 49 + extern "C" { 50 + #endif 51 + 52 + // types that can be stored as GGUF KV data 53 + enum gguf_type { 54 + GGUF_TYPE_UINT8 = 0, 55 + GGUF_TYPE_INT8 = 1, 56 + GGUF_TYPE_UINT16 = 2, 57 + GGUF_TYPE_INT16 = 3, 58 + GGUF_TYPE_UINT32 = 4, 59 + GGUF_TYPE_INT32 = 5, 60 + GGUF_TYPE_FLOAT32 = 6, 61 + GGUF_TYPE_BOOL = 7, 62 + GGUF_TYPE_STRING = 8, 63 + GGUF_TYPE_ARRAY = 9, 64 + GGUF_TYPE_UINT64 = 10, 65 + GGUF_TYPE_INT64 = 11, 66 + GGUF_TYPE_FLOAT64 = 12, 67 + GGUF_TYPE_COUNT, // marks the end of the enum 68 + }; 69 + 70 + struct gguf_context; 71 + 72 + struct gguf_init_params { 73 + bool no_alloc; 74 + 75 + // if not NULL, create a ggml_context and allocate the tensor data in it 76 + struct ggml_context ** ctx; 77 + }; 78 + 79 + GGML_API struct gguf_context * gguf_init_empty(void); 80 + GGML_API struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_params params); 81 + GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params); 82 + //GGML_API struct gguf_context * gguf_init_from_buffer(..); 83 + 84 + GGML_API void gguf_free(struct gguf_context * ctx); 85 + 86 + GGML_API const char * gguf_type_name(enum gguf_type type); 87 + 88 + GGML_API uint32_t gguf_get_version (const struct gguf_context * ctx); 89 + GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx); 90 + GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); 91 + 92 + GGML_API int64_t gguf_get_n_kv(const struct gguf_context * ctx); 93 + GGML_API int64_t gguf_find_key(const struct gguf_context * ctx, const char * key); // 
returns -1 if key is not found 94 + GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int64_t key_id); 95 + 96 + GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int64_t key_id); 97 + GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id); 98 + 99 + // will abort if the wrong type is used for the key 100 + GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int64_t key_id); 101 + GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int64_t key_id); 102 + GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int64_t key_id); 103 + GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int64_t key_id); 104 + GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int64_t key_id); 105 + GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int64_t key_id); 106 + GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int64_t key_id); 107 + GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int64_t key_id); 108 + GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int64_t key_id); 109 + GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int64_t key_id); 110 + GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id); 111 + GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int64_t key_id); 112 + GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id); 113 + GGML_API size_t gguf_get_arr_n (const struct gguf_context * ctx, int64_t key_id); 114 + 115 + // get raw pointer to the first element of the array with the given key_id 116 + // for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference) 117 + GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id); 118 + 119 + // get ith C string from 
array with given key_id 120 + GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i); 121 + 122 + GGML_API int64_t gguf_get_n_tensors (const struct gguf_context * ctx); 123 + GGML_API int64_t gguf_find_tensor (const struct gguf_context * ctx, const char * name); // returns -1 if the tensor is not found 124 + GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id); 125 + GGML_API const char * gguf_get_tensor_name (const struct gguf_context * ctx, int64_t tensor_id); 126 + GGML_API enum ggml_type gguf_get_tensor_type (const struct gguf_context * ctx, int64_t tensor_id); 127 + GGML_API size_t gguf_get_tensor_size (const struct gguf_context * ctx, int64_t tensor_id); 128 + 129 + // removes key if it exists, returns id that the key had prior to removal (-1 if it didn't exist) 130 + GGML_API int64_t gguf_remove_key(struct gguf_context * ctx, const char * key); 131 + 132 + // overrides an existing KV pair or adds a new one, the new KV pair is always at the back 133 + GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val); 134 + GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val); 135 + GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val); 136 + GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val); 137 + GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val); 138 + GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val); 139 + GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val); 140 + GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val); 141 + GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val); 142 + GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * 
key, double val); 143 + GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val); 144 + GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val); 145 + 146 + // creates a new array with n elements of the given type and copies the corresponding number of bytes from data 147 + GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n); 148 + 149 + // creates a new array with n strings and copies the corresponding strings from data 150 + GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, size_t n); 151 + 152 + // set or add KV pairs from another context 153 + GGML_API void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src); 154 + 155 + // add tensor to GGUF context, tensor name must be unique 156 + GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor); 157 + 158 + // after changing a tensor's type, the offsets of all tensors with higher indices are immediately recalculated 159 + // in such a way that the tensor data remains as one contiguous block (except for padding) 160 + GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type); 161 + 162 + // assumes that at least gguf_get_tensor_size bytes can be read from data 163 + GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data); 164 + 165 + // writing gguf files can be done in 3 ways: 166 + // 167 + // - write the entire gguf_context to a binary file in a single pass: 168 + // 169 + // gguf_write_to_file(ctx, fname, /*only_meta =*/ false); 170 + // 171 + // - write only the meta data to a file, then re-open the file and append the tensor data: 172 + // 173 + // gguf_write_to_file(ctx, fname, /*only_meta =*/ true); 174 + // FILE * f = fopen(fname, "ab"); 175 + // fwrite(f, ...); // write tensor 
data 176 + // fclose(f); 177 + // 178 + // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data: 179 + // 180 + // FILE * f = fopen(fname, "wb"); 181 + // const size_t size_meta = gguf_get_meta_size(ctx); 182 + // fseek(f, size_meta, SEEK_SET); 183 + // fwrite(f, ...); // write tensor data 184 + // void * data = malloc(size_meta); 185 + // gguf_get_meta_data(ctx, data); 186 + // rewind(f); 187 + // fwrite(data, 1, size_meta, f); 188 + // free(data); 189 + // fclose(f); 190 + // 191 + 192 + // write the entire context to a binary file 193 + GGML_API bool gguf_write_to_file_ptr(const struct gguf_context * ctx, FILE * file, bool only_meta); 194 + GGML_API bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta); 195 + 196 + // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding 197 + GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx); 198 + 199 + // writes the meta data to pointer "data" 200 + GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data); 201 + 202 + #ifdef __cplusplus 203 + } 204 + #endif
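The layout comment at the top of `gguf.h` says a file begins with the magic `"GGUF"`, a `uint32_t` version, and two `int64_t` counts (tensors, then KV pairs), and that the tensor data blob is aligned to `general.alignment` or `GGUF_DEFAULT_ALIGNMENT` (32). The sketch below reads just those fixed-size fields from a byte buffer and rounds an offset up to the alignment. It is an illustration under stated assumptions, not the library's parser: `demo_gguf_header`, `demo_parse_gguf_header` and `demo_align_up` are hypothetical names, and the code assumes the file and the host are both little-endian so a plain `memcpy` recovers the integers.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

// Hypothetical struct for the fixed-size fields at the start of a GGUF file.
struct demo_gguf_header {
    uint32_t version;
    int64_t  n_tensors;
    int64_t  n_kv;
};

// Reads magic, version, tensor count and KV count from buf.
// Returns false if the buffer is too short or the magic does not match.
// Assumption: little-endian file on a little-endian host.
static bool demo_parse_gguf_header(const uint8_t * buf, size_t len,
                                   struct demo_gguf_header * out) {
    const size_t need = 4 + sizeof(uint32_t) + 2 * sizeof(int64_t); // 24 bytes
    if (len < need || memcmp(buf, "GGUF", 4) != 0) {
        return false;
    }
    memcpy(&out->version,   buf + 4,  sizeof(uint32_t));
    memcpy(&out->n_tensors, buf + 8,  sizeof(int64_t));
    memcpy(&out->n_kv,      buf + 16, sizeof(int64_t));
    return true;
}

// Rounds offset up to the next multiple of alignment (32 by default in GGUF).
static size_t demo_align_up(size_t offset, size_t alignment) {
    assert(alignment > 0);
    return (offset + alignment - 1) / alignment * alignment;
}
```

In a real program `gguf_init_from_file` does this work, along with the variable-length KV and tensor-info sections that follow the header; the sketch only shows why the first 24 bytes are enough to sanity-check a file before handing it to the loader.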
+1588
backend/llama-include/llama.h
··· 1 + #ifndef LLAMA_H 2 + #define LLAMA_H 3 + 4 + #include "ggml.h" 5 + #include "ggml-cpu.h" 6 + #include "ggml-backend.h" 7 + #include "ggml-opt.h" 8 + #include "gguf.h" 9 + 10 + #include <stddef.h> 11 + #include <stdint.h> 12 + #include <stdio.h> 13 + #include <stdbool.h> 14 + 15 + #ifdef LLAMA_SHARED 16 + # if defined(_WIN32) && !defined(__MINGW32__) 17 + # ifdef LLAMA_BUILD 18 + # define LLAMA_API __declspec(dllexport) 19 + # else 20 + # define LLAMA_API __declspec(dllimport) 21 + # endif 22 + # else 23 + # define LLAMA_API __attribute__ ((visibility ("default"))) 24 + # endif 25 + #else 26 + # define LLAMA_API 27 + #endif 28 + 29 + #ifdef __GNUC__ 30 + # define DEPRECATED(func, hint) func __attribute__((deprecated(hint))) 31 + #elif defined(_MSC_VER) 32 + # define DEPRECATED(func, hint) __declspec(deprecated(hint)) func 33 + #else 34 + # define DEPRECATED(func, hint) func 35 + #endif 36 + 37 + #define LLAMA_DEFAULT_SEED 0xFFFFFFFF 38 + 39 + #define LLAMA_TOKEN_NULL -1 40 + 41 + #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' 42 + #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' 43 + #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' 44 + 45 + #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN 46 + #define LLAMA_SESSION_VERSION 9 47 + 48 + #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ 49 + #define LLAMA_STATE_SEQ_VERSION 2 50 + 51 + #ifdef __cplusplus 52 + extern "C" { 53 + #endif 54 + 55 + // 56 + // C interface 57 + // 58 + // TODO: show sample usage 59 + // 60 + 61 + struct llama_vocab; 62 + struct llama_model; 63 + struct llama_context; 64 + struct llama_sampler; 65 + 66 + typedef struct llama_memory_i * llama_memory_t; 67 + 68 + typedef int32_t llama_pos; 69 + typedef int32_t llama_token; 70 + typedef int32_t llama_seq_id; 71 + 72 + enum llama_vocab_type { 73 + LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab 74 + LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback 75 + LLAMA_VOCAB_TYPE_BPE = 2, // 
GPT-2 tokenizer based on byte-level BPE 76 + LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece 77 + LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram 78 + LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization 79 + LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming 80 + }; 81 + 82 + enum llama_rope_type { 83 + LLAMA_ROPE_TYPE_NONE = -1, 84 + LLAMA_ROPE_TYPE_NORM = 0, 85 + LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX, 86 + LLAMA_ROPE_TYPE_MROPE = GGML_ROPE_TYPE_MROPE, 87 + LLAMA_ROPE_TYPE_IMROPE = GGML_ROPE_TYPE_IMROPE, 88 + LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION, 89 + }; 90 + 91 + enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file 92 + LLAMA_TOKEN_TYPE_UNDEFINED = 0, 93 + LLAMA_TOKEN_TYPE_NORMAL = 1, 94 + LLAMA_TOKEN_TYPE_UNKNOWN = 2, 95 + LLAMA_TOKEN_TYPE_CONTROL = 3, 96 + LLAMA_TOKEN_TYPE_USER_DEFINED = 4, 97 + LLAMA_TOKEN_TYPE_UNUSED = 5, 98 + LLAMA_TOKEN_TYPE_BYTE = 6, 99 + }; 100 + 101 + enum llama_token_attr { 102 + LLAMA_TOKEN_ATTR_UNDEFINED = 0, 103 + LLAMA_TOKEN_ATTR_UNKNOWN = 1 << 0, 104 + LLAMA_TOKEN_ATTR_UNUSED = 1 << 1, 105 + LLAMA_TOKEN_ATTR_NORMAL = 1 << 2, 106 + LLAMA_TOKEN_ATTR_CONTROL = 1 << 3, // SPECIAL? 
107 + LLAMA_TOKEN_ATTR_USER_DEFINED = 1 << 4, 108 + LLAMA_TOKEN_ATTR_BYTE = 1 << 5, 109 + LLAMA_TOKEN_ATTR_NORMALIZED = 1 << 6, 110 + LLAMA_TOKEN_ATTR_LSTRIP = 1 << 7, 111 + LLAMA_TOKEN_ATTR_RSTRIP = 1 << 8, 112 + LLAMA_TOKEN_ATTR_SINGLE_WORD = 1 << 9, 113 + }; 114 + 115 + // model file types 116 + enum llama_ftype { 117 + LLAMA_FTYPE_ALL_F32 = 0, 118 + LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors 119 + LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors 120 + LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors 121 + // LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 122 + // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed 123 + // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed 124 + LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors 125 + LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors 126 + LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors 127 + LLAMA_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors 128 + LLAMA_FTYPE_MOSTLY_Q3_K_S = 11, // except 1d tensors 129 + LLAMA_FTYPE_MOSTLY_Q3_K_M = 12, // except 1d tensors 130 + LLAMA_FTYPE_MOSTLY_Q3_K_L = 13, // except 1d tensors 131 + LLAMA_FTYPE_MOSTLY_Q4_K_S = 14, // except 1d tensors 132 + LLAMA_FTYPE_MOSTLY_Q4_K_M = 15, // except 1d tensors 133 + LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors 134 + LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors 135 + LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors 136 + LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19, // except 1d tensors 137 + LLAMA_FTYPE_MOSTLY_IQ2_XS = 20, // except 1d tensors 138 + LLAMA_FTYPE_MOSTLY_Q2_K_S = 21, // except 1d tensors 139 + LLAMA_FTYPE_MOSTLY_IQ3_XS = 22, // except 1d tensors 140 + LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors 141 + LLAMA_FTYPE_MOSTLY_IQ1_S = 24, // except 1d tensors 142 + LLAMA_FTYPE_MOSTLY_IQ4_NL = 25, // except 1d tensors 143 + LLAMA_FTYPE_MOSTLY_IQ3_S = 26, // except 1d tensors 144 + LLAMA_FTYPE_MOSTLY_IQ3_M = 27, // except 1d tensors 145 + 
LLAMA_FTYPE_MOSTLY_IQ2_S = 28, // except 1d tensors 146 + LLAMA_FTYPE_MOSTLY_IQ2_M = 29, // except 1d tensors 147 + LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors 148 + LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors 149 + LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors 150 + //LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // removed from gguf files, use Q4_0 and runtime repack 151 + //LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // removed from gguf files, use Q4_0 and runtime repack 152 + //LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // removed from gguf files, use Q4_0 and runtime repack 153 + LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors 154 + LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors 155 + LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38, // except 1d tensors 156 + LLAMA_FTYPE_MOSTLY_NVFP4 = 39, // except 1d tensors 157 + LLAMA_FTYPE_MOSTLY_Q1_0 = 40, // except 1d tensors 158 + 159 + LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file 160 + }; 161 + 162 + enum llama_rope_scaling_type { 163 + LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1, 164 + LLAMA_ROPE_SCALING_TYPE_NONE = 0, 165 + LLAMA_ROPE_SCALING_TYPE_LINEAR = 1, 166 + LLAMA_ROPE_SCALING_TYPE_YARN = 2, 167 + LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3, 168 + LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_LONGROPE, 169 + }; 170 + 171 + enum llama_pooling_type { 172 + LLAMA_POOLING_TYPE_UNSPECIFIED = -1, 173 + LLAMA_POOLING_TYPE_NONE = 0, 174 + LLAMA_POOLING_TYPE_MEAN = 1, 175 + LLAMA_POOLING_TYPE_CLS = 2, 176 + LLAMA_POOLING_TYPE_LAST = 3, 177 + LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph 178 + }; 179 + 180 + enum llama_attention_type { 181 + LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1, 182 + LLAMA_ATTENTION_TYPE_CAUSAL = 0, 183 + LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1, 184 + }; 185 + 186 + enum llama_flash_attn_type { 187 + LLAMA_FLASH_ATTN_TYPE_AUTO = -1, 188 + LLAMA_FLASH_ATTN_TYPE_DISABLED = 0, 189 + LLAMA_FLASH_ATTN_TYPE_ENABLED = 1, 190 + }; 191 + 192 + 
LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type); 193 + 194 + enum llama_split_mode { 195 + LLAMA_SPLIT_MODE_NONE = 0, // single GPU 196 + LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs 197 + LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported 198 + }; 199 + 200 + // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) 201 + typedef struct llama_token_data { 202 + llama_token id; // token id 203 + float logit; // log-odds of the token 204 + float p; // probability of the token 205 + } llama_token_data; 206 + 207 + typedef struct llama_token_data_array { 208 + // TODO: consider SoA 209 + // NOTE: this pointer can be modified by the samplers 210 + llama_token_data * data; 211 + size_t size; 212 + int64_t selected; // this is the index in the data array (i.e. not the token id) 213 + bool sorted; // note: do not assume the data is sorted - always check this flag 214 + } llama_token_data_array; 215 + 216 + typedef bool (*llama_progress_callback)(float progress, void * user_data); 217 + 218 + // Input data for llama_encode/llama_decode 219 + // A llama_batch object can contain input about one or many sequences 220 + // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens 221 + // 222 + // - token : the token ids of the input (used when embd is NULL) 223 + // - embd : token embeddings (i.e. 
float vector of size n_embd) (used when token is NULL) 224 + // - pos : the positions of the respective token in the sequence 225 + // (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode) 226 + // - seq_id : the sequence to which the respective token belongs 227 + // (if set to NULL, the sequence ID will be assumed to be 0) 228 + // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output 229 + // (if set to NULL: 230 + // - if embeddings: all tokens are output 231 + // - if not: only the last token is output 232 + // ) 233 + // 234 + typedef struct llama_batch { 235 + int32_t n_tokens; 236 + 237 + llama_token * token; 238 + float * embd; 239 + llama_pos * pos; 240 + int32_t * n_seq_id; 241 + llama_seq_id ** seq_id; 242 + int8_t * logits; // TODO: rename this to "output" 243 + } llama_batch; 244 + 245 + enum llama_model_kv_override_type { 246 + LLAMA_KV_OVERRIDE_TYPE_INT, 247 + LLAMA_KV_OVERRIDE_TYPE_FLOAT, 248 + LLAMA_KV_OVERRIDE_TYPE_BOOL, 249 + LLAMA_KV_OVERRIDE_TYPE_STR, 250 + }; 251 + 252 + enum llama_model_meta_key { 253 + LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE, 254 + LLAMA_MODEL_META_KEY_SAMPLING_TOP_K, 255 + LLAMA_MODEL_META_KEY_SAMPLING_TOP_P, 256 + LLAMA_MODEL_META_KEY_SAMPLING_MIN_P, 257 + LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY, 258 + LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD, 259 + LLAMA_MODEL_META_KEY_SAMPLING_TEMP, 260 + LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N, 261 + LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT, 262 + LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT, 263 + LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU, 264 + LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA, 265 + }; 266 + 267 + struct llama_model_kv_override { 268 + enum llama_model_kv_override_type tag; 269 + 270 + char key[128]; 271 + 272 + union { 273 + int64_t val_i64; 274 + double val_f64; 275 + bool val_bool; 276 + char val_str[128]; 277 + }; 278 + }; 279 + 280 + struct 
llama_model_tensor_buft_override { 281 + const char * pattern; 282 + ggml_backend_buffer_type_t buft; 283 + }; 284 + 285 + struct llama_model_params { 286 + // NULL-terminated list of devices to use for offloading (if NULL, all available devices are used) 287 + ggml_backend_dev_t * devices; 288 + 289 + // NULL-terminated list of buffer types to use for tensors that match a pattern 290 + const struct llama_model_tensor_buft_override * tensor_buft_overrides; 291 + 292 + int32_t n_gpu_layers; // number of layers to store in VRAM, a negative value means all layers 293 + enum llama_split_mode split_mode; // how to split the model across multiple GPUs 294 + 295 + // the GPU that is used for the entire model when split_mode is LLAMA_SPLIT_MODE_NONE 296 + int32_t main_gpu; 297 + 298 + // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() 299 + const float * tensor_split; 300 + 301 + // Called with a progress value between 0.0 and 1.0. Pass NULL to disable. 302 + // If the provided progress_callback returns true, model loading continues. 303 + // If it returns false, model loading is immediately aborted. 304 + llama_progress_callback progress_callback; 305 + 306 + // context pointer passed to the progress callback 307 + void * progress_callback_user_data; 308 + 309 + // override key-value pairs of the model meta data 310 + const struct llama_model_kv_override * kv_overrides; 311 + 312 + // Keep the booleans together to avoid misalignment during copy-by-value. 
313 + bool vocab_only; // only load the vocabulary, no weights 314 + bool use_mmap; // use mmap if possible 315 + bool use_direct_io; // use direct io, takes precedence over use_mmap when supported 316 + bool use_mlock; // force system to keep model in RAM 317 + bool check_tensors; // validate model tensor data 318 + bool use_extra_bufts; // use extra buffer types (used for weight repacking) 319 + bool no_host; // bypass host buffer allowing extra buffers to be used 320 + bool no_alloc; // only load metadata and simulate memory allocations 321 + }; 322 + 323 + struct llama_sampler_seq_config { 324 + llama_seq_id seq_id; 325 + struct llama_sampler * sampler; 326 + }; 327 + 328 + // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations 329 + // https://github.com/ggml-org/llama.cpp/pull/7544 330 + struct llama_context_params { 331 + uint32_t n_ctx; // text context, 0 = from model 332 + uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode 333 + uint32_t n_ubatch; // physical maximum batch size 334 + uint32_t n_seq_max; // max number of sequences (i.e. 
distinct states for recurrent models) 335 + int32_t n_threads; // number of threads to use for generation 336 + int32_t n_threads_batch; // number of threads to use for batch processing 337 + 338 + enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` 339 + enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id 340 + enum llama_attention_type attention_type; // attention type to use for embeddings 341 + enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention 342 + 343 + // ref: https://github.com/ggml-org/llama.cpp/pull/2054 344 + float rope_freq_base; // RoPE base frequency, 0 = from model 345 + float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model 346 + float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model 347 + float yarn_attn_factor; // YaRN magnitude scaling factor 348 + float yarn_beta_fast; // YaRN low correction dim 349 + float yarn_beta_slow; // YaRN high correction dim 350 + uint32_t yarn_orig_ctx; // YaRN original context size 351 + float defrag_thold; // [DEPRECATED] defragment the KV cache if holes/size > thold, <= 0 disabled (default) 352 + 353 + ggml_backend_sched_eval_callback cb_eval; 354 + void * cb_eval_user_data; 355 + 356 + enum ggml_type type_k; // data type for K cache [EXPERIMENTAL] 357 + enum ggml_type type_v; // data type for V cache [EXPERIMENTAL] 358 + 359 + // Abort callback 360 + // if it returns true, execution of llama_decode() will be aborted 361 + // currently works only with CPU execution 362 + ggml_abort_callback abort_callback; 363 + void * abort_callback_data; 364 + 365 + // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. 
366 + bool embeddings; // if true, extract embeddings (together with logits) 367 + bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU 368 + bool no_perf; // measure performance timings 369 + bool op_offload; // offload host tensor operations to device 370 + bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) 371 + // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases 372 + // ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573 373 + bool kv_unified; // use a unified buffer across the input sequences when computing the attention 374 + // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix 375 + // ref: https://github.com/ggml-org/llama.cpp/pull/14363 376 + 377 + // [EXPERIMENTAL] 378 + // backend sampler chain configuration (make sure the caller keeps the sampler chains alive) 379 + // note: the samplers must be sampler chains (i.e. 
use llama_sampler_chain_init) 380 + struct llama_sampler_seq_config * samplers; 381 + size_t n_samplers; 382 + }; 383 + 384 + struct llama_model_tensor_override { 385 + const char * pattern; 386 + enum ggml_type type; 387 + }; 388 + 389 + struct llama_model_imatrix_data { 390 + const char * name; 391 + const float * data; 392 + size_t size; 393 + }; 394 + 395 + // model quantization parameters 396 + typedef struct llama_model_quantize_params { 397 + int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() 398 + enum llama_ftype ftype; // quantize to this llama_ftype 399 + enum ggml_type output_tensor_type; // output tensor type 400 + enum ggml_type token_embedding_type; // token embeddings tensor type 401 + bool allow_requantize; // allow quantizing non-f32/f16 tensors 402 + bool quantize_output_tensor; // quantize output.weight 403 + bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored 404 + bool pure; // quantize all tensors to the default type 405 + bool keep_split; // quantize to the same number of shards 406 + bool dry_run; // calculate and show the final quantization size without performing quantization 407 + const struct llama_model_imatrix_data * imatrix; // pointer to importance matrix data 408 + const struct llama_model_kv_override * kv_overrides; // pointer to kv overrides 409 + const struct llama_model_tensor_override * tt_overrides; // pointer to tensor overrides 410 + const int32_t * prune_layers; // pointer to layer indices to prune 411 + } llama_model_quantize_params; 412 + 413 + typedef struct llama_logit_bias { 414 + llama_token token; 415 + float bias; 416 + } llama_logit_bias; 417 + 418 + typedef struct llama_sampler_chain_params { 419 + bool no_perf; // whether to measure performance timings 420 + } llama_sampler_chain_params; 421 + 422 + // used in chat template 423 + typedef struct llama_chat_message { 424 + const char * role; 425 + const 
char * content; 426 + } llama_chat_message; 427 + 428 + // lora adapter 429 + struct llama_adapter_lora; 430 + 431 + // Helpers for getting default parameters 432 + // TODO: update API to start accepting pointers to params structs (https://github.com/ggml-org/llama.cpp/discussions/9172) 433 + LLAMA_API struct llama_model_params llama_model_default_params(void); 434 + LLAMA_API struct llama_context_params llama_context_default_params(void); 435 + LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void); 436 + LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); 437 + 438 + // Initialize the llama + ggml backend 439 + // If numa is true, use NUMA optimizations 440 + // Call once at the start of the program 441 + LLAMA_API void llama_backend_init(void); 442 + 443 + // Call once at the end of the program - currently only used for MPI 444 + LLAMA_API void llama_backend_free(void); 445 + 446 + //optional: 447 + LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa); 448 + 449 + // Optional: an auto threadpool gets created in ggml if not passed explicitly 450 + LLAMA_API void llama_attach_threadpool( 451 + struct llama_context * ctx, 452 + ggml_threadpool_t threadpool, 453 + ggml_threadpool_t threadpool_batch); 454 + 455 + LLAMA_API void llama_detach_threadpool(struct llama_context * ctx); 456 + 457 + typedef void (*llama_model_set_tensor_data_t)(struct ggml_tensor * tensor, void * userdata); 458 + 459 + // Create a new model from GGUF metadata as well as a function to set the tensor data 460 + // - tensors are created as GGML_TYPE_F32 by default, 461 + // override by adding a tensor with the same name but a different name to the context 462 + LLAMA_API struct llama_model * llama_model_init_from_user( 463 + struct gguf_context * metadata, 464 + llama_model_set_tensor_data_t set_tensor_data, // function to initialize tensor data with 465 + void * set_tensor_data_ud, // userdata for function 466 + struct 
llama_model_params params); 467 + 468 + DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file( 469 + const char * path_model, 470 + struct llama_model_params params), 471 + "use llama_model_load_from_file instead"); 472 + 473 + // Load a model from a file 474 + // If the file is split into multiple parts, the file name must follow this pattern: <name>-%05d-of-%05d.gguf 475 + // If the split file name does not follow this pattern, use llama_model_load_from_splits 476 + LLAMA_API struct llama_model * llama_model_load_from_file( 477 + const char * path_model, 478 + struct llama_model_params params); 479 + 480 + // Load a model from an open FILE pointer 481 + LLAMA_API struct llama_model * llama_model_load_from_file_ptr( 482 + FILE * file, 483 + struct llama_model_params params); 484 + 485 + // Load a model from multiple splits (support custom naming scheme) 486 + // The paths must be in the correct order 487 + LLAMA_API struct llama_model * llama_model_load_from_splits( 488 + const char ** paths, 489 + size_t n_paths, 490 + struct llama_model_params params); 491 + 492 + LLAMA_API void llama_model_save_to_file( 493 + const struct llama_model * model, 494 + const char * path_model); 495 + 496 + DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model), 497 + "use llama_model_free instead"); 498 + 499 + LLAMA_API void llama_model_free(struct llama_model * model); 500 + 501 + LLAMA_API struct llama_context * llama_init_from_model( 502 + struct llama_model * model, 503 + struct llama_context_params params); 504 + 505 + DEPRECATED(LLAMA_API struct llama_context * llama_new_context_with_model( 506 + struct llama_model * model, 507 + struct llama_context_params params), 508 + "use llama_init_from_model instead"); 509 + 510 + // Frees all allocated memory 511 + LLAMA_API void llama_free(struct llama_context * ctx); 512 + 513 + enum llama_params_fit_status { 514 + LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit 515 
+ LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit 516 + LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occurred, e.g. because no model could be found at the specified path 517 + }; 518 + 519 + // fits mparams and cparams to free device memory (assumes system memory is unlimited) 520 + // - returns true if the parameters could be successfully modified to fit device memory 521 + // - this function is NOT thread safe because it modifies the global llama logger state 522 + // - only parameters that have the same value as in llama_default_model_params are modified 523 + // with the exception of the context size which is modified if and only if equal to 0 524 + LLAMA_API enum llama_params_fit_status llama_params_fit( 525 + const char * path_model, 526 + struct llama_model_params * mparams, 527 + struct llama_context_params * cparams, 528 + float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements 529 + struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements 530 + size_t * margins, // margins of memory to leave per device in bytes 531 + uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use 532 + enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log 533 + 534 + LLAMA_API int64_t llama_time_us(void); 535 + 536 + LLAMA_API size_t llama_max_devices(void); 537 + LLAMA_API size_t llama_max_parallel_sequences(void); 538 + LLAMA_API size_t llama_max_tensor_buft_overrides(void); 539 + 540 + LLAMA_API bool llama_supports_mmap (void); 541 + LLAMA_API bool llama_supports_mlock (void); 542 + LLAMA_API bool llama_supports_gpu_offload(void); 543 + LLAMA_API bool llama_supports_rpc (void); 544 + 545 + // NOTE: After creating a llama_context, it is recommended to query the actual values using these functions 546 + // In 
some cases the requested values via llama_context_params may differ from the actual values used by the context 547 + // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732 548 + LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); 549 + LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx); 550 + LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); 551 + LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); 552 + LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); 553 + 554 + DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); 555 + DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); 556 + DEPRECATED(LLAMA_API int32_t llama_n_layer (const struct llama_model * model), "use llama_model_n_layer instead"); 557 + DEPRECATED(LLAMA_API int32_t llama_n_head (const struct llama_model * model), "use llama_model_n_head instead"); 558 + 559 + DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead"); 560 + 561 + LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); 562 + LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx); 563 + LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type 564 + 565 + LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); 566 + LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); 567 + 568 + LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model); 569 + LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); 570 + LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model); 571 + LLAMA_API int32_t 
llama_model_n_embd_out (const struct llama_model * model); 572 + LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); 573 + LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); 574 + LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); 575 + LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); 576 + 577 + // Get the model's RoPE frequency scaling factor 578 + LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); 579 + 580 + // Returns the number of classifier outputs (only valid for classifier models) 581 + // Undefined behavior for non-classifier models 582 + LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model); 583 + 584 + // Returns label of classifier output by index (<n_cls_out). Returns nullptr if no label provided 585 + LLAMA_API const char * llama_model_cls_label(const struct llama_model * model, uint32_t i); 586 + 587 + LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab); 588 + 589 + LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab); 590 + 591 + // Functions to access the model's GGUF metadata scalar values 592 + // - The functions return the length of the string on success, or -1 on failure 593 + // - The output string is always null-terminated and cleared on failure 594 + // - When retrieving a string, an extra byte must be allocated to account for the null terminator 595 + // - GGUF array values are not supported by these functions 596 + 597 + // Get metadata value as a string by key name 598 + LLAMA_API int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size); 599 + 600 + // Get the number of metadata key/value pairs 601 + LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model); 602 + 603 + // Get sampling metadata key name. 
Returns nullptr if the key is invalid 604 + LLAMA_API const char * llama_model_meta_key_str(enum llama_model_meta_key key); 605 + 606 + // Get metadata key name by index 607 + LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size); 608 + 609 + // Get metadata value as a string by index 610 + LLAMA_API int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size); 611 + 612 + // Get a string describing the model type 613 + LLAMA_API int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); 614 + 615 + // Returns the total size of all the tensors in the model in bytes 616 + LLAMA_API uint64_t llama_model_size(const struct llama_model * model); 617 + 618 + // Get the default chat template. Returns nullptr if not available 619 + // If name is NULL, returns the default chat template 620 + LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name); 621 + 622 + // Returns the total number of parameters in the model 623 + LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); 624 + 625 + // Returns true if the model contains an encoder that requires llama_encode() call 626 + LLAMA_API bool llama_model_has_encoder(const struct llama_model * model); 627 + 628 + // Returns true if the model contains a decoder that requires llama_decode() call 629 + LLAMA_API bool llama_model_has_decoder(const struct llama_model * model); 630 + 631 + // For encoder-decoder models, this function returns id of the token that must be provided 632 + // to the decoder to start generating output sequence. For other models, it returns -1. 633 + LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model); 634 + 635 + // Returns true if the model is recurrent (like Mamba, RWKV, etc.) 
636 + LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); 637 + 638 + // Returns true if the model is hybrid (like Jamba, Granite, etc.) 639 + LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model); 640 + 641 + // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.) 642 + LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model); 643 + 644 + // Returns 0 on success 645 + LLAMA_API uint32_t llama_model_quantize( 646 + const char * fname_inp, 647 + const char * fname_out, 648 + const llama_model_quantize_params * params); 649 + 650 + // 651 + // Adapters 652 + // 653 + 654 + // Load a LoRA adapter from file 655 + // The adapter is valid as long as the associated model is not freed 656 + LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init( 657 + struct llama_model * model, 658 + const char * path_lora); 659 + 660 + // Functions to access the adapter's GGUF metadata scalar values 661 + // - The functions return the length of the string on success, or -1 on failure 662 + // - The output string is always null-terminated and cleared on failure 663 + // - When retrieving a string, an extra byte must be allocated to account for the null terminator 664 + // - GGUF array values are not supported by these functions 665 + 666 + // Get metadata value as a string by key name 667 + LLAMA_API int32_t llama_adapter_meta_val_str(const struct llama_adapter_lora * adapter, const char * key, char * buf, size_t buf_size); 668 + 669 + // Get the number of metadata key/value pairs 670 + LLAMA_API int32_t llama_adapter_meta_count(const struct llama_adapter_lora * adapter); 671 + 672 + // Get metadata key name by index 673 + LLAMA_API int32_t llama_adapter_meta_key_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size); 674 + 675 + // Get metadata value as a string by index 676 + LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * 
adapter, int32_t i, char * buf, size_t buf_size); 677 + 678 + // Manually free a LoRA adapter 679 + // NOTE: loaded adapters that are not manually freed will be freed when the associated model is deleted 680 + LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); 681 + 682 + // Get the invocation tokens if the current lora is an alora 683 + LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter); 684 + LLAMA_API const llama_token * llama_adapter_get_alora_invocation_tokens (const struct llama_adapter_lora * adapter); 685 + 686 + // The following functions operate on a llama_context, hence the naming: llama_verb_... 687 + 688 + // Set LoRa adapters on the context. Will only modify if the adapters currently in context are different. 689 + LLAMA_API int32_t llama_set_adapters_lora( 690 + struct llama_context * ctx, 691 + struct llama_adapter_lora ** adapters, 692 + size_t n_adapters, 693 + float * scales); 694 + 695 + // Apply a loaded control vector to a llama_context, or if data is NULL, clear 696 + // the currently loaded vector. 697 + // n_embd should be the size of a single layer's control, and data should point 698 + // to an n_embd x n_layers buffer starting from layer 1. 699 + // il_start and il_end are the layer range the vector should apply to (both inclusive) 700 + // See llama_control_vector_load in common to load a control vector. 
701 + LLAMA_API int32_t llama_set_adapter_cvec( 702 + struct llama_context * ctx, 703 + const float * data, 704 + size_t len, 705 + int32_t n_embd, 706 + int32_t il_start, 707 + int32_t il_end); 708 + 709 + // 710 + // Memory 711 + // 712 + 713 + // Clear the memory contents 714 + // If data == true, the data buffers will also be cleared together with the metadata 715 + LLAMA_API void llama_memory_clear( 716 + llama_memory_t mem, 717 + bool data); 718 + 719 + // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) 720 + // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails 721 + // seq_id < 0 : match any sequence 722 + // p0 < 0 : [0, p1] 723 + // p1 < 0 : [p0, inf) 724 + LLAMA_API bool llama_memory_seq_rm( 725 + llama_memory_t mem, 726 + llama_seq_id seq_id, 727 + llama_pos p0, 728 + llama_pos p1); 729 + 730 + // Copy all tokens that belong to the specified sequence to another sequence 731 + // p0 < 0 : [0, p1] 732 + // p1 < 0 : [p0, inf) 733 + LLAMA_API void llama_memory_seq_cp( 734 + llama_memory_t mem, 735 + llama_seq_id seq_id_src, 736 + llama_seq_id seq_id_dst, 737 + llama_pos p0, 738 + llama_pos p1); 739 + 740 + // Removes all tokens that do not belong to the specified sequence 741 + LLAMA_API void llama_memory_seq_keep( 742 + llama_memory_t mem, 743 + llama_seq_id seq_id); 744 + 745 + // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) 746 + // p0 < 0 : [0, p1] 747 + // p1 < 0 : [p0, inf) 748 + LLAMA_API void llama_memory_seq_add( 749 + llama_memory_t mem, 750 + llama_seq_id seq_id, 751 + llama_pos p0, 752 + llama_pos p1, 753 + llama_pos delta); 754 + 755 + // Integer division of the positions by factor of `d > 1` 756 + // p0 < 0 : [0, p1] 757 + // p1 < 0 : [p0, inf) 758 + LLAMA_API void llama_memory_seq_div( 759 + llama_memory_t mem, 760 + llama_seq_id seq_id, 761 + llama_pos p0, 762 + llama_pos p1, 763 + int d); 764 
+ 765 + // Returns the smallest position present in the memory for the specified sequence 766 + // This is typically non-zero only for SWA caches 767 + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory 768 + // Return -1 if the sequence is empty 769 + LLAMA_API llama_pos llama_memory_seq_pos_min( 770 + llama_memory_t mem, 771 + llama_seq_id seq_id); 772 + 773 + // Returns the largest position present in the memory for the specified sequence 774 + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory 775 + // Return -1 if the sequence is empty 776 + LLAMA_API llama_pos llama_memory_seq_pos_max( 777 + llama_memory_t mem, 778 + llama_seq_id seq_id); 779 + 780 + // Check if the memory supports shifting 781 + LLAMA_API bool llama_memory_can_shift(llama_memory_t mem); 782 + 783 + // 784 + // State / sessions 785 + // 786 + 787 + // Returns the *actual* size in bytes of the state 788 + // (logits, embedding and memory) 789 + // Only use when saving the state, not when restoring it, otherwise the size may be too small. 790 + LLAMA_API size_t llama_state_get_size(struct llama_context * ctx); 791 + LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx), 792 + "use llama_state_get_size instead"); 793 + 794 + // Copies the state to the specified destination address. 795 + // Destination needs to have allocated enough memory. 
796 + // Returns the number of bytes copied 797 + LLAMA_API size_t llama_state_get_data( 798 + struct llama_context * ctx, 799 + uint8_t * dst, 800 + size_t size); 801 + LLAMA_API DEPRECATED(size_t llama_copy_state_data( 802 + struct llama_context * ctx, 803 + uint8_t * dst), 804 + "use llama_state_get_data instead"); 805 + 806 + // Set the state reading from the specified address 807 + // Returns the number of bytes read 808 + LLAMA_API size_t llama_state_set_data( 809 + struct llama_context * ctx, 810 + const uint8_t * src, 811 + size_t size); 812 + LLAMA_API DEPRECATED(size_t llama_set_state_data( 813 + struct llama_context * ctx, 814 + const uint8_t * src), 815 + "use llama_state_set_data instead"); 816 + 817 + // Save/load session file 818 + LLAMA_API bool llama_state_load_file( 819 + struct llama_context * ctx, 820 + const char * path_session, 821 + llama_token * tokens_out, 822 + size_t n_token_capacity, 823 + size_t * n_token_count_out); 824 + LLAMA_API DEPRECATED(bool llama_load_session_file( 825 + struct llama_context * ctx, 826 + const char * path_session, 827 + llama_token * tokens_out, 828 + size_t n_token_capacity, 829 + size_t * n_token_count_out), 830 + "use llama_state_load_file instead"); 831 + 832 + LLAMA_API bool llama_state_save_file( 833 + struct llama_context * ctx, 834 + const char * path_session, 835 + const llama_token * tokens, 836 + size_t n_token_count); 837 + LLAMA_API DEPRECATED(bool llama_save_session_file( 838 + struct llama_context * ctx, 839 + const char * path_session, 840 + const llama_token * tokens, 841 + size_t n_token_count), 842 + "use llama_state_save_file instead"); 843 + 844 + // Get the exact size needed to copy the state of a single sequence 845 + LLAMA_API size_t llama_state_seq_get_size( 846 + struct llama_context * ctx, 847 + llama_seq_id seq_id); 848 + 849 + // Copy the state of a single sequence into the specified buffer 850 + LLAMA_API size_t llama_state_seq_get_data( 851 + struct llama_context * ctx, 852 + 
uint8_t * dst, 853 + size_t size, 854 + llama_seq_id seq_id); 855 + 856 + // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence 857 + // Returns: 858 + // - Positive: Ok 859 + // - Zero: Failed to load 860 + LLAMA_API size_t llama_state_seq_set_data( 861 + struct llama_context * ctx, 862 + const uint8_t * src, 863 + size_t size, 864 + llama_seq_id dest_seq_id); 865 + 866 + LLAMA_API size_t llama_state_seq_save_file( 867 + struct llama_context * ctx, 868 + const char * filepath, 869 + llama_seq_id seq_id, 870 + const llama_token * tokens, 871 + size_t n_token_count); 872 + 873 + LLAMA_API size_t llama_state_seq_load_file( 874 + struct llama_context * ctx, 875 + const char * filepath, 876 + llama_seq_id dest_seq_id, 877 + llama_token * tokens_out, 878 + size_t n_token_capacity, 879 + size_t * n_token_count_out); 880 + 881 + // for backwards-compat 882 + #define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1 883 + 884 + // work only with partial states, such as SWA KV cache or recurrent cache (e.g. 
Mamba) 885 + #define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1 886 + 887 + typedef uint32_t llama_state_seq_flags; 888 + 889 + LLAMA_API size_t llama_state_seq_get_size_ext( 890 + struct llama_context * ctx, 891 + llama_seq_id seq_id, 892 + llama_state_seq_flags flags); 893 + 894 + LLAMA_API size_t llama_state_seq_get_data_ext( 895 + struct llama_context * ctx, 896 + uint8_t * dst, 897 + size_t size, 898 + llama_seq_id seq_id, 899 + llama_state_seq_flags flags); 900 + 901 + LLAMA_API size_t llama_state_seq_set_data_ext( 902 + struct llama_context * ctx, 903 + const uint8_t * src, 904 + size_t size, 905 + llama_seq_id dest_seq_id, 906 + llama_state_seq_flags flags); 907 + 908 + // 909 + // Decoding 910 + // 911 + 912 + // Return batch for single sequence of tokens 913 + // The sequence ID will be fixed to 0 914 + // The position of the tokens will be tracked automatically by llama_decode 915 + // 916 + // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it 917 + // 918 + LLAMA_API struct llama_batch llama_batch_get_one( 919 + llama_token * tokens, 920 + int32_t n_tokens); 921 + 922 + // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens 923 + // Each token can be assigned up to n_seq_max sequence ids 924 + // The batch has to be freed with llama_batch_free() 925 + // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float) 926 + // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token 927 + // The rest of the llama_batch members are allocated with size n_tokens 928 + // All members are left uninitialized 929 + LLAMA_API struct llama_batch llama_batch_init( 930 + int32_t n_tokens, 931 + int32_t embd, 932 + int32_t n_seq_max); 933 + 934 + // Frees a batch of tokens allocated with llama_batch_init() 935 + LLAMA_API void llama_batch_free(struct llama_batch batch); 936 + 937 + // Process a batch of tokens. 
938 + // In contrast to llama_decode() - this call does not use KV cache. 939 + // For encoder-decoder contexts, processes the batch using the encoder. 940 + // Can store the encoder output internally for later use by the decoder's cross-attention layers. 941 + // 0 - success 942 + // < 0 - error. the memory state is restored to the state before this call 943 + LLAMA_API int32_t llama_encode( 944 + struct llama_context * ctx, 945 + struct llama_batch batch); 946 + 947 + // Process a batch of tokens. 948 + // Requires the context to have a memory. 949 + // For encoder-decoder contexts, processes the batch using the decoder. 950 + // Positive return values do not mean a fatal error, but rather a warning. 951 + // Upon fatal-error or abort, the ubatches that were already processed will remain in the memory state of the context 952 + // To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max() 953 + // Upon other return values, the memory state is restored to the state before this call 954 + // 0 - success 955 + // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increasing the context) 956 + // 2 - aborted (processed ubatches will remain in the context's memory) 957 + // -1 - invalid input batch 958 + // < -1 - fatal error (processed ubatches will remain in the context's memory) 959 + LLAMA_API int32_t llama_decode( 960 + struct llama_context * ctx, 961 + struct llama_batch batch); 962 + 963 + // Set the number of threads used for decoding 964 + // n_threads is the number of threads used for generation (single token) 965 + // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) 966 + LLAMA_API void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch); 967 + 968 + // Get the number of threads used for generation of a single token.
969 + LLAMA_API int32_t llama_n_threads(struct llama_context * ctx); 970 + 971 + // Get the number of threads used for prompt and batch processing (multiple token). 972 + LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx); 973 + 974 + // Set whether the context outputs embeddings or not 975 + // TODO: rename to avoid confusion with llama_get_embeddings() 976 + LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings); 977 + 978 + // Set whether to use causal attention or not 979 + // If set to true, the model will only attend to the past tokens 980 + LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn); 981 + 982 + // Set whether the model is in warmup mode or not 983 + // If true, all model tensors are activated during llama_decode() to load and cache their weights. 984 + LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup); 985 + 986 + // Set abort callback 987 + LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data); 988 + 989 + // Wait until all computations are finished 990 + // This is automatically done when using one of the functions below to obtain the computation results 991 + // and is not necessary to call it explicitly in most cases 992 + LLAMA_API void llama_synchronize(struct llama_context * ctx); 993 + 994 + // Token logits obtained from the last call to llama_decode() 995 + // The logits for which llama_batch.logits[i] != 0 are stored contiguously 996 + // in the order they have appeared in the batch. 997 + // Rows: number of tokens for which llama_batch.logits[i] != 0 998 + // Cols: n_vocab 999 + // TODO: deprecate in favor of llama_get_logits_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522) 1000 + LLAMA_API float * llama_get_logits(struct llama_context * ctx); 1001 + 1002 + // Logits for the ith token. 
For positive indices, Equivalent to: 1003 + // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab 1004 + // Negative indices can be used to access logits in reverse order, -1 is the last logit. 1005 + // returns NULL for invalid ids. 1006 + LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); 1007 + 1008 + // Get all output token embeddings. 1009 + // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model, 1010 + // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously 1011 + // in the order they have appeared in the batch. 1012 + // shape: [n_outputs*n_embd] 1013 + // Otherwise, returns NULL. 1014 + // TODO: deprecate in favor of llama_get_embeddings_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522) 1015 + LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); 1016 + 1017 + // Get the embeddings for the ith token. For positive indices, Equivalent to: 1018 + // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd 1019 + // Negative indices can be used to access embeddings in reverse order, -1 is the last embedding. 1020 + // shape: [n_embd] (1-dimensional) 1021 + // returns NULL for invalid ids. 1022 + LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); 1023 + 1024 + // Get the embeddings for a sequence id 1025 + // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE 1026 + // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out] with the rank(s) of the sequence 1027 + // otherwise: float[n_embd] (1-dimensional) 1028 + LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); 1029 + 1030 + // 1031 + // backend sampling API [EXPERIMENTAL] 1032 + // note: use only if the llama_context was created with at least one llama_sampler_seq_config 1033 + // 1034 + 1035 + // Get the backend sampled token for the ith token. 
1036 + // Returns LLAMA_TOKEN_NULL if no token was sampled. 1037 + LLAMA_API llama_token llama_get_sampled_token_ith(struct llama_context * ctx, int32_t i); 1038 + 1039 + // Get the backend sampled probabilities for the ith token 1040 + // The index matches llama_get_sampled_token_ith(). 1041 + // Returns NULL if no probabilities were generated. 1042 + LLAMA_API float * llama_get_sampled_probs_ith (struct llama_context * ctx, int32_t i); 1043 + LLAMA_API uint32_t llama_get_sampled_probs_count_ith(struct llama_context * ctx, int32_t i); 1044 + 1045 + // Get the backend sampled logits for the ith token 1046 + // Returns NULL if no logits were sampled. 1047 + LLAMA_API float * llama_get_sampled_logits_ith (struct llama_context * ctx, int32_t i); 1048 + LLAMA_API uint32_t llama_get_sampled_logits_count_ith(struct llama_context * ctx, int32_t i); 1049 + 1050 + // Get the backend sampled candidates (token ids) for the ith token 1051 + // These are needed to map probability/logit indices to vocab token ids. 1052 + // Returns NULL if no candidates were sampled. 1053 + LLAMA_API llama_token * llama_get_sampled_candidates_ith (struct llama_context * ctx, int32_t i); 1054 + LLAMA_API uint32_t llama_get_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i); 1055 + 1056 + // 1057 + // Vocab 1058 + // 1059 + 1060 + LLAMA_API const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token); 1061 + 1062 + LLAMA_API float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token); 1063 + 1064 + LLAMA_API enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token); 1065 + 1066 + // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) 
1067 + LLAMA_API bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token); 1068 + 1069 + // Identify if Token Id is a control token or a render-able token 1070 + LLAMA_API bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token); 1071 + 1072 + // Special tokens 1073 + LLAMA_API llama_token llama_vocab_bos(const struct llama_vocab * vocab); // beginning-of-sentence 1074 + LLAMA_API llama_token llama_vocab_eos(const struct llama_vocab * vocab); // end-of-sentence 1075 + LLAMA_API llama_token llama_vocab_eot(const struct llama_vocab * vocab); // end-of-turn 1076 + LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator 1077 + LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line 1078 + LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding 1079 + LLAMA_API llama_token llama_vocab_mask(const struct llama_vocab * vocab); // mask 1080 + 1081 + LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab); 1082 + LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab); 1083 + LLAMA_API bool llama_vocab_get_add_sep(const struct llama_vocab * vocab); 1084 + 1085 + LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab); 1086 + LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab); 1087 + LLAMA_API llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab); 1088 + LLAMA_API llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab); 1089 + LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab); 1090 + LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab); 1091 + 1092 + DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_text instead"); 1093 + DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, 
llama_token token), "use llama_vocab_get_score instead"); 1094 + DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead"); 1095 + DEPRECATED(LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_eog instead"); 1096 + DEPRECATED(LLAMA_API bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_control instead"); 1097 + DEPRECATED(LLAMA_API llama_token llama_token_bos(const struct llama_vocab * vocab), "use llama_vocab_bos instead"); 1098 + DEPRECATED(LLAMA_API llama_token llama_token_eos(const struct llama_vocab * vocab), "use llama_vocab_eos instead"); 1099 + DEPRECATED(LLAMA_API llama_token llama_token_eot(const struct llama_vocab * vocab), "use llama_vocab_eot instead"); 1100 + DEPRECATED(LLAMA_API llama_token llama_token_cls(const struct llama_vocab * vocab), "use llama_vocab_cls instead"); 1101 + DEPRECATED(LLAMA_API llama_token llama_token_sep(const struct llama_vocab * vocab), "use llama_vocab_sep instead"); 1102 + DEPRECATED(LLAMA_API llama_token llama_token_nl (const struct llama_vocab * vocab), "use llama_vocab_nl instead"); 1103 + DEPRECATED(LLAMA_API llama_token llama_token_pad(const struct llama_vocab * vocab), "use llama_vocab_pad instead"); 1104 + DEPRECATED(LLAMA_API bool llama_add_bos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_bos instead"); 1105 + DEPRECATED(LLAMA_API bool llama_add_eos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_eos instead"); 1106 + DEPRECATED(LLAMA_API llama_token llama_token_fim_pre(const struct llama_vocab * vocab), "use llama_vocab_fim_pre instead"); 1107 + DEPRECATED(LLAMA_API llama_token llama_token_fim_suf(const struct llama_vocab * vocab), "use llama_vocab_fim_suf instead"); 1108 + DEPRECATED(LLAMA_API llama_token llama_token_fim_mid(const struct llama_vocab * vocab), "use 
llama_vocab_fim_mid instead"); 1109 + DEPRECATED(LLAMA_API llama_token llama_token_fim_pad(const struct llama_vocab * vocab), "use llama_vocab_fim_pad instead"); 1110 + DEPRECATED(LLAMA_API llama_token llama_token_fim_rep(const struct llama_vocab * vocab), "use llama_vocab_fim_rep instead"); 1111 + DEPRECATED(LLAMA_API llama_token llama_token_fim_sep(const struct llama_vocab * vocab), "use llama_vocab_fim_sep instead"); 1112 + 1113 + // CLS is equivalent to BOS 1114 + DEPRECATED(LLAMA_API llama_token llama_vocab_cls(const struct llama_vocab * vocab), // classification 1115 + "use llama_vocab_bos instead"); 1116 + 1117 + // 1118 + // Tokenization 1119 + // 1120 + // The API is thread-safe. 1121 + // 1122 + 1123 + /// @details Convert the provided text into tokens. 1124 + /// @param tokens The tokens pointer must be large enough to hold the resulting tokens. 1125 + /// @return Returns the number of tokens on success, no more than n_tokens_max 1126 + /// @return Returns a negative number on failure - the number of tokens that would have been returned 1127 + /// @return Returns INT32_MIN on overflow (e.g., tokenization result size exceeds int32_t limit) 1128 + /// @param add_special Allow to add BOS and EOS tokens if model is configured to do so. 1129 + /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated 1130 + /// as plaintext. Does not insert a leading space. 1131 + LLAMA_API int32_t llama_tokenize( 1132 + const struct llama_vocab * vocab, 1133 + const char * text, 1134 + int32_t text_len, 1135 + llama_token * tokens, 1136 + int32_t n_tokens_max, 1137 + bool add_special, 1138 + bool parse_special); 1139 + 1140 + // Token Id -> Piece. 1141 + // Uses the vocabulary in the provided context. 1142 + // Does not write null terminator to the buffer. 
1143 + // User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix') 1144 + // @param special If true, special tokens are rendered in the output. 1145 + LLAMA_API int32_t llama_token_to_piece( 1146 + const struct llama_vocab * vocab, 1147 + llama_token token, 1148 + char * buf, 1149 + int32_t length, 1150 + int32_t lstrip, 1151 + bool special); 1152 + 1153 + /// @details Convert the provided tokens into text (inverse of llama_tokenize()). 1154 + /// @param text The char pointer must be large enough to hold the resulting text. 1155 + /// @return Returns the number of chars/bytes on success, no more than text_len_max. 1156 + /// @return Returns a negative number on failure - the number of chars/bytes that would have been returned. 1157 + /// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so. 1158 + /// @param unparse_special If true, special tokens are rendered in the output. 1159 + LLAMA_API int32_t llama_detokenize( 1160 + const struct llama_vocab * vocab, 1161 + const llama_token * tokens, 1162 + int32_t n_tokens, 1163 + char * text, 1164 + int32_t text_len_max, 1165 + bool remove_special, 1166 + bool unparse_special); 1167 + 1168 + // 1169 + // Chat templates 1170 + // 1171 + 1172 + /// Apply chat template. Inspired by hf apply_chat_template() on python. 1173 + /// 1174 + /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template 1175 + /// @param tmpl A Jinja template to use for this chat. 1176 + /// @param chat Pointer to a list of multiple llama_chat_message 1177 + /// @param n_msg Number of llama_chat_message in this chat 1178 + /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message. 1179 + /// @param buf A buffer to hold the output formatted prompt. 
The recommended alloc size is 2 * (total number of characters of all messages) 1180 + /// @param length The size of the allocated buffer 1181 + /// @return The total number of bytes of the formatted prompt. If it is larger than the size of the buffer, you may need to re-alloc it and then re-apply the template. 1182 + LLAMA_API int32_t llama_chat_apply_template( 1183 + const char * tmpl, 1184 + const struct llama_chat_message * chat, 1185 + size_t n_msg, 1186 + bool add_ass, 1187 + char * buf, 1188 + int32_t length); 1189 + 1190 + // Get list of built-in chat templates 1191 + LLAMA_API int32_t llama_chat_builtin_templates(const char ** output, size_t len); 1192 + 1193 + // 1194 + // Sampling API 1195 + // 1196 + // Sample usage: 1197 + // 1198 + // // prepare the sampling chain at the start 1199 + // auto sparams = llama_sampler_chain_default_params(); 1200 + // 1201 + // llama_sampler * smpl = llama_sampler_chain_init(sparams); 1202 + // 1203 + // llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50)); 1204 + // llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1)); 1205 + // llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8)); 1206 + // 1207 + // // typically, the chain should end with a sampler such as "greedy", "dist" or "mirostat" 1208 + // // this sampler will be responsible to select the actual token 1209 + // llama_sampler_chain_add(smpl, llama_sampler_init_dist(seed)); 1210 + // 1211 + // ... 1212 + // 1213 + // // decoding loop: 1214 + // while (...) { 1215 + // ... 1216 + // 1217 + // llama_decode(ctx, batch); 1218 + // 1219 + // // sample from the logits of the last token in the batch 1220 + // const llama_token id = llama_sampler_sample(smpl, ctx, -1); 1221 + // 1222 + // ... 
1223 + // } 1224 + // 1225 + // llama_sampler_free(smpl); 1226 + // 1227 + 1228 + typedef void * llama_sampler_context_t; 1229 + 1230 + struct llama_sampler_data { 1231 + struct ggml_tensor * logits; 1232 + struct ggml_tensor * probs; 1233 + struct ggml_tensor * sampled; 1234 + struct ggml_tensor * candidates; 1235 + }; 1236 + 1237 + // user code can implement the interface below in order to create custom llama_sampler 1238 + struct llama_sampler_i { 1239 + const char * (*name) (const struct llama_sampler * smpl); // can be NULL 1240 + void (*accept)( struct llama_sampler * smpl, llama_token token); // can be NULL 1241 + void (*apply) ( struct llama_sampler * smpl, llama_token_data_array * cur_p); // required 1242 + void (*reset) ( struct llama_sampler * smpl); // can be NULL 1243 + struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL 1244 + void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL 1245 + 1246 + // [EXPERIMENTAL] 1247 + // backend sampling interface: 1248 + 1249 + // return true if the backend supports all ops needed by the sampler 1250 + // note: call once per sampler 1251 + bool (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft); 1252 + 1253 + // call after .backend_apply() 1254 + void (*backend_accept)( 1255 + struct llama_sampler * smpl, 1256 + struct ggml_context * ctx, 1257 + struct ggml_cgraph * gf, 1258 + struct ggml_tensor * selected_token); 1259 + 1260 + // call after .backend_init() 1261 + void (*backend_apply)( 1262 + struct llama_sampler * smpl, 1263 + struct ggml_context * ctx, 1264 + struct ggml_cgraph * gf, 1265 + struct llama_sampler_data * data); 1266 + 1267 + // called before graph execution to set inputs for the current ubatch 1268 + void (*backend_set_input)(struct llama_sampler * smpl); 1269 + }; 1270 + 1271 + struct llama_sampler { 1272 + struct llama_sampler_i * iface; 1273 + 1274 + llama_sampler_context_t ctx; 1275 + }; 1276 + 1277 + 
// [EXPERIMENTAL] 1278 + // attach a sampler to the context 1279 + // note: prefer initializing the context with llama_context_params.samplers when possible 1280 + LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl); 1281 + 1282 + // mirror of llama_sampler_i: 1283 + LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_i * iface, llama_sampler_context_t ctx); 1284 + LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); 1285 + LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); 1286 + LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); 1287 + LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl); 1288 + LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl); 1289 + // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add) 1290 + LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); 1291 + 1292 + // llama_sampler_chain 1293 + // a type of llama_sampler that can chain multiple samplers one after another 1294 + 1295 + LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params); 1296 + 1297 + // important: takes ownership of the sampler object and will free it when llama_sampler_free is called 1298 + LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl); 1299 + 1300 + // return NULL if: 1301 + // - the sampler is NULL 1302 + // - the sampler is not a llama_sampler_chain 1303 + // - the index is out of bounds, unless i == -1 1304 + // - if i == -1, returns the chain itself (can be used to check if the sampler is a chain) 1305 + LLAMA_API struct llama_sampler * llama_sampler_chain_get( struct llama_sampler * chain, int32_t i); 1306 + 1307 + // the total number of samplers in the chain 1308 + 
LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain); 1309 + 1310 + // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed 1311 + LLAMA_API struct llama_sampler * llama_sampler_chain_remove( struct llama_sampler * chain, int32_t i); 1312 + 1313 + // available samplers: 1314 + 1315 + LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void); 1316 + 1317 + /// seed == LLAMA_DEFAULT_SEED to use a random seed. 1318 + LLAMA_API struct llama_sampler * llama_sampler_init_dist(uint32_t seed); 1319 + 1320 + /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 1321 + /// Setting k <= 0 makes this a noop 1322 + LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); 1323 + 1324 + /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 1325 + LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep); 1326 + 1327 + /// @details Minimum P sampling as described in https://github.com/ggml-org/llama.cpp/pull/3841 1328 + LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep); 1329 + 1330 + /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. 1331 + LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep); 1332 + 1333 + /// @details Updates the logits l_i' = l_i/t. When t <= 0.0f, the maximum logit is kept at its original value, the rest are set to -inf 1334 + LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t); 1335 + 1336 + /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772. 
1337 + LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent); 1338 + 1339 + /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 1340 + LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed); 1341 + 1342 + /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641 1343 + LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n); 1344 + 1345 + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. 1346 + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. 1347 + /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. 1348 + /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. 1349 + /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. 1350 + /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
1351 + LLAMA_API struct llama_sampler * llama_sampler_init_mirostat( 1352 + int32_t n_vocab, 1353 + uint32_t seed, 1354 + float tau, 1355 + float eta, 1356 + int32_t m); 1357 + 1358 + /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. 1359 + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. 1360 + /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. 1361 + /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. 1362 + /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 1363 + LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2( 1364 + uint32_t seed, 1365 + float tau, 1366 + float eta); 1367 + 1368 + /// @details Initializes a GBNF grammar, see grammars/README.md for details. 1369 + /// @param vocab The vocabulary that this grammar will be used with. 1370 + /// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails. 1371 + /// @param grammar_root The name of the start symbol for the grammar. 
1372 + LLAMA_API struct llama_sampler * llama_sampler_init_grammar( 1373 + const struct llama_vocab * vocab, 1374 + const char * grammar_str, 1375 + const char * grammar_root); 1376 + 1377 + DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy( 1378 + const struct llama_vocab * vocab, 1379 + const char * grammar_str, 1380 + const char * grammar_root, 1381 + const char ** trigger_words, 1382 + size_t num_trigger_words, 1383 + const llama_token * trigger_tokens, 1384 + size_t num_trigger_tokens), 1385 + "use llama_sampler_init_grammar_lazy_patterns instead"); 1386 + 1387 + 1388 + /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639 1389 + /// @param trigger_patterns A list of patterns that will trigger the grammar sampler. Pattern will be matched from the start of the generation output, and grammar sampler will be fed content starting from its first match group. 1390 + /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. Grammar sampler will be fed content starting from the trigger token included. 1391 + LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy_patterns( 1392 + const struct llama_vocab * vocab, 1393 + const char * grammar_str, 1394 + const char * grammar_root, 1395 + const char ** trigger_patterns, 1396 + size_t num_trigger_patterns, 1397 + const llama_token * trigger_tokens, 1398 + size_t num_trigger_tokens); 1399 + 1400 + 1401 + /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first. 
1402 + LLAMA_API struct llama_sampler * llama_sampler_init_penalties( 1403 + int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) 1404 + float penalty_repeat, // 1.0 = disabled 1405 + float penalty_freq, // 0.0 = disabled 1406 + float penalty_present); // 0.0 = disabled 1407 + 1408 + /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 1409 + LLAMA_API struct llama_sampler * llama_sampler_init_dry( 1410 + const struct llama_vocab * vocab, 1411 + int32_t n_ctx_train, 1412 + float dry_multiplier, 1413 + float dry_base, 1414 + int32_t dry_allowed_length, 1415 + int32_t dry_penalty_last_n, 1416 + const char ** seq_breakers, 1417 + size_t num_breakers); 1418 + 1419 + /// adaptive-p: select tokens near a configurable target probability over time. 1420 + /// 1421 + /// the adaptive-p sampler transforms the token probability distribution to favor tokens 1422 + /// that fall near a user-configurable probability target. 1423 + /// 1424 + /// internally, the sampler maintains an exponential moving average of the *ORIGINAL* 1425 + /// probabilities of selected tokens at each sampling step. it uses this EMA to compute an 1426 + /// adapted target probability at each sampling step, thus maintaining the desired target 1427 + /// probability over time. 1428 + /// 1429 + /// adaptive-p selects a token ID rather than just mutating candidates, so it must be last 1430 + /// in the sampler chain (like mirostat, dist, greedy). 1431 + /// 1432 + /// only mild truncation before this sampler is recommended. we suggest applying min-p 1433 + /// before adaptive-p as the only other active sampler in the chain. 
1434 + /// 1435 + /// @param target select tokens near this probability (valid range 0.0 to 1.0; negative = disabled) 1436 + /// @param decay EMA decay for adaptation; history ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99) 1437 + /// @param seed RNG seed 1438 + /// 1439 + /// ref: https://github.com/ggml-org/llama.cpp/pull/17927 1440 + /// 1441 + LLAMA_API struct llama_sampler * llama_sampler_init_adaptive_p( 1442 + float target, 1443 + float decay, 1444 + uint32_t seed); 1445 + 1446 + LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( 1447 + int32_t n_vocab, 1448 + int32_t n_logit_bias, 1449 + const llama_logit_bias * logit_bias); 1450 + 1451 + // this sampler is meant to be used for fill-in-the-middle infilling 1452 + // it's supposed to be used after top_k + top_p sampling 1453 + // 1454 + // 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG 1455 + // 2. combine probs of tokens that have the same prefix 1456 + // 1457 + // example: 1458 + // 1459 + // - before: 1460 + // "hel": 0.5 1461 + // "hell": 0.2 1462 + // "hello": 0.1 1463 + // "dummy": 0.1 1464 + // 1465 + // - after: 1466 + // "hel": 0.8 1467 + // "dummy": 0.1 1468 + // 1469 + // 3. discard non-EOG tokens with low prob 1470 + // 4. if no tokens are left -> pick EOT 1471 + // 1472 + LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab); 1473 + 1474 + // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise 1475 + LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl); 1476 + 1477 + /// @details Sample and accept a token from the idx-th output of the last evaluation 1478 + // 1479 + // Shorthand for: 1480 + // const auto * logits = llama_get_logits_ith(ctx, idx); 1481 + // llama_token_data_array cur_p = { ... init from logits ... 
}; 1482 + // llama_sampler_apply(smpl, &cur_p); 1483 + // auto token = cur_p.data[cur_p.selected].id; 1484 + // llama_sampler_accept(smpl, token); 1485 + // return token; 1486 + // Returns the sampled token 1487 + LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx); 1488 + 1489 + // TODO: extend in the future 1490 + //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...); 1491 + 1492 + // 1493 + // Model split 1494 + // 1495 + 1496 + /// @details Build a split GGUF final path for this chunk. 1497 + /// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf" 1498 + // Returns the split_path length. 1499 + LLAMA_API int32_t llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int32_t split_no, int32_t split_count); 1500 + 1501 + /// @details Extract the path prefix from the split_path if and only if the split_no and split_count match. 1502 + /// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0" 1503 + // Returns the split_prefix length. 1504 + LLAMA_API int32_t llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int32_t split_no, int32_t split_count); 1505 + 1506 + // Print system information 1507 + LLAMA_API const char * llama_print_system_info(void); 1508 + 1509 + // Set callback for all future logging events. 1510 + // If this is not called, or NULL is supplied, everything is output on stderr. 1511 + // The logger state is global so these functions are NOT thread safe. 
1512 + LLAMA_API void llama_log_get(ggml_log_callback * log_callback, void ** user_data); 1513 + LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data); 1514 + 1515 + // 1516 + // Performance utils 1517 + // 1518 + // NOTE: Used by llama.cpp examples/tools, avoid using in third-party apps. Instead, do your own performance measurements. 1519 + // 1520 + 1521 + struct llama_perf_context_data { 1522 + // ms == milliseconds 1523 + double t_start_ms; // absolute start time 1524 + double t_load_ms; // time needed for loading the model 1525 + double t_p_eval_ms; // time needed for processing the prompt 1526 + double t_eval_ms; // time needed for generating tokens 1527 + 1528 + int32_t n_p_eval; // number of prompt tokens 1529 + int32_t n_eval; // number of generated tokens 1530 + int32_t n_reused; // number of times a ggml compute graph had been reused 1531 + }; 1532 + 1533 + struct llama_perf_sampler_data { 1534 + double t_sample_ms; // time needed for sampling in ms 1535 + 1536 + int32_t n_sample; // number of sampled tokens 1537 + }; 1538 + 1539 + LLAMA_API struct llama_perf_context_data llama_perf_context (const struct llama_context * ctx); 1540 + LLAMA_API void llama_perf_context_print(const struct llama_context * ctx); 1541 + LLAMA_API void llama_perf_context_reset( struct llama_context * ctx); 1542 + 1543 + // NOTE: the following work only with samplers constructed via llama_sampler_chain_init 1544 + LLAMA_API struct llama_perf_sampler_data llama_perf_sampler (const struct llama_sampler * chain); 1545 + LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); 1546 + LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); 1547 + 1548 + // print a breakdown of per-device memory use via LLAMA_LOG: 1549 + LLAMA_API void llama_memory_breakdown_print(const struct llama_context * ctx); 1550 + 1551 + // 1552 + // training 1553 + // 1554 + 1555 + // function that returns whether or not a given tensor contains 
trainable parameters 1556 + typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata); 1557 + 1558 + // always returns true 1559 + LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata); 1560 + 1561 + struct llama_opt_params { 1562 + uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0 1563 + 1564 + llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters 1565 + void * param_filter_ud; // userdata for determining which tensors contain trainable parameters 1566 + 1567 + ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters 1568 + void * get_opt_pars_ud; // userdata for calculating optimizer parameters 1569 + 1570 + enum ggml_opt_optimizer_type optimizer_type; 1571 + }; 1572 + 1573 + LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params); 1574 + 1575 + LLAMA_API void llama_opt_epoch( 1576 + struct llama_context * lctx, 1577 + ggml_opt_dataset_t dataset, 1578 + ggml_opt_result_t result_train, 1579 + ggml_opt_result_t result_eval, 1580 + int64_t idata_split, 1581 + ggml_opt_epoch_callback callback_train, 1582 + ggml_opt_epoch_callback callback_eval); 1583 + 1584 + #ifdef __cplusplus 1585 + } 1586 + #endif 1587 + 1588 + #endif // LLAMA_H
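The `llama_split_path` and `llama_split_prefix` comments in the header above describe a file-naming convention for split GGUF models. A minimal sketch of that convention, written in JavaScript purely for illustration: the function names `splitPath` and `splitPrefix` are hypothetical, this is not the library's implementation, and it takes the split numbers exactly as they appear in the file name (how the C function maps its `split_no` argument onto that number is not stated in the comment).

```javascript
// Illustrative only: reproduces the "<prefix>-00002-of-00004.gguf" naming
// shown in the llama.h comments. Not the llama.cpp implementation.

// Zero-pad a split number to five digits, as in the header's example.
function pad5(n) {
  return String(n).padStart(5, "0");
}

// Build the file name for one chunk of a split model.
function splitPath(pathPrefix, splitNo, splitCount) {
  return `${pathPrefix}-${pad5(splitNo)}-of-${pad5(splitCount)}.gguf`;
}

// Recover the prefix only when the numbers in the name match the ones
// passed in; otherwise return null (a choice made for this sketch).
function splitPrefix(path, splitNo, splitCount) {
  const suffix = splitPath("", splitNo, splitCount);
  return path.endsWith(suffix) ? path.slice(0, -suffix.length) : null;
}

console.log(splitPath("/models/ggml-model-q4_0", 2, 4));
console.log(splitPrefix("/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4));
```

Returning `null` on a mismatch mirrors the "if and only if the split_no and split_count match" wording in the header comment; the real function reports its result through a length return value instead.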
backend/models/bge-small.gguf

This is a binary file and will not be displayed.

+6
backend/src/assets/favicon.svg
··· 1 + <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 32 32"> 2 + <rect width="32" height="32" rx="6" fill="#d4a76a"/> 3 + <text x="16" y="23" text-anchor="middle" 4 + font-family="ui-monospace, 'SF Mono', Menlo, monospace" 5 + font-size="22" font-weight="600" fill="#0d0d0f">k</text> 6 + </svg>
+99
backend/src/assets/index.html
··· 1 + <!doctype html> 2 + <html lang="en"> 3 + <head> 4 + <meta charset="utf-8"> 5 + <meta name="viewport" content="width=device-width,initial-scale=1"> 6 + <title>ken</title> 7 + <link rel="icon" type="image/svg+xml" href="/favicon.svg"> 8 + <link rel="stylesheet" href="/style.css"> 9 + <!--OG_META_PLACEHOLDER--> 10 + </head> 11 + <body> 12 + <header> 13 + <a href="/" class="brand">ken</a> 14 + <nav id="signed-nav" class="signed-nav hidden"> 15 + <span id="session-handle" class="session-handle"></span> 16 + <span class="sep">·</span> 17 + <button id="signout-btn" type="button" class="text-btn">logout</button> 18 + </nav> 19 + </header> 20 + 21 + <main> 22 + <!-- signed-out --> 23 + <section id="signed-out" class="hidden"> 24 + <form id="signin-form" class="row"> 25 + <div class="typeahead"> 26 + <input 27 + id="signin-input" 28 + type="text" 29 + placeholder="your handle" 30 + autocomplete="off" 31 + spellcheck="false" 32 + /> 33 + <ul id="handle-suggestions" class="suggestions hidden"></ul> 34 + </div> 35 + <button type="submit">sign in</button> 36 + </form> 37 + </section> 38 + 39 + <!-- signed-in: search is the primary action, everything else is 40 + metadata below it. 
41 + order: search → progress (while indexing) → meta line (when ready) 42 + → error status (only if something goes wrong) → results --> 43 + <section id="signed-in" class="hidden"> 44 + <form id="search-form" class="row hidden"> 45 + <input 46 + id="search-input" 47 + type="text" 48 + placeholder="find a record by describing it" 49 + autocomplete="off" 50 + spellcheck="false" 51 + /> 52 + </form> 53 + 54 + <div id="progress" class="progress hidden"> 55 + <div class="progress-bar"><div class="progress-fill" id="progress-fill"></div></div> 56 + <div class="progress-meta"> 57 + <span id="progress-stage">starting…</span> 58 + <span id="progress-rate" class="muted"></span> 59 + </div> 60 + <div id="progress-flavor" class="progress-flavor muted"></div> 61 + </div> 62 + 63 + <div id="pack-meta" class="pack-meta hidden"> 64 + <span id="pack-stats" class="muted"></span> 65 + <span class="sep muted">·</span> 66 + <span id="pack-state" class="muted"></span> 67 + <button id="pack-save-btn" type="button" class="text-btn text-btn-accent hidden">save to PDS</button> 68 + <button id="pack-delete-btn" type="button" class="text-btn text-btn-danger hidden">delete</button> 69 + </div> 70 + 71 + <div id="status" class="status hidden"></div> 72 + 73 + <section id="results" class="results"></section> 74 + </section> 75 + </main> 76 + 77 + <footer> 78 + <button id="about-btn" type="button" class="text-btn">about</button> 79 + <span class="sep">·</span> 80 + <a href="https://tangled.org/zzstoatzz.io/ken" target="_blank" rel="noopener">source</a> 81 + </footer> 82 + 83 + <!-- about modal — opens from the footer. all explanation happens here 84 + rather than inline in the main flow. --> 85 + <div id="about-overlay" class="overlay hidden"></div> 86 + <div id="about-modal" class="modal hidden" role="dialog" aria-labelledby="about-title"> 87 + <h2 id="about-title">ken</h2> 88 + <p>fuzzy find any record in your atproto repo. 
sign in, embeddings get built locally against your own PDS, and the vector pack lives as a record you can inspect or delete.</p> 89 + <ul class="about-links"> 90 + <li><a id="about-pack-link" href="#" target="_blank" rel="noopener">example pack on pdsls.dev ↗</a></li> 91 + <li><a href="https://tangled.org/zzstoatzz.io/ken" target="_blank" rel="noopener">source ↗</a></li> 92 + </ul> 93 + <button id="about-close" type="button">close</button> 94 + </div> 95 + 96 + <script src="/nsid-logo.js"></script> 97 + <script src="/main.js"></script> 98 + </body> 99 + </html>
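The `<!--OG_META_PLACEHOLDER-->` comment in `index.html` is where the README says `GET /` injects per-query OpenGraph tags for share URLs. A rough sketch of that substitution, assuming a plain string replace: the real backend is Zig and its code is not shown here, so `injectOgTags`, `escapeAttr`, and the tag wording below are all hypothetical.

```javascript
// Hypothetical sketch of per-query OG tag injection. The actual ken
// backend is written in Zig; names and tag contents here are assumptions.

const OG_PLACEHOLDER = "<!--OG_META_PLACEHOLDER-->";

// Escape a value for use inside a double-quoted HTML attribute.
function escapeAttr(s) {
  return String(s ?? "")
    .replaceAll("&", "&amp;")
    .replaceAll("<", "&lt;")
    .replaceAll(">", "&gt;")
    .replaceAll('"', "&quot;");
}

// Replace the placeholder with OG tags for a ?handle=X&q=Y share URL,
// or strip it when either parameter is missing.
function injectOgTags(template, { handle, q } = {}) {
  if (!handle || !q) return template.replace(OG_PLACEHOLDER, "");
  const title = escapeAttr(`${q} in @${handle}'s records`);
  const tags = [
    `<meta property="og:title" content="${title}">`,
    `<meta property="og:site_name" content="ken">`,
  ].join("\n");
  // A function replacement avoids "$&"-style patterns in user input
  // being interpreted by String.prototype.replace.
  return template.replace(OG_PLACEHOLDER, () => tags);
}

const page = "<head><title>ken</title><!--OG_META_PLACEHOLDER--></head>";
console.log(injectOgTags(page, { handle: "alice.example", q: "that post about zig" }));
```

The query text comes straight from the URL, so it is escaped before landing in an attribute; without that, a crafted `q` could break out of the `content="..."` value.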
+832
backend/src/assets/main.js
··· 1 + // ken — auth-first UI. 2 + // 3 + // nothing happens without a signed-in session. /api/me decides which view 4 + // to render at load time: 5 + // - 200 { handle, did } → signed-in section, auto-start indexing the 6 + // authenticated user's own records, then search 7 + // - 401 → signed-out section, sign-in form only. no 8 + // drive-by explore, no third-party lookups. 9 + 10 + const $ = (s) => document.querySelector(s); 11 + 12 + const signedOutSection = $("#signed-out"); 13 + const signedInSection = $("#signed-in"); 14 + const signinForm = $("#signin-form"); 15 + const signinInput = $("#signin-input"); 16 + const handleSuggestions = $("#handle-suggestions"); 17 + const signedNav = $("#signed-nav"); 18 + const sessionHandleEl = $("#session-handle"); 19 + const signoutBtn = $("#signout-btn"); 20 + const statusEl = $("#status"); 21 + const resultsEl = $("#results"); 22 + const searchForm = $("#search-form"); 23 + const searchInput = $("#search-input"); 24 + const progressEl = $("#progress"); 25 + const progressFill = $("#progress-fill"); 26 + const progressStage = $("#progress-stage"); 27 + const progressRate = $("#progress-rate"); 28 + const progressFlavor = $("#progress-flavor"); 29 + const packMetaEl = $("#pack-meta"); 30 + const packStatsEl = $("#pack-stats"); 31 + const packSaveBtn = $("#pack-save-btn"); 32 + const packDeleteBtn = $("#pack-delete-btn"); 33 + const packStateEl = $("#pack-state"); 34 + const aboutBtn = $("#about-btn"); 35 + const aboutModal = $("#about-modal"); 36 + const aboutOverlay = $("#about-overlay"); 37 + const aboutClose = $("#about-close"); 38 + const aboutPackLink = $("#about-pack-link"); 39 + 40 + const TYPEAHEAD_URL = "https://typeahead.waow.tech/xrpc/app.bsky.actor.searchActorsTypeahead"; 41 + 42 + let me = null; // { handle, did } once signed in 43 + let pollTimer = null; 44 + let searchDebounceTimer = null; 45 + let typeaheadTimer = null; 46 + let typeaheadAbort = null; 47 + let suggestionIndex = -1; 48 + let 
currentSuggestions = []; 49 + let indexStartMs = 0; 50 + let lastFresh = 0; 51 + let lastFreshAt = 0; 52 + let smoothedRate = 0; 53 + 54 + // ---------- helpers ---------- 55 + 56 + function escape(s) { 57 + return String(s ?? "") 58 + .replaceAll("&", "&amp;") 59 + .replaceAll("<", "&lt;") 60 + .replaceAll(">", "&gt;") 61 + .replaceAll('"', "&quot;"); 62 + } 63 + 64 + function showStatus(msg, isError = false) { 65 + statusEl.textContent = msg; 66 + statusEl.classList.remove("hidden"); 67 + statusEl.classList.toggle("error", isError); 68 + } 69 + function hideStatus() { 70 + statusEl.classList.add("hidden"); 71 + } 72 + 73 + function formatEta(seconds) { 74 + if (!isFinite(seconds) || seconds <= 0) return ""; 75 + if (seconds < 60) return `~${Math.round(seconds)}s`; 76 + const mins = Math.floor(Math.round(seconds) / 60); 77 + const secs = Math.round(seconds) % 60; 78 + return `~${mins}m ${secs}s`; 79 + } 80 + 81 + function pdslsUrl(uri) { 82 + return `https://pdsls.dev/${uri}`; 83 + } 84 + 85 + // ---------- view state ---------- 86 + 87 + function showSignedOut() { 88 + signedOutSection.classList.remove("hidden"); 89 + signedInSection.classList.add("hidden"); 90 + signedNav.classList.add("hidden"); 91 + sessionHandleEl.textContent = ""; 92 + setTimeout(() => signinInput.focus(), 0); 93 + } 94 + 95 + function showSignedIn(handle) { 96 + me = { handle }; 97 + signedOutSection.classList.add("hidden"); 98 + signedInSection.classList.remove("hidden"); 99 + signedNav.classList.remove("hidden"); 100 + sessionHandleEl.textContent = `@${handle}`; 101 + } 102 + 103 + // ---------- progress ---------- 104 + 105 + function startProgress() { 106 + indexStartMs = Date.now(); 107 + lastFresh = 0; 108 + lastFreshAt = indexStartMs; 109 + smoothedRate = 0; 110 + progressFill.style.width = "0%"; 111 + progressStage.textContent = "starting…"; 112 + progressRate.textContent = ""; 113 + progressFlavor.textContent = ""; 114 + progressEl.classList.remove("hidden"); 115 + } 116 + 117 + // Set the 
single-line progress caption — used for calibrated "last build 118 + // was X" messaging when we have prior-run metadata. No carousel; the about 119 + // modal handles any narrative. 120 + function setFlavor(msg) { 121 + if (progressFlavor.textContent === msg) return; 122 + progressFlavor.style.opacity = "0"; 123 + setTimeout(() => { 124 + progressFlavor.textContent = msg; 125 + progressFlavor.style.opacity = "1"; 126 + }, 150); 127 + } 128 + 129 + function stopProgress() { 130 + progressEl.classList.add("hidden"); 131 + } 132 + 133 + function updateProgress({ fetched, embedded, reused, walking }) { 134 + if (walking) { 135 + progressFill.style.width = "4%"; 136 + progressStage.textContent = `walking repo · fetched ${fetched.toLocaleString()} records`; 137 + progressRate.textContent = ""; 138 + return; 139 + } 140 + const total = Math.max(fetched, embedded, 1); 141 + const pct = Math.min(100, (embedded / total) * 100); 142 + progressFill.style.width = `${pct.toFixed(1)}%`; 143 + 144 + // rate/ETA measured against FRESH work only — reused vectors land in a 145 + // single poll jump and would spike the rate. 146 + const fresh = Math.max(0, embedded - (reused || 0)); 147 + const total_fresh = Math.max(0, total - (reused || 0)); 148 + 149 + const now = Date.now(); 150 + const dt = (now - lastFreshAt) / 1000; 151 + if (dt > 0.1 && fresh > lastFresh) { 152 + const instRate = (fresh - lastFresh) / dt; 153 + smoothedRate = smoothedRate === 0 ? 
instRate : smoothedRate * 0.7 + instRate * 0.3; 154 + lastFresh = fresh; 155 + lastFreshAt = now; 156 + } else if (fresh < lastFresh) { 157 + lastFresh = fresh; 158 + lastFreshAt = now; 159 + } 160 + 161 + if (reused && reused > 0 && total_fresh === 0) { 162 + progressStage.textContent = `reused all ${reused.toLocaleString()} vectors from your existing pack`; 163 + progressRate.textContent = ""; 164 + } else if (reused && reused > 0) { 165 + progressStage.textContent = `reused ${reused.toLocaleString()} · embedding ${fresh.toLocaleString()} / ${total_fresh.toLocaleString()} fresh`; 166 + if (smoothedRate > 0) { 167 + const eta = (total_fresh - fresh) / smoothedRate; 168 + progressRate.textContent = `${smoothedRate.toFixed(0)} rec/s · eta ${formatEta(eta)}`; 169 + } else { 170 + progressRate.textContent = ""; 171 + } 172 + } else { 173 + progressStage.textContent = `embedded ${embedded.toLocaleString()} / ${total.toLocaleString()}`; 174 + if (smoothedRate > 0) { 175 + const eta = (total - embedded) / smoothedRate; 176 + progressRate.textContent = `${smoothedRate.toFixed(0)} rec/s · eta ${formatEta(eta)}`; 177 + } 178 + } 179 + } 180 + 181 + // ---------- index + poll (signed-in) ---------- 182 + 183 + async function kickoffIndexing() { 184 + if (!me) return; 185 + if (pollTimer) clearInterval(pollTimer); 186 + resultsEl.innerHTML = ""; 187 + // search form is visible as soon as we're signed in — partial search 188 + // works against whatever rows have been embedded so far. empty results 189 + // at first, grows as indexing proceeds. 190 + searchForm.classList.remove("hidden"); 191 + hideStatus(); 192 + 193 + // if a pack already exists in the in-memory cache for this did, jump 194 + // straight to search. 
  try {
    const existing = await fetch(`/api/status/${encodeURIComponent(me.handle)}`);
    if (existing.ok) {
      const ej = await existing.json();
      if (ej.status === "ready") {
        stopProgress();
        renderReady(ej);
        return;
      }
      if (ej.status === "indexing") {
        startProgress();
        pollStatus();
        return;
      }
    }
  } catch {}

  startProgress();
  try {
    const r = await fetch(`/api/index/${encodeURIComponent(me.handle)}`, { method: "POST" });
    const j = await r.json().catch(() => ({}));
    if (r.status === 401 || r.status === 403) {
      // session expired mid-flight. back to signed-out.
      stopProgress();
      me = null;
      showSignedOut();
      showStatus("session expired — sign in again", true);
      return;
    }
    if (!r.ok) {
      stopProgress();
      showStatus(`error: ${j.error || r.statusText}`, true);
      return;
    }
    pollStatus();
  } catch (e) {
    stopProgress();
    showStatus(`network error: ${e.message}`, true);
  }
}

function renderReady(j) {
  const colCount = j.collections?.length || 0;
  const rec = j.count || 0;
  const elapsed = indexStartMs ? ((Date.now() - indexStartMs) / 1000).toFixed(1) : null;
  // hide the ready-status banner: this info lives in the pack-meta line now
  hideStatus();
  packStatsEl.textContent = elapsed && elapsed > 0
    ? `${rec.toLocaleString()} records · ${colCount} collections · ${elapsed}s`
    : `${rec.toLocaleString()} records · ${colCount} collections`;
  searchForm.classList.remove("hidden");
  searchInput.focus();
  renderPackActions(j);
}

// show the save / delete buttons based on the pack's current persistence
// state. the intent is that writing a record to the user's repo must be a
// deliberate click, not something we do for them.
function renderPackActions(j) {
  packMetaEl.classList.remove("hidden");
  if (j.persisted) {
    packSaveBtn.classList.add("hidden");
    packDeleteBtn.classList.remove("hidden");
    packStateEl.textContent = "saved";
  } else {
    packSaveBtn.classList.remove("hidden");
    packDeleteBtn.classList.add("hidden");
    packStateEl.textContent = "not saved";
  }
  // point the about-modal's "example pack" link at whatever pack is
  // currently persisted, so people can see a real record on pdsls.dev.
  if (j.persisted_uri && aboutPackLink) {
    aboutPackLink.href = `https://pdsls.dev/${j.persisted_uri}`;
  }
}

async function pollStatus() {
  const poll = async () => {
    // signed out while a poll was pending: stop quietly instead of
    // throwing on me.handle and surfacing a bogus "poll error".
    if (!me) {
      clearInterval(pollTimer);
      return;
    }
    try {
      const r = await fetch(`/api/status/${encodeURIComponent(me.handle)}`);
      const j = await r.json();
      if (!r.ok) {
        showStatus(`error: ${j.error || r.statusText}`, true);
        clearInterval(pollTimer);
        return;
      }
      const { status, records_fetched, records_embedded, records_reused, prior_build_ms, prior_count } = j;
      const fetched = records_fetched || 0;
      const embedded = records_embedded || 0;
      const reused = records_reused || 0;
      const priorMs = prior_build_ms || 0;
      const priorCount = prior_count || 0;
      // calibrated flavor line: if we have a prior pack's build stats,
      // show the user's real numbers. single line, no rotation.
      if (priorMs > 0 && priorCount > 0) {
        const priorSec = (priorMs / 1000).toFixed(1);
        setFlavor(
          `last build: ${priorCount.toLocaleString()} records in ${priorSec}s. unchanged records are reused.`,
        );
      }

      if (status === "ready") {
        stopProgress();
        renderReady(j);
        clearInterval(pollTimer);
      } else if (status === "error") {
        stopProgress();
        showStatus(`indexing failed: ${j.error_msg || "unknown error"}`, true);
        clearInterval(pollTimer);
      } else {
        if (embedded > 0 || reused > 0) {
          updateProgress({ fetched, embedded, reused, walking: false });
        } else {
          updateProgress({ fetched, embedded: 0, reused: 0, walking: true });
        }
      }
    } catch (e) {
      stopProgress();
      showStatus(`poll error: ${e.message}`, true);
      clearInterval(pollTimer);
    }
  };
  poll();
  pollTimer = setInterval(poll, 1000);
}

// ---------- search (signed-in) ----------

async function runSearch(q) {
  if (!me) return;
  if (!q.trim()) {
    resultsEl.innerHTML = "";
    lastQuery = "";
    updateShareButtonVisibility();
    return;
  }
  lastQuery = q.trim();
  updateShareButtonVisibility();
  try {
    const r = await fetch(
      `/api/search/${encodeURIComponent(me.handle)}?q=${encodeURIComponent(q)}&k=30`
    );
    const j = await r.json();
    if (!r.ok) {
      showStatus(`search error: ${j.error || r.statusText}`, true);
      return;
    }
    // if the pack is still indexing, label the results as partial so the
    // user knows why some things might be missing.
    if (j.indexing && j.searchable < j.total) {
      showStatus(
        `searching ${j.searchable.toLocaleString()} of ${j.total.toLocaleString()} records (still indexing)`,
      );
    } else if (!j.indexing && j.total > 0) {
      // clear any "still indexing" banner once indexing is done
      hideStatus();
    }
    renderResults(j.results || []);
  } catch (e) {
    showStatus(`network error: ${e.message}`, true);
  }
}

function renderResults(rows) {
  resultsEl.innerHTML = "";
  for (const row of rows) {
    const title = row.title || "";
    const body = row.body || "";
    const norm = (s) => s.replace(/\s+/g, " ").trim().toLowerCase();
    let mainLine = title;
    let subLine = body;
    // collapse title/body when one merely repeats or prefixes the other
    if (title && body) {
      const t = norm(title),
        b = norm(body);
      if (t === b || b.startsWith(t) || t.startsWith(b)) {
        mainLine = title.length >= body.length ? title : body;
        subLine = "";
      }
    }
    if (!mainLine) mainLine = body || "(empty)";

    const card = document.createElement("article");
    card.className = "result";
    card.innerHTML = `
      <div class="head">
        <span class="collection-chip" data-nsid="${escape(row.collection)}">
          <span class="collection-logo-slot"></span>
          <span class="collection-name">${escape(row.collection)}</span>
        </span>
        ${row.date ? `<span class="date">${escape(row.date)}</span>` : ""}
      </div>
      <div class="body">${escape(mainLine)}</div>
      ${subLine ? `<div class="body-second">${escape(subLine)}</div>` : ""}
      <div class="actions">
        <a class="pdsls-link" href="${escape(pdslsUrl(row.uri))}" target="_blank" rel="noopener" title="open in pdsls.dev">↗</a>
      </div>
    `;
    resultsEl.appendChild(card);
  }

  // paint nsid logos asynchronously so the main results render stays
  // instant. the NsidLogo module caches per-nsid so repeat searches don't
  // re-hit the appview. if the shim isn't loaded we just show the chip
  // without a logo.
  if (window.NsidLogo) {
    const nsids = rows.map((r) => r.collection);
    window.NsidLogo.fetchAvatarsForNsids(nsids)
      .then((avatarMap) => {
        resultsEl.querySelectorAll(".collection-chip").forEach((chip) => {
          const nsid = chip.dataset.nsid;
          const url = avatarMap.get(nsid);
          if (!url) return;
          const slot = chip.querySelector(".collection-logo-slot");
          if (slot) {
            slot.innerHTML = `<img src="${escape(url)}" alt="" class="collection-logo" />`;
          }
        });
      })
      .catch(() => {});
  }
}

searchInput.addEventListener("input", () => {
  clearTimeout(searchDebounceTimer);
  searchDebounceTimer = setTimeout(() => runSearch(searchInput.value), 250);
});
searchForm.addEventListener("submit", (e) => {
  e.preventDefault();
  clearTimeout(searchDebounceTimer);
  runSearch(searchInput.value);
});

// ---------- sign in / sign out ----------

signinForm.addEventListener("submit", (e) => {
  e.preventDefault();
  const h = signinInput.value.trim().replace(/^@/, "");
  if (!h) return;
  window.location.href = `/oauth/login?handle=${encodeURIComponent(h)}`;
});

signoutBtn.addEventListener("click", async () => {
  try {
    await fetch("/oauth/logout", { method: "POST" });
  } catch {}
  me = null;
  if (pollTimer) clearInterval(pollTimer);
  stopProgress();
  hideStatus();
  resultsEl.innerHTML = "";
  searchForm.classList.add("hidden");
  packMetaEl.classList.add("hidden");
  showSignedOut();
});

// ---------- about modal ----------

function showAbout() {
  aboutOverlay.classList.remove("hidden");
  aboutModal.classList.remove("hidden");
}
function hideAbout() {
  aboutOverlay.classList.add("hidden");
  aboutModal.classList.add("hidden");
}
aboutBtn.addEventListener("click", showAbout);
aboutClose.addEventListener("click", hideAbout);
aboutOverlay.addEventListener("click", hideAbout);
document.addEventListener("keydown", (e) => {
  if (e.key === "Escape" && !aboutModal.classList.contains("hidden")) hideAbout();
});

// ---------- pack save / delete ----------

async function refreshPackState() {
  if (!me) return;
  try {
    const r = await fetch(`/api/status/${encodeURIComponent(me.handle)}`);
    if (!r.ok) return;
    const j = await r.json();
    if (j.status === "ready") renderPackActions(j);
  } catch {}
}

packSaveBtn.addEventListener("click", async () => {
  if (!me) return;
  packSaveBtn.disabled = true;
  packSaveBtn.textContent = "saving…";
  try {
    const r = await fetch("/api/pack/save", { method: "POST", body: "" });
    const j = await r.json().catch(() => ({}));
    if (!r.ok) {
      showStatus(`save failed: ${j.error || r.statusText}`, true);
    } else {
      showStatus("pack saved to your PDS");
    }
  } catch (e) {
    showStatus(`save error: ${e.message}`, true);
  } finally {
    packSaveBtn.disabled = false;
    packSaveBtn.textContent = "save pack to my PDS";
    refreshPackState();
  }
});

packDeleteBtn.addEventListener("click", async () => {
  if (!me) return;
  if (!confirm("delete the saved pack record from your PDS? your in-memory search will still work until you sign out.")) return;
  packDeleteBtn.disabled = true;
  packDeleteBtn.textContent = "deleting…";
  try {
    const r = await fetch("/api/pack/delete", { method: "POST", body: "" });
    const j = await r.json().catch(() => ({}));
    if (!r.ok) {
      showStatus(`delete failed: ${j.error || r.statusText}`, true);
    } else {
      const n = j.deleted || 0;
      showStatus(`deleted ${n} pack record${n === 1 ? "" : "s"} from your PDS`);
    }
  } catch (e) {
    showStatus(`delete error: ${e.message}`, true);
  } finally {
    packDeleteBtn.disabled = false;
    packDeleteBtn.textContent = "delete saved pack";
    refreshPackState();
  }
});

// ---------- handle typeahead (for the SIGN-IN input only) ----------

function hideSuggestions() {
  handleSuggestions.classList.add("hidden");
  handleSuggestions.innerHTML = "";
  currentSuggestions = [];
  suggestionIndex = -1;
}

function renderSuggestions(actors) {
  currentSuggestions = actors;
  suggestionIndex = -1;
  if (!actors.length) {
    hideSuggestions();
    return;
  }
  handleSuggestions.innerHTML = actors
    .map(
      (a, i) => `
        <li data-index="${i}" data-handle="${escape(a.handle)}">
          ${a.avatar ? `<img src="${escape(a.avatar)}" alt="" />` : `<span class="avatar-placeholder"></span>`}
          <span class="handle">@${escape(a.handle)}</span>
          ${a.displayName ? `<span class="display">${escape(a.displayName)}</span>` : ""}
        </li>
      `
    )
    .join("");
  handleSuggestions.classList.remove("hidden");
}

async function fetchSuggestions(q) {
  // cancel the previous in-flight request so stale results never render
  if (typeaheadAbort) typeaheadAbort.abort();
  typeaheadAbort = new AbortController();
  try {
    const url = `${TYPEAHEAD_URL}?q=${encodeURIComponent(q)}&limit=8`;
    const r = await fetch(url, { signal: typeaheadAbort.signal });
    if (!r.ok) return;
    const j = await r.json();
    renderSuggestions(j.actors || []);
  } catch (e) {
    if (e.name !== "AbortError") console.warn("typeahead error", e);
  }
}

signinInput.addEventListener("input", () => {
  const q = signinInput.value.trim().replace(/^@/, "");
  clearTimeout(typeaheadTimer);
  if (q.length < 2) {
    hideSuggestions();
    return;
  }
  typeaheadTimer = setTimeout(() => fetchSuggestions(q), 120);
});

signinInput.addEventListener("keydown", (e) => {
  if (handleSuggestions.classList.contains("hidden")) return;
  if (e.key === "ArrowDown") {
    e.preventDefault();
    suggestionIndex = Math.min(suggestionIndex + 1, currentSuggestions.length - 1);
    updateSuggestionHighlight();
  } else if (e.key === "ArrowUp") {
    e.preventDefault();
    suggestionIndex = Math.max(suggestionIndex - 1, -1);
    updateSuggestionHighlight();
  } else if (e.key === "Enter" && suggestionIndex >= 0) {
    e.preventDefault();
    signinInput.value = currentSuggestions[suggestionIndex].handle;
    hideSuggestions();
  } else if (e.key === "Escape") {
    hideSuggestions();
  }
});

function updateSuggestionHighlight() {
  handleSuggestions.querySelectorAll("li").forEach((li, i) => {
    li.classList.toggle("active", i === suggestionIndex);
  });
}

handleSuggestions.addEventListener("mousedown", (e) => {
  const li = e.target.closest("li[data-handle]");
  if (!li) return;
  e.preventDefault();
  signinInput.value = li.dataset.handle;
  hideSuggestions();
  signinInput.focus();
});

document.addEventListener("click", (e) => {
  if (!signinForm.contains(e.target)) hideSuggestions();
});

// ---------- share view ----------

// in share view, the URL contains ?handle=X&q=Y. we treat the page as a
// public read-only browser of someone else's saved pack. no auth needed,
// no save/delete buttons, search bar pre-filled and submitted with the
// query from the URL.
let shareView = false;

async function enterShareView(handle, query) {
  shareView = true;
  me = { handle };
  signedOutSection.classList.add("hidden");
  signedInSection.classList.remove("hidden");
  signedNav.classList.remove("hidden");
  sessionHandleEl.textContent = `@${handle}`;
  // visual hint: show whose pack we're viewing in the meta line
  packMetaEl.classList.remove("hidden");
  packStatsEl.textContent = `viewing @${handle}`;
  packStateEl.textContent = "shared view";
  packSaveBtn.classList.add("hidden");
  packDeleteBtn.classList.add("hidden");
  // hide signout: there's no session to clear
  signoutBtn.classList.add("hidden");

  startProgress();

  // public lazy-load: walks the target's PDS via CAR + reuses every vector
  // from their saved pack. no auth, nothing gets written.
  try {
    const r = await fetch(`/api/share-load/${encodeURIComponent(handle)}`, { method: "POST" });
    if (!r.ok) {
      stopProgress();
      const j = await r.json().catch(() => ({}));
      showStatus(`couldn't load @${handle}'s pack: ${j.error || r.statusText}. they may not have saved one yet.`, true);
      return;
    }
  } catch (e) {
    stopProgress();
    showStatus(`network error: ${e.message}`, true);
    return;
  }

  // poll until ready, then run the query. if the load failed, leave the
  // error banner up instead of searching a pack that never loaded.
  const ready = await pollUntilReady();
  if (!ready) return;
  searchInput.value = query;
  runSearch(query);
}

// resolves true once the pack is ready, false if the backend reports an
// error. transient fetch failures are swallowed and retried.
async function pollUntilReady() {
  return new Promise((resolve) => {
    const tick = async () => {
      try {
        const r = await fetch(`/api/status/${encodeURIComponent(me.handle)}`);
        const j = await r.json();
        if (j.status === "ready") {
          stopProgress();
          searchForm.classList.remove("hidden");
          // refresh meta line with the real counts
          const colCount = j.collections?.length || 0;
          const rec = j.count || 0;
          packStatsEl.textContent = `${rec.toLocaleString()} records · ${colCount} collections · @${me.handle}`;
          resolve(true);
          return;
        }
        if (j.status === "error") {
          stopProgress();
          showStatus(`load failed: ${j.error_msg || "unknown"}`, true);
          resolve(false);
          return;
        }
        // still loading
        if (j.records_embedded > 0 || j.records_reused > 0) {
          updateProgress({
            fetched: j.records_fetched || 0,
            embedded: j.records_embedded || 0,
            reused: j.records_reused || 0,
            walking: false,
          });
        } else {
          // report the real fetched count while walking, as pollStatus does
          updateProgress({ fetched: j.records_fetched || 0, embedded: 0, reused: 0, walking: true });
        }
      } catch {}
      setTimeout(tick, 800);
    };
    tick();
  });
}

// ---------- share button + modal ----------

const shareBtn = document.createElement("button");
shareBtn.type = "button";
shareBtn.className = "text-btn text-btn-accent";
shareBtn.textContent = "share";
shareBtn.id = "share-btn";

const shareOverlay = document.createElement("div");
shareOverlay.className = "overlay hidden";
shareOverlay.id = "share-overlay";
const shareModal = document.createElement("div");
shareModal.className = "modal hidden";
shareModal.id = "share-modal";
shareModal.setAttribute("role", "dialog");
shareModal.innerHTML = `
  <h2>share this search</h2>
  <p>the link will let anyone load <span id="share-handle"></span>'s saved pack and run this query against it.</p>
  <p>records on a public PDS are already publicly readable — this just gives a friendlier interface to browse them by meaning. nothing new is exposed by sharing.</p>
  <input id="share-url" type="text" readonly />
  <div class="share-actions">
    <button id="share-copy" type="button">copy link</button>
    <button id="share-cancel" type="button">close</button>
  </div>
`;
document.body.appendChild(shareOverlay);
document.body.appendChild(shareModal);

const shareHandleEl = shareModal.querySelector("#share-handle");
const shareUrlEl = shareModal.querySelector("#share-url");
const shareCopyBtn = shareModal.querySelector("#share-copy");
const shareCancelBtn = shareModal.querySelector("#share-cancel");

// written by runSearch above; safe because no search can run before this
// top-level declaration has been evaluated.
let lastQuery = "";

function buildShareUrl() {
  if (!me) return "";
  const params = new URLSearchParams();
  params.set("handle", me.handle);
  if (lastQuery) params.set("q", lastQuery);
  return `${window.location.origin}/?${params.toString()}`;
}

function openShareModal() {
  if (!me || !lastQuery) return;
  shareHandleEl.textContent = `@${me.handle}`;
  shareUrlEl.value = buildShareUrl();
  shareOverlay.classList.remove("hidden");
  shareModal.classList.remove("hidden");
}
function closeShareModal() {
  shareOverlay.classList.add("hidden");
  shareModal.classList.add("hidden");
}

shareCancelBtn.addEventListener("click", closeShareModal);
shareOverlay.addEventListener("click", closeShareModal);
shareCopyBtn.addEventListener("click", async () => {
  const url = shareUrlEl.value;
  try {
    await navigator.clipboard.writeText(url);
    shareCopyBtn.textContent = "copied!";
    setTimeout(() => (shareCopyBtn.textContent = "copy link"), 1500);
  } catch {
    // fallback for older browsers / no clipboard permission
    // (document.execCommand is deprecated but still widely supported)
    shareUrlEl.select();
    document.execCommand("copy");
    shareCopyBtn.textContent = "copied!";
    setTimeout(() => (shareCopyBtn.textContent = "copy link"), 1500);
  }
});
shareBtn.addEventListener("click", () => {
  // mobile native share if available, otherwise modal
  if (navigator.share && /Mobi|Android|iPhone/i.test(navigator.userAgent)) {
    navigator
      .share({
        title: `ken — "${lastQuery}" in @${me.handle}'s records`,
        url: buildShareUrl(),
      })
      .catch((e) => {
        // dismissing the native share sheet rejects with AbortError;
        // only fall back to the modal on a real failure.
        if (e?.name !== "AbortError") openShareModal();
      });
  } else {
    openShareModal();
  }
});

// inject the share button into the meta line, after the delete button
packMetaEl.appendChild(shareBtn);

// hide share until there's a query to share
function updateShareButtonVisibility() {
  if (lastQuery && me && !shareView) {
    shareBtn.classList.remove("hidden");
  } else {
    shareBtn.classList.add("hidden");
  }
}
shareBtn.classList.add("hidden");


// ---------- bootstrap ----------

async function init() {
  // clean any leftover ?logged_in= query from the oauth callback
  const params = new URLSearchParams(window.location.search);
  if (params.has("logged_in")) {
    window.history.replaceState({}, "", window.location.pathname);
  }

  // share-view path: ?handle=X&q=Y means a recipient is opening someone
  // else's shared search. skip auth and lazy-load the target's pack.
  const shareHandle = params.get("handle");
  const shareQuery = params.get("q");
  if (shareHandle && shareQuery) {
    enterShareView(shareHandle, shareQuery);
    return;
  }

  try {
    const r = await fetch("/api/me");
    if (r.ok) {
      const j = await r.json();
      showSignedIn(j.handle);
      kickoffIndexing();
      return;
    }
  } catch {}
  showSignedOut();
}

init();
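The progress readout in `updateProgress` smooths the embedding rate with an exponential moving average (70% previous rate, 30% newest sample) and derives the ETA from the work still outstanding. Below is a minimal standalone sketch of that arithmetic, so it can be tried without the DOM. `makeRateTracker` and its parameter names are hypothetical and not part of the repo; the 0.7/0.3 weights and the 0.1 s minimum interval are the ones main.js uses.

```javascript
// Hypothetical helper (not in the repo): isolates the rate smoothing that
// updateProgress performs, under the same weights and minimum interval.
function makeRateTracker(alpha = 0.3, minIntervalSec = 0.1) {
  let lastCount = 0;
  let lastAtMs = null;
  let rate = 0; // smoothed records per second

  return {
    // Feed a cumulative count and a timestamp; returns the smoothed rate.
    sample(count, nowMs) {
      if (lastAtMs === null) {
        // first call only establishes the baseline
        lastCount = count;
        lastAtMs = nowMs;
        return rate;
      }
      const dt = (nowMs - lastAtMs) / 1000;
      if (dt > minIntervalSec && count > lastCount) {
        const inst = (count - lastCount) / dt;
        rate = rate === 0 ? inst : rate * (1 - alpha) + inst * alpha;
        lastCount = count;
        lastAtMs = nowMs;
      } else if (count < lastCount) {
        // counter went backwards (a new run): re-baseline, keep the rate
        lastCount = count;
        lastAtMs = nowMs;
      }
      return rate;
    },
    // Seconds remaining at the current smoothed rate; Infinity if unknown.
    etaSeconds(total, done) {
      return rate > 0 ? Math.max(0, total - done) / rate : Infinity;
    },
  };
}

const tracker = makeRateTracker();
tracker.sample(0, 0);      // baseline
tracker.sample(16, 1000);  // one batch of 16 after 1 s
tracker.sample(48, 2000);  // two more batches in the next second
console.log(tracker.etaSeconds(100, 48));
```

Feeding the tracker only the fresh (non-reused) count, as main.js does, keeps a bulk jump of reused vectors from inflating the rate.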
+79
backend/src/assets/nsid-logo.js
// nsid-logo.js
// ---------------------------------------------------------------------------
// SPECIAL CASING: isolated by design so it can be ripped out.
//
// atproto lexicon NSIDs look like `{tld}.{sld}.{...}`. by convention the
// `{sld}.{tld}` (first two segments reversed) is the app's owning domain,
// and that domain almost always has a bluesky profile with an avatar. so:
//
//   sh.tangled.feed.post → tangled.sh → (redirect) tangled.org → profile
//   app.bsky.feed.post   → bsky.app → profile
//   fm.plyr.track        → plyr.fm → profile
//   tech.waow.ken.pack   → waow.tech → profile
//
// the heuristic is lifted from the sibling at-me project. when/if a real
// lexicon-metadata service exists, this file deletes cleanly: remove the
// <script> tag in index.html and any `window.NsidLogo.*` calls in main.js.
// nothing else depends on it.
// ---------------------------------------------------------------------------

(function () {
  // handles that have moved between domains but whose NSIDs still use the
  // old prefix. add entries when you see a project migrate.
  const DOMAIN_REDIRECTS = {
    "tangled.sh": "tangled.org",
  };

  // nsid → Promise<string|null>. caches negative results too so we don't
  // re-hit the appview for every render.
  const cache = new Map();

  function nsidToDomain(nsid) {
    if (!nsid || typeof nsid !== "string") return null;
    const parts = nsid.split(".");
    if (parts.length < 2) return null;
    const reversed = `${parts[1]}.${parts[0]}`;
    return DOMAIN_REDIRECTS[reversed] || reversed;
  }

  async function fetchAvatarForNsid(nsid) {
    if (cache.has(nsid)) return cache.get(nsid);
    const domain = nsidToDomain(nsid);
    if (!domain) {
      cache.set(nsid, Promise.resolve(null));
      return null;
    }
    const p = (async () => {
      try {
        const r = await fetch(
          `https://public.api.bsky.app/xrpc/app.bsky.actor.getProfile?actor=${encodeURIComponent(domain)}`,
          { signal: AbortSignal.timeout(3500) },
        );
        if (!r.ok) return null;
        const j = await r.json();
        return j.avatar || null;
      } catch (e) {
        return null;
      }
    })();
    cache.set(nsid, p);
    return p;
  }

  // fetch avatars for a set of NSIDs in parallel. returns a Map keyed by
  // the original nsid → url | null. caller can use it to paint logos into
  // already-rendered DOM without blocking the initial render.
  async function fetchAvatarsForNsids(nsids) {
    const unique = [...new Set(nsids)];
    const entries = await Promise.all(
      unique.map(async (nsid) => [nsid, await fetchAvatarForNsid(nsid)]),
    );
    return new Map(entries);
  }

  window.NsidLogo = {
    nsidToDomain,
    fetchAvatarForNsid,
    fetchAvatarsForNsids,
  };
})();
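To make the NSID heuristic concrete, here is a standalone copy of the domain mapping run against the examples from the header comment, plus two edge cases. It is a sketch for experimenting outside the browser: the function body mirrors `nsidToDomain` in nsid-logo.js, and the redirect table holds only the single entry that file ships with.

```javascript
// Standalone copy of the mapping in nsid-logo.js, for experimentation only.
const DOMAIN_REDIRECTS = {
  "tangled.sh": "tangled.org",
};

// Reverse the first two NSID segments to guess the owning domain, then
// apply any known domain migration.
function nsidToDomain(nsid) {
  if (!nsid || typeof nsid !== "string") return null;
  const parts = nsid.split(".");
  if (parts.length < 2) return null;
  const reversed = `${parts[1]}.${parts[0]}`;
  return DOMAIN_REDIRECTS[reversed] || reversed;
}

const samples = [
  "sh.tangled.feed.post",
  "app.bsky.feed.post",
  "fm.plyr.track",
  "tech.waow.ken.pack",
  "nodots", // too few segments: no domain can be derived
  null,     // non-string input is rejected
];
for (const nsid of samples) {
  console.log(nsid, "->", nsidToDomain(nsid));
}
```

Only the first two segments matter, so every collection under one authority (for example `app.bsky.feed.post` and `app.bsky.graph.follow`) resolves to the same domain. The shared domain still costs one profile request per distinct NSID, because the cache in nsid-logo.js is keyed by NSID.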
+19
backend/src/assets/og.svg
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1200 630">
  <rect width="1200" height="630" fill="#0d0d0f"/>
  <rect x="80" y="80" width="160" height="160" rx="24" fill="#d4a76a"/>
  <text x="160" y="200" text-anchor="middle"
        font-family="ui-monospace, 'SF Mono', Menlo, monospace"
        font-size="120" font-weight="700" fill="#0d0d0f">k</text>
  <text x="280" y="180" text-anchor="start"
        font-family="ui-monospace, 'SF Mono', Menlo, monospace"
        font-size="92" font-weight="500" fill="#e6e2d7">ken</text>
  <text x="80" y="380" text-anchor="start"
        font-family="ui-monospace, 'SF Mono', Menlo, monospace"
        font-size="44" font-weight="400" fill="#e6e2d7">fuzzy find any record</text>
  <text x="80" y="440" text-anchor="start"
        font-family="ui-monospace, 'SF Mono', Menlo, monospace"
        font-size="44" font-weight="400" fill="#e6e2d7">in your atproto repo.</text>
  <text x="80" y="560" text-anchor="start"
        font-family="ui-monospace, 'SF Mono', Menlo, monospace"
        font-size="28" font-weight="400" fill="#7a766c">vectors live on your own PDS.</text>
</svg>
+505
backend/src/assets/style.css
/* ken: minimal, dark, monospace, mobile-first.
 * palette kept to a handful of color custom properties. single accent
 * (warm amber) so the page has one clear interaction color. sizing via
 * clamp() throughout so the layout breathes from 360px phones up to wide
 * desktops with no breakpoints. soft radii, subtle transitions.
 */

:root {
  --bg: #0d0d0f;
  --fg: #e6e2d7;
  --fg-mute: #7a766c;
  --accent: #d4a76a;
  --accent-dim: #8a6b40;
  --border: #242329;
  --surface: #16161a;
  --surface-hover: #1c1c22;
  --danger: #e08872;
  --danger-dim: #4a2a22;

  --radius: clamp(6px, 0.6vmin, 10px);
  --gap: clamp(10px, 2vmin, 16px);
  --pad: clamp(14px, 3vmin, 22px);
  --text-body: clamp(13px, 1.6vmin, 15px);
  --text-small: clamp(11px, 1.3vmin, 12px);
  --text-h: clamp(15px, 2vmin, 18px);
}

* { box-sizing: border-box; margin: 0; padding: 0; }

html, body {
  background: var(--bg);
  color: var(--fg);
  font: var(--text-body) / 1.5 ui-monospace, "SF Mono", Menlo, "Cascadia Code", monospace;
  min-height: 100vh;
  -webkit-font-smoothing: antialiased;
  -moz-osx-font-smoothing: grayscale;
}

a {
  color: var(--accent);
  text-decoration: none;
  border-bottom: 1px solid transparent;
  transition: border-color 0.15s;
}
a:hover { border-bottom-color: var(--accent); }

button {
  font: inherit;
  cursor: pointer;
  background: transparent;
  color: var(--fg);
  border: 1px solid var(--border);
  padding: clamp(8px, 1.4vmin, 10px) clamp(12px, 2vmin, 16px);
  border-radius: var(--radius);
  transition: all 0.15s;
  min-height: 44px;
}
button:hover:not(:disabled) {
  border-color: var(--fg-mute);
}
button:disabled { opacity: 0.5; cursor: wait; }

.hidden { display: none !important; }
.muted { color: var(--fg-mute); }

/* ------- header ------- */

header {
  display: flex;
  align-items: baseline;
  justify-content: space-between;
  gap: var(--gap);
  padding: clamp(14px, 2.5vmin, 22px) var(--pad);
}

.brand {
  font-size: var(--text-h);
  color: var(--fg);
  border-bottom: none;
  letter-spacing: -0.01em;
}
.brand:hover { border-bottom: none; color: var(--accent); }

.signed-nav {
  display: flex;
  align-items: baseline;
  gap: 8px;
  font-size: var(--text-small);
  color: var(--fg-mute);
  font-variant-numeric: tabular-nums;
}
.session-handle { color: var(--accent); }
.sep { color: var(--fg-mute); opacity: 0.4; }

.text-btn {
  background: transparent;
  border: none;
  padding: 0;
  min-height: 0;
  color: var(--fg-mute);
  font: inherit;
  font-size: var(--text-small);
  cursor: pointer;
  transition: color 0.15s;
}
.text-btn:hover { color: var(--fg); }

/* ------- footer ------- */

footer {
  display: flex;
  justify-content: center;
  align-items: baseline;
  gap: 8px;
  padding: clamp(20px, 4vmin, 40px) var(--pad);
  font-size: var(--text-small);
  color: var(--fg-mute);
}
footer .text-btn { color: var(--fg-mute); }
footer .text-btn:hover { color: var(--accent); }
footer a { color: var(--fg-mute); }
footer a:hover { color: var(--accent); }

/* ------- main ------- */

main {
  max-width: 820px;
  margin: 0 auto;
  padding: var(--pad);
  padding-bottom: clamp(40px, 8vmin, 80px);
}

.row {
  position: relative;
  margin-bottom: var(--gap);
  display: flex;
  gap: var(--gap);
}

.row input[type="text"],
#search-form input {
  flex: 1;
  background: var(--surface);
  border: 1px solid var(--border);
  color: var(--fg);
  padding: clamp(12px, 2vmin, 16px);
  font: inherit;
  font-size: var(--text-body);
  border-radius: var(--radius);
  outline: none;
  transition: border-color 0.15s;
  min-height: 44px;
}
.row input[type="text"]:focus,
#search-form input:focus { border-color: var(--accent-dim); }
.row input[type="text"]::placeholder,
#search-form input::placeholder { color: var(--fg-mute); }

#signed-out .row button {
  background: var(--accent-dim);
  color: var(--fg);
  border-color: var(--accent-dim);
}
#signed-out .row button:hover {
  background: var(--accent);
  color: var(--bg);
  border-color: var(--accent);
}

/* ------- typeahead ------- */

.typeahead {
  position: relative;
  flex: 1;
  display: flex;
}
.typeahead input { flex: 1; }

.suggestions {
  position: absolute;
  top: calc(100% + 4px);
  left: 0;
  right: 0;
  list-style: none;
  background: var(--surface);
  border: 1px solid var(--border);
  border-radius: var(--radius);
  overflow: hidden;
  z-index: 10;
  max-height: 320px;
  overflow-y: auto;
}
.suggestions li {
  display: flex;
  align-items: center;
  gap: 10px;
  padding: clamp(8px, 1.5vmin, 12px) clamp(10px, 2vmin, 14px);
  cursor: pointer;
  font-size: var(--text-small);
  border-bottom: 1px solid var(--border);
  min-height: 44px;
}
.suggestions li:last-child { border-bottom: none; }
.suggestions li.active,
.suggestions li:hover { background: var(--surface-hover); }
.suggestions img,
.suggestions .avatar-placeholder {
  width: 24px;
  height: 24px;
  border-radius: 50%;
  background: var(--border);
  flex-shrink: 0;
  object-fit: cover;
}
.suggestions .handle { color: var(--fg); }
.suggestions .display { color: var(--fg-mute); }

/* ------- status / progress
------- */ 219 + 220 + .status { 221 + font-size: var(--text-small); 222 + color: var(--fg-mute); 223 + margin-bottom: var(--gap); 224 + padding: clamp(10px, 1.8vmin, 14px); 225 + background: var(--surface); 226 + border: 1px solid var(--border); 227 + border-radius: var(--radius); 228 + } 229 + .status.error { color: var(--danger); border-color: var(--danger-dim); } 230 + 231 + .progress { 232 + margin-bottom: var(--gap); 233 + padding: clamp(12px, 2vmin, 16px); 234 + background: var(--surface); 235 + border: 1px solid var(--border); 236 + border-radius: var(--radius); 237 + } 238 + .progress-bar { 239 + height: 4px; 240 + background: var(--border); 241 + border-radius: 2px; 242 + overflow: hidden; 243 + margin-bottom: 10px; 244 + } 245 + .progress-fill { 246 + height: 100%; 247 + background: var(--accent); 248 + width: 0%; 249 + transition: width 0.4s ease-out; 250 + position: relative; 251 + } 252 + .progress-fill::after { 253 + content: ""; 254 + position: absolute; 255 + inset: 0; 256 + background: linear-gradient(90deg, transparent, rgba(255,255,255,0.12), transparent); 257 + animation: shimmer 1.4s linear infinite; 258 + } 259 + @keyframes shimmer { 260 + 0% { transform: translateX(-100%); } 261 + 100% { transform: translateX(100%); } 262 + } 263 + .progress-meta { 264 + display: flex; 265 + justify-content: space-between; 266 + gap: var(--gap); 267 + font-size: var(--text-small); 268 + margin-bottom: 6px; 269 + font-variant-numeric: tabular-nums; 270 + flex-wrap: wrap; 271 + } 272 + .progress-flavor { 273 + font-size: var(--text-small); 274 + min-height: 1.4em; 275 + transition: opacity 0.2s; 276 + } 277 + 278 + /* ------- pack meta (single compact line) ------- */ 279 + 280 + .pack-meta { 281 + display: flex; 282 + align-items: baseline; 283 + flex-wrap: wrap; 284 + gap: 8px; 285 + font-size: var(--text-small); 286 + padding: clamp(10px, 1.5vmin, 14px) clamp(2px, 1vmin, 6px); 287 + font-variant-numeric: tabular-nums; 288 + } 289 + 290 + .text-btn-accent { 
291 + color: var(--accent); 292 + font-size: var(--text-small); 293 + } 294 + .text-btn-accent:hover { color: var(--fg); } 295 + 296 + .text-btn-danger { 297 + color: var(--danger); 298 + font-size: var(--text-small); 299 + } 300 + .text-btn-danger:hover { color: var(--fg); } 301 + 302 + #share-btn { margin-left: auto; } 303 + 304 + #share-modal input[type="text"] { 305 + width: 100%; 306 + background: var(--bg); 307 + border: 1px solid var(--border); 308 + color: var(--fg); 309 + padding: 10px 12px; 310 + font: inherit; 311 + font-size: var(--text-small); 312 + border-radius: var(--radius); 313 + margin-top: clamp(10px, 2vmin, 14px); 314 + outline: none; 315 + } 316 + #share-modal input[type="text"]:focus { border-color: var(--accent-dim); } 317 + #share-modal .share-actions { 318 + display: flex; 319 + gap: 8px; 320 + margin-top: clamp(12px, 2vmin, 16px); 321 + } 322 + #share-modal .share-actions button { flex: 1; } 323 + #share-modal #share-copy { 324 + background: var(--accent-dim); 325 + border-color: var(--accent-dim); 326 + color: var(--fg); 327 + } 328 + #share-modal #share-copy:hover { 329 + background: var(--accent); 330 + color: var(--bg); 331 + border-color: var(--accent); 332 + } 333 + 334 + /* ------- results ------- */ 335 + 336 + .results { 337 + display: flex; 338 + flex-direction: column; 339 + gap: 6px; 340 + } 341 + 342 + .result { 343 + background: var(--surface); 344 + border: 1px solid var(--border); 345 + border-radius: var(--radius); 346 + padding: clamp(10px, 1.8vmin, 14px); 347 + display: grid; 348 + grid-template-columns: 1fr auto; 349 + gap: 4px 12px; 350 + transition: background 0.1s, border-color 0.1s; 351 + } 352 + .result:hover { 353 + background: var(--surface-hover); 354 + border-color: var(--fg-mute); 355 + } 356 + 357 + .result .head { 358 + grid-column: 1; 359 + display: flex; 360 + align-items: baseline; 361 + gap: 10px; 362 + min-width: 0; 363 + } 364 + 365 + .result .collection-chip { 366 + display: inline-flex; 367 + 
align-items: center; 368 + gap: 6px; 369 + min-width: 0; 370 + } 371 + .result .collection-name { 372 + font-size: var(--text-small); 373 + color: var(--fg-mute); 374 + white-space: nowrap; 375 + overflow: hidden; 376 + text-overflow: ellipsis; 377 + } 378 + .result .collection-logo-slot { 379 + width: 14px; 380 + height: 14px; 381 + border-radius: 3px; 382 + background: var(--border); 383 + flex-shrink: 0; 384 + display: inline-flex; 385 + overflow: hidden; 386 + } 387 + .result .collection-logo { 388 + width: 100%; 389 + height: 100%; 390 + object-fit: cover; 391 + } 392 + 393 + .result .date { 394 + font-size: var(--text-small); 395 + color: var(--fg-mute); 396 + margin-left: auto; 397 + white-space: nowrap; 398 + } 399 + 400 + .result .body { 401 + grid-column: 1; 402 + font-size: var(--text-body); 403 + line-height: 1.45; 404 + display: -webkit-box; 405 + -webkit-line-clamp: 2; 406 + -webkit-box-orient: vertical; 407 + overflow: hidden; 408 + word-break: break-word; 409 + } 410 + 411 + .result .body-second { 412 + grid-column: 1; 413 + font-size: var(--text-small); 414 + color: var(--fg-mute); 415 + line-height: 1.45; 416 + display: -webkit-box; 417 + -webkit-line-clamp: 2; 418 + -webkit-box-orient: vertical; 419 + overflow: hidden; 420 + word-break: break-word; 421 + } 422 + 423 + .result .actions { 424 + grid-column: 2; 425 + grid-row: 1 / span 3; 426 + display: flex; 427 + align-items: start; 428 + } 429 + .result .pdsls-link { 430 + color: var(--fg-mute); 431 + font-size: var(--text-small); 432 + padding: 4px 8px; 433 + border-radius: 4px; 434 + border: 1px solid transparent; 435 + } 436 + .result .pdsls-link:hover { 437 + color: var(--accent); 438 + border-color: var(--border); 439 + } 440 + 441 + /* ------- about modal ------- */ 442 + 443 + .overlay { 444 + position: fixed; 445 + inset: 0; 446 + background: rgba(0, 0, 0, 0.6); 447 + z-index: 100; 448 + } 449 + 450 + .modal { 451 + position: fixed; 452 + top: 50%; 453 + left: 50%; 454 + transform: 
translate(-50%, -50%); 455 + background: var(--surface); 456 + border: 1px solid var(--border); 457 + border-radius: var(--radius); 458 + padding: clamp(20px, 4vmin, 32px); 459 + max-width: min(92vw, 460px); 460 + z-index: 101; 461 + } 462 + .modal h2 { 463 + font-size: var(--text-h); 464 + margin-bottom: clamp(10px, 2vmin, 14px); 465 + color: var(--fg); 466 + font-weight: 500; 467 + } 468 + .modal p { 469 + font-size: var(--text-small); 470 + line-height: 1.6; 471 + color: var(--fg); 472 + margin-bottom: clamp(8px, 1.5vmin, 12px); 473 + } 474 + .modal p:last-of-type { color: var(--fg-mute); } 475 + .modal ul.about-links { 476 + list-style: none; 477 + margin: clamp(14px, 2.5vmin, 20px) 0; 478 + display: flex; 479 + flex-direction: column; 480 + gap: 6px; 481 + } 482 + .modal ul.about-links a { font-size: var(--text-small); } 483 + .modal button { 484 + margin-top: clamp(10px, 2vmin, 14px); 485 + font-size: var(--text-small); 486 + } 487 + 488 + /* ------- mobile layout tweaks ------- */ 489 + 490 + @media (max-width: 520px) { 491 + header { gap: 8px; } 492 + .stats { flex-basis: 100%; order: 5; margin-left: 0; } 493 + #signed-out .row { flex-direction: column; } 494 + #signed-out .row button { width: 100%; } 495 + .pack-actions-row { flex-direction: column; align-items: stretch; } 496 + .pack-btn { width: 100%; } 497 + .result { 498 + grid-template-columns: 1fr; 499 + } 500 + .result .actions { 501 + grid-column: 1; 502 + grid-row: auto; 503 + justify-content: flex-end; 504 + } 505 + }
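the sizing tokens in the stylesheet above all follow one pattern, `clamp(MIN, N vmin, MAX)`. as a rough worked example of how one of them resolves, here is a small python sketch of the arithmetic for `--text-body: clamp(13px, 1.6vmin, 15px)`. the function names are made up for illustration and the viewport sizes are arbitrary; it assumes `vmin` is 1% of the smaller viewport dimension and that `clamp(lo, v, hi)` behaves like `max(lo, min(v, hi))`.

```python
def vmin_px(percent: float, viewport_w: float, viewport_h: float) -> float:
    """Length of `percent` vmin units in px: 1vmin = 1% of the smaller side."""
    return percent / 100.0 * min(viewport_w, viewport_h)


def css_clamp(lo: float, preferred: float, hi: float) -> float:
    """Equivalent of CSS clamp(lo, preferred, hi)."""
    return max(lo, min(preferred, hi))


def text_body_px(viewport_w: float, viewport_h: float) -> float:
    """Resolve --text-body: clamp(13px, 1.6vmin, 15px) for a given viewport."""
    return css_clamp(13.0, vmin_px(1.6, viewport_w, viewport_h), 15.0)


if __name__ == "__main__":
    # (width, height) pairs are illustrative only
    for w, h in [(360, 640), (1600, 900), (1920, 1080)]:
        print(f"{w}x{h}: {text_body_px(w, h):.2f}px")
```

under those assumptions the preferred value only lands between the bounds when the smaller viewport side is roughly 813px to 937px; narrower viewports sit on the 13px floor and larger ones on the 15px ceiling.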
+159
backend/src/display.zig
··· 1 + //! display heuristic — port of embedder/build_pack.py display_for_record. 2 + //! 3 + //! produces {title, body, date} strings for a record card, distinct from the 4 + //! embedding text. the embed text is machine-readable flattened json; these 5 + //! fields are what a human sees in a search result. 6 + 7 + const std = @import("std"); 8 + const json = std.json; 9 + const Allocator = std.mem.Allocator; 10 + 11 + const TITLE_FIELDS = [_][]const u8{ 12 + "title", "name", "displayName", "label", "subject", "summary", 13 + }; 14 + 15 + const BODY_FIELDS = [_][]const u8{ 16 + "text", "content", "body", "description", "plaintext", "code", 17 + "bio", "about", "message", 18 + }; 19 + 20 + const DATE_FIELDS = [_][]const u8{ 21 + "createdAt", "publishedAt", "indexedAt", "updatedAt", 22 + "endedAt", "endDate", "startedAt", "startDate", "issuedAt", 23 + }; 24 + 25 + pub const Display = struct { 26 + title: []const u8, 27 + body: []const u8, 28 + date: []const u8, 29 + }; 30 + 31 + fn isIsoPrefix(s: []const u8) bool { 32 + if (s.len < 10) return false; 33 + for (0..4) |i| if (!std.ascii.isDigit(s[i])) return false; 34 + if (s[4] != '-') return false; 35 + for (5..7) |i| if (!std.ascii.isDigit(s[i])) return false; 36 + if (s[7] != '-') return false; 37 + for (8..10) |i| if (!std.ascii.isDigit(s[i])) return false; 38 + return true; 39 + } 40 + 41 + fn looksLikeIdentifier(s: []const u8) bool { 42 + if (std.mem.startsWith(u8, s, "did:")) return true; 43 + if (std.mem.startsWith(u8, s, "at://")) return true; 44 + if (std.mem.startsWith(u8, s, "bafy") or std.mem.startsWith(u8, s, "bafk")) return true; 45 + return false; 46 + } 47 + 48 + /// extract a plain string from a json value suitable for display. 49 + /// returns null if the value has no usable content. 
50 + fn stringifyForDisplay(v: json.Value) ?[]const u8 { 51 + switch (v) { 52 + .string => |s| { 53 + const trimmed = std.mem.trim(u8, s, " \t\n\r"); 54 + if (trimmed.len == 0) return null; 55 + if (looksLikeIdentifier(trimmed)) return null; 56 + if (isIsoPrefix(trimmed)) return trimmed[0..10]; 57 + return trimmed; 58 + }, 59 + .object => |obj| { 60 + inline for (&[_][]const u8{ "text", "name", "displayName", "title" }) |k| { 61 + if (obj.get(k)) |inner| { 62 + if (stringifyForDisplay(inner)) |out| return out; 63 + } 64 + } 65 + return null; 66 + }, 67 + else => return null, 68 + } 69 + } 70 + 71 + /// walk the whole record looking for the first usable string leaf. 72 + /// used as a fallback when no TITLE_FIELDS matched. 73 + fn findFirstString(value: json.Value) ?[]const u8 { 74 + switch (value) { 75 + .string => |s| { 76 + if (s.len < 3) return null; 77 + if (looksLikeIdentifier(s)) return null; 78 + if (isIsoPrefix(s)) return null; 79 + return s; 80 + }, 81 + .object => |obj| { 82 + var it = obj.iterator(); 83 + while (it.next()) |entry| { 84 + const k = entry.key_ptr.*; 85 + if (k.len == 0 or k[0] == '$') continue; 86 + if (std.mem.eql(u8, k, "cid") or std.mem.eql(u8, k, "rev") or 87 + std.mem.eql(u8, k, "sig") or std.mem.eql(u8, k, "version")) continue; 88 + if (findFirstString(entry.value_ptr.*)) |out| return out; 89 + } 90 + return null; 91 + }, 92 + .array => |arr| { 93 + for (arr.items) |item| { 94 + if (findFirstString(item)) |out| return out; 95 + } 96 + return null; 97 + }, 98 + else => return null, 99 + } 100 + } 101 + 102 + fn clip(arena: Allocator, s: []const u8, max: usize) ![]const u8 { 103 + var take: usize = @min(s.len, max); // then back off to a UTF-8 codepoint boundary 104 + while (take > 0 and take < s.len and (s[take] & 0xC0) == 0x80) take -= 1; 105 + return try arena.dupe(u8, s[0..take]); 106 + } 107 + pub fn displayForRecord(arena: Allocator, value: json.Value) !Display { 108 + var title: []const u8 = ""; 109 + var body: []const u8 = ""; 110 + var date: []const u8 = ""; 111 + 112 + if (value != .object) { 113 + const fallback = findFirstString(value) orelse "";
114 + return .{ 115 + .title = try clip(arena, fallback, 120), 116 + .body = "", 117 + .date = "", 118 + }; 119 + } 120 + 121 + // title 122 + for (TITLE_FIELDS) |k| { 123 + if (value.object.get(k)) |v| { 124 + if (stringifyForDisplay(v)) |s| { 125 + title = try clip(arena, s, 120); 126 + break; 127 + } 128 + } 129 + } 130 + 131 + // body 132 + for (BODY_FIELDS) |k| { 133 + if (value.object.get(k)) |v| { 134 + if (stringifyForDisplay(v)) |s| { 135 + body = try clip(arena, s, 300); 136 + break; 137 + } 138 + } 139 + } 140 + 141 + // date 142 + for (DATE_FIELDS) |k| { 143 + if (value.object.get(k)) |v| { 144 + if (v == .string and isIsoPrefix(v.string)) { 145 + date = try arena.dupe(u8, v.string[0..10]); 146 + break; 147 + } 148 + } 149 + } 150 + 151 + // fallback: if no title found, use the first usable string leaf 152 + if (title.len == 0) { 153 + if (findFirstString(value)) |s| { 154 + title = try clip(arena, s, 120); 155 + } 156 + } 157 + 158 + return .{ .title = title, .body = body, .date = date }; 159 + }
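the heuristic in display.zig above is described as a port of a python function, which this diff does not include. as a reading aid, here is a minimal python sketch of the same field-priority logic, written from the zig code rather than copied from `embedder/build_pack.py`. every name in it is illustrative, it works on plain dicts instead of `json.Value`, and it clips by code point where the zig version clips by byte.

```python
import re

TITLE_FIELDS = ("title", "name", "displayName", "label", "subject", "summary")
BODY_FIELDS = ("text", "content", "body", "description", "plaintext", "code",
               "bio", "about", "message")
DATE_FIELDS = ("createdAt", "publishedAt", "indexedAt", "updatedAt",
               "endedAt", "endDate", "startedAt", "startDate", "issuedAt")
SKIP_KEYS = {"cid", "rev", "sig", "version"}
ISO_PREFIX = re.compile(r"[0-9]{4}-[0-9]{2}-[0-9]{2}")  # YYYY-MM-DD at the start


def looks_like_identifier(s):
    # DIDs, at-uris and CID-looking strings are not useful on a result card
    return s.startswith(("did:", "at://", "bafy", "bafk"))


def stringify_for_display(v):
    """Return a display string for a field value, or None if unusable."""
    if isinstance(v, str):
        t = v.strip(" \t\n\r")
        if not t or looks_like_identifier(t):
            return None
        return t[:10] if ISO_PREFIX.match(t) else t
    if isinstance(v, dict):
        for k in ("text", "name", "displayName", "title"):
            if k in v:
                out = stringify_for_display(v[k])
                if out is not None:
                    return out
    return None


def find_first_string(v):
    """Depth-first search for the first usable string leaf (title fallback)."""
    if isinstance(v, str):
        if len(v) < 3 or looks_like_identifier(v) or ISO_PREFIX.match(v):
            return None
        return v
    if isinstance(v, dict):
        for k, inner in v.items():
            if not k or k.startswith("$") or k in SKIP_KEYS:
                continue
            out = find_first_string(inner)
            if out is not None:
                return out
    elif isinstance(v, list):
        for item in v:
            out = find_first_string(item)
            if out is not None:
                return out
    return None


def first_match(record, fields, limit):
    # first field in priority order that yields a usable string, clipped
    for k in fields:
        s = stringify_for_display(record.get(k))
        if s:
            return s[:limit]
    return ""


def display_for_record(value):
    if not isinstance(value, dict):
        return {"title": (find_first_string(value) or "")[:120], "body": "", "date": ""}
    title = first_match(value, TITLE_FIELDS, 120)
    body = first_match(value, BODY_FIELDS, 300)
    date = ""
    for k in DATE_FIELDS:
        v = value.get(k)
        if isinstance(v, str) and ISO_PREFIX.match(v):
            date = v[:10]
            break
    if not title:  # fallback: first usable string leaf anywhere in the record
        title = (find_first_string(value) or "")[:120]
    return {"title": title, "body": body, "date": date}
```

one consequence worth noticing: a record that only has a `text` field (a typical post) gets the same string as both title and body, because the title fallback finds the first string leaf. a record whose only strings are identifiers and timestamps (a like, for example) ends up with an empty title and body and just a date.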
+977
backend/src/indexer.zig
··· 1 + //! in-memory index of embedded PDS records per DID. 2 + //! 3 + //! the cache is a DID → IndexedPack map, mutex-protected. indexing runs on a 4 + //! worker thread: status starts as `indexing`, transitions to `ready` or 5 + //! `error` when done. search queries block on the embedder mutex but not on 6 + //! the cache lock (they take a snapshot pointer under the lock, then scan 7 + //! vectors without holding it). 8 + //! 9 + //! indexing itself is read-only: no PDS writes happen here (saving a pack is a separate, opt-in step) and the cache is in-memory only. restart = cold cache. 10 + 11 + const std = @import("std"); 12 + const Io = std.Io; 13 + const Allocator = std.mem.Allocator; 14 + const json = std.json; 15 + const zat = @import("zat"); 16 + 17 + const pds = @import("pds.zig"); 18 + const repo_walk = @import("repo_walk.zig"); 19 + const record_text = @import("record_text.zig"); 20 + const display = @import("display.zig"); 21 + const Llama = @import("llama.zig"); 22 + const store = @import("state.zig"); 23 + const oauth = @import("oauth.zig"); 24 + 25 + const PACK_COLLECTION = "tech.waow.ken.pack"; 26 + /// leave ~1 MB headroom under the 5 MB blob limit so positions/json overhead 27 + /// doesn't push us over. at 384 dims × 4 bytes (1536 bytes/vector), 4 MiB holds ~2730 vectors/chunk. 28 + const BLOB_CHUNK_BYTES: usize = 4 * 1024 * 1024; 29 + 30 + pub const Status = enum { indexing, ready, @"error" }; 31 + 32 + pub const Entry = struct { 33 + uri: []const u8, 34 + /// content-addressed id of the record at index time. used by incremental 35 + /// re-index: if a future walk sees the same (uri, cid), we reuse the 36 + /// existing vector instead of re-embedding. if the record's content 37 + /// changed, the PDS assigns a new cid and we re-embed.
38 + cid: []const u8, 39 + collection: []const u8, 40 + title: []const u8, 41 + body: []const u8, 42 + date: []const u8, 43 + }; 44 + 45 + pub const IndexedPack = struct { 46 + arena: std.heap.ArenaAllocator, 47 + handle: []const u8, 48 + did: []const u8, 49 + pds_url: []const u8, 50 + status: Status, 51 + error_msg: []const u8, 52 + entries: []Entry, 53 + /// row-major [count * dim] L2-normalized f32 vectors. populated 54 + /// incrementally during indexing — use `valid` to know which rows have 55 + /// real data yet, since reused and fresh records land at non-contiguous 56 + /// positions. empty until after reserve + first doIndex pass. 57 + vectors: []f32, 58 + /// parallel to entries/vectors: 0 = vector not yet written, 1 = valid. 59 + /// writes are monotonic (0→1 only) so concurrent reads from search are 60 + /// safe even without atomics — worst case a just-written record is 61 + /// missed by the current query and picked up on the next one. 62 + valid: []u8, 63 + dim: usize, 64 + /// per-collection record counts 65 + collections: []CollectionCount, 66 + /// at-uri of the pack record on the user's PDS, if they've chosen to 67 + /// persist it. null until the user clicks "save" in the UI. saving the 68 + /// pack writes a `tech.waow.ken.pack` record + blobs; we store the 69 + /// uri so deletion and status reporting can find the existing record. 70 + /// string lives in the pack's arena. 71 + persisted_uri: ?[]const u8, 72 + /// total wall time of the indexing job in ms — walk + reuse-load + 73 + /// embed. written to the pack manifest when the user saves, so future 74 + /// runs can show calibrated ETA messages. 75 + build_ms: i64, 76 + /// pulled from the prior pack's manifest when loadExistingPack 77 + /// succeeds. 0 if there's no prior pack (or it predates these fields). 78 + /// the frontend uses this + prior_count to show "last time this took Ns 79 + /// for M records" instead of a generic loading message. 
80 + prior_build_ms: i64, 81 + prior_count: usize, 82 + /// progress during indexing 83 + records_fetched: usize, 84 + records_embedded: usize, 85 + /// number of records whose vectors were loaded from the user's prior 86 + /// pack on their PDS (same uri + cid → content unchanged). counts as 87 + /// part of `records_embedded` but we report it separately so the UI can 88 + /// show reuse-baseline vs freshly-embedded progress and compute rate / 89 + /// ETA against only the fresh work. 90 + records_reused: usize, 91 + indexed_at_ms: i64, 92 + 93 + pub fn count(self: *const IndexedPack) usize { 94 + if (self.dim == 0) return 0; 95 + return self.vectors.len / self.dim; 96 + } 97 + 98 + /// number of rows whose vectors have been written so far. used by the 99 + /// status endpoint and by search to gate which rows it scores. 100 + pub fn validCount(self: *const IndexedPack) usize { 101 + var c: usize = 0; 102 + for (self.valid) |v| c += v; 103 + return c; 104 + } 105 + }; 106 + 107 + pub const CollectionCount = struct { 108 + nsid: []const u8, 109 + count: usize, 110 + }; 111 + 112 + pub const Cache = struct { 113 + allocator: Allocator, 114 + mutex: Io.Mutex, 115 + /// map from DID → *IndexedPack. the cache owns the packs. 116 + packs: std.StringHashMap(*IndexedPack), 117 + 118 + pub fn init(allocator: Allocator) Cache { 119 + return .{ 120 + .allocator = allocator, 121 + .mutex = Io.Mutex.init, 122 + .packs = std.StringHashMap(*IndexedPack).init(allocator), 123 + }; 124 + } 125 + 126 + pub fn get(self: *Cache, io: Io, did: []const u8) ?*IndexedPack { 127 + self.mutex.lockUncancelable(io); 128 + defer self.mutex.unlock(io); 129 + return self.packs.get(did); 130 + } 131 + 132 + /// drop a pack from the cache. frees the pack's arena (which contains 133 + /// everything the pack owns: entries, vectors, strings, etc) and the 134 + /// duplicated DID key. safe to call if nothing is cached for this DID. 
135 + pub fn remove(self: *Cache, io: Io, did: []const u8) void { 136 + self.mutex.lockUncancelable(io); 137 + defer self.mutex.unlock(io); 138 + const kv = self.packs.fetchRemove(did) orelse return; 139 + self.allocator.free(kv.key); 140 + var arena = kv.value.arena; 141 + arena.deinit(); 142 + } 143 + 144 + /// reserve a slot for a new indexing job. returns the pack if newly 145 + /// created, null if one already exists (caller should return its status). 146 + pub fn reserve( 147 + self: *Cache, 148 + io: Io, 149 + handle: []const u8, 150 + did: []const u8, 151 + pds_url: []const u8, 152 + ) !?*IndexedPack { 153 + self.mutex.lockUncancelable(io); 154 + defer self.mutex.unlock(io); 155 + 156 + if (self.packs.get(did)) |_| return null; 157 + 158 + var arena = std.heap.ArenaAllocator.init(self.allocator); 159 + const arena_alloc = arena.allocator(); 160 + 161 + const pack = try arena_alloc.create(IndexedPack); 162 + pack.* = .{ 163 + .arena = undefined, // assigned below, after the arena allocations 164 + .handle = try arena_alloc.dupe(u8, handle), 165 + .did = try arena_alloc.dupe(u8, did), 166 + .pds_url = try arena_alloc.dupe(u8, pds_url), 167 + .status = .indexing, 168 + .error_msg = "", 169 + .entries = &.{}, 170 + .vectors = &.{}, 171 + .valid = &.{}, 172 + .dim = 0, 173 + .collections = &.{}, 174 + .persisted_uri = null, 175 + .build_ms = 0, 176 + .prior_build_ms = 0, 177 + .prior_count = 0, 178 + .records_fetched = 0, 179 + .records_embedded = 0, 180 + .records_reused = 0, 181 + .indexed_at_ms = 0, 182 + }; 183 + pack.arena = arena; // copy the arena state last: the dupes above advanced it, so an earlier copy would be stale 184 + // duplicate the DID for the hash map key so the cache owns it 185 + // independently of the pack's arena (in case we ever want to evict). 186 + const key = try self.allocator.dupe(u8, did); 187 + try self.packs.put(key, pack); 188 + 189 + return pack; 190 + } 191 + }; 192 + 193 + /// the per-indexing-job context passed to the worker thread.
194 + pub const Job = struct { 195 + allocator: Allocator, 196 + io: Io, 197 + cache: *Cache, 198 + pack: *IndexedPack, 199 + embedder: *Llama.Embedder, 200 + embedder_mutex: *Io.Mutex, 201 + /// cap per collection for v1 — sampling broadly across many collections 202 + /// beats drowning in a single huge one (same lesson as the python pipeline). 203 + max_per_collection: u32, 204 + }; 205 + 206 + /// entry point for the worker thread. runs the full indexing pipeline and 207 + /// updates the pack's status when done. never returns an error — all failures 208 + /// are recorded in pack.error_msg. 209 + pub fn runJob(job: *Job) void { 210 + doIndex(job) catch |err| { 211 + std.log.err("indexing job failed for {s}: {t}", .{ job.pack.handle, err }); 212 + const msg = std.fmt.allocPrint(job.pack.arena.allocator(), "{t}", .{err}) catch "unknown"; 213 + job.pack.status = .@"error"; 214 + job.pack.error_msg = msg; 215 + }; 216 + } 217 + 218 + fn doIndex(job: *Job) !void { 219 + const pack = job.pack; 220 + const arena = pack.arena.allocator(); 221 + 222 + // track total wall time so we can stamp it into the pack record. lets 223 + // the appview / frontend show calibrated ETA messages ("repos this 224 + // size typically take N seconds") instead of a generic guess. 225 + const job_start_ns: i128 = Io.Timestamp.now(job.io, .real).nanoseconds; 226 + 227 + var transport = zat.HttpTransport.init(job.io, job.allocator); 228 + defer transport.deinit(); 229 + 230 + // per-run scratch arena for HTTP bodies — freed at end of doIndex. lets us 231 + // not balloon pack.arena with the raw listRecords JSON bodies. 232 + var scratch = std.heap.ArenaAllocator.init(job.allocator); 233 + defer scratch.deinit(); 234 + const scratch_alloc = scratch.allocator(); 235 + 236 + std.log.info("indexing @{s} ({s}) from {s}", .{ pack.handle, pack.did, pack.pds_url }); 237 + 238 + // 1. walk the repo. 
one HTTP call fetches the whole MST + blocks as a 239 + // CAR file; we parse it locally and yield records in lexicographic 240 + // order. if that fails (old PDS, network issue, whatever), fall 241 + // back to the paginated listRecords walker. 242 + var all_records: std.ArrayList(pds.Record) = .empty; 243 + var collection_of: std.ArrayList([]const u8) = .empty; 244 + var per_collection: std.ArrayList(CollectionCount) = .empty; 245 + 246 + if (repo_walk.walkRepo(scratch_alloc, job.io, pack.pds_url, pack.did)) |walked| { 247 + std.log.info(" car walk: {d} records, {d} KB", .{ walked.records.len, walked.car_bytes / 1024 }); 248 + for (walked.records) |r| { 249 + try all_records.append(scratch_alloc, .{ 250 + .uri = r.uri, 251 + .cid = r.cid, 252 + .value = r.value, 253 + }); 254 + try collection_of.append(scratch_alloc, r.collection); 255 + } 256 + pack.records_fetched = all_records.items.len; 257 + 258 + // derive per-collection counts from the flat walk output. records 259 + // come out lex-sorted so same-collection runs are contiguous. 
260 + var cursor: usize = 0; 261 + while (cursor < collection_of.items.len) { 262 + const start = cursor; 263 + const coll = collection_of.items[start]; 264 + while (cursor < collection_of.items.len and 265 + std.mem.eql(u8, collection_of.items[cursor], coll)) 266 + { 267 + cursor += 1; 268 + } 269 + try per_collection.append(arena, .{ 270 + .nsid = try arena.dupe(u8, coll), 271 + .count = cursor - start, 272 + }); 273 + } 274 + } else |err| { 275 + std.log.warn(" car walk failed ({t}); falling back to listRecords", .{err}); 276 + 277 + const collections = try pds.describeRepo(scratch_alloc, &transport, pack.pds_url, pack.did); 278 + std.log.info(" {d} collections (listRecords fallback)", .{collections.len}); 279 + 280 + for (collections) |collection| { 281 + var cursor: ?[]const u8 = null; 282 + var count: usize = 0; 283 + while (true) { 284 + const page = pds.listRecords( 285 + scratch_alloc, 286 + &transport, 287 + pack.pds_url, 288 + pack.did, 289 + collection, 290 + cursor, 291 + 100, 292 + ) catch |list_err| { 293 + std.log.warn(" listRecords({s}) failed: {t}", .{ collection, list_err }); 294 + break; 295 + }; 296 + 297 + for (page.records) |r| { 298 + try all_records.append(scratch_alloc, r); 299 + try collection_of.append(scratch_alloc, collection); 300 + count += 1; 301 + pack.records_fetched = all_records.items.len; 302 + if (job.max_per_collection > 0 and count >= job.max_per_collection) break; 303 + } 304 + 305 + if (job.max_per_collection > 0 and count >= job.max_per_collection) break; 306 + cursor = page.cursor; 307 + if (cursor == null or page.records.len == 0) break; 308 + } 309 + 310 + if (count > 0) { 311 + try per_collection.append(arena, .{ 312 + .nsid = try arena.dupe(u8, collection), 313 + .count = count, 314 + }); 315 + } 316 + } 317 + } 318 + 319 + std.log.info(" fetched {d} records total", .{all_records.items.len}); 320 + pack.collections = per_collection.items; 321 + 322 + if (all_records.items.len == 0) { 323 + pack.status = .ready; 324 + 
pack.indexed_at_ms = @intCast(@divTrunc(Io.Timestamp.now(job.io, .real).nanoseconds, std.time.ns_per_ms)); 325 + return; 326 + } 327 + 328 + // 3. prepare (text, entry) pairs — drop records that fail to produce text 329 + const n = all_records.items.len; 330 + const dim: usize = @intCast(job.embedder.n_embd); 331 + 332 + var all_texts: std.ArrayList([]const u8) = .empty; 333 + var all_entries: std.ArrayList(Entry) = .empty; 334 + try all_texts.ensureTotalCapacity(scratch_alloc, n); 335 + try all_entries.ensureTotalCapacity(scratch_alloc, n); 336 + 337 + for (all_records.items, 0..) |rec, i| { 338 + const collection = collection_of.items[i]; 339 + const text = record_text.recordToText(scratch_alloc, collection, rec.value) catch continue; 340 + if (text.len == 0) continue; 341 + 342 + const disp = display.displayForRecord(arena, rec.value) catch display.Display{ 343 + .title = "", 344 + .body = "", 345 + .date = "", 346 + }; 347 + 348 + try all_texts.append(scratch_alloc, text); 349 + try all_entries.append(scratch_alloc, .{ 350 + .uri = try arena.dupe(u8, rec.uri), 351 + .cid = try arena.dupe(u8, rec.cid), 352 + .collection = try arena.dupe(u8, collection), 353 + .title = disp.title, 354 + .body = disp.body, 355 + .date = disp.date, 356 + }); 357 + } 358 + 359 + const n_prepared = all_texts.items.len; 360 + const vectors = try arena.alloc(f32, n_prepared * dim); 361 + const entries = try arena.alloc(Entry, n_prepared); 362 + const valid = try arena.alloc(u8, n_prepared); 363 + @memset(valid, 0); 364 + 365 + // expose the output buffers on the pack BEFORE the embed loop runs so 366 + // that concurrent search queries can scan the partial pack. rows start 367 + // invalid (valid[i]=0) and flip to valid (=1) as they get written. 368 + pack.vectors = vectors; 369 + pack.entries = entries; 370 + pack.valid = valid; 371 + pack.dim = dim; 372 + @memcpy(entries, all_entries.items); 373 + 374 + // 3b. try to load the previous pack from the user's PDS. 
if present and 375 + // the embedding dim matches, we'll reuse any entry whose (uri, cid) still 376 + // matches the current repo state. cid comes from atproto's content 377 + // addressing: if the record's bytes changed, its cid changes. same (uri, 378 + // cid) = content is byte-identical = vector is reusable. 379 + // 380 + // reuse_slots[i] holds an index into existing.old_vectors for records 381 + // we're reusing, or null for records we need to embed fresh. 382 + var existing: ?ExistingPack = null; 383 + existing = loadExistingPack(scratch_alloc, &transport, pack.pds_url, pack.did, dim) catch |err| eblk: { 384 + std.log.warn(" loadExistingPack failed ({t}) — full re-embed", .{err}); 385 + break :eblk null; 386 + }; 387 + if (existing) |ex| { 388 + pack.prior_build_ms = ex.prior_build_ms; 389 + pack.prior_count = ex.prior_count; 390 + // a pack exists on PDS — reconstruct its at-uri and mark the 391 + // in-memory pack as persisted so /api/status reports the correct 392 + // state on a fresh sign-in. without this, the frontend would show 393 + // "not saved" even though a pack is already there, prompting the 394 + // user to re-save and accumulate duplicates. 395 + pack.persisted_uri = std.fmt.allocPrint( 396 + pack.arena.allocator(), 397 + "at://{s}/{s}/{s}", 398 + .{ pack.did, PACK_COLLECTION, ex.old_pack_rkey }, 399 + ) catch null; 400 + } 401 + 402 + var reuse_slots = try scratch_alloc.alloc(?usize, n_prepared); 403 + @memset(reuse_slots, null); 404 + var n_reused: usize = 0; 405 + if (existing) |ex| { 406 + for (all_entries.items, 0..) 
|e, idx| { 407 + const cached = ex.uri_to_cached.get(e.uri) orelse continue; 408 + if (!std.mem.eql(u8, cached.cid, e.cid)) continue; 409 + reuse_slots[idx] = cached.vec_idx; 410 + n_reused += 1; 411 + } 412 + std.log.info(" loaded existing pack: {d} entries, reusing {d}/{d}", .{ 413 + ex.uri_to_cached.count(), 414 + n_reused, 415 + n_prepared, 416 + }); 417 + } 418 + const n_fresh = n_prepared - n_reused; 419 + 420 + // copy reused vectors upfront — they don't touch the embedder. flipping 421 + // valid[i] to 1 publishes the row to any concurrent search. 422 + if (existing) |ex| { 423 + for (0..n_prepared) |i| { 424 + const src_idx = reuse_slots[i] orelse continue; 425 + @memcpy( 426 + vectors[i * dim ..][0..dim], 427 + ex.old_vectors[src_idx * dim ..][0..dim], 428 + ); 429 + valid[i] = 1; 430 + } 431 + } 432 + pack.records_reused = n_reused; 433 + pack.records_embedded = n_reused; 434 + 435 + // copy ALL entries into the final `entries` slice up front. reuse or 436 + // embed doesn't affect the Entry record itself — only the vector. 437 + @memcpy(entries, all_entries.items); 438 + 439 + // 4. batch-embed the fresh (non-reused) records only. gather their 440 + // indices into all_texts/vectors so the batch loop can walk them 441 + // contiguously even when they're interleaved with reused entries. 442 + // 443 + // instrumentation: every LOG_EVERY batches we log wall time, tokens, and 444 + // rate. BERT self-attention is O(L²) per sequence, so a batch of long 445 + // records can take 100× longer than a batch of short records at the 446 + // same batch count — tok/s is the real hardware throughput signal. 
447 + const fresh_indices = try scratch_alloc.alloc(usize, n_fresh); 448 + { 449 + var w: usize = 0; 450 + for (0..n_prepared) |i| { 451 + if (reuse_slots[i] == null) { 452 + fresh_indices[w] = i; 453 + w += 1; 454 + } 455 + } 456 + } 457 + 458 + const BATCH_SIZE: usize = 16; 459 + const LOG_EVERY: usize = 8; 460 + var batch_texts: [BATCH_SIZE][]const u8 = undefined; 461 + var batch_out: [BATCH_SIZE][]f32 = undefined; 462 + 463 + var window_records: usize = 0; 464 + var window_tokens: usize = 0; 465 + var window_start_ns: i128 = Io.Timestamp.now(job.io, .real).nanoseconds; 466 + 467 + var pos: usize = 0; 468 + while (pos < n_fresh) { 469 + const take = @min(BATCH_SIZE, n_fresh - pos); 470 + for (0..take) |k| { 471 + const i = fresh_indices[pos + k]; 472 + batch_texts[k] = all_texts.items[i]; 473 + batch_out[k] = vectors[i * dim ..][0..dim]; 474 + } 475 + 476 + var batch_arena = std.heap.ArenaAllocator.init(job.allocator); 477 + defer batch_arena.deinit(); 478 + const batch_alloc = batch_arena.allocator(); 479 + 480 + const batch_start_ns = Io.Timestamp.now(job.io, .real).nanoseconds; 481 + const stats_opt: ?Llama.Embedder.BatchStats = sblk: { 482 + job.embedder_mutex.lockUncancelable(job.io); 483 + defer job.embedder_mutex.unlock(job.io); 484 + break :sblk job.embedder.embedBatch( 485 + batch_alloc, 486 + batch_texts[0..take], 487 + batch_out[0..take], 488 + ) catch |err| { 489 + std.log.warn(" embedBatch failed at fresh pos {d}: {t}", .{ pos, err }); 490 + for (0..take) |k| @memset(batch_out[k], 0); 491 + break :sblk null; 492 + }; 493 + }; 494 + const batch_elapsed_ns = Io.Timestamp.now(job.io, .real).nanoseconds - batch_start_ns; 495 + 496 + // publish all freshly-written rows in this batch to concurrent 497 + // searches. write the vectors *then* the valid bits — on x86 stores 498 + // are ordered, so search will never see valid=1 with stale vector. 
499 + for (0..take) |k| valid[fresh_indices[pos + k]] = 1; 500 + 501 + pos += take; 502 + pack.records_embedded = n_reused + pos; 503 + 504 + if (stats_opt) |stats| { 505 + window_records += take; 506 + window_tokens += stats.total_tokens; 507 + } 508 + 509 + if (pos == n_fresh or (pos / BATCH_SIZE) % LOG_EVERY == 0) { 510 + const now_ns = Io.Timestamp.now(job.io, .real).nanoseconds; 511 + const window_ms: f64 = @as(f64, @floatFromInt(now_ns - window_start_ns)) / 1_000_000.0; 512 + const rec_per_s: f64 = if (window_ms > 0) 513 + @as(f64, @floatFromInt(window_records)) * 1000.0 / window_ms 514 + else 515 + 0; 516 + const tok_per_s: f64 = if (window_ms > 0) 517 + @as(f64, @floatFromInt(window_tokens)) * 1000.0 / window_ms 518 + else 519 + 0; 520 + const last_batch_ms: f64 = @as(f64, @floatFromInt(batch_elapsed_ns)) / 1_000_000.0; 521 + const last_tokens: usize = if (stats_opt) |s| s.total_tokens else 0; 522 + const first_i = fresh_indices[pos - take]; 523 + const first_coll = all_entries.items[first_i].collection; 524 + std.log.info( 525 + " fresh {d}/{d} window: {d} rec / {d} tok in {d:.0}ms ({d:.1} rec/s, {d:.0} tok/s); last batch: {d} tok in {d:.0}ms, coll={s}", 526 + .{ 527 + pos, 528 + n_fresh, 529 + window_records, 530 + window_tokens, 531 + window_ms, 532 + rec_per_s, 533 + tok_per_s, 534 + last_tokens, 535 + last_batch_ms, 536 + first_coll, 537 + }, 538 + ); 539 + window_records = 0; 540 + window_tokens = 0; 541 + window_start_ns = now_ns; 542 + } 543 + } 544 + 545 + // pack.vectors/entries/valid/dim were published up front so concurrent 546 + // searches could see partial results; only the status flip happens here. 
547 + pack.indexed_at_ms = @intCast(@divTrunc(Io.Timestamp.now(job.io, .real).nanoseconds, std.time.ns_per_ms)); 548 + pack.status = .ready; 549 + 550 + std.log.info(" pack ready: {d} reused, {d} freshly embedded", .{ n_reused, n_fresh }); 551 + 552 + pack.build_ms = @intCast(@divTrunc( 553 + Io.Timestamp.now(job.io, .real).nanoseconds - job_start_ns, 554 + std.time.ns_per_ms, 555 + )); 556 + std.log.info(" total indexing wall time: {d} ms", .{pack.build_ms}); 557 + 558 + // NOTE: no automatic writePackToPds here. writing a record + blobs to 559 + // the user's repo is a mutation they must explicitly opt into. see the 560 + // /api/pack/save endpoint (server.zig) for the opt-in write path. 561 + // loading a prior pack for reuse (read-only) still happens 562 + // automatically above — that doesn't mutate anything. 563 + } 564 + 565 + // --------------------------------------------------------------------------- 566 + // incremental re-index: load an existing pack back from the user's PDS 567 + // --------------------------------------------------------------------------- 568 + 569 + const CachedEntry = struct { 570 + cid: []const u8, 571 + vec_idx: usize, 572 + }; 573 + 574 + const ExistingPack = struct { 575 + uri_to_cached: std.StringHashMap(CachedEntry), 576 + /// row-major [count * dim] f32 vectors, already L2-normalized (that's how 577 + /// we wrote them). 578 + old_vectors: []f32, 579 + old_pack_rkey: []const u8, 580 + /// wall time of the previous full build, pulled straight from the 581 + /// manifest's `buildMs` field. 0 if the prior pack predates that field. 582 + /// used by the frontend to show calibrated "last time this took Ns" 583 + /// messaging on the next run. 584 + prior_build_ms: i64, 585 + /// number of records in the previous pack (manifest's `count` field). 
586 + prior_count: usize, 587 + }; 588 + 589 + /// fetch the newest tech.waow.ken.pack record from the user's PDS, download 590 + /// its positions + vectors blobs, and build a uri→(cid, vec_idx) lookup 591 + /// table. unauthenticated: all reads go through public listRecords and 592 + /// sync.getBlob. 593 + /// 594 + /// returns null if no prior pack exists, or if the prior pack was built with 595 + /// a different embedding dim (can't reuse vectors across model changes). 596 + fn loadExistingPack( 597 + arena: Allocator, 598 + transport: *zat.HttpTransport, 599 + pds_url: []const u8, 600 + did: []const u8, 601 + current_dim: usize, 602 + ) !?ExistingPack { 603 + // take only the newest pack record. atproto lists records in descending 604 + // rkey order by default, and our rkeys are TIDs (monotonic time-based), 605 + // so records[0] is the latest. 606 + const page = try pds.listRecords(arena, transport, pds_url, did, PACK_COLLECTION, null, 1); 607 + if (page.records.len == 0) { 608 + std.log.info(" no existing pack on PDS", .{}); 609 + return null; 610 + } 611 + 612 + const pack_record = page.records[0]; 613 + std.log.info(" found existing pack: {s}", .{pack_record.uri}); 614 + const rkey = blk: { 615 + const last_slash = std.mem.lastIndexOfScalar(u8, pack_record.uri, '/') orelse { 616 + std.log.warn(" pack uri missing slash: {s}", .{pack_record.uri}); 617 + return null; 618 + }; 619 + break :blk pack_record.uri[last_slash + 1 ..]; 620 + }; 621 + const old_rkey = try arena.dupe(u8, rkey); 622 + 623 + if (pack_record.value != .object) { 624 + std.log.warn(" pack value is not an object (tag={t})", .{pack_record.value}); 625 + return null; 626 + } 627 + const rec_obj = pack_record.value.object; 628 + 629 + // dim must match — vectors from a different model live in a different 630 + // space and can't be mixed with fresh ones. 
631 + const dim_val = rec_obj.get("dim") orelse { 632 + std.log.warn(" pack missing 'dim' field", .{}); 633 + return null; 634 + }; 635 + if (dim_val != .integer) { 636 + std.log.warn(" pack 'dim' is not integer (tag={t})", .{dim_val}); 637 + return null; 638 + } 639 + const old_dim: usize = @intCast(dim_val.integer); 640 + if (old_dim != current_dim) { 641 + std.log.info(" existing pack has dim={d}, current={d} — can't reuse", .{ old_dim, current_dim }); 642 + return null; 643 + } 644 + 645 + // positions blob → JSON array of {uri, cid} 646 + const positions_blob = rec_obj.get("positions") orelse { 647 + std.log.warn(" pack missing 'positions' field", .{}); 648 + return null; 649 + }; 650 + const positions_cid = (try extractBlobCid(positions_blob)) orelse { 651 + std.log.warn(" pack 'positions' has no extractable cid", .{}); 652 + return null; 653 + }; 654 + std.log.info(" loading positions blob {s}", .{positions_cid}); 655 + const positions_bytes = try pds.getBlob(arena, transport, pds_url, did, positions_cid); 656 + std.log.info(" positions blob {d} bytes", .{positions_bytes.len}); 657 + const positions_parsed = json.parseFromSliceLeaky(json.Value, arena, positions_bytes, .{}) catch |err| { 658 + std.log.warn(" positions blob parse failed: {t} (first 120 bytes: {s})", .{ err, positions_bytes[0..@min(positions_bytes.len, 120)] }); 659 + return null; 660 + }; 661 + if (positions_parsed != .array) { 662 + std.log.warn(" positions parsed but not an array (tag={t})", .{positions_parsed}); 663 + return null; 664 + } 665 + const n = positions_parsed.array.items.len; 666 + std.log.info(" positions has {d} entries", .{n}); 667 + 668 + // vectors array → chunked blobs of raw f32 bytes 669 + const vectors_val = rec_obj.get("vectors") orelse { 670 + std.log.warn(" pack missing 'vectors' field", .{}); 671 + return null; 672 + }; 673 + if (vectors_val != .array) { 674 + std.log.warn(" pack 'vectors' is not array (tag={t})", .{vectors_val}); 675 + return null; 676 + } 677 + 
678 + const old_vectors = try arena.alloc(f32, n * old_dim); 679 + const old_vectors_bytes: []u8 = std.mem.sliceAsBytes(old_vectors); 680 + var byte_offset: usize = 0; 681 + for (vectors_val.array.items) |chunk_val| { 682 + const chunk_cid = (try extractBlobCid(chunk_val)) orelse return null; 683 + const chunk_bytes = try pds.getBlob(arena, transport, pds_url, did, chunk_cid); 684 + if (byte_offset + chunk_bytes.len > old_vectors_bytes.len) { 685 + // truncated / mismatched blob; bail out of reuse rather than 686 + // risk corrupt vectors 687 + std.log.warn(" existing pack vector blob overflow — skipping reuse", .{}); 688 + return null; 689 + } 690 + @memcpy(old_vectors_bytes[byte_offset..][0..chunk_bytes.len], chunk_bytes); 691 + byte_offset += chunk_bytes.len; 692 + } 693 + if (byte_offset != old_vectors_bytes.len) { 694 + std.log.warn( 695 + " existing pack vector size mismatch: got {d} bytes, expected {d} — skipping reuse", 696 + .{ byte_offset, old_vectors_bytes.len }, 697 + ); 698 + return null; 699 + } 700 + 701 + // uri → (cid, vec_idx). entries without a `cid` field (packs written by 702 + // older builds) are skipped: without content addressing we can't verify 703 + // the record is unchanged, and re-embedding is the safe default. 704 + var map = std.StringHashMap(CachedEntry).init(arena); 705 + for (positions_parsed.array.items, 0..) |entry_val, i| { 706 + if (entry_val != .object) continue; 707 + const obj = entry_val.object; 708 + const uri_v = obj.get("uri") orelse continue; 709 + const cid_v = obj.get("cid") orelse continue; 710 + if (uri_v != .string or cid_v != .string) continue; 711 + try map.put( 712 + try arena.dupe(u8, uri_v.string), 713 + .{ 714 + .cid = try arena.dupe(u8, cid_v.string), 715 + .vec_idx = i, 716 + }, 717 + ); 718 + } 719 + 720 + // prior-run metadata for calibrated ETA messaging. absence is fine — 721 + // packs written before these fields existed just report 0. 
722 + const prior_build_ms: i64 = blk: { 723 + const v = rec_obj.get("buildMs") orelse break :blk 0; 724 + break :blk switch (v) { 725 + .integer => |i| i, 726 + else => 0, 727 + }; 728 + }; 729 + const prior_count: usize = blk: { 730 + const v = rec_obj.get("count") orelse break :blk 0; 731 + break :blk switch (v) { 732 + .integer => |i| if (i > 0) @intCast(i) else 0, 733 + else => 0, 734 + }; 735 + }; 736 + 737 + return .{ 738 + .uri_to_cached = map, 739 + .old_vectors = old_vectors, 740 + .old_pack_rkey = old_rkey, 741 + .prior_build_ms = prior_build_ms, 742 + .prior_count = prior_count, 743 + }; 744 + } 745 + 746 + /// extract `$link` out of a blob-ref JSON value (the shape written by 747 + /// uploadBlob): `{"$type":"blob","ref":{"$link":"<cid>"},...}`. 748 + fn extractBlobCid(val: json.Value) !?[]const u8 { 749 + if (val != .object) return null; 750 + const ref = val.object.get("ref") orelse return null; 751 + if (ref != .object) return null; 752 + const link = ref.object.get("$link") orelse return null; 753 + if (link != .string) return null; 754 + return link.string; 755 + } 756 + 757 + /// serialize the indexed pack's positions + vectors as PDS blobs, then 758 + /// create a `tech.waow.ken.pack` record referencing them. the returned 759 + /// at-uri is duped into `result_alloc` so it outlives the internal arena. 760 + /// 761 + /// this is the only write path to the user's PDS. it is ONLY reachable via 762 + /// `POST /api/pack/save` — doIndex never calls it directly. writing a 763 + /// record to someone's repo is a mutation they must explicitly consent to. 
764 + pub fn writePackToPds( 765 + result_alloc: Allocator, 766 + session: store.Session, 767 + pack: *IndexedPack, 768 + ) ![]u8 { 769 + const build_ms = pack.build_ms; 770 + const n_reused = pack.records_reused; 771 + const n_fresh = if (pack.records_embedded >= pack.records_reused) 772 + pack.records_embedded - pack.records_reused 773 + else 774 + 0; 775 + var arena = std.heap.ArenaAllocator.init(result_alloc); 776 + defer arena.deinit(); 777 + const alloc = arena.allocator(); 778 + 779 + const n = pack.entries.len; 780 + const dim = pack.dim; 781 + 782 + // ---- positions blob: JSON array of {uri, cid} per entry ---- 783 + // this is the MINIMAL shape needed by loadExistingPack for incremental 784 + // reuse. title/body/date/collection are reconstructed from the fresh 785 + // walk on every re-index, so storing them in the pack is pure waste — 786 + // they bloated the blob to ~4 MB for a 17k-record repo, which tripped 787 + // the user's PDS blob-size limit and caused uploadBlob to return an 788 + // error that we couldn't even parse. ~1 MB now. 789 + var positions: std.ArrayList(u8) = .empty; 790 + try positions.appendSlice(alloc, "["); 791 + for (pack.entries, 0..) |e, i| { 792 + if (i > 0) try positions.appendSlice(alloc, ","); 793 + try positions.appendSlice(alloc, "{\"uri\":"); 794 + try writeJsonStringTo(&positions, alloc, e.uri); 795 + try positions.appendSlice(alloc, ",\"cid\":"); 796 + try writeJsonStringTo(&positions, alloc, e.cid); 797 + try positions.appendSlice(alloc, "}"); 798 + } 799 + try positions.appendSlice(alloc, "]"); 800 + 801 + if (positions.items.len > 5 * 1024 * 1024) { 802 + return error.PositionsTooLarge; 803 + } 804 + 805 + std.log.info(" uploading positions blob ({d} KB)", .{positions.items.len / 1024}); 806 + // declare as octet-stream even though the bytes are JSON. the ATProto 807 + // PDS has a bug in sync.getBlob where application/json blobs come back 808 + // as a serialized fs.ReadStream object instead of the file contents. 
809 + // binary mime types hit a different code path that correctly pipes the 810 + // stream. we parse as JSON regardless on the load side, so the mime 811 + // lie is invisible to consumers. verified empirically 2026-04-09. 812 + const positions_ref = try oauth.uploadBlob(alloc, session, positions.items, "application/octet-stream"); 813 + 814 + // ---- vector chunks: raw f32 bytes, chunked under the blob limit ---- 815 + const vector_bytes_per_record = dim * @sizeOf(f32); 816 + const records_per_chunk = @max(@as(usize, 1), BLOB_CHUNK_BYTES / vector_bytes_per_record); 817 + 818 + var vector_refs: std.ArrayList(oauth.BlobRef) = .empty; 819 + var offset: usize = 0; 820 + while (offset < n) { 821 + const take = @min(records_per_chunk, n - offset); 822 + const float_start = offset * dim; 823 + const float_end = (offset + take) * dim; 824 + // f32 slice (indexed in floats, not bytes) → byte slice 825 + const float_slice = pack.vectors[float_start..float_end]; 826 + const byte_slice: []const u8 = @as([*]const u8, @ptrCast(float_slice.ptr))[0 ..
float_slice.len * @sizeOf(f32)]; 827 + 828 + std.log.info(" uploading vectors chunk {d}: {d} records, {d} KB", .{ 829 + vector_refs.items.len, 830 + take, 831 + byte_slice.len / 1024, 832 + }); 833 + const ref = try oauth.uploadBlob(alloc, session, byte_slice, "application/octet-stream"); 834 + try vector_refs.append(alloc, ref); 835 + offset += take; 836 + } 837 + 838 + // ---- create the manifest record ---- 839 + var record: std.ArrayList(u8) = .empty; 840 + try record.appendSlice(alloc, "{"); 841 + try record.print(alloc, "\"$type\":\"{s}\",", .{PACK_COLLECTION}); 842 + try record.appendSlice(alloc, "\"model\":\"BAAI/bge-small-en-v1.5\","); 843 + try record.print(alloc, "\"dim\":{d},", .{dim}); 844 + try record.appendSlice(alloc, "\"encoding\":\"float32\","); 845 + try record.print(alloc, "\"count\":{d},", .{n}); 846 + 847 + // positions is a single blob ref 848 + try record.print(alloc, 849 + "\"positions\":{{\"$type\":\"blob\",\"ref\":{{\"$link\":\"{s}\"}},\"mimeType\":\"{s}\",\"size\":{d}}},", 850 + .{ positions_ref.cid, positions_ref.mime_type, positions_ref.size }, 851 + ); 852 + 853 + // vectors is an array of blob refs 854 + try record.appendSlice(alloc, "\"vectors\":["); 855 + for (vector_refs.items, 0..) 
|r, i| { 856 + if (i > 0) try record.appendSlice(alloc, ","); 857 + try record.print(alloc, 858 + "{{\"$type\":\"blob\",\"ref\":{{\"$link\":\"{s}\"}},\"mimeType\":\"{s}\",\"size\":{d}}}", 859 + .{ r.cid, r.mime_type, r.size }, 860 + ); 861 + } 862 + try record.appendSlice(alloc, "],"); 863 + 864 + // build metrics — appview / frontend use these for calibrated ETA 865 + // messaging on the next run ("the last full build took N seconds for M 866 + // records, so yours will take ~X") 867 + try record.print(alloc, "\"buildMs\":{d},", .{build_ms}); 868 + try record.print(alloc, "\"reusedCount\":{d},", .{n_reused}); 869 + try record.print(alloc, "\"freshCount\":{d},", .{n_fresh}); 870 + 871 + // createdAt — best-effort iso8601 872 + const now_secs: i64 = @intCast(@divTrunc(Io.Timestamp.now(oauth.config().io, .real).nanoseconds, std.time.ns_per_s)); 873 + const epoch_secs: std.time.epoch.EpochSeconds = .{ .secs = @intCast(now_secs) }; 874 + const day = epoch_secs.getDaySeconds(); 875 + const year_day = epoch_secs.getEpochDay().calculateYearDay(); 876 + const md = year_day.calculateMonthDay(); 877 + try record.print(alloc, 878 + "\"createdAt\":\"{d:0>4}-{d:0>2}-{d:0>2}T{d:0>2}:{d:0>2}:{d:0>2}.000Z\"", 879 + .{ 880 + year_day.year, 881 + @intFromEnum(md.month), 882 + md.day_index + 1, 883 + day.getHoursIntoDay(), 884 + day.getMinutesIntoHour(), 885 + day.getSecondsIntoMinute(), 886 + }, 887 + ); 888 + try record.appendSlice(alloc, "}"); 889 + 890 + std.log.info(" createRecord {s} ({d} bytes)", .{ PACK_COLLECTION, record.items.len }); 891 + const pack_uri = try oauth.createRecord(alloc, session, PACK_COLLECTION, record.items); 892 + std.log.info(" pack stored at {s}", .{pack_uri}); 893 + // dupe out of the arena so the caller can hold it after deinit 894 + return result_alloc.dupe(u8, pack_uri); 895 + } 896 + 897 + fn writeJsonStringTo(buf: *std.ArrayList(u8), alloc: Allocator, s: []const u8) !void { 898 + try buf.append(alloc, '"'); 899 + for (s) |ch| { 900 + switch (ch) { 
901 + '"' => try buf.appendSlice(alloc, "\\\""), 902 + '\\' => try buf.appendSlice(alloc, "\\\\"), 903 + '\n' => try buf.appendSlice(alloc, "\\n"), 904 + '\r' => try buf.appendSlice(alloc, "\\r"), 905 + '\t' => try buf.appendSlice(alloc, "\\t"), 906 + 0...0x08, 0x0b, 0x0c, 0x0e...0x1f => { 907 + try buf.print(alloc, "\\u{x:0>4}", .{ch}); 908 + }, 909 + else => try buf.append(alloc, ch), 910 + } 911 + } 912 + try buf.append(alloc, '"'); 913 + } 914 + 915 + /// top-K search over an indexed pack. returns indices into pack.entries, 916 + /// sorted by descending cosine similarity. caller owns the returned slice. 917 + /// 918 + /// safe to call during indexing: only rows with valid[i]==1 are scored, 919 + /// so the result reflects the partial pack state at snapshot time. 920 + pub fn search( 921 + allocator: Allocator, 922 + pack: *const IndexedPack, 923 + query_vec: []const f32, 924 + k: usize, 925 + ) ![]usize { 926 + if (pack.entries.len == 0 or pack.dim == 0) return &.{}; 927 + if (query_vec.len != pack.dim) return &.{}; 928 + 929 + // query is assumed already L2-normalized (Llama.embedder does this) 930 + const n = pack.entries.len; 931 + const dim = pack.dim; 932 + const scores = try allocator.alloc(f32, n); 933 + defer allocator.free(scores); 934 + 935 + // sentinel = invalid row, ensures it never wins a top-K slot 936 + const neg_inf = -std.math.inf(f32); 937 + var valid_count: usize = 0; 938 + for (0..n) |i| { 939 + if (pack.valid[i] == 0) { 940 + scores[i] = neg_inf; 941 + continue; 942 + } 943 + valid_count += 1; 944 + const row = pack.vectors[i * dim ..][0..dim]; 945 + var s: f32 = 0; 946 + for (0..dim) |j| s += row[j] * query_vec[j]; 947 + scores[i] = s; 948 + } 949 + 950 + if (valid_count == 0) return &.{}; 951 + 952 + const kk = @min(k, valid_count); 953 + const idxs = try allocator.alloc(usize, kk); 954 + errdefer allocator.free(idxs); 955 + 956 + // O(n * k) top-k selection — fine for k=30, n up to tens of thousands 957 + var taken: 
std.DynamicBitSetUnmanaged = try .initEmpty(allocator, n); 958 + defer taken.deinit(allocator); 959 + 960 + for (0..kk) |slot| { 961 + var best: isize = -1; 962 + var best_score: f32 = neg_inf; 963 + for (0..n) |i| { 964 + if (taken.isSet(i)) continue; 965 + if (scores[i] > best_score) { 966 + best_score = scores[i]; 967 + best = @intCast(i); 968 + } 969 + } 970 + if (best < 0) return idxs[0..slot]; 971 + const u: usize = @intCast(best); 972 + idxs[slot] = u; 973 + taken.set(u); 974 + } 975 + 976 + return idxs; 977 + }
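The `search` function above picks its top-K with k passes over a full score array. As an illustration only, the sketch below performs the same validity-masked cosine ranking in a single pass, keeping a small insertion-sorted buffer. `topKInsertion` and its parameter names are hypothetical and not part of this commit; like the source, it assumes rows and query are already L2-normalized, so the dot product equals cosine similarity.

```zig
const std = @import("std");

/// Hypothetical helper (not in the commit): scores every published row
/// against `query` and keeps the best `out_idx.len` rows, best first.
/// `vectors` is row-major [valid.len * dim]. Returns the number of
/// slots filled in `out_idx` / `out_score`.
fn topKInsertion(
    vectors: []const f32,
    valid: []const u8,
    dim: usize,
    query: []const f32,
    out_idx: []usize,
    out_score: []f32,
) usize {
    std.debug.assert(out_idx.len == out_score.len);
    std.debug.assert(query.len == dim);
    std.debug.assert(vectors.len == valid.len * dim);

    var filled: usize = 0;
    for (valid, 0..) |v, i| {
        if (v == 0) continue; // row not published yet, never scored

        const row = vectors[i * dim ..][0..dim];
        var s: f32 = 0;
        for (row, query) |a, b| s += a * b;

        if (filled == out_idx.len) {
            // buffer full: skip unless this row beats the current worst
            if (filled == 0 or s <= out_score[filled - 1]) continue;
        } else {
            filled += 1;
        }

        // shift lower-scoring entries down, then drop the new one in place
        var j = filled - 1;
        while (j > 0 and out_score[j - 1] < s) : (j -= 1) {
            out_score[j] = out_score[j - 1];
            out_idx[j] = out_idx[j - 1];
        }
        out_score[j] = s;
        out_idx[j] = i;
    }
    return filled;
}
```

For k around 30 either approach is cheap; the single-pass form mainly avoids allocating the scores array and the bitset. Ties resolve to the lower row index, matching the strict `>` comparison in `search`.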
+303
backend/src/llama.zig
··· 1 + //! llama.cpp wrapper for bge-small-en-v1.5 embedding inference. 2 + //! 3 + //! model + context are loaded once at startup and reused across requests. 4 + //! `embed()` takes a text string, returns an owned 384-dim f32 vector. 5 + //! 6 + //! NOTE: not thread-safe yet — callers must serialize access to a shared 7 + //! Embedder (Mutex + queue, or per-thread Embedders). for phase B we accept 8 + //! a single global embedder behind a mutex. 9 + 10 + const std = @import("std"); 11 + const Allocator = std.mem.Allocator; 12 + 13 + // disable glibc's fortified stdio wrappers — they use __builtin_va_arg_pack 14 + // which zig's C translator doesn't support. transitively included by llama.h 15 + // → ggml.h → stdio.h → bits/stdio2.h on x86_64-linux-gnu builds. works fine on 16 + // macos (uses its own stdio) and on the dockerized linux build with this shim. 17 + const c = @cImport({ 18 + @cDefine("_FORTIFY_SOURCE", "0"); 19 + @cInclude("llama.h"); 20 + @cInclude("ggml-backend.h"); 21 + }); 22 + 23 + pub const Embedder = struct { 24 + model: *c.llama_model, 25 + ctx: *c.llama_context, 26 + n_embd: usize, 27 + 28 + pub const InitError = error{ 29 + ModelLoadFailed, 30 + ContextFailed, 31 + }; 32 + 33 + pub fn init(model_path: [:0]const u8) InitError!Embedder { 34 + c.llama_backend_init(); 35 + // on linux the ggml backends (cpu, blas, etc.) ship as separate .so files 36 + // that must be explicitly loaded before a model is loaded. the default 37 + // search path in ggml_backend_load_all() is unreliable in stripped-down 38 + // containers — use the explicit _from_path variant so we're not at the 39 + // mercy of argv[0] detection. on macos metal is compiled into libllama 40 + // so ggml_backend_load_all is a no-op there. 
41 + const backend_path: [:0]const u8 = if (std.c.getenv("GGML_BACKEND_PATH")) |p| 42 + std.mem.span(@as([*:0]const u8, @ptrCast(p))) 43 + else 44 + "/app/lib"; 45 + c.ggml_backend_load_all_from_path(backend_path.ptr); 46 + 47 + const mp = c.llama_model_default_params(); 48 + const model = c.llama_model_load_from_file(model_path.ptr, mp) orelse { 49 + std.log.err("llama_model_load_from_file failed for {s}", .{model_path}); 50 + return error.ModelLoadFailed; 51 + }; 52 + 53 + // 4096-token context + batch lets us process ~16 records at a time 54 + // (each capped to 256 tokens via MAX_TOKENS_PER_RECORD below). that's 55 + // the biggest win over single-record embedding: llama.cpp's encode 56 + // vectorizes across the full batch, so per-record overhead drops from 57 + // ~300 ms/rec on shared-cpu-2x to ~30 ms/rec. 58 + var cp = c.llama_context_default_params(); 59 + cp.n_ctx = 4096; 60 + cp.n_batch = 4096; 61 + cp.n_ubatch = 4096; 62 + // CRITICAL: default is 1 → llama_encode rejects any seq_id > 0 with 63 + // "failed to initialize batch". must be >= BATCH_SIZE used by 64 + // indexer.zig. 32 gives headroom without hurting memory. 65 + cp.n_seq_max = 32; 66 + cp.embeddings = true; 67 + cp.pooling_type = c.LLAMA_POOLING_TYPE_CLS; 68 + // match the fly performance-2x vm: 2 dedicated vCPUs. oversubscribing 69 + // here (n_threads > vCPUs) just causes context switches, not speedup. 
70 + cp.n_threads = 2; 71 + cp.n_threads_batch = 2; 72 + 73 + const ctx = c.llama_init_from_model(model, cp) orelse { 74 + c.llama_model_free(model); 75 + return error.ContextFailed; 76 + }; 77 + 78 + const n_embd: usize = @intCast(c.llama_model_n_embd(model)); 79 + 80 + std.log.info("llama: loaded {s}, n_embd={d}", .{ model_path, n_embd }); 81 + 82 + return .{ 83 + .model = model, 84 + .ctx = ctx, 85 + .n_embd = n_embd, 86 + }; 87 + } 88 + 89 + pub fn deinit(self: *Embedder) void { 90 + c.llama_free(self.ctx); 91 + c.llama_model_free(self.model); 92 + c.llama_backend_free(); 93 + } 94 + 95 + pub const EmbedError = error{ 96 + TokenizeFailed, 97 + EncodeFailed, 98 + NoEmbeddings, 99 + BatchTooLarge, 100 + OutOfMemory, 101 + }; 102 + 103 + /// tokens per record cap (after truncation). bge-small's native context 104 + /// is 512; we cap at 256 so 16 records fit in a 4096-token batch. 105 + pub const MAX_TOKENS_PER_RECORD: usize = 256; 106 + /// scratch buffer used when calling llama_tokenize. MUST be large enough 107 + /// to hold the *untruncated* tokenization of any reasonable atproto 108 + /// record; otherwise llama_tokenize returns -N (needed N slots) WITHOUT 109 + /// writing any tokens, leaving the scratch buffer uninitialized. we then 110 + /// fed that garbage to llama_encode and got rc=-1 forever. 4096 covers 111 + /// anything that isn't pathological; if we still overflow, we fail the 112 + /// batch rather than synthesize bogus tokens. 113 + pub const TOKENIZE_SCRATCH: usize = 4096; 114 + pub const BATCH_TOKEN_BUDGET: usize = 4096; 115 + 116 + pub const BatchStats = struct { 117 + /// tokens summed across every sequence in the batch (after truncation) 118 + total_tokens: usize, 119 + /// longest single sequence in the batch 120 + max_seq_tokens: usize, 121 + }; 122 + 123 + /// embed many texts in one llama_encode call. writes L2-normalized 124 + /// vectors into `out[i][0..n_embd]`.
`texts.len` and `out.len` must 125 + /// match; caller owns `out` and its sub-slices. 126 + /// 127 + /// returns token stats so callers can correlate wall time to actual 128 + /// work. texts longer than MAX_TOKENS_PER_RECORD are truncated. callers 129 + /// must keep `texts.len <= n_seq_max` (32, see init). 130 + pub fn embedBatch( 131 + self: *Embedder, 132 + allocator: Allocator, 133 + texts: []const []const u8, 134 + out: [][]f32, 135 + ) EmbedError!BatchStats { 136 + std.debug.assert(texts.len == out.len); 137 + if (texts.len == 0) return .{ .total_tokens = 0, .max_seq_tokens = 0 }; 138 + 139 + const vocab = c.llama_model_get_vocab(self.model); 140 + 141 + // tokenize everything up front so we can check the total token count 142 + // before building the llama_batch. 143 + var per_seq_tokens = try allocator.alloc([]c.llama_token, texts.len); 144 + defer { 145 + for (per_seq_tokens) |t| allocator.free(t); 146 + allocator.free(per_seq_tokens); 147 + } 148 + for (per_seq_tokens) |*t| t.* = &.{}; // start every entry empty so the defer above never frees an uninitialized slice on an early return 149 + var total_tokens: usize = 0; 150 + var max_seq_tokens: usize = 0; 151 + for (texts, 0..) |text, i| { 152 + // scratch is TOKENIZE_SCRATCH wide so llama_tokenize can actually 153 + // write the full tokenization. only after it succeeds do we 154 + // truncate to MAX_TOKENS_PER_RECORD. this avoids the 155 + // uninitialized-memory bug: on overflow llama_tokenize returns 156 + // -needed WITHOUT touching the buffer. 157 + var scratch: [TOKENIZE_SCRATCH]c.llama_token = undefined; 158 + const n_raw = c.llama_tokenize( 159 + vocab, 160 + text.ptr, 161 + @intCast(text.len), 162 + &scratch, 163 + scratch.len, 164 + true, // add_special 165 + false, // parse_special 166 + ); 167 + if (n_raw <= 0) { 168 + // either an error (0) or overflow (<0, -n_raw was needed). 169 + // fail the whole batch rather than synthesize bogus tokens; 170 + // the caller (indexer.zig) zero-fills this batch's output rows.
171 + std.log.warn( 172 + "llama_tokenize failed for seq {d} (n_raw={d}, text_len={d}); failing batch", 173 + .{ i, n_raw, text.len }, 174 + ); 175 + return error.TokenizeFailed; 176 + } 177 + const n_real: usize = @intCast(n_raw); 178 + // truncate to the batch cap. this keeps [CLS] at position 0 179 + // (which is all CLS pooling actually reads) and drops the [SEP] 180 + // and any tail tokens beyond MAX_TOKENS_PER_RECORD. 181 + const n = @min(n_real, MAX_TOKENS_PER_RECORD); 182 + const slice = try allocator.alloc(c.llama_token, n); 183 + @memcpy(slice, scratch[0..n]); 184 + per_seq_tokens[i] = slice; 185 + total_tokens += n; 186 + if (n > max_seq_tokens) max_seq_tokens = n; 187 + } 188 + 189 + if (total_tokens > BATCH_TOKEN_BUDGET) return error.BatchTooLarge; 190 + 191 + // clear any KV state from the previous batch 192 + c.llama_memory_clear(c.llama_get_memory(self.ctx), true); 193 + 194 + var batch = c.llama_batch_init(@intCast(BATCH_TOKEN_BUDGET), 0, @intCast(texts.len)); 195 + defer c.llama_batch_free(batch); 196 + 197 + // fill the batch: one row per (token, seq_id, pos) tuple. with 198 + // LLAMA_POOLING_TYPE_CLS only the token at position 0 of each sequence 199 + // feeds the pooled embedding; we still set logits=1 on every token. 200 + for (per_seq_tokens, 0..) |seq_tokens, seq_id| { 201 + for (seq_tokens, 0..) |tok, pos| { 202 + const idx: usize = @intCast(batch.n_tokens); 203 + batch.token[idx] = tok; 204 + batch.pos[idx] = @intCast(pos); 205 + batch.n_seq_id[idx] = 1; 206 + batch.seq_id[idx][0] = @intCast(seq_id); 207 + batch.logits[idx] = 1; 208 + batch.n_tokens += 1; 209 + } 210 + } 211 + 212 + const rc = c.llama_encode(self.ctx, batch); 213 + if (rc != 0) { 214 + // dump enough info to diagnose: per-seq token counts, totals, 215 + // first 64 bytes of each text. one line per failure so grep works.
216 + var preview: [64]u8 = undefined; 217 + std.log.warn( 218 + "embedBatch encode rc={d} n_seq={d} total_tokens={d} max_seq_tokens={d}", 219 + .{ rc, texts.len, total_tokens, max_seq_tokens }, 220 + ); 221 + for (texts, 0..) |t, i| { 222 + const pv_len = @min(t.len, preview.len); 223 + @memcpy(preview[0..pv_len], t[0..pv_len]); 224 + // scrub newlines in the preview 225 + for (preview[0..pv_len]) |*ch| { 226 + if (ch.* == '\n' or ch.* == '\r') ch.* = ' '; 227 + } 228 + std.log.warn(" seq[{d}] tokens={d} text='{s}'", .{ 229 + i, 230 + per_seq_tokens[i].len, 231 + preview[0..pv_len], 232 + }); 233 + } 234 + return error.EncodeFailed; 235 + } 236 + 237 + const stats: BatchStats = .{ 238 + .total_tokens = total_tokens, 239 + .max_seq_tokens = max_seq_tokens, 240 + }; 241 + 242 + for (0..texts.len) |seq_id| { 243 + const raw = c.llama_get_embeddings_seq(self.ctx, @intCast(seq_id)) orelse 244 + return error.NoEmbeddings; 245 + var sum_sq: f32 = 0; 246 + for (0..self.n_embd) |i| { 247 + out[seq_id][i] = raw[i]; 248 + sum_sq += raw[i] * raw[i]; 249 + } 250 + const norm: f32 = @sqrt(sum_sq); 251 + if (norm > 0) { 252 + for (out[seq_id]) |*v| v.* /= norm; 253 + } 254 + } 255 + 256 + return stats; 257 + } 258 + 259 + /// embed a single text string. returns an owned, L2-normalized f32 slice 260 + /// of length `self.n_embd`. caller frees. kept for the /api/embed debug 261 + /// endpoint and single-shot query embedding. 
262 + pub fn embed(self: *Embedder, allocator: Allocator, text: []const u8) EmbedError![]f32 { 263 + const vocab = c.llama_model_get_vocab(self.model); 264 + 265 + // tokenize (bge expects BOS/EOS around the input → add_special = true) 266 + var tokens_buf: [512]c.llama_token = undefined; 267 + const n_tokens = c.llama_tokenize( 268 + vocab, 269 + text.ptr, 270 + @intCast(text.len), 271 + &tokens_buf, 272 + tokens_buf.len, 273 + true, // add_special 274 + false, // parse_special 275 + ); 276 + if (n_tokens < 0) return error.TokenizeFailed; 277 + 278 + // clear KV cache from the previous call 279 + c.llama_memory_clear(c.llama_get_memory(self.ctx), true); 280 + 281 + const batch = c.llama_batch_get_one(&tokens_buf, n_tokens); 282 + const rc = c.llama_encode(self.ctx, batch); 283 + if (rc != 0) return error.EncodeFailed; 284 + 285 + const raw = c.llama_get_embeddings_seq(self.ctx, 0) orelse return error.NoEmbeddings; 286 + 287 + // copy out + L2 normalize in one pass 288 + const out = try allocator.alloc(f32, self.n_embd); 289 + errdefer allocator.free(out); 290 + 291 + var sum_sq: f32 = 0; 292 + for (0..self.n_embd) |i| { 293 + out[i] = raw[i]; 294 + sum_sq += raw[i] * raw[i]; 295 + } 296 + const norm: f32 = @sqrt(sum_sq); 297 + if (norm > 0) { 298 + for (out) |*v| v.* /= norm; 299 + } 300 + 301 + return out; 302 + } 303 + };
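The comments in llama.zig tie several numbers together: 16 records per batch, a 256-token cap per record, a 4096-token batch budget, and `n_seq_max = 32`. A minimal sketch of how those relationships could be checked at compile time follows. The values are copied from the diff; `N_SEQ_MAX` and `worstCaseTokens` are hypothetical names introduced here for illustration, since the source sets `cp.n_seq_max` inline and keeps `BATCH_SIZE` in indexer.zig.

```zig
const std = @import("std");

// values copied from the diff above
const BATCH_SIZE: usize = 16; // indexer.zig
const MAX_TOKENS_PER_RECORD: usize = 256; // llama.zig
const BATCH_TOKEN_BUDGET: usize = 4096; // llama.zig
// hypothetical named constant; the source assigns cp.n_seq_max = 32 inline
const N_SEQ_MAX: usize = 32;

comptime {
    // each record in a batch gets its own seq_id, so the batch size must
    // not exceed the number of sequences the context accepts
    if (BATCH_SIZE > N_SEQ_MAX)
        @compileError("BATCH_SIZE must not exceed n_seq_max");
    // a batch made entirely of max-length records must still fit the budget
    if (BATCH_SIZE * MAX_TOKENS_PER_RECORD > BATCH_TOKEN_BUDGET)
        @compileError("a full batch of max-length records would overflow the token budget");
}

/// worst-case token count for a batch of `n_records` truncated records
fn worstCaseTokens(n_records: usize) usize {
    return n_records * MAX_TOKENS_PER_RECORD;
}
```

With the values above, a full batch lands exactly on the budget (16 × 256 = 4096), so the `BatchTooLarge` check in `embedBatch` should only trip if a caller passes more than 16 texts.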
+117
backend/src/main.zig
··· 1 + //! embed-on-pds backend — zig http server + llama.cpp inference + PDS index. 2 + //! 3 + //! phase F: OAuth + vectors-on-PDS persistence. 4 + //! no persistent local storage — OAuth sessions live in memory (reset on 5 + //! redeploy), llama model + viewer assets live in the image, indexed packs 6 + //! write back to the user's own PDS as records + blobs. 7 + 8 + const std = @import("std"); 9 + const Io = std.Io; 10 + const Thread = std.Thread; 11 + const server = @import("server.zig"); 12 + const Llama = @import("llama.zig"); 13 + const indexer = @import("indexer.zig"); 14 + const state = @import("state.zig"); 15 + const oauth = @import("oauth.zig"); 16 + 17 + const SOCKET_TIMEOUT_SECS = 30; 18 + 19 + var threaded_io: Io.Threaded = undefined; 20 + pub const std_options_debug_threaded_io: ?*Io.Threaded = &threaded_io; 21 + 22 + fn getenv(name: [*:0]const u8) ?[]const u8 { 23 + return if (std.c.getenv(name)) |p| std.mem.span(p) else null; 24 + } 25 + 26 + pub fn main(init: std.process.Init) !void { 27 + _ = init; 28 + const allocator = std.heap.smp_allocator; 29 + 30 + threaded_io = Io.Threaded.init(allocator, .{}); 31 + const io = threaded_io.io(); 32 + 33 + const port: u16 = blk: { 34 + const port_str = getenv("PORT") orelse "3000"; 35 + break :blk std.fmt.parseInt(u16, port_str, 10) catch 3000; 36 + }; 37 + 38 + // --- state (in-memory oauth sessions) --- 39 + state.init(io, allocator); 40 + defer state.close(); 41 + 42 + // --- oauth config --- 43 + oauth.init(.{ 44 + .io = io, 45 + .client_id = getenv("OAUTH_CLIENT_ID") orelse "https://embed-on-pds.fly.dev/oauth-client-metadata.json", 46 + .redirect_uri = getenv("OAUTH_REDIRECT_URI") orelse "https://embed-on-pds.fly.dev/oauth/callback", 47 + .frontend_origin = getenv("FRONTEND_ORIGIN") orelse "https://embed-on-pds.fly.dev", 48 + .client_key_hex = getenv("OAUTH_CLIENT_SECRET_KEY") orelse "", 49 + }); 50 + if (oauth.config().client_key_hex.len != 64) { 51 + std.log.warn("OAUTH_CLIENT_SECRET_KEY not set — 
oauth flows will fail", .{}); 52 + } 53 + 54 + // --- llama model --- 55 + const model_path: [:0]const u8 = if (std.c.getenv("MODEL_PATH")) |p| 56 + std.mem.span(@as([*:0]const u8, @ptrCast(p))) 57 + else 58 + "models/bge-small.gguf"; 59 + 60 + var embedder = Llama.Embedder.init(model_path) catch |err| { 61 + std.log.err("failed to init embedder: {t}", .{err}); 62 + return err; 63 + }; 64 + defer embedder.deinit(); 65 + 66 + var embedder_mutex: Io.Mutex = Io.Mutex.init; 67 + var cache = indexer.Cache.init(allocator); 68 + 69 + const dev_noauth = blk: { 70 + const v = getenv("EMBED_DEV_NOAUTH") orelse break :blk false; 71 + break :blk std.mem.eql(u8, v, "1") or std.mem.eql(u8, v, "true"); 72 + }; 73 + if (dev_noauth) { 74 + std.log.warn("EMBED_DEV_NOAUTH=1 — consent gate bypassed. DO NOT DEPLOY WITH THIS.", .{}); 75 + } 76 + 77 + var app = server.App{ 78 + .io = io, 79 + .allocator = allocator, 80 + .embedder = &embedder, 81 + .embedder_mutex = &embedder_mutex, 82 + .cache = &cache, 83 + .dev_noauth = dev_noauth, 84 + }; 85 + 86 + var addr = try Io.net.IpAddress.parse("::", port); 87 + var listener = addr.listen(io, .{ .reuse_address = true }) catch |err| { 88 + std.log.err("failed to listen on port {d}: {t}", .{ port, err }); 89 + return err; 90 + }; 91 + defer listener.deinit(io); 92 + 93 + std.log.info("embed-on-pds listening on :{d}", .{port}); 94 + 95 + while (true) { 96 + const stream = listener.accept(io) catch |err| { 97 + std.log.err("accept error: {t}", .{err}); 98 + continue; 99 + }; 100 + setSocketTimeout(stream.socket.handle, SOCKET_TIMEOUT_SECS) catch {}; 101 + const t = Thread.spawn(.{}, server.handleConnection, .{ stream, &app }) catch |err| { 102 + std.log.err("spawn error: {t}", .{err}); 103 + stream.close(io); 104 + continue; 105 + }; 106 + t.detach(); 107 + } 108 + } 109 + 110 + fn setSocketTimeout(fd: std.posix.fd_t, secs: u32) !void { 111 + const timeout = std.mem.toBytes(std.posix.timeval{ 112 + .sec = @intCast(secs), 113 + .usec = 0, 114 + 
}); 115 + try std.posix.setsockopt(fd, std.posix.SOL.SOCKET, std.posix.SO.RCVTIMEO, &timeout); 116 + try std.posix.setsockopt(fd, std.posix.SOL.SOCKET, std.posix.SO.SNDTIMEO, &timeout); 117 + }
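a minimal sketch of the env parsing that `main` does inline above, pulled into two pure helpers so the fallback behaviour is easy to see. `parsePort` and `parseFlag` are hypothetical names, not functions in this commit; the commit itself reads `PORT` and `EMBED_DEV_NOAUTH` through `std.c.getenv` directly.

```zig
const std = @import("std");

/// hypothetical helper: parse a port string, falling back to `default`
/// when the value is absent or is not a valid base-10 u16.
fn parsePort(raw: ?[]const u8, default: u16) u16 {
    const s = raw orelse return default;
    return std.fmt.parseInt(u16, s, 10) catch default;
}

/// hypothetical helper: only the exact strings "1" and "true" enable the
/// flag; anything else, including an unset variable, leaves it off.
fn parseFlag(raw: ?[]const u8) bool {
    const v = raw orelse return false;
    return std.mem.eql(u8, v, "1") or std.mem.eql(u8, v, "true");
}
```

taking `?[]const u8` instead of calling getenv inside the helper keeps both functions free of process state, so they can be exercised with `zig test` without setting environment variables.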
+1059
backend/src/oauth.zig
··· 1 + //! OAuth client + authenticated PDS request helpers. 2 + //! 3 + //! lifted from pollz/backend/src/http.zig; the OAuth flow itself (PAR, 4 + //! token exchange, DPoP nonce retry, token refresh) is mechanical protocol 5 + //! work and there's no value in re-inventing it. what differs: 6 + //! - scopes: `atproto repo:tech.waow.ken.pack blob:*/*` 7 + //! - session storage is in state.zig (in-memory HashMap, no disk) 8 + //! - we add record/blob write wrappers specific to the embed.pack use case 9 + //! 10 + //! reference: https://atproto.com/specs/oauth 11 + 12 + const std = @import("std"); 13 + const Io = std.Io; 14 + const http = std.http; 15 + const mem = std.mem; 16 + const json = std.json; 17 + const Allocator = mem.Allocator; 18 + 19 + const zat = @import("zat"); 20 + const zat_oauth = zat.oauth; 21 + // renamed from `state` to `store` so it doesn't shadow the OAuth `state` 22 + // parameter in handleLogin/handleCallback 23 + const store = @import("state.zig"); 24 + 25 + pub const SCOPE = "atproto repo:tech.waow.ken.pack blob:*/*"; 26 + 27 + // runtime config: set once by main.zig at startup via oauth.init() 28 + pub const Config = struct { 29 + io: Io, 30 + client_id: []const u8, 31 + redirect_uri: []const u8, 32 + frontend_origin: []const u8, 33 + client_key_hex: []const u8, // 64 hex chars (32 bytes p256 private) 34 + }; 35 + 36 + var cfg: Config = undefined; 37 + var cfg_set: bool = false; 38 + 39 + pub fn init(c: Config) void { 40 + cfg = c; 41 + cfg_set = true; 42 + } 43 + 44 + pub fn config() Config { 45 + std.debug.assert(cfg_set); 46 + return cfg; 47 + } 48 + 49 + pub fn getClientKeypair() !zat.Keypair { 50 + if (cfg.client_key_hex.len != 64) return error.InvalidClientKey; 51 + var key_bytes: [32]u8 = undefined; 52 + _ = std.fmt.hexToBytes(&key_bytes, cfg.client_key_hex) catch return error.InvalidClientKey; 53 + return zat.Keypair.fromSecretKey(.p256, key_bytes); 54 + } 55 + 56 + pub fn keypairFromHex(hex: []const u8) !zat.Keypair { 57 + if (hex.len 
!= 64) return error.InvalidKeyHex; 58 + var key_bytes: [32]u8 = undefined; 59 + _ = std.fmt.hexToBytes(&key_bytes, hex) catch return error.InvalidKeyHex; 60 + return zat.Keypair.fromSecretKey(.p256, key_bytes); 61 + } 62 + 63 + // --------------------------------------------------------------------------- 64 + // basic HTTP helpers 65 + // --------------------------------------------------------------------------- 66 + 67 + pub fn httpGet(alloc: Allocator, url: []const u8) ![]u8 { 68 + var client: std.http.Client = .{ .allocator = alloc, .io = cfg.io }; 69 + defer client.deinit(); 70 + 71 + var aw: std.Io.Writer.Allocating = .init(alloc); 72 + const result = client.fetch(.{ 73 + .location = .{ .url = url }, 74 + .response_writer = &aw.writer, 75 + .headers = .{ .accept_encoding = .{ .override = "identity" } }, 76 + }) catch { 77 + aw.deinit(); 78 + return error.FetchFailed; 79 + }; 80 + if (result.status != .ok) { 81 + aw.deinit(); 82 + return error.FetchFailed; 83 + } 84 + return aw.toOwnedSlice() catch error.FetchFailed; 85 + } 86 + 87 + pub const HttpResult = struct { 88 + status: http.Status, 89 + body: []u8, 90 + dpop_nonce: ?[]const u8, 91 + }; 92 + 93 + pub fn doPost( 94 + alloc: Allocator, 95 + url: []const u8, 96 + payload: []const u8, 97 + extra_headers: []const http.Header, 98 + ) !HttpResult { 99 + var client: std.http.Client = .{ .allocator = alloc, .io = cfg.io }; 100 + defer client.deinit(); 101 + 102 + var req = try client.request(.POST, try std.Uri.parse(url), .{ 103 + .extra_headers = extra_headers, 104 + .headers = .{ 105 + .content_type = .{ .override = "application/x-www-form-urlencoded" }, 106 + .accept_encoding = .{ .override = "identity" }, 107 + }, 108 + }); 109 + defer req.deinit(); 110 + 111 + req.transfer_encoding = .{ .content_length = payload.len }; 112 + var body_writer = try req.sendBodyUnflushed(&.{}); 113 + try body_writer.writer.writeAll(payload); 114 + try body_writer.end(); 115 + try req.connection.?.flush(); 116 + 117 + var 
redirect_buf: [1]u8 = undefined; 118 + var response = req.receiveHead(&redirect_buf) catch return error.FetchFailed; 119 + 120 + var dpop_nonce: ?[]const u8 = null; 121 + var it = response.head.iterateHeaders(); 122 + while (it.next()) |h| { 123 + if (std.ascii.eqlIgnoreCase(h.name, "dpop-nonce")) { 124 + dpop_nonce = try alloc.dupe(u8, h.value); 125 + break; 126 + } 127 + } 128 + 129 + var aw: std.Io.Writer.Allocating = .init(alloc); 130 + const reader = response.reader(&.{}); 131 + _ = reader.streamRemaining(&aw.writer) catch { 132 + aw.deinit(); 133 + return error.FetchFailed; 134 + }; 135 + const resp_body = aw.toOwnedSlice() catch return error.FetchFailed; 136 + 137 + return .{ .status = response.head.status, .body = resp_body, .dpop_nonce = dpop_nonce }; 138 + } 139 + 140 + pub fn isDpopNonceError(status: http.Status, body: []const u8) bool { 141 + if (status != .bad_request and status != .unauthorized) return false; 142 + return mem.indexOf(u8, body, "use_dpop_nonce") != null; 143 + } 144 + 145 + pub fn isWwwAuthNonceError(status: http.Status, www_auth: ?[]const u8) bool { 146 + if (status != .unauthorized) return false; 147 + const h = www_auth orelse return false; 148 + return mem.indexOf(u8, h, "use_dpop_nonce") != null; 149 + } 150 + 151 + // --------------------------------------------------------------------------- 152 + // oauth metadata discovery 153 + // --------------------------------------------------------------------------- 154 + 155 + pub fn fetchAuthServerUrl(alloc: Allocator, pds_url: []const u8) ![]const u8 { 156 + const url = try std.fmt.allocPrint(alloc, "{s}/.well-known/oauth-protected-resource", .{pds_url}); 157 + defer alloc.free(url); 158 + 159 + const body = try httpGet(alloc, url); 160 + defer alloc.free(body); 161 + 162 + const parsed = try json.parseFromSlice(json.Value, alloc, body, .{}); 163 + defer parsed.deinit(); 164 + 165 + const servers = parsed.value.object.get("authorization_servers") orelse return error.NoAuthServers; 
166 + if (servers != .array or servers.array.items.len == 0) return error.NoAuthServers; 167 + const first = servers.array.items[0]; 168 + if (first != .string) return error.NoAuthServers; 169 + return alloc.dupe(u8, first.string); 170 + } 171 + 172 + pub fn fetchAuthServerMeta(alloc: Allocator, authserver_url: []const u8) !json.Parsed(json.Value) { 173 + const url = try std.fmt.allocPrint(alloc, "{s}/.well-known/oauth-authorization-server", .{authserver_url}); 174 + defer alloc.free(url); 175 + const body = try httpGet(alloc, url); 176 + return json.parseFromSlice(json.Value, alloc, body, .{}); 177 + } 178 + 179 + pub fn jsonGetString(value: json.Value, key: []const u8) ?[]const u8 { 180 + if (value != .object) return null; 181 + const v = value.object.get(key) orelse return null; 182 + if (v != .string) return null; 183 + return v.string; 184 + } 185 + 186 + // --------------------------------------------------------------------------- 187 + // PAR + token exchange + refresh 188 + // --------------------------------------------------------------------------- 189 + 190 + pub const ParResult = struct { request_uri: []const u8, dpop_nonce: []const u8 }; 191 + pub const ParParams = struct { 192 + par_url: []const u8, 193 + authserver_url: []const u8, 194 + client_id: []const u8, 195 + redirect_uri: []const u8, 196 + scope: []const u8, 197 + state: []const u8, 198 + pkce_challenge: []const u8, 199 + handle: []const u8, 200 + client_keypair: *const zat.Keypair, 201 + dpop_keypair: *const zat.Keypair, 202 + }; 203 + 204 + pub fn sendParRequest(alloc: Allocator, params: ParParams) !ParResult { 205 + const client_assertion = try zat_oauth.createClientAssertion( 206 + alloc, cfg.io, params.client_keypair, params.client_id, params.authserver_url, 207 + ); 208 + defer alloc.free(client_assertion); 209 + 210 + const dpop_proof = try zat_oauth.createDpopProof( 211 + alloc, cfg.io, params.dpop_keypair, "POST", params.par_url, null, null, 212 + ); 213 + defer 
alloc.free(dpop_proof); 214 + 215 + const form_params = [_][2][]const u8{ 216 + .{ "response_type", "code" }, 217 + .{ "code_challenge", params.pkce_challenge }, 218 + .{ "code_challenge_method", "S256" }, 219 + .{ "redirect_uri", params.redirect_uri }, 220 + .{ "scope", params.scope }, 221 + .{ "state", params.state }, 222 + .{ "login_hint", params.handle }, 223 + .{ "client_id", params.client_id }, 224 + .{ "client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" }, 225 + .{ "client_assertion", client_assertion }, 226 + }; 227 + const form_body = try zat_oauth.formEncode(alloc, &form_params); 228 + defer alloc.free(form_body); 229 + 230 + var result = try doPost(alloc, params.par_url, form_body, &.{ 231 + .{ .name = "DPoP", .value = dpop_proof }, 232 + }); 233 + 234 + if (isDpopNonceError(result.status, result.body)) { 235 + const nonce = result.dpop_nonce orelse return error.MissingDpopNonce; 236 + alloc.free(result.body); 237 + const proof2 = try zat_oauth.createDpopProof( 238 + alloc, cfg.io, params.dpop_keypair, "POST", params.par_url, nonce, null, 239 + ); 240 + defer alloc.free(proof2); 241 + result = try doPost(alloc, params.par_url, form_body, &.{ 242 + .{ .name = "DPoP", .value = proof2 }, 243 + }); 244 + } 245 + defer alloc.free(result.body); 246 + 247 + if (result.status != .ok and result.status != .created) { 248 + std.log.warn("PAR error ({t}): {s}", .{ result.status, result.body }); 249 + return error.ParFailed; 250 + } 251 + 252 + const parsed = try json.parseFromSlice(json.Value, alloc, result.body, .{}); 253 + defer parsed.deinit(); 254 + const request_uri = jsonGetString(parsed.value, "request_uri") orelse return error.MissingRequestUri; 255 + 256 + return .{ 257 + .request_uri = try alloc.dupe(u8, request_uri), 258 + .dpop_nonce = if (result.dpop_nonce) |n| try alloc.dupe(u8, n) else try alloc.dupe(u8, ""), 259 + }; 260 + } 261 + 262 + pub const TokenResult = struct { 263 + access_token: []const u8, 264 + refresh_token: 
[]const u8, 265 + sub: []const u8, 266 + dpop_nonce: []const u8, 267 + }; 268 + 269 + pub const TokenParams = struct { 270 + token_url: []const u8, 271 + authserver_url: []const u8, 272 + client_id: []const u8, 273 + redirect_uri: []const u8, 274 + code: []const u8, 275 + pkce_verifier: []const u8, 276 + client_keypair: *const zat.Keypair, 277 + dpop_keypair: *const zat.Keypair, 278 + dpop_nonce: []const u8, 279 + }; 280 + 281 + pub fn sendTokenRequest(alloc: Allocator, params: TokenParams) !TokenResult { 282 + const client_assertion = try zat_oauth.createClientAssertion( 283 + alloc, cfg.io, params.client_keypair, params.client_id, params.authserver_url, 284 + ); 285 + defer alloc.free(client_assertion); 286 + 287 + const dpop_proof = try zat_oauth.createDpopProof( 288 + alloc, cfg.io, params.dpop_keypair, "POST", params.token_url, 289 + if (params.dpop_nonce.len > 0) params.dpop_nonce else null, null, 290 + ); 291 + defer alloc.free(dpop_proof); 292 + 293 + const form_params = [_][2][]const u8{ 294 + .{ "grant_type", "authorization_code" }, 295 + .{ "code", params.code }, 296 + .{ "redirect_uri", params.redirect_uri }, 297 + .{ "code_verifier", params.pkce_verifier }, 298 + .{ "client_id", params.client_id }, 299 + .{ "client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" }, 300 + .{ "client_assertion", client_assertion }, 301 + }; 302 + const form_body = try zat_oauth.formEncode(alloc, &form_params); 303 + defer alloc.free(form_body); 304 + 305 + var result = try doPost(alloc, params.token_url, form_body, &.{ 306 + .{ .name = "DPoP", .value = dpop_proof }, 307 + }); 308 + 309 + if (isDpopNonceError(result.status, result.body)) { 310 + const nonce = result.dpop_nonce orelse return error.MissingDpopNonce; 311 + alloc.free(result.body); 312 + const proof2 = try zat_oauth.createDpopProof( 313 + alloc, cfg.io, params.dpop_keypair, "POST", params.token_url, nonce, null, 314 + ); 315 + defer alloc.free(proof2); 316 + result = try 
doPost(alloc, params.token_url, form_body, &.{ 317 + .{ .name = "DPoP", .value = proof2 }, 318 + }); 319 + } 320 + defer alloc.free(result.body); 321 + 322 + if (result.status != .ok) { 323 + std.log.warn("token exchange error ({t}): {s}", .{ result.status, result.body }); 324 + return error.TokenExchangeFailed; 325 + } 326 + 327 + const parsed = try json.parseFromSlice(json.Value, alloc, result.body, .{}); 328 + defer parsed.deinit(); 329 + 330 + return .{ 331 + .access_token = try alloc.dupe(u8, jsonGetString(parsed.value, "access_token") orelse return error.MissingAccessToken), 332 + .refresh_token = try alloc.dupe(u8, jsonGetString(parsed.value, "refresh_token") orelse return error.MissingRefreshToken), 333 + .sub = try alloc.dupe(u8, jsonGetString(parsed.value, "sub") orelse return error.MissingSub), 334 + .dpop_nonce = if (result.dpop_nonce) |n| try alloc.dupe(u8, n) else try alloc.dupe(u8, ""), 335 + }; 336 + } 337 + 338 + // --------------------------------------------------------------------------- 339 + // authenticated PDS request (DPoP + nonce retry + 401 refresh) 340 + // --------------------------------------------------------------------------- 341 + 342 + pub const PdsError = error{ 343 + Unauthorized, 344 + FetchFailed, 345 + InvalidSessionKey, 346 + AuthHeaderTooLong, 347 + DpopNonceRetryExhausted, 348 + OutOfMemory, 349 + }; 350 + 351 + /// make an authenticated request to a user's PDS. returns the response body 352 + /// as an owned slice. on 401 automatically refreshes the access token once 353 + /// and retries; otherwise propagates the error. 
354 + pub fn pdsAuthedRequest( 355 + alloc: Allocator, 356 + session: store.Session, 357 + method_str: []const u8, 358 + path: []const u8, 359 + body: ?[]const u8, 360 + content_type: []const u8, 361 + ) ![]u8 { 362 + const dpop_keypair = keypairFromHex(session.dpop_private_key) catch return error.InvalidSessionKey; 363 + 364 + var access_token_buf: [2048]u8 = undefined; 365 + if (session.access_token.len > access_token_buf.len) return error.AuthHeaderTooLong; @memcpy(access_token_buf[0..session.access_token.len], session.access_token); // bounds-checked: an oversized stored token must not overrun the buffer 366 + var access_token_len = session.access_token.len; 367 + var refreshed = false; 368 + 369 + for (0..2) |attempt| { 370 + const access_token = access_token_buf[0..access_token_len]; 371 + const url = try std.fmt.allocPrint(alloc, "{s}{s}", .{ session.pds_url, path }); 372 + defer alloc.free(url); 373 + 374 + const ath = try zat_oauth.accessTokenHash(alloc, access_token); 375 + defer alloc.free(ath); 376 + 377 + var nonce: ?[]const u8 = if (session.dpop_pds_nonce.len > 0) session.dpop_pds_nonce else null; 378 + 379 + const inner = for (0..2) |_| { 380 + const dpop_proof = try zat_oauth.createDpopProof(alloc, cfg.io, &dpop_keypair, method_str, url, nonce, ath); 381 + defer alloc.free(dpop_proof); 382 + 383 + var auth_hdr_buf: [4096]u8 = undefined; 384 + const auth_header = std.fmt.bufPrint(&auth_hdr_buf, "DPoP {s}", .{access_token}) catch return error.AuthHeaderTooLong; 385 + 386 + const http_method: http.Method = if (mem.eql(u8, method_str, "POST")) .POST else .GET; 387 + 388 + var client: std.http.Client = .{ .allocator = alloc, .io = cfg.io }; 389 + defer client.deinit(); 390 + 391 + var req = try client.request(http_method, try std.Uri.parse(url), .{ 392 + .extra_headers = &.{ 393 + .{ .name = "Authorization", .value = auth_header }, 394 + .{ .name = "DPoP", .value = dpop_proof }, 395 + }, 396 + // two typed overrides: 397 + // 1. content_type carries the body's mime 398 + // 2. accept_encoding=identity forces the PDS to send raw 399 + // bytes. 
std.http.Client's client.request() low-level 400 + // path does NOT transparently decompress responses, so 401 + // without this we get gzip bytes and json.parseFromSlice 402 + // fails with SyntaxError. this has to be in the TYPED 403 + // headers slot — putting it in extra_headers is 404 + // additive (zig still sends its default gzip/deflate/ 405 + // zstd header alongside yours) and the server picks 406 + // whichever first. see notes/languages/ziglang/0.15/io.md 407 + // for the full writeup and an empirical repro. 408 + .headers = .{ 409 + .content_type = .{ .override = content_type }, 410 + .accept_encoding = .{ .override = "identity" }, 411 + }, 412 + }); 413 + defer req.deinit(); 414 + 415 + if (body) |b| { 416 + req.transfer_encoding = .{ .content_length = b.len }; 417 + var body_writer = try req.sendBodyUnflushed(&.{}); 418 + try body_writer.writer.writeAll(b); 419 + try body_writer.end(); 420 + try req.connection.?.flush(); 421 + } else { 422 + try req.sendBodiless(); 423 + } 424 + 425 + var redirect_buf: [1]u8 = undefined; 426 + var response = req.receiveHead(&redirect_buf) catch return error.FetchFailed; 427 + 428 + var new_nonce: ?[]const u8 = null; 429 + var www_auth: ?[]const u8 = null; 430 + var hit = response.head.iterateHeaders(); 431 + while (hit.next()) |h| { 432 + if (std.ascii.eqlIgnoreCase(h.name, "dpop-nonce")) { 433 + new_nonce = h.value; 434 + } else if (std.ascii.eqlIgnoreCase(h.name, "www-authenticate")) { 435 + www_auth = h.value; 436 + } 437 + } 438 + 439 + var aw: std.Io.Writer.Allocating = .init(alloc); 440 + const reader = response.reader(&.{}); 441 + _ = reader.streamRemaining(&aw.writer) catch { 442 + aw.deinit(); 443 + return error.FetchFailed; 444 + }; 445 + const resp_body = aw.toOwnedSlice() catch return error.FetchFailed; 446 + 447 + if (new_nonce) |n| store.updateSessionNonce(session.did, .pds, n); 448 + 449 + const is_nonce_err = new_nonce != null and (isDpopNonceError(response.head.status, resp_body) or 
isWwwAuthNonceError(response.head.status, www_auth)); 450 + if (is_nonce_err) { 451 + alloc.free(resp_body); 452 + nonce = try alloc.dupe(u8, new_nonce.?); 453 + continue; 454 + } 455 + break .{ response.head.status, resp_body }; 456 + } else { 457 + return error.DpopNonceRetryExhausted; 458 + }; 459 + 460 + const status = inner[0]; 461 + const resp_body = inner[1]; 462 + 463 + if (status != .unauthorized) { 464 + return resp_body; 465 + } 466 + 467 + alloc.free(resp_body); 468 + if (attempt > 0 or refreshed) return error.Unauthorized; 469 + 470 + std.log.info("access token rejected for {s}, refreshing", .{session.did}); 471 + const new_tokens = refreshAccessToken(alloc, session, &dpop_keypair) catch return error.Unauthorized; 472 + if (new_tokens.access_token.len > access_token_buf.len) return error.AuthHeaderTooLong; 473 + @memcpy(access_token_buf[0..new_tokens.access_token.len], new_tokens.access_token); 474 + access_token_len = new_tokens.access_token.len; 475 + refreshed = true; 476 + } 477 + 478 + return error.Unauthorized; 479 + } 480 + 481 + fn refreshAccessToken( 482 + alloc: Allocator, 483 + session: store.Session, 484 + dpop_keypair: *const zat.Keypair, 485 + ) !TokenResult { 486 + var authserver_meta = try fetchAuthServerMeta(alloc, session.authserver_iss); 487 + defer authserver_meta.deinit(); 488 + 489 + const token_url = jsonGetString(authserver_meta.value, "token_endpoint") orelse return error.MissingTokenEndpoint; 490 + 491 + const client_keypair = getClientKeypair() catch return error.InvalidSessionKey; 492 + const client_id = cfg.client_id; 493 + 494 + const client_assertion = try zat_oauth.createClientAssertion(alloc, cfg.io, &client_keypair, client_id, session.authserver_iss); 495 + defer alloc.free(client_assertion); 496 + 497 + var authserver_nonce: ?[]const u8 = if (session.dpop_authserver_nonce.len > 0) session.dpop_authserver_nonce else null; 498 + 499 + for (0..2) |_| { 500 + const dpop_proof = try zat_oauth.createDpopProof(alloc, cfg.io, 
dpop_keypair, "POST", token_url, authserver_nonce, null); 501 + defer alloc.free(dpop_proof); 502 + 503 + const form_params = [_][2][]const u8{ 504 + .{ "grant_type", "refresh_token" }, 505 + .{ "refresh_token", session.refresh_token }, 506 + .{ "client_id", client_id }, 507 + .{ "client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" }, 508 + .{ "client_assertion", client_assertion }, 509 + }; 510 + const form_body = try zat_oauth.formEncode(alloc, &form_params); 511 + defer alloc.free(form_body); 512 + 513 + const result = try doPost(alloc, token_url, form_body, &.{ 514 + .{ .name = "DPoP", .value = dpop_proof }, 515 + }); 516 + 517 + if (result.dpop_nonce) |n| store.updateSessionNonce(session.did, .authserver, n); 518 + 519 + if (isDpopNonceError(result.status, result.body)) { 520 + authserver_nonce = result.dpop_nonce; 521 + alloc.free(result.body); 522 + continue; 523 + } 524 + 525 + defer alloc.free(result.body); 526 + if (result.status != .ok) { 527 + std.log.warn("token refresh error ({t}): {s}", .{ result.status, result.body }); 528 + return error.TokenRefreshFailed; 529 + } 530 + 531 + const parsed = try json.parseFromSlice(json.Value, alloc, result.body, .{}); 532 + defer parsed.deinit(); 533 + 534 + const new_access = try alloc.dupe(u8, jsonGetString(parsed.value, "access_token") orelse return error.MissingAccessToken); 535 + const new_refresh = try alloc.dupe(u8, jsonGetString(parsed.value, "refresh_token") orelse return error.MissingRefreshToken); 536 + 537 + store.updateSessionTokens(session.did, new_access, new_refresh); 538 + 539 + return .{ 540 + .access_token = new_access, 541 + .refresh_token = new_refresh, 542 + .sub = try alloc.dupe(u8, session.did), 543 + .dpop_nonce = if (result.dpop_nonce) |n| try alloc.dupe(u8, n) else try alloc.dupe(u8, ""), 544 + }; 545 + } 546 + return error.TokenRefreshFailed; 547 + } 548 + 549 + // --------------------------------------------------------------------------- 550 + // high-level 
PDS write operations (for the indexer to use) 551 + // --------------------------------------------------------------------------- 552 + 553 + // --------------------------------------------------------------------------- 554 + // HTTP route handlers — called from server.zig's dispatcher 555 + // --------------------------------------------------------------------------- 556 + 557 + pub fn handleClientMetadata(request: *http.Server.Request) !void { 558 + var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator); 559 + defer arena.deinit(); 560 + const alloc = arena.allocator(); 561 + 562 + const keypair = getClientKeypair() catch { 563 + try sendError(request, .internal_server_error, "server configuration error"); 564 + return; 565 + }; 566 + const jwk = keypair.jwk(alloc) catch { 567 + try sendError(request, .internal_server_error, "key error"); 568 + return; 569 + }; 570 + 571 + var body: std.ArrayList(u8) = .empty; 572 + try body.print(alloc, 573 + \\{{ 574 + \\ "client_id": "{s}", 575 + \\ "client_name": "ken", 576 + \\ "client_uri": "{s}", 577 + \\ "application_type": "web", 578 + \\ "grant_types": ["authorization_code", "refresh_token"], 579 + \\ "response_types": ["code"], 580 + \\ "redirect_uris": ["{s}"], 581 + \\ "token_endpoint_auth_method": "private_key_jwt", 582 + \\ "token_endpoint_auth_signing_alg": "ES256", 583 + \\ "scope": "{s}", 584 + \\ "dpop_bound_access_tokens": true, 585 + \\ "jwks": {{"keys": [{s}]}} 586 + \\}} 587 + , .{ cfg.client_id, getClientOrigin(), cfg.redirect_uri, SCOPE, jwk }); 588 + 589 + try sendJson(request, body.items); 590 + } 591 + 592 + pub fn handleJwks(request: *http.Server.Request) !void { 593 + var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator); 594 + defer arena.deinit(); 595 + const alloc = arena.allocator(); 596 + 597 + const keypair = getClientKeypair() catch { 598 + try sendError(request, .internal_server_error, "server configuration error"); 599 + return; 600 + }; 601 + const jwks = 
zat_oauth.jwksJson(alloc, &keypair) catch { 602 + try sendError(request, .internal_server_error, "key error"); 603 + return; 604 + }; 605 + try sendJson(request, jwks); 606 + } 607 + 608 + pub fn handleLogin(request: *http.Server.Request) !void { 609 + var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator); 610 + defer arena.deinit(); 611 + const alloc = arena.allocator(); 612 + 613 + const target = request.head.target; 614 + const handle_str = extractQueryParam(target, "handle") orelse { 615 + try sendError(request, .bad_request, "missing handle parameter"); 616 + return; 617 + }; 618 + 619 + var handle_resolver = zat.HandleResolver.init(cfg.io, alloc); 620 + defer handle_resolver.deinit(); 621 + const did = handle_resolver.resolve(zat.Handle.parse(handle_str) orelse { 622 + try sendError(request, .bad_request, "invalid handle"); 623 + return; 624 + }) catch { 625 + try sendError(request, .bad_request, "could not resolve handle"); 626 + return; 627 + }; 628 + 629 + var did_resolver = zat.DidResolver.init(cfg.io, alloc); 630 + defer did_resolver.deinit(); 631 + var did_doc = did_resolver.resolve(zat.Did.parse(did) orelse { 632 + try sendError(request, .bad_request, "invalid DID"); 633 + return; 634 + }) catch { 635 + try sendError(request, .bad_request, "could not resolve DID"); 636 + return; 637 + }; 638 + defer did_doc.deinit(); 639 + 640 + const pds_url = did_doc.pdsEndpoint() orelse { 641 + try sendError(request, .bad_request, "no PDS endpoint"); 642 + return; 643 + }; 644 + 645 + const authserver_url = fetchAuthServerUrl(alloc, pds_url) catch { 646 + try sendError(request, .bad_request, "could not discover auth server"); 647 + return; 648 + }; 649 + 650 + var authserver_meta = fetchAuthServerMeta(alloc, authserver_url) catch { 651 + try sendError(request, .bad_request, "could not fetch auth server metadata"); 652 + return; 653 + }; 654 + defer authserver_meta.deinit(); 655 + 656 + const authserver_iss = jsonGetString(authserver_meta.value, "issuer") 
orelse { 657 + try sendError(request, .bad_request, "auth server missing issuer"); 658 + return; 659 + }; 660 + const par_url = jsonGetString(authserver_meta.value, "pushed_authorization_request_endpoint") orelse { 661 + try sendError(request, .bad_request, "auth server missing PAR endpoint"); 662 + return; 663 + }; 664 + const authorization_endpoint = jsonGetString(authserver_meta.value, "authorization_endpoint") orelse { 665 + try sendError(request, .bad_request, "auth server missing authorization endpoint"); 666 + return; 667 + }; 668 + 669 + const pkce_verifier = try zat_oauth.generatePkceVerifier(alloc, cfg.io); 670 + const pkce_challenge = try zat_oauth.generatePkceChallenge(alloc, pkce_verifier); 671 + const state = try zat_oauth.generateState(alloc, cfg.io); 672 + 673 + var dpop_key_bytes: [32]u8 = undefined; 674 + cfg.io.random(&dpop_key_bytes); 675 + const dpop_keypair = zat.Keypair.fromSecretKey(.p256, dpop_key_bytes) catch { 676 + try sendError(request, .internal_server_error, "key generation failed"); 677 + return; 678 + }; 679 + 680 + const client_keypair = getClientKeypair() catch { 681 + try sendError(request, .internal_server_error, "server configuration error"); 682 + return; 683 + }; 684 + 685 + const par_result = sendParRequest(alloc, .{ 686 + .par_url = par_url, 687 + .authserver_url = authserver_iss, 688 + .client_id = cfg.client_id, 689 + .redirect_uri = cfg.redirect_uri, 690 + .scope = SCOPE, 691 + .state = state, 692 + .pkce_challenge = pkce_challenge, 693 + .handle = handle_str, 694 + .client_keypair = &client_keypair, 695 + .dpop_keypair = &dpop_keypair, 696 + }) catch { 697 + try sendError(request, .bad_gateway, "PAR request failed"); 698 + return; 699 + }; 700 + 701 + const dpop_hex = std.fmt.bytesToHex(dpop_key_bytes, .lower); 702 + store.insertAuthRequest( 703 + state, authserver_iss, did, handle_str, pds_url, 704 + pkce_verifier, SCOPE, par_result.dpop_nonce, &dpop_hex, 705 + ) catch { 706 + try sendError(request, 
.internal_server_error, "could not store auth request"); 707 + return; 708 + }; 709 + 710 + var redirect_url: std.ArrayList(u8) = .empty; 711 + try redirect_url.print(alloc, "{s}?request_uri={s}&client_id={s}&state={s}", .{ 712 + authorization_endpoint, par_result.request_uri, cfg.client_id, state, 713 + }); 714 + try sendRedirect(request, redirect_url.items); 715 + } 716 + 717 + pub fn handleCallback(request: *http.Server.Request) !void { 718 + var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator); 719 + defer arena.deinit(); 720 + const alloc = arena.allocator(); 721 + 722 + const target = request.head.target; 723 + const code = extractQueryParam(target, "code") orelse { 724 + try sendError(request, .bad_request, "missing code"); 725 + return; 726 + }; 727 + const state = extractQueryParam(target, "state") orelse { 728 + try sendError(request, .bad_request, "missing state"); 729 + return; 730 + }; 731 + const iss_raw = extractQueryParam(target, "iss"); 732 + const iss = if (iss_raw) |raw| blk: { 733 + const buf = try alloc.dupe(u8, raw); 734 + break :blk std.Uri.percentDecodeBackwards(buf, buf); 735 + } else null; 736 + 737 + const auth_req = (try store.getAuthRequest(alloc, state)) orelse { 738 + try sendError(request, .bad_request, "unknown state — login may have expired"); 739 + return; 740 + }; 741 + 742 + if (iss) |issuer| { 743 + if (!mem.eql(u8, issuer, auth_req.authserver_iss)) { 744 + try sendError(request, .bad_request, "issuer mismatch"); 745 + return; 746 + } 747 + } 748 + 749 + const dpop_keypair = keypairFromHex(auth_req.dpop_private_key) catch { 750 + try sendError(request, .internal_server_error, "invalid stored key"); 751 + return; 752 + }; 753 + 754 + const client_keypair = getClientKeypair() catch { 755 + try sendError(request, .internal_server_error, "server configuration error"); 756 + return; 757 + }; 758 + 759 + var authserver_meta = fetchAuthServerMeta(alloc, auth_req.authserver_iss) catch { 760 + try sendError(request, 
.bad_gateway, "could not fetch auth server metadata"); 761 + return; 762 + }; 763 + defer authserver_meta.deinit(); 764 + 765 + const token_url = jsonGetString(authserver_meta.value, "token_endpoint") orelse { 766 + try sendError(request, .bad_gateway, "auth server missing token endpoint"); 767 + return; 768 + }; 769 + 770 + const token_result = sendTokenRequest(alloc, .{ 771 + .token_url = token_url, 772 + .authserver_url = auth_req.authserver_iss, 773 + .client_id = cfg.client_id, 774 + .redirect_uri = cfg.redirect_uri, 775 + .code = code, 776 + .pkce_verifier = auth_req.pkce_verifier, 777 + .client_keypair = &client_keypair, 778 + .dpop_keypair = &dpop_keypair, 779 + .dpop_nonce = auth_req.dpop_authserver_nonce, 780 + }) catch { 781 + try sendError(request, .bad_gateway, "token exchange failed"); 782 + return; 783 + }; 784 + 785 + if (!mem.eql(u8, token_result.sub, auth_req.did)) { 786 + try sendError(request, .bad_request, "token subject mismatch"); 787 + return; 788 + } 789 + 790 + store.upsertSession( 791 + auth_req.did, auth_req.handle, auth_req.pds_url, auth_req.authserver_iss, 792 + token_result.access_token, token_result.refresh_token, 793 + token_result.dpop_nonce, "", 794 + auth_req.dpop_private_key, 795 + ) catch { 796 + try sendError(request, .internal_server_error, "could not store session"); 797 + return; 798 + }; 799 + store.deleteAuthRequest(state); 800 + 801 + // redirect back to the frontend with ?logged_in={handle} so the JS can 802 + // resume the explore flow automatically 803 + var redirect_url: std.ArrayList(u8) = .empty; 804 + try redirect_url.print(alloc, "{s}/?logged_in={s}", .{ cfg.frontend_origin, auth_req.handle }); 805 + 806 + var cookie_buf: [512]u8 = undefined; 807 + const cookie = std.fmt.bufPrint( 808 + &cookie_buf, 809 + "embed_session={s}; HttpOnly; Secure; SameSite=Lax; Path=/; Max-Age=2592000", 810 + .{auth_req.did}, 811 + ) catch { 812 + try sendError(request, .internal_server_error, "cookie error"); 813 + return; 814 + }; 
815 + 816 + try request.respond("", .{ 817 + .status = .found, 818 + .extra_headers = &.{ 819 + .{ .name = "location", .value = redirect_url.items }, 820 + .{ .name = "set-cookie", .value = cookie }, 821 + }, 822 + }); 823 + } 824 + 825 + pub fn handleLogout(request: *http.Server.Request) !void { 826 + if (getSessionDid(request)) |did| { 827 + store.deleteSession(did); 828 + } 829 + try request.respond("{\"ok\":true}", .{ 830 + .status = .ok, 831 + .extra_headers = &.{ 832 + .{ .name = "content-type", .value = "application/json" }, 833 + .{ .name = "set-cookie", .value = "embed_session=; HttpOnly; Secure; SameSite=Lax; Path=/; Max-Age=0" }, 834 + }, 835 + }); 836 + } 837 + 838 + // --------------------------------------------------------------------------- 839 + // response helpers + cookie parsing 840 + // --------------------------------------------------------------------------- 841 + 842 + pub fn getSessionDid(request: *http.Server.Request) ?[]const u8 { 843 + var it = request.iterateHeaders(); 844 + while (it.next()) |h| { 845 + if (std.ascii.eqlIgnoreCase(h.name, "cookie")) { 846 + return parseCookieValue(h.value, "embed_session"); 847 + } 848 + } 849 + return null; 850 + } 851 + 852 + fn parseCookieValue(cookie_header: []const u8, name: []const u8) ?[]const u8 { 853 + var it = mem.splitSequence(u8, cookie_header, "; "); 854 + while (it.next()) |pair| { 855 + if (mem.startsWith(u8, pair, name)) { 856 + if (pair.len > name.len and pair[name.len] == '=') { 857 + return pair[name.len + 1 ..]; 858 + } 859 + } 860 + } 861 + return null; 862 + } 863 + 864 + fn extractQueryParam(target: []const u8, name: []const u8) ?[]const u8 { 865 + const q_idx = mem.indexOf(u8, target, "?") orelse return null; 866 + const query = target[q_idx + 1 ..]; 867 + var it = mem.splitScalar(u8, query, '&'); 868 + while (it.next()) |pair| { 869 + const eq_idx = mem.indexOf(u8, pair, "=") orelse continue; 870 + if (mem.eql(u8, pair[0..eq_idx], name)) { 871 + return pair[eq_idx + 1 ..]; 872 
+ } 873 + } 874 + return null; 875 + } 876 + 877 + fn getClientOrigin() []const u8 { 878 + const cid = cfg.client_id; 879 + const scheme_end = mem.indexOf(u8, cid, "://") orelse return cid; 880 + const after = cid[scheme_end + 3 ..]; 881 + const path_start = mem.indexOf(u8, after, "/") orelse return cid; 882 + return cid[0 .. scheme_end + 3 + path_start]; 883 + } 884 + 885 + fn sendError(request: *http.Server.Request, status: http.Status, message: []const u8) !void { 886 + var buf: [512]u8 = undefined; 887 + const body = std.fmt.bufPrint(&buf, "{{\"error\":\"{s}\"}}", .{message}) catch "{\"error\":\"internal error\"}"; 888 + try request.respond(body, .{ 889 + .status = status, 890 + .extra_headers = &.{ 891 + .{ .name = "content-type", .value = "application/json" }, 892 + .{ .name = "access-control-allow-origin", .value = "*" }, 893 + }, 894 + }); 895 + } 896 + 897 + fn sendJson(request: *http.Server.Request, body: []const u8) !void { 898 + try request.respond(body, .{ 899 + .status = .ok, 900 + .extra_headers = &.{ 901 + .{ .name = "content-type", .value = "application/json" }, 902 + .{ .name = "access-control-allow-origin", .value = "*" }, 903 + }, 904 + }); 905 + } 906 + 907 + fn sendRedirect(request: *http.Server.Request, location: []const u8) !void { 908 + try request.respond("", .{ 909 + .status = .found, 910 + .extra_headers = &.{ 911 + .{ .name = "location", .value = location }, 912 + }, 913 + }); 914 + } 915 + 916 + // --------------------------------------------------------------------------- 917 + // high-level PDS write operations (for the indexer to use) 918 + // --------------------------------------------------------------------------- 919 + 920 + pub const BlobRef = struct { 921 + cid: []const u8, 922 + mime_type: []const u8, 923 + size: i64, 924 + }; 925 + 926 + /// upload a blob via com.atproto.repo.uploadBlob. returns the blob ref fields 927 + /// that go into the referencing record. caller owns cid string. 
928 + pub fn uploadBlob( 929 + alloc: Allocator, 930 + session: store.Session, 931 + bytes: []const u8, 932 + mime_type: []const u8, 933 + ) !BlobRef { 934 + const resp = try pdsAuthedRequest( 935 + alloc, 936 + session, 937 + "POST", 938 + "/xrpc/com.atproto.repo.uploadBlob", 939 + bytes, 940 + mime_type, 941 + ); 942 + defer alloc.free(resp); 943 + 944 + const parsed = json.parseFromSlice(json.Value, alloc, resp, .{}) catch |err| { 945 + // log enough of the body to diagnose a non-JSON error from the PDS 946 + // (413 Payload Too Large, HTML error page, empty body, etc) — we 947 + // couldn't figure out "ParseFailed" without this. 948 + const preview_len = @min(resp.len, 400); 949 + std.log.warn( 950 + "uploadBlob: response was not json ({t}). bytes={d}, body[0..{d}]={s}", 951 + .{ err, bytes.len, preview_len, resp[0..preview_len] }, 952 + ); 953 + return error.ParseFailed; 954 + }; 955 + defer parsed.deinit(); 956 + 957 + const blob = parsed.value.object.get("blob") orelse { 958 + const preview_len = @min(resp.len, 400); 959 + std.log.warn( 960 + "uploadBlob: response missing 'blob' field. 
bytes={d}, body[0..{d}]={s}", 961 + .{ bytes.len, preview_len, resp[0..preview_len] }, 962 + ); 963 + return error.MissingBlob; 964 + }; 965 + if (blob != .object) return error.MissingBlob; 966 + 967 + const ref = blob.object.get("ref") orelse return error.MissingBlob; 968 + var cid: []const u8 = ""; 969 + if (ref == .object) { 970 + if (ref.object.get("$link")) |link| { 971 + if (link == .string) cid = link.string; 972 + } 973 + } 974 + if (cid.len == 0) return error.MissingBlob; 975 + 976 + const mime = if (blob.object.get("mimeType")) |m| 977 + if (m == .string) m.string else mime_type 978 + else 979 + mime_type; 980 + 981 + const size: i64 = if (blob.object.get("size")) |s| 982 + if (s == .integer) s.integer else @intCast(bytes.len) 983 + else 984 + @intCast(bytes.len); 985 + 986 + return .{ 987 + .cid = try alloc.dupe(u8, cid), 988 + .mime_type = try alloc.dupe(u8, mime), 989 + .size = size, 990 + }; 991 + } 992 + 993 + /// create a record at `collection/<tid>`. returns the at-uri as an owned 994 + /// slice. callers that need just the rkey can parse it via zat.AtUri. 995 + pub fn createRecord( 996 + alloc: Allocator, 997 + session: store.Session, 998 + collection: []const u8, 999 + record_json: []const u8, 1000 + ) ![]u8 { 1001 + const body = try std.fmt.allocPrint( 1002 + alloc, 1003 + "{{\"repo\":\"{s}\",\"collection\":\"{s}\",\"record\":{s}}}", 1004 + .{ session.did, collection, record_json }, 1005 + ); 1006 + defer alloc.free(body); 1007 + 1008 + const resp = try pdsAuthedRequest( 1009 + alloc, 1010 + session, 1011 + "POST", 1012 + "/xrpc/com.atproto.repo.createRecord", 1013 + body, 1014 + "application/json", 1015 + ); 1016 + defer alloc.free(resp); 1017 + 1018 + const parsed = json.parseFromSlice(json.Value, alloc, resp, .{}) catch |err| { 1019 + const preview_len = @min(resp.len, 400); 1020 + std.log.warn( 1021 + "createRecord: response was not json ({t}). 
body[0..{d}]={s}", 1022 + .{ err, preview_len, resp[0..preview_len] }, 1023 + ); 1024 + return error.ParseFailed; 1025 + }; 1026 + defer parsed.deinit(); 1027 + 1028 + const uri_v = parsed.value.object.get("uri") orelse return error.MissingUri; 1029 + if (uri_v != .string) return error.MissingUri; 1030 + return alloc.dupe(u8, uri_v.string); 1031 + } 1032 + 1033 + /// delete a record via com.atproto.repo.deleteRecord. used to sweep the old 1034 + /// pack after a successful incremental re-index so the PDS only holds one 1035 + /// current pack at a time. 1036 + pub fn deleteRecord( 1037 + alloc: Allocator, 1038 + session: store.Session, 1039 + collection: []const u8, 1040 + rkey: []const u8, 1041 + ) !void { 1042 + const body = try std.fmt.allocPrint( 1043 + alloc, 1044 + "{{\"repo\":\"{s}\",\"collection\":\"{s}\",\"rkey\":\"{s}\"}}", 1045 + .{ session.did, collection, rkey }, 1046 + ); 1047 + defer alloc.free(body); 1048 + 1049 + const resp = try pdsAuthedRequest( 1050 + alloc, 1051 + session, 1052 + "POST", 1053 + "/xrpc/com.atproto.repo.deleteRecord", 1054 + body, 1055 + "application/json", 1056 + ); 1057 + defer alloc.free(resp); 1058 + // PDS returns a small JSON envelope on success; we don't need its fields. 1059 + }
+221
backend/src/pds.zig
··· 1 + //! thin PDS client: handle → DID, DID → PDS URL, describeRepo, listRecords. 2 + //! 3 + //! for v1 all reads go through bsky's public appview (resolveHandle) and 4 + //! plc.directory (DID document). no auth required — we're only reading 5 + //! public data to build an in-memory search index. 6 + 7 + const std = @import("std"); 8 + const json = std.json; 9 + const Allocator = std.mem.Allocator; 10 + const zat = @import("zat"); 11 + 12 + const PUBLIC_APPVIEW = "https://public.api.bsky.app"; 13 + const PLC_DIRECTORY = "https://plc.directory"; 14 + 15 + pub const ResolveError = error{ 16 + ResolveFailed, 17 + NoPdsEndpoint, 18 + RequestFailed, 19 + OutOfMemory, 20 + ParseFailed, 21 + }; 22 + 23 + pub const Identity = struct { 24 + did: []const u8, 25 + pds: []const u8, 26 + }; 27 + 28 + /// resolve a handle or DID to {did, pds_url}. allocates into `arena`. 29 + pub fn resolveIdentity( 30 + arena: Allocator, 31 + transport: *zat.HttpTransport, 32 + actor: []const u8, 33 + ) ResolveError!Identity { 34 + // step 1: if `actor` already looks like a DID, use it directly; otherwise 35 + // resolve via bsky's public appview. 
36 + const did: []const u8 = if (std.mem.startsWith(u8, actor, "did:")) 37 + try arena.dupe(u8, actor) 38 + else 39 + try resolveHandle(arena, transport, actor); 40 + 41 + // step 2: DID document → PDS service endpoint 42 + const pds = try resolvePdsFromDid(arena, transport, did); 43 + return .{ .did = did, .pds = pds }; 44 + } 45 + 46 + fn resolveHandle( 47 + arena: Allocator, 48 + transport: *zat.HttpTransport, 49 + handle: []const u8, 50 + ) ResolveError![]const u8 { 51 + const url = try std.fmt.allocPrint( 52 + arena, 53 + "{s}/xrpc/com.atproto.identity.resolveHandle?handle={s}", 54 + .{ PUBLIC_APPVIEW, handle }, 55 + ); 56 + 57 + const result = transport.fetch(.{ .url = url }) catch return error.RequestFailed; 58 + if (result.status != .ok) return error.ResolveFailed; 59 + 60 + const parsed = json.parseFromSliceLeaky(json.Value, arena, result.body, .{}) catch 61 + return error.ParseFailed; 62 + 63 + const did_val = parsed.object.get("did") orelse return error.ResolveFailed; 64 + if (did_val != .string) return error.ResolveFailed; 65 + return try arena.dupe(u8, did_val.string); 66 + } 67 + 68 + fn resolvePdsFromDid( 69 + arena: Allocator, 70 + transport: *zat.HttpTransport, 71 + did: []const u8, 72 + ) ResolveError![]const u8 { 73 + const url = if (std.mem.startsWith(u8, did, "did:plc:")) 74 + try std.fmt.allocPrint(arena, "{s}/{s}", .{ PLC_DIRECTORY, did }) 75 + else if (std.mem.startsWith(u8, did, "did:web:")) blk: { 76 + const domain = did["did:web:".len..]; 77 + break :blk try std.fmt.allocPrint(arena, "https://{s}/.well-known/did.json", .{domain}); 78 + } else return error.ResolveFailed; 79 + 80 + const result = transport.fetch(.{ .url = url }) catch return error.RequestFailed; 81 + if (result.status != .ok) return error.ResolveFailed; 82 + 83 + const parsed = json.parseFromSliceLeaky(json.Value, arena, result.body, .{}) catch 84 + return error.ParseFailed; 85 + 86 + const services = parsed.object.get("service") orelse return error.NoPdsEndpoint; 87 + if 
(services != .array) return error.NoPdsEndpoint; 88 + 89 + for (services.array.items) |svc| { 90 + if (svc != .object) continue; 91 + const svc_type = svc.object.get("type") orelse continue; 92 + if (svc_type != .string) continue; 93 + if (!std.mem.eql(u8, svc_type.string, "AtprotoPersonalDataServer")) continue; 94 + const endpoint = svc.object.get("serviceEndpoint") orelse continue; 95 + if (endpoint != .string) continue; 96 + return try arena.dupe(u8, endpoint.string); 97 + } 98 + return error.NoPdsEndpoint; 99 + } 100 + 101 + /// get the list of collection NSIDs present in a repo. 102 + pub fn describeRepo( 103 + arena: Allocator, 104 + transport: *zat.HttpTransport, 105 + pds: []const u8, 106 + did: []const u8, 107 + ) ResolveError![]const []const u8 { 108 + const url = try std.fmt.allocPrint( 109 + arena, 110 + "{s}/xrpc/com.atproto.repo.describeRepo?repo={s}", 111 + .{ pds, did }, 112 + ); 113 + 114 + const result = transport.fetch(.{ .url = url }) catch return error.RequestFailed; 115 + if (result.status != .ok) return error.ResolveFailed; 116 + 117 + const parsed = json.parseFromSliceLeaky(json.Value, arena, result.body, .{}) catch 118 + return error.ParseFailed; 119 + 120 + const collections = parsed.object.get("collections") orelse return error.ParseFailed; 121 + if (collections != .array) return error.ParseFailed; 122 + 123 + var out: std.ArrayList([]const u8) = .empty; 124 + for (collections.array.items) |c| { 125 + if (c != .string) continue; 126 + try out.append(arena, try arena.dupe(u8, c.string)); 127 + } 128 + return out.items; 129 + } 130 + 131 + pub const Record = struct { 132 + uri: []const u8, 133 + cid: []const u8, 134 + value: json.Value, // arena-backed, walked by record_text.zig 135 + }; 136 + 137 + pub const ListPage = struct { 138 + records: []Record, 139 + cursor: ?[]const u8, 140 + }; 141 + 142 + /// one page of listRecords for a given collection. caller loops until cursor 143 + /// is null to get everything. 
`limit` max is 100 per the spec. 144 + pub fn listRecords( 145 + arena: Allocator, 146 + transport: *zat.HttpTransport, 147 + pds: []const u8, 148 + did: []const u8, 149 + collection: []const u8, 150 + cursor: ?[]const u8, 151 + limit: u32, 152 + ) ResolveError!ListPage { 153 + const url = if (cursor) |c| 154 + try std.fmt.allocPrint( 155 + arena, 156 + "{s}/xrpc/com.atproto.repo.listRecords?repo={s}&collection={s}&limit={d}&cursor={s}", 157 + .{ pds, did, collection, limit, c }, 158 + ) 159 + else 160 + try std.fmt.allocPrint( 161 + arena, 162 + "{s}/xrpc/com.atproto.repo.listRecords?repo={s}&collection={s}&limit={d}", 163 + .{ pds, did, collection, limit }, 164 + ); 165 + 166 + const result = transport.fetch(.{ .url = url }) catch return error.RequestFailed; 167 + if (result.status != .ok) return error.ResolveFailed; 168 + 169 + const parsed = json.parseFromSliceLeaky(json.Value, arena, result.body, .{}) catch 170 + return error.ParseFailed; 171 + 172 + var out_records: std.ArrayList(Record) = .empty; 173 + 174 + if (parsed.object.get("records")) |records_val| { 175 + if (records_val == .array) { 176 + for (records_val.array.items) |r| { 177 + if (r != .object) continue; 178 + const uri = (r.object.get("uri") orelse continue); 179 + const cid = (r.object.get("cid") orelse continue); 180 + const value = r.object.get("value") orelse continue; 181 + if (uri != .string or cid != .string) continue; 182 + try out_records.append(arena, .{ 183 + .uri = try arena.dupe(u8, uri.string), 184 + .cid = try arena.dupe(u8, cid.string), 185 + .value = value, 186 + }); 187 + } 188 + } 189 + } 190 + 191 + const cursor_out: ?[]const u8 = blk: { 192 + const v = parsed.object.get("cursor") orelse break :blk null; 193 + if (v != .string) break :blk null; 194 + if (v.string.len == 0) break :blk null; 195 + break :blk try arena.dupe(u8, v.string); 196 + }; 197 + 198 + return .{ .records = out_records.items, .cursor = cursor_out }; 199 + } 200 + 201 + /// fetch a blob by cid via 
com.atproto.sync.getBlob. unauthenticated — blobs 202 + /// in a public PDS are readable by anyone. returns the raw bytes owned by 203 + /// `arena`. 204 + pub fn getBlob( 205 + arena: Allocator, 206 + transport: *zat.HttpTransport, 207 + pds: []const u8, 208 + did: []const u8, 209 + cid: []const u8, 210 + ) ResolveError![]const u8 { 211 + const url = try std.fmt.allocPrint( 212 + arena, 213 + "{s}/xrpc/com.atproto.sync.getBlob?did={s}&cid={s}", 214 + .{ pds, did, cid }, 215 + ); 216 + const result = transport.fetch(.{ .url = url }) catch return error.RequestFailed; 217 + if (result.status != .ok) return error.ResolveFailed; 218 + // result.body is transport-owned; dup into the caller's arena so it 219 + // outlives this call. 220 + return try arena.dupe(u8, result.body); 221 + }
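the listRecords doc comment says the caller loops until the cursor is null. a sketch of that loop shape, using a hypothetical in-memory `FakeRepo` in place of the real transport so it runs on its own. `FakeRepo`, `Page`, and `countAll` are illustration names, not part of pds.zig:

```zig
const std = @import("std");

/// hypothetical page shape, mirroring ListPage's records + cursor pair.
const Page = struct {
    records: []const []const u8,
    cursor: ?[]const u8,
};

/// hypothetical stand-in for a PDS: serves `items` in pages of `limit`,
/// using the next start index (as decimal text) as the cursor.
const FakeRepo = struct {
    items: []const []const u8,
    cursor_buf: [20]u8 = undefined,

    fn list(self: *FakeRepo, cursor: ?[]const u8, limit: usize) !Page {
        const start: usize = if (cursor) |c| try std.fmt.parseInt(usize, c, 10) else 0;
        const end = @min(start + limit, self.items.len);
        const next: ?[]const u8 = if (end < self.items.len)
            try std.fmt.bufPrint(&self.cursor_buf, "{d}", .{end})
        else
            null;
        return .{ .records = self.items[start..end], .cursor = next };
    }
};

/// walk every page, feeding each returned cursor into the next request,
/// and stop once the server stops returning one.
fn countAll(repo: *FakeRepo, limit: usize) !usize {
    var total: usize = 0;
    var cursor: ?[]const u8 = null;
    while (true) {
        const page = try repo.list(cursor, limit);
        total += page.records.len;
        cursor = page.cursor orelse break;
    }
    return total;
}

test "pagination visits every record" {
    const items = [_][]const u8{ "a", "b", "c", "d", "e" };
    var repo: FakeRepo = .{ .items = &items };
    try std.testing.expectEqual(@as(usize, 5), try countAll(&repo, 2));
}
```

with the real client, the cursor is an opaque string from the PDS, so it would likely need percent-encoding before being placed in the query string.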
+189
backend/src/record_text.zig
··· 1 + //! record_to_text heuristic — port of embedder/build_pack.py. 2 + //! 3 + //! walks arbitrary atproto record JSON and produces embedding-ready text. 4 + //! rules: 5 + //! - prepend `collection: <nsid>` so the lex name is in the vector 6 + //! - drop atproto plumbing keys ($type, cid, rev, sig, prev, version, $link) 7 + //! - drop identifier-shaped strings (DIDs, at-uris, CIDs, TID rkeys, hex hashes) 8 + //! - convert iso timestamps to year only (keeps temporal signal, not noise) 9 + //! - render as flat `key.path: value` lines 10 + //! - strongRefs and DIDs are currently dropped (v2: deref via slingshot) 11 + //! - truncate final text at 4000 chars 12 + 13 + const std = @import("std"); 14 + const json = std.json; 15 + const Allocator = std.mem.Allocator; 16 + 17 + const MAX_TEXT_LEN = 4000; 18 + 19 + const NOISE_KEYS = [_][]const u8{ 20 + "$type", "cid", "rev", "sig", "prev", "version", "$link", 21 + }; 22 + 23 + fn isNoiseKey(k: []const u8) bool { 24 + if (k.len == 0) return false; 25 + if (k[0] == '$') return true; 26 + for (NOISE_KEYS) |n| if (std.mem.eql(u8, k, n)) return true; 27 + return false; 28 + } 29 + 30 + fn isIdentifier(s: []const u8) bool { 31 + if (s.len == 0) return false; 32 + if (std.mem.startsWith(u8, s, "did:plc:")) return true; 33 + if (std.mem.startsWith(u8, s, "did:web:")) return true; 34 + if (std.mem.startsWith(u8, s, "at://")) return true; 35 + // CID v1 raw/dag-cbor 36 + if (std.mem.startsWith(u8, s, "bafy")) return true; 37 + if (std.mem.startsWith(u8, s, "bafk")) return true; 38 + // TID format: exactly 13 chars, lowercase base32 subset [2-7a-z] 39 + if (s.len == 13) { 40 + var all_tid = true; 41 + for (s) |ch| { 42 + const ok = (ch >= 'a' and ch <= 'z') or (ch >= '2' and ch <= '7'); 43 + if (!ok) { 44 + all_tid = false; 45 + break; 46 + } 47 + } 48 + if (all_tid) return true; 49 + } 50 + // hex hash of 32+ chars 51 + if (s.len >= 32) { 52 + var all_hex = true; 53 + for (s) |ch| { 54 + const ok = (ch >= '0' and ch <= '9') or 
(ch >= 'a' and ch <= 'f'); 55 + if (!ok) { 56 + all_hex = false; 57 + break; 58 + } 59 + } 60 + if (all_hex) return true; 61 + } 62 + return false; 63 + } 64 + 65 + fn looksLikeIsoTimestamp(s: []const u8) bool { 66 + // YYYY-MM-DDT... pattern (loose) 67 + if (s.len < 11) return false; 68 + for (0..4) |i| if (!std.ascii.isDigit(s[i])) return false; 69 + if (s[4] != '-') return false; 70 + for (5..7) |i| if (!std.ascii.isDigit(s[i])) return false; 71 + if (s[7] != '-') return false; 72 + for (8..10) |i| if (!std.ascii.isDigit(s[i])) return false; 73 + return s[10] == 'T' or s[10] == ' '; 74 + } 75 + 76 + /// shape check: is this a strongRef {uri: "at://...", cid: "bafy..."}? 77 + fn isStrongRef(v: json.Value) bool { 78 + if (v != .object) return false; 79 + const uri_v = v.object.get("uri") orelse return false; 80 + const cid_v = v.object.get("cid") orelse return false; 81 + if (uri_v != .string or cid_v != .string) return false; 82 + return std.mem.startsWith(u8, uri_v.string, "at://"); 83 + } 84 + 85 + pub const Builder = struct { 86 + buf: std.ArrayList(u8), 87 + allocator: Allocator, 88 + 89 + pub fn init(allocator: Allocator) Builder { 90 + return .{ .buf = .empty, .allocator = allocator }; 91 + } 92 + 93 + fn appendLine(self: *Builder, path: []const u8, value: []const u8) !void { 94 + if (self.buf.items.len >= MAX_TEXT_LEN) return; 95 + if (self.buf.items.len > 0) try self.buf.append(self.allocator, '\n'); 96 + if (path.len > 0) { 97 + try self.buf.appendSlice(self.allocator, path); 98 + try self.buf.appendSlice(self.allocator, ": "); 99 + } 100 + try self.buf.appendSlice(self.allocator, value); 101 + } 102 + }; 103 + 104 + /// walk a parsed record value, appending flattened text to `builder`. 
105 + fn walk(builder: *Builder, node: json.Value, path: []const u8) anyerror!void { 106 + if (builder.buf.items.len >= MAX_TEXT_LEN) return; 107 + 108 + // strongRef short-circuit: drop (v2: inline dereffed content) 109 + if (isStrongRef(node)) return; 110 + 111 + switch (node) { 112 + .object => |obj| { 113 + var it = obj.iterator(); 114 + while (it.next()) |entry| { 115 + const k = entry.key_ptr.*; 116 + if (isNoiseKey(k)) continue; 117 + 118 + const next_path = if (path.len == 0) 119 + try builder.allocator.dupe(u8, k) 120 + else 121 + try std.fmt.allocPrint(builder.allocator, "{s}.{s}", .{ path, k }); 122 + 123 + try walk(builder, entry.value_ptr.*, next_path); 124 + } 125 + }, 126 + .array => |arr| { 127 + // pure string arrays → comma-joined 128 + var all_strings = true; 129 + for (arr.items) |item| { 130 + if (item != .string) { 131 + all_strings = false; 132 + break; 133 + } 134 + } 135 + if (all_strings) { 136 + var kept: std.ArrayList([]const u8) = .empty; 137 + for (arr.items) |item| { 138 + if (!isIdentifier(item.string)) { 139 + try kept.append(builder.allocator, item.string); 140 + } 141 + } 142 + if (kept.items.len > 0) { 143 + const joined = try std.mem.join(builder.allocator, ", ", kept.items); 144 + try builder.appendLine(path, joined); 145 + } 146 + } else { 147 + for (arr.items) |item| try walk(builder, item, path); 148 + } 149 + }, 150 + .string => |s| { 151 + // DID reference → drop (v2: inline profile) 152 + if (std.mem.startsWith(u8, s, "did:plc:") or std.mem.startsWith(u8, s, "did:web:")) return; 153 + if (isIdentifier(s)) return; 154 + if (looksLikeIsoTimestamp(s)) { 155 + try builder.appendLine(path, s[0..4]); // keep year only 156 + return; 157 + } 158 + try builder.appendLine(path, s); 159 + }, 160 + .integer => |i| { 161 + if (i == 0) return; 162 + const s = try std.fmt.allocPrint(builder.allocator, "{d}", .{i}); 163 + try builder.appendLine(path, s); 164 + }, 165 + .float => |f| { 166 + if (f == 0.0) return; 167 + const s = try 
std.fmt.allocPrint(builder.allocator, "{d}", .{f}); 168 + try builder.appendLine(path, s); 169 + }, 170 + .bool => return, // rarely meaningful for retrieval 171 + .null => return, 172 + .number_string => |s| try builder.appendLine(path, s), 173 + } 174 + } 175 + 176 + /// produce embedding-ready text for a record. result is allocated in `arena` 177 + /// and truncated at MAX_TEXT_LEN. 178 + pub fn recordToText( 179 + arena: Allocator, 180 + collection: []const u8, 181 + value: json.Value, 182 + ) ![]const u8 { 183 + var b = Builder.init(arena); 184 + // prepend the NSID — strongest structural signal 185 + try b.appendLine("collection", collection); 186 + try walk(&b, value, ""); 187 + const len = @min(b.buf.items.len, MAX_TEXT_LEN); 188 + return b.buf.items[0..len]; 189 + }
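recordToText cuts the buffer at MAX_TEXT_LEN bytes, which can land in the middle of a multi-byte UTF-8 sequence. whether that matters depends on how the tokenizer treats invalid bytes. a sketch of a boundary-aware cut, assuming a hypothetical helper `truncateUtf8` that is not part of this commit:

```zig
const std = @import("std");

/// hypothetical helper: returns at most `max_len` bytes of `s`, backing up
/// so the cut never falls inside a multi-byte UTF-8 sequence.
fn truncateUtf8(s: []const u8, max_len: usize) []const u8 {
    if (s.len <= max_len) return s;
    var end = max_len;
    // continuation bytes look like 0b10xxxxxx. if the first dropped byte is
    // one, the cut is mid-codepoint, so step back to the lead byte.
    while (end > 0 and (s[end] & 0xC0) == 0x80) end -= 1;
    return s[0..end];
}

test "truncateUtf8 keeps codepoints whole" {
    const s = "h\xc3\xa9llo"; // "héllo": the é is two bytes
    try std.testing.expectEqualStrings("h", truncateUtf8(s, 2));
    try std.testing.expectEqualStrings("h\xc3\xa9", truncateUtf8(s, 3));
}
```

the result can be up to three bytes shorter than `max_len`, which is negligible against a 4000-byte budget.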
+182
backend/src/repo_walk.zig
··· 1 + //! single-request repo walk via `com.atproto.sync.getRepo`. 2 + //! 3 + //! the previous walker used paginated `listRecords` on every collection — 4 + //! ~172 HTTP round trips for a 17k-record repo. this does one request, 5 + //! pulls the entire repo as a CAR file (DAG-CBOR blocks + MST structure), 6 + //! and yields records by walking the MST locally. 7 + //! 8 + //! relies on zat for the protocol primitives: 9 + //! - zat.car.read / findBlock for CAR parsing 10 + //! - zat.mst.decodeMstNode for MST node decoding 11 + //! - zat.cbor.decodeAll for per-record DAG-CBOR decoding 12 + //! - zat.multibase.base32lower.encode for CID → "bafy..." stringification 13 + //! 14 + //! records come out with the same (uri, cid, value) shape as pds.listRecords 15 + //! so indexer.doIndex can swap the two walks behind a try-CAR-fallback-to- 16 + //! listRecords flow. 17 + 18 + const std = @import("std"); 19 + const Io = std.Io; 20 + const Allocator = std.mem.Allocator; 21 + const json = std.json; 22 + const zat = @import("zat"); 23 + 24 + pub const Record = struct { 25 + /// at://{did}/{collection}/{rkey} 26 + uri: []const u8, 27 + /// base32 multibase CID string ("bafy...") 28 + cid: []const u8, 29 + /// NSID, derived from the MST key's collection prefix 30 + collection: []const u8, 31 + /// parsed record body, converted from DAG-CBOR to std.json.Value so it 32 + /// plugs into the existing record_text / display pipelines 33 + value: json.Value, 34 + }; 35 + 36 + pub const WalkResult = struct { 37 + records: []Record, 38 + /// total CAR size in bytes (for logging / diagnostics) 39 + car_bytes: usize, 40 + }; 41 + 42 + pub const WalkError = error{ 43 + FetchFailed, 44 + EmptyCar, 45 + NoCommitBlock, 46 + InvalidCommit, 47 + NoDataField, 48 + InvalidDataField, 49 + BlockNotFound, 50 + InvalidMstNode, 51 + InvalidRecordCbor, 52 + OutOfMemory, 53 + }; 54 + 55 + /// Fetch the repo CAR for `did` from `pds_url`, parse it, walk the MST, and 56 + /// return every record. 
All results live in `arena`. 57 + pub fn walkRepo( 58 + arena: Allocator, 59 + io: Io, 60 + pds_url: []const u8, 61 + did: []const u8, 62 + ) WalkError!WalkResult { 63 + // 1. HTTP GET /xrpc/com.atproto.sync.getRepo?did=... 64 + var transport = zat.HttpTransport.init(io, arena); 65 + defer transport.deinit(); 66 + const url = std.fmt.allocPrint(arena, "{s}/xrpc/com.atproto.sync.getRepo?did={s}", .{ pds_url, did }) catch return error.OutOfMemory; 67 + const result = transport.fetch(.{ .url = url }) catch return error.FetchFailed; 68 + if (result.status != .ok) return error.FetchFailed; 69 + const car_bytes = result.body; 70 + if (car_bytes.len == 0) return error.EmptyCar; 71 + 72 + // 2. parse CAR → blocks keyed by CID. bump the safety limits well above 73 + // zat's defaults (2 MB / 10k blocks) — a 17k-record repo lands around 74 + // 20 MB with ~18k blocks, and we want headroom for larger repos. 75 + const repo_car = zat.car.readWithOptions(arena, car_bytes, .{ 76 + .max_size = 256 * 1024 * 1024, 77 + .max_blocks = 200_000, 78 + }) catch return error.EmptyCar; 79 + if (repo_car.roots.len == 0) return error.EmptyCar; 80 + 81 + // 3. commit block → data CID (MST root) 82 + const commit_cid_raw = repo_car.roots[0].raw; 83 + const commit_data = zat.car.findBlock(repo_car, commit_cid_raw) orelse return error.NoCommitBlock; 84 + const commit_value = zat.cbor.decodeAll(arena, commit_data) catch return error.InvalidCommit; 85 + const data_cbor = commit_value.get("data") orelse return error.NoDataField; 86 + const data_cid_raw = switch (data_cbor) { 87 + .cid => |c| c.raw, 88 + else => return error.InvalidDataField, 89 + }; 90 + 91 + // 4. 
walk the MST from the data root, yielding records in order 92 + var records: std.ArrayList(Record) = .empty; 93 + try walkNode(arena, repo_car, data_cid_raw, did, &records); 94 + 95 + return .{ 96 + .records = try records.toOwnedSlice(arena), 97 + .car_bytes = car_bytes.len, 98 + }; 99 + } 100 + 101 + fn walkNode( 102 + arena: Allocator, 103 + repo_car: zat.car.Car, 104 + node_cid: []const u8, 105 + did: []const u8, 106 + out: *std.ArrayList(Record), 107 + ) WalkError!void { 108 + const node_data = zat.car.findBlock(repo_car, node_cid) orelse return error.BlockNotFound; 109 + const node = zat.mst.decodeMstNode(arena, node_data) catch return error.InvalidMstNode; 110 + 111 + // MST invariant: left subtree holds keys strictly less than every key 112 + // in this node. walk it first so the output ends up in lexicographic 113 + // order (same contract as pds.listRecords, alphabetical by collection). 114 + if (node.left) |left_cid| try walkNode(arena, repo_car, left_cid, did, out); 115 + 116 + // prefix-compressed keys: entries[i].key = entries[i-1].key[0..prefix_len] ++ suffix. 117 + // a 512-byte reconstruction buffer is plenty — atproto MST keys are 118 + // `{collection}/{rkey}`, which maxes out around 128 bytes in practice. 119 + var key_buf: [512]u8 = undefined; 120 + for (node.entries) |entry| { 121 + if (entry.prefix_len + entry.key_suffix.len > key_buf.len) continue; 122 + @memcpy(key_buf[entry.prefix_len..][0..entry.key_suffix.len], entry.key_suffix); 123 + const key = key_buf[0 .. 
entry.prefix_len + entry.key_suffix.len]; 124 + 125 + const slash = std.mem.indexOfScalar(u8, key, '/') orelse continue; 126 + const collection = arena.dupe(u8, key[0..slash]) catch return error.OutOfMemory; 127 + const rkey = key[slash + 1 ..]; 128 + 129 + // record block — raw DAG-CBOR bytes 130 + const value_data = zat.car.findBlock(repo_car, entry.value) orelse continue; 131 + const value_cbor = zat.cbor.decodeAll(arena, value_data) catch continue; 132 + const value_json = try cborToJson(arena, value_cbor); 133 + 134 + const uri = std.fmt.allocPrint(arena, "at://{s}/{s}/{s}", .{ did, collection, rkey }) catch return error.OutOfMemory; 135 + const cid_str = zat.multibase.encode(arena, .base32lower, entry.value) catch return error.OutOfMemory; 136 + 137 + try out.append(arena, .{ 138 + .uri = uri, 139 + .cid = cid_str, 140 + .collection = collection, 141 + .value = value_json, 142 + }); 143 + 144 + // right subtree of THIS entry (keys between this one and the next) 145 + if (entry.tree) |tree_cid| try walkNode(arena, repo_car, tree_cid, did, out); 146 + } 147 + } 148 + 149 + /// Convert a zat.cbor.Value into a std.json.Value. This exists so records 150 + /// pulled through the CAR walker can flow through the same record_text / 151 + /// display pipeline as records from pds.listRecords (which returns 152 + /// json.Value already, since listRecords is a JSON XRPC endpoint). 153 + /// 154 + /// dropped: byte strings and CID refs. record_text treats them as noise 155 + /// anyway — CID-shaped strings get filtered by isIdentifier, and raw bytes 156 + /// don't contribute semantic signal. 
157 + fn cborToJson(arena: Allocator, cv: zat.cbor.Value) WalkError!json.Value { 158 + return switch (cv) { 159 + .unsigned => |u| .{ .integer = std.math.cast(i64, u) orelse return .{ .null = {} } }, 160 + .negative => |n| .{ .integer = n }, 161 + .text => |t| .{ .string = arena.dupe(u8, t) catch return error.OutOfMemory }, 162 + .boolean => |b| .{ .bool = b }, 163 + .null => .{ .null = {} }, 164 + .bytes => .{ .null = {} }, 165 + .cid => .{ .null = {} }, 166 + .array => |arr| blk: { 167 + var items: json.Array = .init(arena); 168 + items.ensureTotalCapacity(arr.len) catch return error.OutOfMemory; 169 + for (arr) |v| items.appendAssumeCapacity(try cborToJson(arena, v)); 170 + break :blk .{ .array = items }; 171 + }, 172 + .map => |entries| blk: { 173 + var obj: json.ObjectMap = .init(arena); 174 + obj.ensureTotalCapacity(@intCast(entries.len)) catch return error.OutOfMemory; 175 + for (entries) |e| { 176 + const key_copy = arena.dupe(u8, e.key) catch return error.OutOfMemory; 177 + obj.putAssumeCapacity(key_copy, try cborToJson(arena, e.value)); 178 + } 179 + break :blk .{ .object = obj }; 180 + }, 181 + }; 182 + }
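walkNode rebuilds each MST key from the previous one: the first `prefix_len` bytes stay in the buffer and the suffix is copied in after them. a standalone sketch of that reconstruction, assuming a simplified hypothetical `Entry` with only the two fields involved (the real entries come from zat.mst.decodeMstNode):

```zig
const std = @import("std");

/// hypothetical, trimmed-down entry: just the fields used for key rebuilding.
const Entry = struct {
    prefix_len: usize,
    key_suffix: []const u8,
};

/// overwrite the tail of `key_buf` with this entry's suffix and return the
/// full key. the leading `prefix_len` bytes are whatever the previous entry
/// left behind. returns null if the key would not fit.
fn rebuildKey(key_buf: []u8, entry: Entry) ?[]const u8 {
    const total = entry.prefix_len + entry.key_suffix.len;
    if (total > key_buf.len) return null;
    @memcpy(key_buf[entry.prefix_len..][0..entry.key_suffix.len], entry.key_suffix);
    return key_buf[0..total];
}

test "rebuild prefix-compressed keys" {
    const entries = [_]Entry{
        .{ .prefix_len = 0, .key_suffix = "app.bsky.feed.like/3kaaa" },
        .{ .prefix_len = 21, .key_suffix = "bbb" },
        .{ .prefix_len = 14, .key_suffix = "post/3kccc" },
    };
    const expected = [_][]const u8{
        "app.bsky.feed.like/3kaaa",
        "app.bsky.feed.like/3kbbb",
        "app.bsky.feed.post/3kccc",
    };
    var key_buf: [512]u8 = undefined;
    for (entries, expected) |entry, want| {
        const key = rebuildKey(&key_buf, entry) orelse return error.KeyTooLong;
        try std.testing.expectEqualStrings(want, key);
    }
}
```

because each key depends on the bytes left by the one before it, skipping an entry without writing its suffix would leave later keys in that node built on a stale prefix.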
+857
backend/src/server.zig
··· 1 + //! HTTP routing for embed-on-pds. 2 + //! 3 + //! routes: 4 + //! GET / → index.html (embedded) 5 + //! GET /style.css → style.css (embedded) 6 + //! GET /main.js → main.js (embedded) 7 + //! GET /health → {status, phase} 8 + //! POST /api/embed → {dim, vector} (debug endpoint) 9 + //! POST /api/index/:handle → kick off indexing job; returns {status} 10 + //! GET /api/status/:handle → {status, records_fetched, records_embedded, ...} 11 + //! GET /api/search/:handle?q= → {results: [{uri, collection, title, body, date}, ...]} 12 + 13 + const std = @import("std"); 14 + const Io = std.Io; 15 + const http = std.http; 16 + const mem = std.mem; 17 + const json = std.json; 18 + const Thread = std.Thread; 19 + const Allocator = mem.Allocator; 20 + 21 + const Llama = @import("llama.zig"); 22 + const indexer = @import("indexer.zig"); 23 + const pds = @import("pds.zig"); 24 + const oauth = @import("oauth.zig"); 25 + const store = @import("state.zig"); 26 + const zat = @import("zat"); 27 + 28 + const HTTP_BUF_SIZE = 65536; 29 + 30 + // ---------- embedded viewer assets ---------- 31 + 32 + const INDEX_HTML = @embedFile("assets/index.html"); 33 + const STYLE_CSS = @embedFile("assets/style.css"); 34 + const MAIN_JS = @embedFile("assets/main.js"); 35 + const FAVICON_SVG = @embedFile("assets/favicon.svg"); 36 + const OG_SVG = @embedFile("assets/og.svg"); 37 + // isolated special-casing, can be removed without touching anything else. 38 + // see the header comment in the file itself. 39 + const NSID_LOGO_JS = @embedFile("assets/nsid-logo.js"); 40 + 41 + // ---------- app state ---------- 42 + 43 + pub const App = struct { 44 + io: Io, 45 + allocator: Allocator, 46 + embedder: *Llama.Embedder, 47 + embedder_mutex: *Io.Mutex, 48 + cache: *indexer.Cache, 49 + /// dev-only: skip the consent gate in handleIndex. set from 50 + /// EMBED_DEV_NOAUTH=1 at startup. never true in prod. 
51 + dev_noauth: bool = false, 52 + }; 53 + 54 + // ---------- connection handling ---------- 55 + 56 + pub fn handleConnection(stream: Io.net.Stream, app: *App) void { 57 + defer stream.close(app.io); 58 + 59 + var read_buffer: [HTTP_BUF_SIZE]u8 = undefined; 60 + var write_buffer: [HTTP_BUF_SIZE]u8 = undefined; 61 + 62 + var reader = stream.reader(app.io, &read_buffer); 63 + var writer = stream.writer(app.io, &write_buffer); 64 + 65 + var srv = http.Server.init(&reader.interface, &writer.interface); 66 + 67 + while (true) { 68 + var request = srv.receiveHead() catch |err| { 69 + if (err != error.HttpConnectionClosing and err != error.EndOfStream) { 70 + std.log.debug("http receive error: {t}", .{err}); 71 + } 72 + return; 73 + }; 74 + 75 + handleRequest(&request, app) catch |err| { 76 + std.log.err("request error: {t}", .{err}); 77 + return; 78 + }; 79 + 80 + if (!request.head.keep_alive) return; 81 + } 82 + } 83 + 84 + fn handleRequest(request: *http.Server.Request, app: *App) !void { 85 + if (request.head.method == .OPTIONS) { 86 + try sendCorsPreflight(request); 87 + return; 88 + } 89 + 90 + const target = request.head.target; 91 + const path = if (mem.indexOf(u8, target, "?")) |qi| target[0..qi] else target; 92 + const query = if (mem.indexOf(u8, target, "?")) |qi| target[qi + 1 ..] 
else ""; 93 + 94 + if (mem.eql(u8, path, "/")) { 95 + try handleIndexHtml(request, app, query); 96 + } else if (mem.eql(u8, path, "/style.css")) { 97 + try sendAsset(request, STYLE_CSS, "text/css; charset=utf-8"); 98 + } else if (mem.eql(u8, path, "/main.js")) { 99 + try sendAsset(request, MAIN_JS, "application/javascript; charset=utf-8"); 100 + } else if (mem.eql(u8, path, "/nsid-logo.js")) { 101 + try sendAsset(request, NSID_LOGO_JS, "application/javascript; charset=utf-8"); 102 + } else if (mem.eql(u8, path, "/favicon.svg") or mem.eql(u8, path, "/favicon.ico")) { 103 + try sendAsset(request, FAVICON_SVG, "image/svg+xml"); 104 + } else if (mem.eql(u8, path, "/og.svg")) { 105 + try sendAsset(request, OG_SVG, "image/svg+xml"); 106 + } else if (mem.eql(u8, path, "/health")) { 107 + try sendJson(request, "{\"status\":\"ok\",\"phase\":\"F\"}"); 108 + } else if (mem.eql(u8, path, "/oauth-client-metadata.json")) { 109 + try oauth.handleClientMetadata(request); 110 + } else if (mem.eql(u8, path, "/oauth/jwks")) { 111 + try oauth.handleJwks(request); 112 + } else if (mem.eql(u8, path, "/oauth/login")) { 113 + try oauth.handleLogin(request); 114 + } else if (mem.eql(u8, path, "/oauth/callback")) { 115 + try oauth.handleCallback(request); 116 + } else if (mem.eql(u8, path, "/oauth/logout") and request.head.method == .POST) { 117 + try handleLogout(request, app); 118 + } else if (mem.eql(u8, path, "/api/me")) { 119 + try handleMe(request, app); 120 + } else if (mem.eql(u8, path, "/api/pack/save") and request.head.method == .POST) { 121 + try handlePackSave(request, app); 122 + } else if (mem.eql(u8, path, "/api/pack/delete") and request.head.method == .POST) { 123 + try handlePackDelete(request, app); 124 + } else if (mem.eql(u8, path, "/api/embed") and request.head.method == .POST) { 125 + try handleEmbed(request, app); 126 + } else if (mem.startsWith(u8, path, "/api/index/") and request.head.method == .POST) { 127 + const handle = path["/api/index/".len..]; 128 + try 
handleIndex(request, app, handle); 129 + } else if (mem.startsWith(u8, path, "/api/share-load/") and request.head.method == .POST) { 130 + const handle = path["/api/share-load/".len..]; 131 + try handleShareLoad(request, app, handle); 132 + } else if (mem.startsWith(u8, path, "/api/status/")) { 133 + const handle = path["/api/status/".len..]; 134 + try handleStatus(request, app, handle); 135 + } else if (mem.startsWith(u8, path, "/api/search/")) { 136 + const handle = path["/api/search/".len..]; 137 + try handleSearch(request, app, handle, query); 138 + } else { 139 + try sendNotFound(request); 140 + } 141 + } 142 + 143 + // ---------- route: GET / (with optional ?handle=X&q=Y for share previews) ---------- 144 + 145 + /// serve the SPA shell. when `?handle=X&q=Y` is present in the query string 146 + /// we treat the request as a share view and inject Open Graph + Twitter 147 + /// card meta tags into the HTML head so social crawlers (bsky, slack, 148 + /// discord, telegram, etc) render a preview that names the handle and the 149 + /// query. crawlers don't run JS, so static-tag injection is the only way 150 + /// to get useful previews from a single-page app. 151 + fn handleIndexHtml(request: *http.Server.Request, app: *App, query: []const u8) !void { 152 + var arena = std.heap.ArenaAllocator.init(app.allocator); 153 + defer arena.deinit(); 154 + const alloc = arena.allocator(); 155 + 156 + const handle_raw = extractQueryParam(query, "handle"); 157 + const q_raw = extractQueryParam(query, "q"); 158 + 159 + const placeholder = "<!--OG_META_PLACEHOLDER-->"; 160 + const default_meta = 161 + \\<meta property="og:title" content="ken"> 162 + \\<meta property="og:description" content="fuzzy find any record in your atproto repo. 
vectors live on your own PDS."> 163 + \\<meta property="og:image" content="https://ken.waow.tech/og.svg"> 164 + \\<meta property="og:url" content="https://ken.waow.tech/"> 165 + \\<meta property="og:type" content="website"> 166 + \\<meta name="twitter:card" content="summary_large_image"> 167 + \\<meta name="twitter:title" content="ken"> 168 + \\<meta name="twitter:description" content="fuzzy find any record in your atproto repo."> 169 + \\<meta name="twitter:image" content="https://ken.waow.tech/og.svg"> 170 + ; 171 + 172 + // when share params are present, build a per-query preview 173 + const meta_block: []const u8 = if (handle_raw != null and q_raw != null) blk: { 174 + const handle = try urlDecode(alloc, handle_raw.?); 175 + const q = try urlDecode(alloc, q_raw.?); 176 + // escape for HTML attribute context 177 + const handle_e = try htmlAttrEscape(alloc, handle); 178 + const q_e = try htmlAttrEscape(alloc, q); 179 + const handle_u = try urlEncode(alloc, handle); 180 + const q_u = try urlEncode(alloc, q); 181 + break :blk try std.fmt.allocPrint(alloc, 182 + \\<meta property="og:title" content="ken — &quot;{s}&quot; in @{s}'s records"> 183 + \\<meta property="og:description" content="semantic search across @{s}'s atproto repo."> 184 + \\<meta property="og:image" content="https://ken.waow.tech/og.svg"> 185 + \\<meta property="og:url" content="https://ken.waow.tech/?handle={s}&amp;q={s}"> 186 + \\<meta property="og:type" content="website"> 187 + \\<meta name="twitter:card" content="summary_large_image"> 188 + \\<meta name="twitter:title" content="ken — &quot;{s}&quot; in @{s}'s records"> 189 + \\<meta name="twitter:description" content="semantic search across @{s}'s atproto repo."> 190 + \\<meta name="twitter:image" content="https://ken.waow.tech/og.svg"> 191 + , .{ q_e, handle_e, handle_e, handle_u, q_u, q_e, handle_e, handle_e }); 192 + } else default_meta; 193 + 194 + // substitute the placeholder. 
INDEX_HTML is comptime-known but the 195 + // substituted result is per-request, so we build it in the arena. 196 + const idx = mem.indexOf(u8, INDEX_HTML, placeholder) orelse { 197 + // placeholder missing — fall back to serving the static page 198 + try sendAsset(request, INDEX_HTML, "text/html; charset=utf-8"); 199 + return; 200 + }; 201 + const out = try std.fmt.allocPrint(alloc, "{s}{s}{s}", .{ 202 + INDEX_HTML[0..idx], 203 + meta_block, 204 + INDEX_HTML[idx + placeholder.len ..], 205 + }); 206 + try request.respond(out, .{ 207 + .status = .ok, 208 + .extra_headers = &.{ 209 + .{ .name = "content-type", .value = "text/html; charset=utf-8" }, 210 + .{ .name = "cache-control", .value = "no-store" }, 211 + }, 212 + }); 213 + } 214 + 215 + /// pull a single param value out of a URL query string. returns the raw 216 + /// (still percent-encoded) value, or null if not found. caller is 217 + /// responsible for urlDecode if they want the cleartext form. 218 + fn extractQueryParam(query: []const u8, name: []const u8) ?[]const u8 { 219 + var it = mem.splitScalar(u8, query, '&'); 220 + while (it.next()) |param| { 221 + if (mem.startsWith(u8, param, name) and 222 + param.len > name.len and param[name.len] == '=') 223 + { 224 + return param[name.len + 1 ..]; 225 + } 226 + } 227 + return null; 228 + } 229 + 230 + /// percent-encode unsafe URL characters. used for building the canonical 231 + /// share URL inside og:url. lazy implementation; only encodes the bytes 232 + /// that actually break URLs in attribute context. 233 + fn urlEncode(alloc: Allocator, s: []const u8) ![]const u8 { 234 + var out: std.ArrayList(u8) = .empty; 235 + for (s) |ch| { 236 + const safe = (ch >= 'a' and ch <= 'z') or 237 + (ch >= 'A' and ch <= 'Z') or 238 + (ch >= '0' and ch <= '9') or 239 + ch == '-' or ch == '_' or ch == '.' 
or ch == '~'; 240 + if (safe) { 241 + try out.append(alloc, ch); 242 + } else { 243 + try out.print(alloc, "%{X:0>2}", .{ch}); 244 + } 245 + } 246 + return out.items; 247 + } 248 + 249 + /// minimal HTML attribute escape — enough for OG content="" injection. 250 + fn htmlAttrEscape(alloc: Allocator, s: []const u8) ![]const u8 { 251 + var out: std.ArrayList(u8) = .empty; 252 + for (s) |ch| { 253 + switch (ch) { 254 + '&' => try out.appendSlice(alloc, "&amp;"), 255 + '"' => try out.appendSlice(alloc, "&quot;"), 256 + '<' => try out.appendSlice(alloc, "&lt;"), 257 + '>' => try out.appendSlice(alloc, "&gt;"), 258 + else => try out.append(alloc, ch), 259 + } 260 + } 261 + return out.items; 262 + } 263 + 264 + // ---------- route: /oauth/logout ---------- 265 + 266 + /// wraps oauth.handleLogout with cache eviction. this is the crucial bit: 267 + /// the in-memory IndexedPack is keyed by DID, not by session, so a bare 268 + /// session delete leaves the pack sitting in the cache. next sign-in hits 269 + /// the same DID → cache hit → returns the stale pack → user never sees 270 + /// new records. evicting on logout means the next sign-in runs a fresh 271 + /// walk, which still reuses vectors via loadExistingPack so it stays fast. 272 + fn handleLogout(request: *http.Server.Request, app: *App) !void { 273 + if (oauth.getSessionDid(request)) |did| { 274 + app.cache.remove(app.io, did); 275 + } 276 + try oauth.handleLogout(request); 277 + } 278 + 279 + // ---------- route: /api/me ---------- 280 + 281 + /// returns {handle, did} for the authenticated session, or 401 otherwise. 282 + /// the frontend uses this to decide whether to show the signed-in UI. 
 283 + fn handleMe(request: *http.Server.Request, app: *App) !void { 284 + var arena = std.heap.ArenaAllocator.init(app.allocator); 285 + defer arena.deinit(); 286 + const alloc = arena.allocator(); 287 + 288 + const did = oauth.getSessionDid(request) orelse { 289 + try sendJsonStatus(request, .unauthorized, "{\"error\":\"not signed in\"}"); 290 + return; 291 + }; 292 + const session = (try store.getSession(alloc, did)) orelse { 293 + try sendJsonStatus(request, .unauthorized, "{\"error\":\"session expired\"}"); 294 + return; 295 + }; 296 + 297 + var buf: std.ArrayList(u8) = .empty; 298 + try buf.appendSlice(alloc, "{\"handle\":"); 299 + try writeJsonString(&buf, alloc, session.handle); 300 + try buf.appendSlice(alloc, ",\"did\":"); 301 + try writeJsonString(&buf, alloc, session.did); 302 + try buf.appendSlice(alloc, "}"); 303 + try sendJson(request, buf.items); 304 + } 305 + 306 + // ---------- routes: /api/pack/save + /api/pack/delete ---------- 307 + // 308 + // writing a tech.waow.ken.pack record (and its blobs) to the user's PDS 309 + // is an explicit opt-in. both endpoints require the session's DID to match 310 + // the pack's DID: you can only mutate your own repo. 311 + 312 + /// `POST /api/pack/save` → serialize the in-memory pack to PDS blobs + 313 + /// record. updates `pack.persisted_uri` so future status calls report the 314 + /// pack as saved. safe to call repeatedly: each call writes a new record, 315 + /// updates the pointer, then makes a best-effort delete of the previous 316 + /// record, so repeated saves don't accumulate packs. 
317 + fn handlePackSave(request: *http.Server.Request, app: *App) !void { 318 + var arena = std.heap.ArenaAllocator.init(app.allocator); 319 + defer arena.deinit(); 320 + const alloc = arena.allocator(); 321 + 322 + const did = oauth.getSessionDid(request) orelse { 323 + try sendJsonStatus(request, .unauthorized, "{\"error\":\"not signed in\"}"); 324 + return; 325 + }; 326 + const session = (try store.getSession(alloc, did)) orelse { 327 + try sendJsonStatus(request, .unauthorized, "{\"error\":\"session expired\"}"); 328 + return; 329 + }; 330 + 331 + const pack = app.cache.get(app.io, did) orelse { 332 + try sendJsonStatus(request, .not_found, "{\"error\":\"no pack for this did — index first\"}"); 333 + return; 334 + }; 335 + if (pack.status != .ready) { 336 + try sendJsonStatus(request, .service_unavailable, "{\"error\":\"pack still indexing\"}"); 337 + return; 338 + } 339 + 340 + // capture the old URI before we write the new one. replace semantics: 341 + // write-then-delete-old, so a write failure leaves the prior pack 342 + // intact and we never end up with zero packs during the swap. also 343 + // prevents duplicate accumulation from repeated save clicks. 344 + const old_uri = pack.persisted_uri; 345 + 346 + // dupe into pack's own arena so persisted_uri outlives the request 347 + const pack_alloc = pack.arena.allocator(); 348 + const new_uri = indexer.writePackToPds(pack_alloc, session, pack) catch |err| { 349 + std.log.warn("handlePackSave: writePackToPds failed: {t}", .{err}); 350 + try sendJsonStatus(request, .internal_server_error, "{\"error\":\"failed to save pack\"}"); 351 + return; 352 + }; 353 + pack.persisted_uri = new_uri; 354 + 355 + // best-effort delete of the prior pack record. if this fails, the 356 + // orphan lingers but the user can sweep it with the delete button. 
357 + if (old_uri) |ou| { 358 + if (extractRkey(ou)) |rkey| { 359 + oauth.deleteRecord(alloc, session, "tech.waow.ken.pack", rkey) catch |err| { 360 + std.log.warn("handlePackSave: deleteRecord(old pack) failed: {t}", .{err}); 361 + }; 362 + } 363 + } 364 + 365 + var buf: std.ArrayList(u8) = .empty; 366 + try buf.appendSlice(alloc, "{\"ok\":true,\"persisted_uri\":"); 367 + try writeJsonString(&buf, alloc, new_uri); 368 + try buf.appendSlice(alloc, "}"); 369 + try sendJson(request, buf.items); 370 + } 371 + 372 + /// pull the rkey (last path segment) out of an at-uri like 373 + /// `at://did:plc:.../tech.waow.ken.pack/3abc`. returns null on malformed input. 374 + fn extractRkey(at_uri: []const u8) ?[]const u8 { 375 + const idx = mem.lastIndexOfScalar(u8, at_uri, '/') orelse return null; 376 + if (idx + 1 >= at_uri.len) return null; 377 + return at_uri[idx + 1 ..]; 378 + } 379 + 380 + /// `POST /api/pack/delete` → delete the tech.waow.ken.pack record from 381 + /// the user's PDS. the in-memory pack stays valid for search; only the 382 + /// persisted copy is removed. safe to call if nothing is persisted. 
383 + fn handlePackDelete(request: *http.Server.Request, app: *App) !void { 384 + var arena = std.heap.ArenaAllocator.init(app.allocator); 385 + defer arena.deinit(); 386 + const alloc = arena.allocator(); 387 + 388 + const did = oauth.getSessionDid(request) orelse { 389 + try sendJsonStatus(request, .unauthorized, "{\"error\":\"not signed in\"}"); 390 + return; 391 + }; 392 + const session = (try store.getSession(alloc, did)) orelse { 393 + try sendJsonStatus(request, .unauthorized, "{\"error\":\"session expired\"}"); 394 + return; 395 + }; 396 + 397 + const pack = app.cache.get(app.io, did) orelse { 398 + try sendJsonStatus(request, .not_found, "{\"error\":\"no pack for this did\"}"); 399 + return; 400 + }; 401 + 402 + // if the pack isn't persisted right now, we still want to be able to 403 + // clean up any stray pack records left from earlier runs — list and 404 + // delete everything in the collection. this makes the delete button 405 + // idempotent from the user's point of view: "there is nothing of mine 406 + // on the PDS in this collection" is the postcondition. 
407 + var transport = zat.HttpTransport.init(app.io, alloc); 408 + defer transport.deinit(); 409 + 410 + const PACK_COLLECTION = "tech.waow.ken.pack"; 411 + var deleted: usize = 0; 412 + var cursor: ?[]const u8 = null; 413 + while (true) { 414 + const page = pds.listRecords(alloc, &transport, pack.pds_url, pack.did, PACK_COLLECTION, cursor, 100) catch break; 415 + for (page.records) |r| { 416 + const rkey = blk: { 417 + const at_uri = zat.AtUri.parse(r.uri) orelse continue; 418 + break :blk at_uri.rkey() orelse continue; 419 + }; 420 + oauth.deleteRecord(alloc, session, PACK_COLLECTION, rkey) catch |err| { 421 + std.log.warn(" deleteRecord({s}) failed: {t}", .{ rkey, err }); 422 + continue; 423 + }; 424 + deleted += 1; 425 + } 426 + cursor = page.cursor; 427 + if (cursor == null or page.records.len == 0) break; 428 + } 429 + 430 + pack.persisted_uri = null; 431 + 432 + var buf: std.ArrayList(u8) = .empty; 433 + try buf.print(alloc, "{{\"ok\":true,\"deleted\":{d}}}", .{deleted}); 434 + try sendJson(request, buf.items); 435 + } 436 + 437 + // ---------- route: /api/embed (debug) ---------- 438 + 439 + const EmbedRequest = struct { q: []const u8 }; 440 + 441 + fn handleEmbed(request: *http.Server.Request, app: *App) !void { 442 + var arena = std.heap.ArenaAllocator.init(app.allocator); 443 + defer arena.deinit(); 444 + const alloc = arena.allocator(); 445 + 446 + const body_reader = request.readerExpectContinue(&.{}) catch |err| { 447 + std.log.warn("reader init failed: {t}", .{err}); 448 + try sendJsonStatus(request, .bad_request, "{\"error\":\"failed to read body\"}"); 449 + return; 450 + }; 451 + const body = body_reader.allocRemaining(alloc, Io.Limit.limited(16 * 1024)) catch |err| { 452 + std.log.warn("failed to read /api/embed body: {t}", .{err}); 453 + try sendJsonStatus(request, .bad_request, "{\"error\":\"failed to read body\"}"); 454 + return; 455 + }; 456 + 457 + const parsed = json.parseFromSliceLeaky(EmbedRequest, alloc, body, .{ 458 + 
.ignore_unknown_fields = true, 459 + }) catch { 460 + try sendJsonStatus(request, .bad_request, "{\"error\":\"invalid json\"}"); 461 + return; 462 + }; 463 + if (parsed.q.len == 0 or parsed.q.len > 2000) { 464 + try sendJsonStatus(request, .bad_request, "{\"error\":\"bad query length\"}"); 465 + return; 466 + } 467 + 468 + const vec: []f32 = blk: { 469 + app.embedder_mutex.lockUncancelable(app.io); 470 + defer app.embedder_mutex.unlock(app.io); 471 + break :blk app.embedder.embed(alloc, parsed.q) catch |err| { 472 + std.log.err("embed error: {t}", .{err}); 473 + try sendJsonStatus(request, .internal_server_error, "{\"error\":\"embed failed\"}"); 474 + return; 475 + }; 476 + }; 477 + 478 + var buf: std.ArrayList(u8) = .empty; 479 + try buf.print(alloc, "{{\"dim\":{d},\"vector\":[", .{vec.len}); 480 + for (vec, 0..) |v, i| { 481 + if (i > 0) try buf.appendSlice(alloc, ","); 482 + try buf.print(alloc, "{d:.8}", .{v}); 483 + } 484 + try buf.appendSlice(alloc, "]}"); 485 + 486 + try sendJson(request, buf.items); 487 + } 488 + 489 + // ---------- route: /api/index/:handle ---------- 490 + 491 + /// shared kick-off used by both `POST /api/index/:handle` (consent-gated) 492 + /// and `POST /api/share-load/:handle` (public). resolves the handle, checks 493 + /// the cache, reserves a new pack if needed, spawns the indexer worker, 494 + /// and writes the status response. doIndex is read-only now (no auto-write 495 + /// to PDS), so this is safe to call for any handle without consent. 
496 + fn kickoffIndexing(request: *http.Server.Request, app: *App, handle: []const u8) !void { 497 + var resolve_arena = std.heap.ArenaAllocator.init(app.allocator); 498 + defer resolve_arena.deinit(); 499 + const ra = resolve_arena.allocator(); 500 + 501 + var transport = zat.HttpTransport.init(app.io, ra); 502 + defer transport.deinit(); 503 + 504 + const identity = pds.resolveIdentity(ra, &transport, handle) catch |err| { 505 + std.log.warn("resolveIdentity failed for {s}: {t}", .{ handle, err }); 506 + try sendJsonStatus(request, .not_found, "{\"error\":\"failed to resolve handle\"}"); 507 + return; 508 + }; 509 + 510 + if (app.cache.get(app.io, identity.did)) |existing| { 511 + try writeStatusResponse(request, app, existing); 512 + return; 513 + } 514 + 515 + const pack = (try app.cache.reserve(app.io, handle, identity.did, identity.pds)) orelse { 516 + const existing = app.cache.get(app.io, identity.did).?; 517 + try writeStatusResponse(request, app, existing); 518 + return; 519 + }; 520 + 521 + const job = try app.allocator.create(indexer.Job); 522 + job.* = .{ 523 + .allocator = app.allocator, 524 + .io = app.io, 525 + .cache = app.cache, 526 + .pack = pack, 527 + .embedder = app.embedder, 528 + .embedder_mutex = app.embedder_mutex, 529 + .max_per_collection = 0, 530 + }; 531 + 532 + const t = Thread.spawn(.{}, indexerWorker, .{job}) catch |err| { 533 + std.log.err("failed to spawn indexer thread: {t}", .{err}); 534 + app.allocator.destroy(job); 535 + try sendJsonStatus(request, .internal_server_error, "{\"error\":\"failed to spawn worker\"}"); 536 + return; 537 + }; 538 + t.detach(); 539 + 540 + try writeStatusResponse(request, app, pack); 541 + } 542 + 543 + /// `POST /api/index/:handle` — owner-gated indexing entry point. ensures 544 + /// the requesting user can only index their own DID, then delegates to 545 + /// kickoffIndexing. 
546 + fn handleIndex(request: *http.Server.Request, app: *App, handle: []const u8) !void { 547 + if (handle.len == 0 or handle.len > 253) { 548 + try sendJsonStatus(request, .bad_request, "{\"error\":\"invalid handle\"}"); 549 + return; 550 + } 551 + 552 + // dev-only escape hatch for local smoke tests 553 + const session_did_opt: ?[]const u8 = blk: { 554 + if (app.dev_noauth) break :blk null; 555 + break :blk oauth.getSessionDid(request) orelse { 556 + try sendJsonStatus(request, .unauthorized, "{\"error\":\"sign in to index your own records\"}"); 557 + return; 558 + }; 559 + }; 560 + 561 + if (session_did_opt) |session_did| { 562 + // resolve to verify ownership before kicking anything off 563 + var arena = std.heap.ArenaAllocator.init(app.allocator); 564 + defer arena.deinit(); 565 + const ra = arena.allocator(); 566 + var transport = zat.HttpTransport.init(app.io, ra); 567 + defer transport.deinit(); 568 + const identity = pds.resolveIdentity(ra, &transport, handle) catch { 569 + try sendJsonStatus(request, .not_found, "{\"error\":\"failed to resolve handle\"}"); 570 + return; 571 + }; 572 + if (!mem.eql(u8, identity.did, session_did)) { 573 + try sendJsonStatus(request, .forbidden, "{\"error\":\"you can only index your own records\"}"); 574 + return; 575 + } 576 + } 577 + 578 + try kickoffIndexing(request, app, handle); 579 + } 580 + 581 + /// `POST /api/share-load/:handle` — public load. spawns the same indexer 582 + /// worker as /api/index, but with no consent gate. this is safe because 583 + /// doIndex is read-only: it walks the user's PDS via CAR (public), pulls 584 + /// their already-published `tech.waow.ken.pack` record (public), and 585 + /// reuses every vector by (uri, cid). nothing gets written. the resulting 586 + /// pack lands in the in-memory cache so /api/search can serve queries 587 + /// against it. takes a few seconds for a 17k-record repo. 
588 + fn handleShareLoad(request: *http.Server.Request, app: *App, handle: []const u8) !void { 589 + if (handle.len == 0 or handle.len > 253) { 590 + try sendJsonStatus(request, .bad_request, "{\"error\":\"invalid handle\"}"); 591 + return; 592 + } 593 + try kickoffIndexing(request, app, handle); 594 + } 595 + 596 + fn indexerWorker(job: *indexer.Job) void { 597 + defer job.allocator.destroy(job); 598 + indexer.runJob(job); 599 + } 600 + 601 + // ---------- route: /api/status/:handle ---------- 602 + 603 + fn handleStatus(request: *http.Server.Request, app: *App, handle: []const u8) !void { 604 + // resolve handle → DID to find the cached pack 605 + var arena = std.heap.ArenaAllocator.init(app.allocator); 606 + defer arena.deinit(); 607 + const ra = arena.allocator(); 608 + 609 + var transport = zat.HttpTransport.init(app.io, ra); 610 + defer transport.deinit(); 611 + 612 + const identity = pds.resolveIdentity(ra, &transport, handle) catch { 613 + try sendJsonStatus(request, .not_found, "{\"error\":\"failed to resolve handle\"}"); 614 + return; 615 + }; 616 + 617 + if (app.cache.get(app.io, identity.did)) |pack| { 618 + try writeStatusResponse(request, app, pack); 619 + } else { 620 + try sendJsonStatus(request, .not_found, "{\"error\":\"not indexed\"}"); 621 + } 622 + } 623 + 624 + fn writeStatusResponse(request: *http.Server.Request, app: *App, pack: *indexer.IndexedPack) !void { 625 + var arena = std.heap.ArenaAllocator.init(app.allocator); 626 + defer arena.deinit(); 627 + const alloc = arena.allocator(); 628 + 629 + var buf: std.ArrayList(u8) = .empty; 630 + try buf.appendSlice(alloc, "{"); 631 + 632 + try buf.print(alloc, "\"handle\":\"{s}\",", .{pack.handle}); 633 + try buf.print(alloc, "\"did\":\"{s}\",", .{pack.did}); 634 + 635 + const status_str = switch (pack.status) { 636 + .indexing => "indexing", 637 + .ready => "ready", 638 + .@"error" => "error", 639 + }; 640 + try buf.print(alloc, "\"status\":\"{s}\",", .{status_str}); 641 + try buf.print(alloc, 
"\"records_fetched\":{d},", .{pack.records_fetched}); 642 + try buf.print(alloc, "\"records_embedded\":{d},", .{pack.records_embedded}); 643 + try buf.print(alloc, "\"records_reused\":{d},", .{pack.records_reused}); 644 + try buf.print(alloc, "\"count\":{d},", .{pack.count()}); 645 + try buf.print(alloc, "\"indexed_at_ms\":{d},", .{pack.indexed_at_ms}); 646 + try buf.print(alloc, "\"build_ms\":{d},", .{pack.build_ms}); 647 + try buf.print(alloc, "\"prior_build_ms\":{d},", .{pack.prior_build_ms}); 648 + try buf.print(alloc, "\"prior_count\":{d},", .{pack.prior_count}); 649 + if (pack.persisted_uri) |uri| { 650 + try buf.appendSlice(alloc, "\"persisted\":true,\"persisted_uri\":"); 651 + try writeJsonString(&buf, alloc, uri); 652 + try buf.appendSlice(alloc, ","); 653 + } else { 654 + try buf.appendSlice(alloc, "\"persisted\":false,"); 655 + } 656 + 657 + if (pack.error_msg.len > 0) { 658 + try buf.appendSlice(alloc, "\"error_msg\":"); 659 + try writeJsonString(&buf, alloc, pack.error_msg); 660 + try buf.appendSlice(alloc, ","); 661 + } 662 + 663 + try buf.appendSlice(alloc, "\"collections\":["); 664 + for (pack.collections, 0..) |c, i| { 665 + if (i > 0) try buf.appendSlice(alloc, ","); 666 + try buf.appendSlice(alloc, "{\"nsid\":"); 667 + try writeJsonString(&buf, alloc, c.nsid); 668 + try buf.print(alloc, ",\"count\":{d}}}", .{c.count}); 669 + } 670 + try buf.appendSlice(alloc, "]}"); 671 + 672 + try sendJson(request, buf.items); 673 + } 674 + 675 + // ---------- route: /api/search/:handle?q=... ---------- 676 + 677 + fn handleSearch(request: *http.Server.Request, app: *App, handle: []const u8, query: []const u8) !void { 678 + var arena = std.heap.ArenaAllocator.init(app.allocator); 679 + defer arena.deinit(); 680 + const alloc = arena.allocator(); 681 + 682 + // parse query string: q=...&k=... 
683 + var q_str: []const u8 = ""; 684 + var k: usize = 20; 685 + var it = mem.splitScalar(u8, query, '&'); 686 + while (it.next()) |param| { 687 + if (mem.startsWith(u8, param, "q=")) { 688 + q_str = try urlDecode(alloc, param[2..]); 689 + } else if (mem.startsWith(u8, param, "k=")) { 690 + k = std.fmt.parseInt(usize, param[2..], 10) catch 20; 691 + } 692 + } 693 + 694 + if (q_str.len == 0) { 695 + try sendJsonStatus(request, .bad_request, "{\"error\":\"missing q\"}"); 696 + return; 697 + } 698 + 699 + // resolve handle → DID → pack 700 + var transport = zat.HttpTransport.init(app.io, alloc); 701 + defer transport.deinit(); 702 + 703 + const identity = pds.resolveIdentity(alloc, &transport, handle) catch { 704 + try sendJsonStatus(request, .not_found, "{\"error\":\"failed to resolve handle\"}"); 705 + return; 706 + }; 707 + 708 + const pack = app.cache.get(app.io, identity.did) orelse { 709 + try sendJsonStatus(request, .not_found, "{\"error\":\"not indexed\"}"); 710 + return; 711 + }; 712 + 713 + // search works during indexing too — the pack has a `valid` bitmap 714 + // that gates which rows have been embedded so far. we just report the 715 + // current state to the client so the UI can label partial results. 716 + 717 + // embed the query 718 + const query_vec: []f32 = blk: { 719 + app.embedder_mutex.lockUncancelable(app.io); 720 + defer app.embedder_mutex.unlock(app.io); 721 + break :blk app.embedder.embed(alloc, q_str) catch |err| { 722 + std.log.err("query embed failed: {t}", .{err}); 723 + try sendJsonStatus(request, .internal_server_error, "{\"error\":\"embed failed\"}"); 724 + return; 725 + }; 726 + }; 727 + 728 + // top-k 729 + const idxs = try indexer.search(alloc, pack, query_vec, k); 730 + 731 + // render — include current indexing state so the frontend can label 732 + // results as partial while the pack is still being built. 
733 + const searchable = pack.validCount(); 734 + const total = pack.entries.len; 735 + const is_indexing = pack.status == .indexing; 736 + 737 + var buf: std.ArrayList(u8) = .empty; 738 + try buf.print(alloc, "{{\"indexing\":{s},\"searchable\":{d},\"total\":{d},\"results\":[", .{ 739 + if (is_indexing) "true" else "false", 740 + searchable, 741 + total, 742 + }); 743 + for (idxs, 0..) |idx, i| { 744 + if (i > 0) try buf.appendSlice(alloc, ","); 745 + const e = pack.entries[idx]; 746 + try buf.appendSlice(alloc, "{\"uri\":"); 747 + try writeJsonString(&buf, alloc, e.uri); 748 + try buf.appendSlice(alloc, ",\"collection\":"); 749 + try writeJsonString(&buf, alloc, e.collection); 750 + try buf.appendSlice(alloc, ",\"title\":"); 751 + try writeJsonString(&buf, alloc, e.title); 752 + try buf.appendSlice(alloc, ",\"body\":"); 753 + try writeJsonString(&buf, alloc, e.body); 754 + try buf.appendSlice(alloc, ",\"date\":"); 755 + try writeJsonString(&buf, alloc, e.date); 756 + try buf.appendSlice(alloc, "}"); 757 + } 758 + try buf.appendSlice(alloc, "]}"); 759 + 760 + try sendJson(request, buf.items); 761 + } 762 + 763 + // ---------- helpers ---------- 764 + 765 + fn urlDecode(alloc: Allocator, s: []const u8) ![]const u8 { 766 + var out: std.ArrayList(u8) = .empty; 767 + var i: usize = 0; 768 + while (i < s.len) { 769 + const ch = s[i]; 770 + if (ch == '+') { 771 + try out.append(alloc, ' '); 772 + i += 1; 773 + } else if (ch == '%' and i + 2 < s.len) { 774 + const hi = std.fmt.charToDigit(s[i + 1], 16) catch { 775 + try out.append(alloc, ch); 776 + i += 1; 777 + continue; 778 + }; 779 + const lo = std.fmt.charToDigit(s[i + 2], 16) catch { 780 + try out.append(alloc, ch); 781 + i += 1; 782 + continue; 783 + }; 784 + try out.append(alloc, @intCast(hi * 16 + lo)); 785 + i += 3; 786 + } else { 787 + try out.append(alloc, ch); 788 + i += 1; 789 + } 790 + } 791 + return out.items; 792 + } 793 + 794 + fn writeJsonString(buf: *std.ArrayList(u8), alloc: Allocator, s: []const u8) 
!void {
    try buf.append(alloc, '"');
    for (s) |ch| {
        switch (ch) {
            '"' => try buf.appendSlice(alloc, "\\\""),
            '\\' => try buf.appendSlice(alloc, "\\\\"),
            '\n' => try buf.appendSlice(alloc, "\\n"),
            '\r' => try buf.appendSlice(alloc, "\\r"),
            '\t' => try buf.appendSlice(alloc, "\\t"),
            // remaining control bytes have no short escape; emit \u00XX
            0...0x08, 0x0b, 0x0c, 0x0e...0x1f => {
                try buf.print(alloc, "\\u{x:0>4}", .{ch});
            },
            else => try buf.append(alloc, ch),
        }
    }
    try buf.append(alloc, '"');
}

// ---------- response helpers ----------

fn sendJson(request: *http.Server.Request, body: []const u8) !void {
    try sendJsonStatus(request, .ok, body);
}

fn sendJsonStatus(request: *http.Server.Request, status: http.Status, body: []const u8) !void {
    try request.respond(body, .{
        .status = status,
        .extra_headers = &.{
            .{ .name = "content-type", .value = "application/json" },
            .{ .name = "access-control-allow-origin", .value = "*" },
        },
    });
}

fn sendAsset(request: *http.Server.Request, body: []const u8, content_type: []const u8) !void {
    try request.respond(body, .{
        .status = .ok,
        .extra_headers = &.{
            .{ .name = "content-type", .value = content_type },
            .{ .name = "cache-control", .value = "public, max-age=300" },
        },
    });
}

fn sendCorsPreflight(request: *http.Server.Request) !void {
    try request.respond("", .{
        .status = .no_content,
        .extra_headers = &.{
            .{ .name = "access-control-allow-origin", .value = "*" },
            .{ .name = "access-control-allow-methods", .value = "GET, POST, OPTIONS" },
            .{ .name = "access-control-allow-headers", .value = "content-type" },
        },
    });
}

fn sendNotFound(request: *http.Server.Request) !void {
    try request.respond("{\"error\":\"not found\"}", .{
        .status = .not_found,
        .extra_headers = &.{
            .{ .name = "content-type", .value = "application/json" },
            .{ .name = "access-control-allow-origin", .value = "*" },
        },
    });
}
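The escaping rules used by the JSON string writer earlier in this file can be exercised in isolation. The sketch below is a fixed-buffer variant written only for illustration: `escapeInto` is a hypothetical name, and it uses `std.fmt.bufPrint` instead of the growable buffer the backend uses.

```zig
const std = @import("std");

/// Hypothetical helper, not part of the ken source.
/// Writes `s` as a quoted JSON string into `out` and returns the bytes used.
/// Fails with error.NoSpaceLeft if `out` is too small.
fn escapeInto(out: []u8, s: []const u8) ![]u8 {
    var n: usize = 0;
    n += (try std.fmt.bufPrint(out[n..], "\"", .{})).len;
    for (s) |ch| {
        const written = switch (ch) {
            '"' => try std.fmt.bufPrint(out[n..], "\\\"", .{}),
            '\\' => try std.fmt.bufPrint(out[n..], "\\\\", .{}),
            '\n' => try std.fmt.bufPrint(out[n..], "\\n", .{}),
            '\r' => try std.fmt.bufPrint(out[n..], "\\r", .{}),
            '\t' => try std.fmt.bufPrint(out[n..], "\\t", .{}),
            // other control bytes become \u00XX
            0...0x08, 0x0b, 0x0c, 0x0e...0x1f => try std.fmt.bufPrint(out[n..], "\\u{x:0>4}", .{ch}),
            else => try std.fmt.bufPrint(out[n..], "{c}", .{ch}),
        };
        n += written.len;
    }
    n += (try std.fmt.bufPrint(out[n..], "\"", .{})).len;
    return out[0..n];
}

test "quotes, newlines and control bytes are escaped" {
    var buf: [64]u8 = undefined;
    const got = try escapeInto(&buf, "a\"b\n\x01");
    try std.testing.expectEqualStrings("\"a\\\"b\\n\\u0001\"", got);
}
```

Bytes at 0x20 and above pass through unchanged, so UTF-8 sequences in record text survive as-is.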
+328
backend/src/state.zig
//! in-memory replacement for what was a sqlite-backed db.zig.
//!
//! the user said: "we don't need to mount a volume to the fly machine." so
//! we don't. OAuth sessions live in memory. the fly machine stays pinned
//! (min_machines_running=1, auto_stop=off) so restarts are rare and only
//! happen on redeploy. when that happens, users re-auth. acceptable UX for a
//! research demo, and it keeps the deployment model simple: **the PDS is the
//! only persistent storage this project uses.** period.
//!
//! pack data lives on the user's PDS as `tech.waow.ken.pack` records +
//! blobs. to answer "is this DID already indexed?" we call listRecords on
//! their PDS directly; there is no local cache of pack locations.
//!
//! public surface mirrors the old db.zig so oauth.zig and indexer.zig can
//! swap `@import` without other changes.

const std = @import("std");
const Io = std.Io;
const Allocator = std.mem.Allocator;

var gpa: Allocator = undefined;
var io: Io = undefined;
pub var mutex: Io.Mutex = .init;

/// heap-owned storage for an auth request (transient: seconds between
/// /oauth/login and /oauth/callback). all string fields own their bytes.
const StoredAuthRequest = struct {
    state: []u8,
    authserver_iss: []u8,
    did: []u8,
    handle: []u8,
    pds_url: []u8,
    pkce_verifier: []u8,
    scope: []u8,
    dpop_authserver_nonce: []u8,
    dpop_private_key: []u8,
    created_at: i64,

    fn deinit(self: *StoredAuthRequest, a: Allocator) void {
        a.free(self.state);
        a.free(self.authserver_iss);
        a.free(self.did);
        a.free(self.handle);
        a.free(self.pds_url);
        a.free(self.pkce_verifier);
        a.free(self.scope);
        a.free(self.dpop_authserver_nonce);
        a.free(self.dpop_private_key);
    }
};

const StoredSession = struct {
    did: []u8,
    handle: []u8,
    pds_url: []u8,
    authserver_iss: []u8,
    access_token: []u8,
    refresh_token: []u8,
    dpop_authserver_nonce: []u8,
    dpop_pds_nonce: []u8,
    dpop_private_key: []u8,
    created_at: i64,

    fn deinit(self: *StoredSession, a: Allocator) void {
        a.free(self.did);
        a.free(self.handle);
        a.free(self.pds_url);
        a.free(self.authserver_iss);
        a.free(self.access_token);
        a.free(self.refresh_token);
        a.free(self.dpop_authserver_nonce);
        a.free(self.dpop_pds_nonce);
        a.free(self.dpop_private_key);
    }
};

const StoredExchange = struct {
    did: []u8,
    created_at: i64,
};

var auth_requests: std.StringHashMap(StoredAuthRequest) = undefined;
var sessions: std.StringHashMap(StoredSession) = undefined;
var exchange_tokens: std.StringHashMap(StoredExchange) = undefined;

fn timestamp() i64 {
    return @intCast(@divFloor(Io.Timestamp.now(io, .real).nanoseconds, std.time.ns_per_s));
}

pub fn init(app_io: Io, app_allocator: Allocator) void {
    io = app_io;
    gpa = app_allocator;
    auth_requests = std.StringHashMap(StoredAuthRequest).init(gpa);
    sessions = std.StringHashMap(StoredSession).init(gpa);
    exchange_tokens = std.StringHashMap(StoredExchange).init(gpa);
    std.log.info("state: in-memory (oauth sessions reset on restart)", .{});
}

pub fn close() void {}

// ---------------------------------------------------------------------------
// public structs: strings are arena-allocated and returned to callers
// ---------------------------------------------------------------------------

pub const AuthRequest = struct {
    state: []const u8,
    authserver_iss: []const u8,
    did: []const u8,
    handle: []const u8,
    pds_url: []const u8,
    pkce_verifier: []const u8,
    scope: []const u8,
    dpop_authserver_nonce: []const u8,
    dpop_private_key: []const u8,
};

pub const Session = struct {
    did: []const u8,
    handle: []const u8,
    pds_url: []const u8,
    authserver_iss: []const u8,
    access_token: []const u8,
    refresh_token: []const u8,
    dpop_authserver_nonce: []const u8,
    dpop_pds_nonce: []const u8,
    dpop_private_key: []const u8,
};

// ---------------------------------------------------------------------------
// auth_requests: insert, get (arena-duped copy), delete
// ---------------------------------------------------------------------------

pub fn insertAuthRequest(
    state: []const u8,
    authserver_iss: []const u8,
    did: []const u8,
    handle: []const u8,
    pds_url: []const u8,
    pkce_verifier: []const u8,
    scope: []const u8,
    dpop_nonce: []const u8,
    dpop_private_key_hex: []const u8,
) !void {
    mutex.lockUncancelable(io);
    defer mutex.unlock(io);

    // heap-copy everything
    const stored: StoredAuthRequest = .{
        .state = try gpa.dupe(u8, state),
        .authserver_iss = try gpa.dupe(u8, authserver_iss),
        .did = try gpa.dupe(u8, did),
        .handle = try gpa.dupe(u8, handle),
        .pds_url = try gpa.dupe(u8, pds_url),
        .pkce_verifier = try gpa.dupe(u8, pkce_verifier),
        .scope = try gpa.dupe(u8, scope),
        .dpop_authserver_nonce = try gpa.dupe(u8, dpop_nonce),
        .dpop_private_key = try gpa.dupe(u8, dpop_private_key_hex),
        .created_at = timestamp(),
    };
    // if a previous entry existed, free it first
    // (the map key is the entry's own `state` slice, so deinit frees the key too)
    if (auth_requests.fetchRemove(stored.state)) |kv| {
        var prev = kv.value;
        prev.deinit(gpa);
    }
    try auth_requests.put(stored.state, stored);
}

/// caller provides an arena; returned struct's strings are dup'd into it.
pub fn getAuthRequest(arena: Allocator, state: []const u8) !?AuthRequest {
    mutex.lockUncancelable(io);
    defer mutex.unlock(io);

    const stored = auth_requests.getPtr(state) orelse return null;
    return AuthRequest{
        .state = try arena.dupe(u8, stored.state),
        .authserver_iss = try arena.dupe(u8, stored.authserver_iss),
        .did = try arena.dupe(u8, stored.did),
        .handle = try arena.dupe(u8, stored.handle),
        .pds_url = try arena.dupe(u8, stored.pds_url),
        .pkce_verifier = try arena.dupe(u8, stored.pkce_verifier),
        .scope = try arena.dupe(u8, stored.scope),
        .dpop_authserver_nonce = try arena.dupe(u8, stored.dpop_authserver_nonce),
        .dpop_private_key = try arena.dupe(u8, stored.dpop_private_key),
    };
}

pub fn deleteAuthRequest(state: []const u8) void {
    mutex.lockUncancelable(io);
    defer mutex.unlock(io);
    if (auth_requests.fetchRemove(state)) |kv| {
        var stored = kv.value;
        stored.deinit(gpa);
    }
}

// ---------------------------------------------------------------------------
// sessions: upsert, get (arena-duped copy), delete, updateNonce, updateTokens
// ---------------------------------------------------------------------------

pub fn upsertSession(
    did: []const u8,
    handle: []const u8,
    pds_url: []const u8,
    authserver_iss: []const u8,
    access_token: []const u8,
    refresh_token: []const u8,
    dpop_authserver_nonce: []const u8,
    dpop_pds_nonce: []const u8,
    dpop_private_key_hex: []const u8,
) !void {
    mutex.lockUncancelable(io);
    defer mutex.unlock(io);

    const new_session: StoredSession = .{
        .did = try gpa.dupe(u8, did),
        .handle = try gpa.dupe(u8, handle),
        .pds_url = try gpa.dupe(u8, pds_url),
        .authserver_iss = try gpa.dupe(u8, authserver_iss),
        .access_token = try gpa.dupe(u8, access_token),
        .refresh_token = try gpa.dupe(u8, refresh_token),
        .dpop_authserver_nonce = try gpa.dupe(u8, dpop_authserver_nonce),
        .dpop_pds_nonce = try gpa.dupe(u8, dpop_pds_nonce),
        .dpop_private_key = try gpa.dupe(u8, dpop_private_key_hex),
        .created_at = timestamp(),
    };

    if (sessions.fetchRemove(new_session.did)) |kv| {
        var prev = kv.value;
        prev.deinit(gpa);
    }
    try sessions.put(new_session.did, new_session);
}

pub fn getSession(arena: Allocator, did: []const u8) !?Session {
    mutex.lockUncancelable(io);
    defer mutex.unlock(io);

    const stored = sessions.getPtr(did) orelse return null;
    return Session{
        .did = try arena.dupe(u8, stored.did),
        .handle = try arena.dupe(u8, stored.handle),
        .pds_url = try arena.dupe(u8, stored.pds_url),
        .authserver_iss = try arena.dupe(u8, stored.authserver_iss),
        .access_token = try arena.dupe(u8, stored.access_token),
        .refresh_token = try arena.dupe(u8, stored.refresh_token),
        .dpop_authserver_nonce = try arena.dupe(u8, stored.dpop_authserver_nonce),
        .dpop_pds_nonce = try arena.dupe(u8, stored.dpop_pds_nonce),
        .dpop_private_key = try arena.dupe(u8, stored.dpop_private_key),
    };
}

pub fn deleteSession(did: []const u8) void {
    mutex.lockUncancelable(io);
    defer mutex.unlock(io);
    if (sessions.fetchRemove(did)) |kv| {
        var stored = kv.value;
        stored.deinit(gpa);
    }
}

pub fn updateSessionNonce(did: []const u8, field: enum { authserver, pds }, nonce: []const u8) void {
    mutex.lockUncancelable(io);
    defer mutex.unlock(io);
    const stored = sessions.getPtr(did) orelse return;
    // on allocation failure the old nonce is kept
    const new_val = gpa.dupe(u8, nonce) catch return;
    switch (field) {
        .authserver => {
            gpa.free(stored.dpop_authserver_nonce);
            stored.dpop_authserver_nonce = new_val;
        },
        .pds => {
            gpa.free(stored.dpop_pds_nonce);
            stored.dpop_pds_nonce = new_val;
        },
    }
}

pub fn updateSessionTokens(did: []const u8, access_token: []const u8, refresh_token: []const u8) void {
    mutex.lockUncancelable(io);
    defer mutex.unlock(io);
    const stored = sessions.getPtr(did) orelse return;
    // allocate both copies before freeing anything, so a failure leaves the
    // stored tokens untouched
    const new_at = gpa.dupe(u8, access_token) catch return;
    const new_rt = gpa.dupe(u8, refresh_token) catch {
        gpa.free(new_at);
        return;
    };
    gpa.free(stored.access_token);
    gpa.free(stored.refresh_token);
    stored.access_token = new_at;
    stored.refresh_token = new_rt;
}

// ---------------------------------------------------------------------------
// exchange_tokens: one-time short-lived codes
// ---------------------------------------------------------------------------

pub fn insertExchangeToken(token: []const u8, did: []const u8) !void {
    mutex.lockUncancelable(io);
    defer mutex.unlock(io);
    const stored: StoredExchange = .{
        .did = try gpa.dupe(u8, did),
        .created_at = timestamp(),
    };
    errdefer gpa.free(stored.did);
    const key = try gpa.dupe(u8, token);
    errdefer gpa.free(key);
    if (exchange_tokens.fetchRemove(key)) |kv| {
        // the map owns its keys separately here, so free the old key as well
        gpa.free(kv.key);
        gpa.free(kv.value.did);
    }
    try exchange_tokens.put(key, stored);
}

/// consume (delete) and return the associated DID. caller owns the returned
/// slice, which is allocated with the provided arena. tokens older than 60
/// seconds are deleted and reported as missing.
pub fn consumeExchangeToken(arena: Allocator, token: []const u8) !?[]const u8 {
    mutex.lockUncancelable(io);
    defer mutex.unlock(io);
    const cutoff = timestamp() - 60;
    const kv = exchange_tokens.fetchRemove(token) orelse return null;
    defer {
        gpa.free(kv.key);
        gpa.free(kv.value.did);
    }
    if (kv.value.created_at < cutoff) return null;
    return try arena.dupe(u8, kv.value.did);
}
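The exchange-token rules in state.zig (a token can be consumed once, and only within 60 seconds of being issued) are easier to check with the clock and the lock taken out. The following is a simplified, single-threaded sketch: `TokenStore`, its methods and the explicit `now` parameter are hypothetical and exist only for this illustration, whereas the real module reads the time from `Io` and holds a mutex.

```zig
const std = @import("std");
const Allocator = std.mem.Allocator;

const ttl_seconds: i64 = 60;

const Entry = struct { did: []u8, created_at: i64 };

/// Hypothetical stand-in for the exchange_tokens table (no mutex, no clock).
const TokenStore = struct {
    gpa: Allocator,
    map: std.StringHashMap(Entry),

    fn init(gpa: Allocator) TokenStore {
        return .{ .gpa = gpa, .map = std.StringHashMap(Entry).init(gpa) };
    }

    fn deinit(self: *TokenStore) void {
        var it = self.map.iterator();
        while (it.next()) |e| {
            self.gpa.free(e.key_ptr.*);
            self.gpa.free(e.value_ptr.did);
        }
        self.map.deinit();
    }

    fn insert(self: *TokenStore, token: []const u8, did: []const u8, now: i64) !void {
        const did_copy = try self.gpa.dupe(u8, did);
        errdefer self.gpa.free(did_copy);
        const key = try self.gpa.dupe(u8, token);
        errdefer self.gpa.free(key);
        // replace any earlier entry for the same token, freeing key and value
        if (self.map.fetchRemove(key)) |old| {
            self.gpa.free(old.key);
            self.gpa.free(old.value.did);
        }
        try self.map.put(key, .{ .did = did_copy, .created_at = now });
    }

    /// Always removes the token; returns the DID (copied into `arena`) only
    /// if the token was issued within the last `ttl_seconds`.
    fn consume(self: *TokenStore, arena: Allocator, token: []const u8, now: i64) !?[]const u8 {
        const kv = self.map.fetchRemove(token) orelse return null;
        defer {
            self.gpa.free(kv.key);
            self.gpa.free(kv.value.did);
        }
        if (kv.value.created_at < now - ttl_seconds) return null;
        return try arena.dupe(u8, kv.value.did);
    }
};

test "token is single use and expires" {
    var store = TokenStore.init(std.testing.allocator);
    defer store.deinit();
    var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator);
    defer arena_state.deinit();
    const arena = arena_state.allocator();

    try store.insert("abc", "did:plc:example", 1000);
    const did = (try store.consume(arena, "abc", 1030)).?;
    try std.testing.expectEqualStrings("did:plc:example", did);
    // second use of the same token finds nothing
    try std.testing.expect((try store.consume(arena, "abc", 1031)) == null);

    // 61 seconds later the token is deleted without being honoured
    try store.insert("old", "did:plc:example", 1000);
    try std.testing.expect((try store.consume(arena, "old", 1061)) == null);
}
```

Running the test under `std.testing.allocator` also reports leaks, which is a cheap way to confirm that both the key and the value are freed on every path.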
+114
lexicons/tech/waow/ken/pack.json
{
  "lexicon": 1,
  "id": "tech.waow.ken.pack",
  "description": "A snapshot of semantic embeddings derived from records on this repo. Vectors live as blob attachments (chunked) and positions live as a JSON blob; this record is the manifest. Multiple packs may coexist (different models, different time slices).",
  "defs": {
    "main": {
      "type": "record",
      "key": "tid",
      "record": {
        "type": "object",
        "required": ["model", "dim", "encoding", "count", "positions", "vectors", "createdAt"],
        "properties": {
          "model": {
            "type": "string",
            "description": "Canonical id of the embedding model. Vectors from different models live in non-comparable spaces; consumers MUST filter by model before doing nearest-neighbor.",
            "maxLength": 128
          },
          "dim": {
            "type": "integer",
            "description": "Vector dimensionality.",
            "minimum": 1,
            "maximum": 8192
          },
          "encoding": {
            "type": "string",
            "description": "Wire encoding of the vectors blob. 'float32' is the canonical full-fidelity form; 'int8' and 'binary' are quantized variants for cold-start size.",
            "knownValues": ["float32", "int8", "binary"]
          },
          "count": {
            "type": "integer",
            "description": "Number of vectors in the pack. Must equal the entry count in the positions and vectors blobs.",
            "minimum": 0
          },
          "positions": {
            "type": "blob",
            "description": "JSON-encoded array of {uri, cid} per entry, same length and order as the vectors. Stored as application/octet-stream as a workaround for an upstream PDS bug where sync.getBlob mishandles application/json blobs (returns a serialized stream object instead of the file contents).",
            "accept": ["application/octet-stream"],
            "maxSize": 50000000
          },
          "vectors": {
            "type": "array",
            "description": "Chunked raw-f32 vector blobs. Concatenate in order to reconstruct count*dim floats.",
            "items": {
              "type": "blob",
              "accept": ["application/octet-stream"],
              "maxSize": 5000000
            }
          },
          "metadata": {
            "type": "blob",
            "description": "Optional msgpack/json blob with per-vector metadata (cluster labels, contentHashes, source collection NSID).",
            "accept": ["application/octet-stream", "application/json"],
            "maxSize": 50000000
          },
          "coverage": {
            "type": "ref",
            "ref": "#coverage",
            "description": "Which subset of the repo this pack covers."
          },
          "builder": {
            "type": "string",
            "description": "Identifier of the tool that produced this pack. Free-form, for debugging.",
            "maxLength": 256
          },
          "supersedes": {
            "type": "string",
            "format": "at-uri",
            "description": "If this pack replaces a prior one, point at it. Consumers may use this to chain history."
          },
          "buildMs": {
            "type": "integer",
            "description": "Total wall time (ms) it took to build this pack, from kicking off the walk through final embed. Lets consumers calibrate ETA messages for future builds of comparable repos.",
            "minimum": 0
          },
          "reusedCount": {
            "type": "integer",
            "description": "Of the `count` entries, how many had their vectors reused from a prior pack (same uri+cid). The remainder were freshly embedded.",
            "minimum": 0
          },
          "freshCount": {
            "type": "integer",
            "description": "Of the `count` entries, how many were freshly embedded (not reused from a prior pack).",
            "minimum": 0
          },
          "createdAt": {
            "type": "string",
            "format": "datetime"
          }
        }
      }
    },
    "coverage": {
      "type": "object",
      "description": "Description of which records this pack was built from.",
      "properties": {
        "collections": {
          "type": "array",
          "description": "NSIDs of collections included in this pack. Empty/absent means 'all collections in the repo at build time'.",
          "items": { "type": "string", "format": "nsid" }
        },
        "from": {
          "type": "string",
          "format": "datetime",
          "description": "Earliest source createdAt covered."
        },
        "to": {
          "type": "string",
          "format": "datetime",
          "description": "Latest source createdAt covered."
        }
      }
    }
  }
}
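The `vectors` field says to concatenate the chunk blobs in order to get `count*dim` floats. A minimal sketch of that step for a consumer follows. It assumes little-endian float32 (the lexicon does not state a byte order) and that every chunk length is a multiple of 4; `decodeChunks` and its error names are hypothetical and not part of ken.

```zig
const std = @import("std");

/// Hypothetical consumer-side helper, not taken from the ken source.
/// Decodes `chunks` (already downloaded blob bytes, in manifest order) into
/// `out`, which must hold exactly count*dim floats.
/// Assumption: values are little-endian float32.
fn decodeChunks(chunks: []const []const u8, count: usize, dim: usize, out: []f32) !void {
    const want = count * dim;
    if (out.len != want) return error.BadOutputLength;

    // the manifest's count and dim must agree with the total blob size
    var total: usize = 0;
    for (chunks) |chunk| {
        if (chunk.len % 4 != 0) return error.UnalignedChunk;
        total += chunk.len;
    }
    if (total != want * 4) return error.SizeMismatch;

    var i: usize = 0;
    for (chunks) |chunk| {
        var off: usize = 0;
        while (off < chunk.len) : (off += 4) {
            const bits = std.mem.readInt(u32, chunk[off..][0..4], .little);
            out[i] = @bitCast(bits);
            i += 1;
        }
    }
}

test "two chunks of two 2-dim vectors" {
    const a = [_]u8{ 0, 0, 0x80, 0x3f, 0, 0, 0, 0x40 }; // 1.0, 2.0
    const b = [_]u8{ 0, 0, 0x40, 0x40, 0, 0, 0x80, 0x40 }; // 3.0, 4.0
    const chunks = [_][]const u8{ &a, &b };
    var out: [4]f32 = undefined;
    try decodeChunks(&chunks, 2, 2, &out);
    try std.testing.expectEqualSlices(f32, &[_]f32{ 1.0, 2.0, 3.0, 4.0 }, &out);
}
```

Row `k` of the decoded buffer (`out[k*dim .. (k+1)*dim]`) then pairs with entry `k` of the `positions` array, since both are described as having the same length and order.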