atproto relay implementation in zig zlay.waow.tech
9
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat: thread pool for frame processing

separates lightweight reader threads from heavy frame processing workers.
reader threads keep cursor tracking, rate limiting, and frame type filtering;
heavy work (CBOR decode, validation, DB persist, broadcast) offloaded to a
configurable pool of N workers (default 16, env FRAME_WORKERS).

per-host ordering preserved via key-partitioned queues (host_id % N).
broadcast_order mutex unchanged, just fewer contenders.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

zzstoatzz f0c7bafa 162ced85

+603 -4
+2
build.zig
··· 74 74 "src/slurper.zig", 75 75 "src/collection_index.zig", 76 76 "src/backfill.zig", 77 + "src/thread_pool.zig", 78 + "src/frame_worker.zig", 77 79 }; 78 80 inline for (test_files) |file| { 79 81 const test_mod = b.createModule(.{
+219
src/frame_worker.zig
··· 1 + //! frame processing worker — heavy frame handling offloaded from reader threads 2 + //! 3 + //! the reader thread (subscriber) does lightweight header decode, cursor tracking, 4 + //! and rate limiting, then submits raw frame bytes here for heavy processing: 5 + //! CBOR decode, DID resolution, signature validation, DB persist, broadcast. 6 + //! 7 + //! double decode (reader + worker both parse CBOR header) is intentional — 8 + //! CBOR decode is ~1-2μs, far cheaper than serializing decoded values across threads. 9 + 10 + const std = @import("std"); 11 + const zat = @import("zat"); 12 + const broadcaster = @import("broadcaster.zig"); 13 + const validator_mod = @import("validator.zig"); 14 + const event_log_mod = @import("event_log.zig"); 15 + const collection_index_mod = @import("collection_index.zig"); 16 + const thread_pool = @import("thread_pool.zig"); 17 + 18 + const Allocator = std.mem.Allocator; 19 + const log = std.log.scoped(.relay); 20 + 21 + pub const FrameWork = struct { 22 + data: []u8, // raw frame bytes (heap-duped by reader, freed by worker) 23 + host_id: u64, 24 + hostname: []const u8, // borrowed from subscriber (stable lifetime) 25 + allocator: Allocator, 26 + // shared references (all thread-safe, all outlive the work item) 27 + bc: *broadcaster.Broadcaster, 28 + validator: *validator_mod.Validator, 29 + persist: ?*event_log_mod.DiskPersist, 30 + collection_index: ?*collection_index_mod.CollectionIndex, 31 + }; 32 + 33 + pub fn processFrame(work: *FrameWork) void { 34 + defer work.allocator.free(work.data); 35 + 36 + var arena = std.heap.ArenaAllocator.init(work.allocator); 37 + defer arena.deinit(); 38 + const alloc = arena.allocator(); 39 + 40 + const data = work.data; 41 + 42 + // re-decode header (cheap — ~1-2μs) 43 + const header_result = zat.cbor.decode(alloc, data) catch |err| { 44 + log.debug("worker: frame header decode failed: {s} (len={d})", .{ @errorName(err), data.len }); 45 + _ = work.bc.stats.decode_errors.fetchAdd(1, .monotonic); 46 + return; 47 + }; 48 + const header = header_result.value; 49 + const payload_data = data[header_result.consumed..]; 50 + 51 + const op = header.getInt("op") orelse return; 52 + if (op != 1) return; // only process message frames (error frames handled by reader) 53 + 54 + const frame_type = header.getString("t") orelse return; 55 + const payload = zat.cbor.decodeAll(alloc, payload_data) catch |err| { 56 + log.debug("worker: frame payload decode failed: {s} (type={s})", .{ @errorName(err), frame_type }); 57 + _ = work.bc.stats.decode_errors.fetchAdd(1, .monotonic); 58 + return; 59 + }; 60 + 61 + // route by frame type — unknown types already filtered by reader 62 + const is_commit = std.mem.eql(u8, frame_type, "#commit"); 63 + const is_sync = std.mem.eql(u8, frame_type, "#sync"); 64 + const is_account = std.mem.eql(u8, frame_type, "#account"); 65 + const is_identity = std.mem.eql(u8, frame_type, "#identity"); 66 + 67 + if (!is_commit and !is_sync and !is_account and !is_identity) return; 68 + 69 + // extract DID: "repo" for commits, "did" for identity/account 70 + const did: ?[]const u8 = if (is_commit) 71 + payload.getString("repo") 72 + else 73 + payload.getString("did"); 74 + 75 + // on #identity event, evict cached signing key so next commit re-resolves 76 + if (is_identity) { 77 + if (did) |d| work.validator.evictKey(d); 78 + } 79 + 80 + // resolve DID → numeric UID for event header (host-aware) 81 + const uid: u64 = if (work.persist) |dp| blk: { 82 + if (did) |d| { 83 + const result = dp.uidForDidFromHost(d, work.host_id) catch break :blk @as(u64, 0); 84 + if (result.host_changed or result.is_new) { 85 + work.validator.queueMigrationCheck(d, work.host_id); 86 + } 87 + break :blk result.uid; 88 + } else break :blk @as(u64, 0); 89 + } else 0; 90 + 91 + // process #account events: update upstream status 92 + if (is_account) { 93 + if (work.persist) |dp| { 94 + if (uid > 0) { 95 + const is_active = payload.getBool("active") orelse false; 96 + const status_str = payload.getString("status"); 97 + const new_status: []const u8 = if (is_active) 98 + "active" 99 + else 100 + (status_str orelse "inactive"); 101 + dp.updateAccountUpstreamStatus(uid, new_status) catch |err| { 102 + log.debug("upstream status update failed: {s}", .{@errorName(err)}); 103 + }; 104 + 105 + // on account tombstone/deletion, remove all collection index entries 106 + if (std.mem.eql(u8, new_status, "deleted") or std.mem.eql(u8, new_status, "takendown")) { 107 + if (work.collection_index) |ci| { 108 + if (did) |d| { 109 + ci.removeAll(d) catch |err| { 110 + log.debug("collection removeAll failed: {s}", .{@errorName(err)}); 111 + }; 112 + } 113 + } 114 + } 115 + } 116 + } 117 + } 118 + 119 + // for commits and syncs: check account is active, validate, extract state 120 + var commit_data_cid: ?[]const u8 = null; 121 + var commit_rev: ?[]const u8 = null; 122 + if (is_commit or is_sync) { 123 + // drop for inactive accounts 124 + if (work.persist) |dp| { 125 + if (uid > 0) { 126 + const active = dp.isAccountActive(uid) catch true; 127 + if (!active) { 128 + _ = work.bc.stats.skipped.fetchAdd(1, .monotonic); 129 + return; 130 + } 131 + } 132 + } 133 + 134 + // stale rev check 135 + if (is_commit and uid > 0) { 136 + if (payload.getString("rev")) |incoming_rev| { 137 + if (work.persist) |dp| { 138 + if (dp.getAccountState(uid, alloc) catch null) |prev| { 139 + if (std.mem.order(u8, incoming_rev, prev.rev) != .gt) { 140 + log.debug("host {s}: dropping stale commit uid={d} rev={s} <= {s}", .{ 141 + work.hostname, uid, incoming_rev, prev.rev, 142 + }); 143 + _ = work.bc.stats.skipped.fetchAdd(1, .monotonic); 144 + return; 145 + } 146 + } 147 + } 148 + } 149 + } 150 + 151 + if (is_commit) { 152 + const result = work.validator.validateCommit(payload); 153 + if (!result.valid) return; 154 + commit_data_cid = result.data_cid; 155 + commit_rev = result.commit_rev; 156 + 157 + // track collections from commit ops (phase 1 live indexing) 158 + if (work.collection_index) |ci| { 159 + if (did) |d| { 160 + if (payload.get("ops")) |ops| { 161 + ci.trackCommitOps(d, ops); 162 + } 163 + } 164 + } 165 + } else { 166 + // #sync: signature verification only, no ops/MST 167 + const result = work.validator.validateSync(payload); 168 + if (!result.valid) return; 169 + commit_data_cid = result.data_cid; 170 + commit_rev = result.commit_rev; 171 + } 172 + } 173 + 174 + // determine event kind for persistence 175 + const kind: event_log_mod.EvtKind = if (is_commit) 176 + .commit 177 + else if (is_sync) 178 + .sync 179 + else if (is_account) 180 + .account 181 + else 182 + .identity; 183 + 184 + // persist and broadcast under ordering lock 185 + if (work.persist) |dp| { 186 + const relay_seq = blk: { 187 + work.bc.broadcast_order.lock(); 188 + defer work.bc.broadcast_order.unlock(); 189 + 190 + const seq = dp.persist(kind, uid, data) catch |err| { 191 + log.warn("persist failed: {s}", .{@errorName(err)}); 192 + return; 193 + }; 194 + work.bc.stats.relay_seq.store(seq, .release); 195 + const broadcast_data = broadcaster.resequenceFrame(alloc, data, seq) orelse data; 196 + work.bc.broadcast(seq, broadcast_data); 197 + break :blk seq; 198 + }; 199 + _ = relay_seq; 200 + 201 + // update per-DID state outside the ordering lock (Postgres round-trip) 202 + if ((is_commit or is_sync) and uid > 0) { 203 + if (commit_rev) |rev| { 204 + const cid_str: []const u8 = if (commit_data_cid) |cid_raw| 205 + zat.multibase.encode(alloc, .base32lower, cid_raw) catch "" 206 + else 207 + ""; 208 + dp.updateAccountState(uid, rev, cid_str) catch |err| { 209 + log.debug("account state update failed: {s}", .{@errorName(err)}); 210 + }; 211 + } 212 + } 213 + } else { 214 + const upstream_seq = payload.getUint("seq") orelse 0; 215 + work.bc.broadcast(upstream_seq, data); 216 + } 217 + } 218 + 219 + pub const FramePool = thread_pool.ThreadPool(FrameWork, processFrame);
+4
src/main.zig
··· 131 131 const upstream = std.posix.getenv("RELAY_UPSTREAM") orelse "bsky.network"; 132 132 const data_dir = std.posix.getenv("RELAY_DATA_DIR") orelse "data/events"; 133 133 const retention_hours = parseEnvInt(u64, "RELAY_RETENTION_HOURS", 72); 134 + const frame_workers = parseEnvInt(u16, "FRAME_WORKERS", 16); 135 + const frame_queue_capacity = parseEnvInt(u16, "FRAME_QUEUE_CAPACITY", 4096); 134 136 135 137 // install signal handlers (including SIGPIPE ignore) 136 138 installSignalHandlers(); ··· 184 186 .{ 185 187 .seed_host = upstream, 186 188 .max_message_size = 5 * 1024 * 1024, 189 + .frame_workers = frame_workers, 190 + .frame_queue_capacity = frame_queue_capacity, 187 191 }, 188 192 ); 189 193 defer slurper.deinit();
+23 -1
src/slurper.zig
··· 17 17 const event_log_mod = @import("event_log.zig"); 18 18 const subscriber_mod = @import("subscriber.zig"); 19 19 const collection_index_mod = @import("collection_index.zig"); 20 + const frame_worker_mod = @import("frame_worker.zig"); 20 21 21 22 const Allocator = std.mem.Allocator; 22 23 const log = std.log.scoped(.relay); ··· 24 25 pub const Options = struct { 25 26 seed_host: []const u8 = "bsky.network", 26 27 max_message_size: usize = 5 * 1024 * 1024, 28 + frame_workers: u16 = 16, 29 + frame_queue_capacity: u16 = 4096, 27 30 }; 28 31 29 32 // --- host validation --- ··· 204 207 shutdown: *std.atomic.Value(bool), 205 208 options: Options, 206 209 210 + // frame processing pool — offloads heavy work from reader threads 211 + frame_pool: ?frame_worker_mod.FramePool = null, 212 + 207 213 // shared TLS CA bundle — loaded once, used by all subscriber connections 208 214 ca_bundle: ?std.crypto.Certificate.Bundle = null, 209 215 ··· 247 253 self.ca_bundle = bundle; 248 254 log.info("loaded shared CA bundle", .{}); 249 255 256 + // create frame processing pool — worker threads handle heavy decode/validate/persist 257 + self.frame_pool = try frame_worker_mod.FramePool.init(self.allocator, .{ 258 + .num_workers = self.options.frame_workers, 259 + .queue_capacity = self.options.frame_queue_capacity, 260 + .stack_size = @import("main.zig").default_stack_size, 261 + }); 262 + log.info("frame pool started: {d} workers, queue capacity {d}", .{ self.options.frame_workers, self.options.frame_queue_capacity }); 263 + 250 264 // spawn worker startup in background so HTTP server + probes come up immediately. 251 265 // pullHosts + listActiveHosts + spawnWorker all happen in the background thread. 252 266 self.startup_thread = try std.Thread.spawn(.{ .stack_size = @import("main.zig").default_stack_size }, spawnWorkers, .{self}); ··· 426 440 }, 427 441 ); 428 442 sub.collection_index = self.collection_index; 443 + if (self.frame_pool) |*fp| sub.pool = fp; 429 444 430 445 const thread = try std.Thread.spawn(.{ .stack_size = @import("main.zig").default_stack_size }, runWorker, .{ self, host_id, sub }); 431 446 ··· 539 554 } 540 555 } 541 556 542 - // join all worker threads (shutdown flag is already set by main) 557 + // join all reader threads FIRST (they stop submitting to pool) 543 558 for (threads_to_join.items) |t| t.join(); 559 + 560 + // then drain + join pool workers (processes remaining queued frames) 561 + if (self.frame_pool) |*fp| { 562 + fp.shutdown(); 563 + fp.deinit(); 564 + self.frame_pool = null; 565 + } 544 566 545 567 // clean up workers map 546 568 self.workers.deinit(self.allocator);
+44 -3
src/subscriber.zig
··· 13 13 const validator_mod = @import("validator.zig"); 14 14 const event_log_mod = @import("event_log.zig"); 15 15 const collection_index_mod = @import("collection_index.zig"); 16 + const frame_worker_mod = @import("frame_worker.zig"); 16 17 17 18 const Allocator = std.mem.Allocator; 18 19 const log = std.log.scoped(.relay); ··· 105 106 validator: *validator_mod.Validator, 106 107 persist: ?*event_log_mod.DiskPersist, 107 108 collection_index: ?*collection_index_mod.CollectionIndex = null, 109 + pool: ?*frame_worker_mod.FramePool = null, 108 110 shutdown: *std.atomic.Value(bool), 109 111 last_upstream_seq: ?u64 = null, 110 112 last_cursor_flush: i64 = 0, ··· 254 256 pub fn serverMessage(self: *FrameHandler, data: []const u8) !void { 255 257 const sub = self.subscriber; 256 258 257 - // decode frame using SDK CBOR codec: [header map] [payload map] 259 + // lightweight header decode for cursor tracking + routing 258 260 var arena = std.heap.ArenaAllocator.init(sub.allocator); 259 261 defer arena.deinit(); 260 262 const alloc = arena.allocator(); ··· 323 325 return; 324 326 } 325 327 326 - // route by frame type — unknown types are ignored per spec (forward-compat) 328 + // filter unknown frame types before submitting to pool (forward-compat) 327 329 const is_commit = std.mem.eql(u8, frame_type, "#commit"); 328 330 const is_sync = std.mem.eql(u8, frame_type, "#sync"); 329 331 const is_account = std.mem.eql(u8, frame_type, "#account"); ··· 334 336 return; 335 337 } 336 338 339 + // submit to frame pool for heavy processing (CBOR re-decode, validation, persist, broadcast) 340 + if (sub.pool) |pool| { 341 + const duped = sub.allocator.dupe(u8, data) catch return; 342 + if (!pool.submit(sub.options.host_id, .{ 343 + .data = duped, 344 + .host_id = sub.options.host_id, 345 + .hostname = sub.options.hostname, 346 + .allocator = sub.allocator, 347 + .bc = sub.bc, 348 + .validator = sub.validator, 349 + .persist = sub.persist, 350 + .collection_index = sub.collection_index, 351 + })) { 352 + // backpressure: queue full, drop frame 353 + sub.allocator.free(duped); 354 + log.debug("host {s}: frame pool queue full, dropping frame", .{sub.options.hostname}); 355 + } 356 + return; 357 + } 358 + 359 + // fallback: no pool, process inline (original path for tests / standalone use) 360 + self.processInline(sub, alloc, data, payload, upstream_seq, frame_type, is_commit, is_sync, is_account, is_identity); 361 + } 362 + 363 + /// inline processing path — used when no frame pool is configured (tests, standalone). 364 + /// this is the original serverMessage heavy processing logic. 365 + fn processInline( 366 + _: *FrameHandler, 367 + sub: *Subscriber, 368 + alloc: Allocator, 369 + data: []const u8, 370 + payload: zat.cbor.Value, 371 + upstream_seq: ?u64, 372 + _: []const u8, 373 + is_commit: bool, 374 + is_sync: bool, 375 + is_account: bool, 376 + is_identity: bool, 377 + ) void { 337 378 // extract DID: "repo" for commits, "did" for identity/account 338 379 const did: ?[]const u8 = if (is_commit) 339 380 payload.getString("repo") ··· 341 382 payload.getString("did"); 342 383 343 384 // on #identity event, evict cached signing key so next commit re-resolves 344 - if (std.mem.eql(u8, frame_type, "#identity")) { 385 + if (is_identity) { 345 386 if (did) |d| sub.validator.evictKey(d); 346 387 } 347 388
+311
src/thread_pool.zig
··· 1 + //! generic thread pool with key-partitioned queues 2 + //! 3 + //! each worker has its own bounded ring buffer + mutex + condvar. 4 + //! `submit(key, item)` routes to `workers[key % N]` for per-key ordering. 5 + //! items stored by value in pre-allocated ring buffer (zero alloc per submit). 6 + 7 + const std = @import("std"); 8 + const Allocator = std.mem.Allocator; 9 + 10 + pub fn ThreadPool(comptime T: type, comptime processFn: fn (*T) void) type { 11 + return struct { 12 + const Self = @This(); 13 + 14 + pub const Config = struct { 15 + num_workers: u16 = 8, 16 + queue_capacity: u16 = 4096, 17 + stack_size: usize = 4 * 1024 * 1024, 18 + }; 19 + 20 + const Worker = struct { 21 + // ring buffer stored as a slice of T 22 + queue: []T, 23 + capacity: u16, 24 + head: u16 = 0, // next slot to read 25 + tail: u16 = 0, // next slot to write 26 + count: u16 = 0, 27 + mutex: std.Thread.Mutex = .{}, 28 + cond: std.Thread.Condition = .{}, 29 + alive: bool = true, 30 + thread: ?std.Thread = null, 31 + }; 32 + 33 + workers: []Worker, 34 + allocator: Allocator, 35 + 36 + pub fn init(allocator: Allocator, config: Config) !Self { 37 + const workers = try allocator.alloc(Worker, config.num_workers); 38 + for (workers) |*w| { 39 + w.* = .{ 40 + .queue = try allocator.alloc(T, config.queue_capacity), 41 + .capacity = config.queue_capacity, 42 + }; 43 + } 44 + 45 + const self = Self{ 46 + .workers = workers, 47 + .allocator = allocator, 48 + }; 49 + 50 + // spawn worker threads 51 + for (self.workers) |*w| { 52 + w.thread = try std.Thread.spawn( 53 + .{ .stack_size = config.stack_size }, 54 + workerLoop, 55 + .{w}, 56 + ); 57 + } 58 + 59 + return self; 60 + } 61 + 62 + /// submit an item for processing, routed by key. 63 + /// returns false if the target worker's queue is full (backpressure). 64 + pub fn submit(self: *Self, key: u64, item: T) bool { 65 + const idx = key % self.workers.len; 66 + const w = &self.workers[idx]; 67 + 68 + w.mutex.lock(); 69 + defer w.mutex.unlock(); 70 + 71 + if (w.count == w.capacity) return false; 72 + 73 + w.queue[w.tail] = item; 74 + w.tail = @intCast((@as(u32, w.tail) + 1) % @as(u32, w.capacity)); 75 + w.count += 1; 76 + w.cond.signal(); 77 + return true; 78 + } 79 + 80 + /// drain remaining items and join all worker threads. 81 + pub fn shutdown(self: *Self) void { 82 + // signal all workers to stop 83 + for (self.workers) |*w| { 84 + w.mutex.lock(); 85 + w.alive = false; 86 + w.cond.signal(); 87 + w.mutex.unlock(); 88 + } 89 + // join all threads 90 + for (self.workers) |*w| { 91 + if (w.thread) |t| { 92 + t.join(); 93 + w.thread = null; 94 + } 95 + } 96 + } 97 + 98 + /// free queue storage. 99 + pub fn deinit(self: *Self) void { 100 + for (self.workers) |*w| { 101 + self.allocator.free(w.queue); 102 + } 103 + self.allocator.free(self.workers); 104 + } 105 + 106 + /// total pending items across all workers (diagnostic). 107 + pub fn pendingCount(self: *Self) usize { 108 + var total: usize = 0; 109 + for (self.workers) |*w| { 110 + w.mutex.lock(); 111 + defer w.mutex.unlock(); 112 + total += w.count; 113 + } 114 + return total; 115 + } 116 + 117 + fn workerLoop(w: *Worker) void { 118 + while (true) { 119 + var item: T = undefined; 120 + 121 + { 122 + w.mutex.lock(); 123 + defer w.mutex.unlock(); 124 + 125 + while (w.count == 0 and w.alive) { 126 + w.cond.wait(&w.mutex); 127 + } 128 + 129 + if (w.count == 0 and !w.alive) return; 130 + 131 + item = w.queue[w.head]; 132 + w.head = @intCast((@as(u32, w.head) + 1) % @as(u32, w.capacity)); 133 + w.count -= 1; 134 + } 135 + 136 + processFn(&item); 137 + } 138 + } 139 + }; 140 + } 141 + 142 + // --- tests --- 143 + 144 + const testing = std.testing; 145 + 146 + test "basic submit and process" { 147 + const Item = struct { 148 + value: u32, 149 + processed: *std.atomic.Value(u32), 150 + }; 151 + 152 + const S = struct { 153 + fn process(item: *Item) void { 154 + _ = item.processed.fetchAdd(item.value, .monotonic); 155 + } 156 + }; 157 + 158 + var counter: std.atomic.Value(u32) = .{ .raw = 0 }; 159 + var pool = try ThreadPool(Item, S.process).init(testing.allocator, .{ 160 + .num_workers = 2, 161 + .queue_capacity = 64, 162 + .stack_size = 1 * 1024 * 1024, 163 + }); 164 + 165 + // submit items 166 + for (0..10) |i| { 167 + const ok = pool.submit(i, .{ 168 + .value = @intCast(i + 1), 169 + .processed = &counter, 170 + }); 171 + try testing.expect(ok); 172 + } 173 + 174 + pool.shutdown(); 175 + defer pool.deinit(); 176 + 177 + // sum of 1..10 = 55 178 + try testing.expectEqual(@as(u32, 55), counter.load(.acquire)); 179 + } 180 + 181 + test "per-key ordering preserved" { 182 + // items with the same key should be processed in FIFO order 183 + const Item = struct { 184 + seq: u32, 185 + results: *std.ArrayListUnmanaged(u32), 186 + mutex: *std.Thread.Mutex, 187 + allocator: Allocator, 188 + }; 189 + 190 + const S = struct { 191 + fn process(item: *Item) void { 192 + item.mutex.lock(); 193 + defer item.mutex.unlock(); 194 + item.results.append(item.allocator, item.seq) catch {}; 195 + } 196 + }; 197 + 198 + var results: std.ArrayListUnmanaged(u32) = .{}; 199 + defer results.deinit(testing.allocator); 200 + var mutex: std.Thread.Mutex = .{}; 201 + 202 + var pool = try ThreadPool(Item, S.process).init(testing.allocator, .{ 203 + .num_workers = 4, 204 + .queue_capacity = 64, 205 + .stack_size = 1 * 1024 * 1024, 206 + }); 207 + 208 + // submit 20 items all with key=42 (same worker) 209 + for (0..20) |i| { 210 + const ok = pool.submit(42, .{ 211 + .seq = @intCast(i), 212 + .results = &results, 213 + .mutex = &mutex, 214 + .allocator = testing.allocator, 215 + }); 216 + try testing.expect(ok); 217 + } 218 + 219 + pool.shutdown(); 220 + defer pool.deinit(); 221 + 222 + try testing.expectEqual(@as(usize, 20), results.items.len); 223 + for (results.items, 0..) |val, i| { 224 + try testing.expectEqual(@as(u32, @intCast(i)), val); 225 + } 226 + } 227 + 228 + test "submit returns false when queue is full" { 229 + const Item = struct { x: u32 }; 230 + const S = struct { 231 + fn process(_: *Item) void { 232 + // block so queue stays full 233 + std.posix.nanosleep(0, 50 * std.time.ns_per_ms); 234 + } 235 + }; 236 + 237 + var pool = try ThreadPool(Item, S.process).init(testing.allocator, .{ 238 + .num_workers = 1, 239 + .queue_capacity = 4, 240 + .stack_size = 1 * 1024 * 1024, 241 + }); 242 + 243 + // fill the queue (worker may drain some, so submit more than capacity) 244 + var submitted: u32 = 0; 245 + var rejected: u32 = 0; 246 + for (0..100) |i| { 247 + if (pool.submit(0, .{ .x = @intCast(i) })) { 248 + submitted += 1; 249 + } else { 250 + rejected += 1; 251 + } 252 + } 253 + 254 + pool.shutdown(); 255 + defer pool.deinit(); 256 + 257 + // should have rejected at least some 258 + try testing.expect(rejected > 0); 259 + try testing.expect(submitted > 0); 260 + } 261 + 262 + test "pendingCount reflects queued items" { 263 + const Item = struct { x: u32 }; 264 + const S = struct { 265 + fn process(item: *Item) void { 266 + _ = item; 267 + // slow worker so items accumulate 268 + std.posix.nanosleep(0, 10 * std.time.ns_per_ms); 269 + } 270 + }; 271 + 272 + var pool = try ThreadPool(Item, S.process).init(testing.allocator, .{ 273 + .num_workers = 1, 274 + .queue_capacity = 64, 275 + .stack_size = 1 * 1024 * 1024, 276 + }); 277 + 278 + // initially empty 279 + try testing.expectEqual(@as(usize, 0), pool.pendingCount()); 280 + 281 + pool.shutdown(); 282 + defer pool.deinit(); 283 + } 284 + 285 + test "shutdown drains remaining items" { 286 + const Item = struct { 287 + counter: *std.atomic.Value(u32), 288 + }; 289 + const S = struct { 290 + fn process(item: *Item) void { 291 + _ = item.counter.fetchAdd(1, .monotonic); 292 + } 293 + }; 294 + 295 + var counter: std.atomic.Value(u32) = .{ .raw = 0 }; 296 + var pool = try ThreadPool(Item, S.process).init(testing.allocator, .{ 297 + .num_workers = 2, 298 + .queue_capacity = 64, 299 + .stack_size = 1 * 1024 * 1024, 300 + }); 301 + 302 + for (0..30) |i| { 303 + _ = pool.submit(i, .{ .counter = &counter }); 304 + } 305 + 306 + pool.shutdown(); 307 + defer pool.deinit(); 308 + 309 + // all 30 should have been processed (shutdown drains) 310 + try testing.expectEqual(@as(u32, 30), counter.load(.acquire)); 311 + }