atproto relay implementation in zig zlay.waow.tech
9
fork

Configure Feed

Select the types of activity you want to include in your feed.

refactor: use std.http.Server for HTTP endpoint

Replace hand-rolled HTTP parser with zig 0.15's std.http.Server,
which provides buffered I/O that correctly handles POST bodies
split across TCP segments behind traefik. This is the proper fix
for the issue that commit 4733560 worked around with a multi-read loop.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

zzstoatzz ce29acc5 4733560e

+160 -191
+3 -10
src/broadcaster.zig
··· 573 573 } 574 574 575 575 pub fn formatStatsResponse(stats: *const Stats, buf: []u8) []const u8 { 576 - var json_buf: [2048]u8 = undefined; 577 - const json = std.fmt.bufPrint(&json_buf, 576 + return std.fmt.bufPrint(buf, 578 577 \\{{"seq":{d},"relay_seq":{d},"consumers":{d},"connected_inbound":{d},"frames_in":{d},"frames_out":{d},"validated":{d},"failed":{d},"skipped":{d},"decode_errors":{d},"cache_hits":{d},"cache_misses":{d},"slow_consumers":{d},"uptime_seconds":{d}}} 579 578 , .{ 580 579 stats.seq.load(.acquire), ··· 591 590 stats.cache_misses.load(.acquire), 592 591 stats.slow_consumers.load(.acquire), 593 592 std.time.timestamp() - stats.start_time, 594 - }) catch return "HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\n\r\n"; 595 - 596 - return std.fmt.bufPrint( 597 - buf, 598 - "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {d}\r\nConnection: close\r\n\r\n{s}", 599 - .{ json.len, json }, 600 - ) catch "HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\n\r\n"; 593 + }) catch ""; 601 594 } 602 595 603 596 // --- tests --- ··· 710 703 var buf: [4096]u8 = undefined; 711 704 const response = formatStatsResponse(&stats, &buf); 712 705 713 - try std.testing.expect(std.mem.startsWith(u8, response, "HTTP/1.1 200 OK")); 706 + try std.testing.expect(std.mem.startsWith(u8, response, "{")); 714 707 try std.testing.expect(std.mem.indexOf(u8, response, "\"seq\":100") != null); 715 708 try std.testing.expect(std.mem.indexOf(u8, response, "\"consumers\":5") != null); 716 709 try std.testing.expect(std.mem.indexOf(u8, response, "\"frames_in\":200") != null);
+157 -181
src/main.zig
··· 18 18 //! /_health, /_stats, /metrics — health, stats, prometheus 19 19 20 20 const std = @import("std"); 21 + const http = std.http; 21 22 const websocket = @import("websocket"); 22 23 const broadcaster = @import("broadcaster.zig"); 23 24 const validator_mod = @import("validator.zig"); ··· 213 214 fn handleHttpConn(stream: std.net.Stream, stats: *broadcaster.Stats, persist: *event_log_mod.DiskPersist, slurper: *slurper_mod.Slurper, ci: *collection_index_mod.CollectionIndex) void { 214 215 defer stream.close(); 215 216 216 - // read request — may need multiple reads if proxy splits headers/body 217 - var buf: [8192]u8 = undefined; 218 - var total: usize = 0; 219 - total = stream.read(&buf) catch return; 220 - if (total == 0) return; 221 - 222 - // parse first line: "METHOD /path HTTP/1.1" 223 - const line_end = std.mem.indexOfScalar(u8, buf[0..total], '\n') orelse return; 224 - const first_line = buf[0..line_end]; 225 - 226 - const method_end = std.mem.indexOfScalar(u8, first_line, ' ') orelse return; 227 - const method = first_line[0..method_end]; 228 - 229 - const path_start = method_end + 1; 230 - const rest = first_line[path_start..]; 231 - const path_end = std.mem.indexOfScalar(u8, rest, ' ') orelse rest.len; 232 - const path = rest[0..path_end]; 233 - 234 - // find end of headers 235 - const header_end = std.mem.indexOf(u8, buf[0..total], "\r\n\r\n"); 236 - 237 - // for POSTs: if we have headers, parse Content-Length and read remaining body 238 - if (std.mem.eql(u8, method, "POST")) { 239 - if (header_end) |he| { 240 - const headers = buf[0..he]; 241 - const content_length = parseContentLength(headers) orelse 0; 242 - const body_start = he + 4; 243 - const body_needed = body_start + content_length; 217 + var recv_buf: [8192]u8 = undefined; 218 + var send_buf: [8192]u8 = undefined; 219 + var connection_reader = stream.reader(&recv_buf); 220 + var connection_writer = stream.writer(&send_buf); 221 + var server = http.Server.init(connection_reader.interface(), &connection_writer.interface); 244 222 245 - // keep reading until we have the full body (or buffer is full) 246 - while (total < body_needed and total < buf.len) { 247 - const m = stream.read(buf[total..]) catch break; 248 - if (m == 0) break; 249 - total += m; 250 - } 251 - } 252 - } 223 + var request = server.receiveHead() catch return; 253 224 254 - const request = buf[0..total]; 255 - // re-find header_end in case more data shifted things (it won't, but be safe) 256 - const he = std.mem.indexOf(u8, request, "\r\n\r\n"); 257 - const body: []const u8 = if (he) |h| request[h + 4 ..] else ""; 225 + const target = request.head.target; 226 + // extract path and query before reading body (head strings reference recv_buf) 227 + const qmark = std.mem.indexOfScalar(u8, target, '?'); 228 + const path = target[0..(qmark orelse target.len)]; 229 + const query = if (qmark) |q| target[q + 1 ..] else ""; 258 230 259 - if (std.mem.eql(u8, method, "GET")) { 260 - handleGet(stream, path, stats, persist, slurper, ci); 261 - } else if (std.mem.eql(u8, method, "POST")) { 262 - handlePost(stream, path, request[0 .. he orelse total], body, persist, slurper); 231 + if (request.head.method == .GET) { 232 + handleGet(&request, path, query, stats, persist, slurper, ci); 233 + } else if (request.head.method == .POST) { 234 + handlePost(&request, path, persist, slurper); 263 235 } else { 264 - httpRespond(stream, "405 Method Not Allowed", "text/plain", "method not allowed"); 265 - } 266 - } 267 - 268 - fn parseContentLength(headers: []const u8) ?usize { 269 - var iter = std.mem.splitScalar(u8, headers, '\n'); 270 - while (iter.next()) |line| { 271 - const trimmed = std.mem.trimRight(u8, line, "\r"); 272 - const colon = std.mem.indexOfScalar(u8, trimmed, ':') orelse continue; 273 - const key = std.mem.trim(u8, trimmed[0..colon], " "); 274 - if (std.ascii.eqlIgnoreCase(key, "content-length")) { 275 - const val = std.mem.trim(u8, trimmed[colon + 1 ..], " "); 276 - return std.fmt.parseInt(usize, val, 10) catch null; 277 - } 236 + respondText(&request, .method_not_allowed, "method not allowed"); 278 237 } 279 - return null; 280 238 } 281 239 282 - fn handleGet(stream: std.net.Stream, full_path: []const u8, stats: *broadcaster.Stats, persist: *event_log_mod.DiskPersist, slurper: *slurper_mod.Slurper, ci: *collection_index_mod.CollectionIndex) void { 283 - // split path from query string 284 - const qmark = std.mem.indexOfScalar(u8, full_path, '?'); 285 - const path = full_path[0..(qmark orelse full_path.len)]; 286 - const query = if (qmark) |q| full_path[q + 1 ..] else ""; 287 - 240 + fn handleGet(request: *http.Server.Request, path: []const u8, query: []const u8, stats: *broadcaster.Stats, persist: *event_log_mod.DiskPersist, slurper: *slurper_mod.Slurper, ci: *collection_index_mod.CollectionIndex) void { 288 241 if (std.mem.eql(u8, path, "/_health") or std.mem.eql(u8, path, "/xrpc/_health")) { 289 - httpRespond(stream, "200 OK", "application/json", "{\"status\":\"ok\"}"); 242 + respondJson(request, .ok, "{\"status\":\"ok\"}"); 290 243 } else if (std.mem.eql(u8, path, "/_stats")) { 291 244 var stats_buf: [4096]u8 = undefined; 292 - const response = broadcaster.formatStatsResponse(stats, &stats_buf); 293 - _ = stream.write(response) catch {}; 245 + const body = broadcaster.formatStatsResponse(stats, &stats_buf); 246 + respondJson(request, .ok, body); 294 247 } else if (std.mem.eql(u8, path, "/metrics")) { 295 248 var metrics_buf: [4096]u8 = undefined; 296 249 const body = broadcaster.formatPrometheusMetrics(stats, &metrics_buf); 297 - var resp_buf: [8192]u8 = undefined; 298 - const response = std.fmt.bufPrint(&resp_buf, "HTTP/1.1 200 OK\r\nContent-Type: text/plain; version=0.0.4; charset=utf-8\r\nContent-Length: {d}\r\nConnection: close\r\n\r\n{s}", .{ body.len, body }) catch return; 299 - _ = stream.write(response) catch {}; 250 + request.respond(body, .{ .status = .ok, .keep_alive = false, .extra_headers = &.{.{ .name = "content-type", .value = "text/plain; version=0.0.4; charset=utf-8" }} }) catch {}; 300 251 } else if (std.mem.eql(u8, path, "/xrpc/com.atproto.sync.listRepos")) { 301 - handleListRepos(stream, query, persist); 252 + handleListRepos(request, query, persist); 302 253 } else if (std.mem.eql(u8, path, "/xrpc/com.atproto.sync.getRepoStatus")) { 303 - handleGetRepoStatus(stream, query, persist); 254 + handleGetRepoStatus(request, query, persist); 304 255 } else if (std.mem.eql(u8, path, "/xrpc/com.atproto.sync.getLatestCommit")) { 305 - handleGetLatestCommit(stream, query, persist); 256 + handleGetLatestCommit(request, query, persist); 306 257 } else if (std.mem.eql(u8, path, "/xrpc/com.atproto.sync.listReposByCollection")) { 307 - handleListReposByCollection(stream, query, ci); 258 + handleListReposByCollection(request, query, ci); 308 259 } else if (std.mem.eql(u8, path, "/admin/hosts")) { 309 - handleAdminListHosts(stream, persist, slurper); 260 + handleAdminListHosts(request, persist, slurper); 310 261 } else if (std.mem.eql(u8, path, "/")) { 311 - httpRespond(stream, "200 OK", "text/plain", 262 + respondText(request, .ok, 312 263 \\ _ 313 264 \\ ___| | __ _ _ _ 314 265 \\|_ / |/ _` | | | | ··· 323 274 \\ 324 275 ); 325 276 } else if (std.mem.eql(u8, path, "/favicon.svg") or std.mem.eql(u8, path, "/favicon.ico")) { 326 - httpRespond(stream, "200 OK", "image/svg+xml", 277 + request.respond( 327 278 \\<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 32 32"> 328 279 \\<rect width="32" height="32" rx="6" fill="#1a1a2e"/> 329 280 \\<text x="16" y="24" font-family="monospace" font-size="22" font-weight="bold" fill="#e94560" text-anchor="middle">Z</text> 330 281 \\</svg> 331 - ); 282 + , .{ .status = .ok, .keep_alive = false, .extra_headers = &.{.{ .name = "content-type", .value = "image/svg+xml" }} }) catch {}; 332 283 } else { 333 - httpRespond(stream, "404 Not Found", "text/plain", "not found"); 284 + respondText(request, .not_found, "not found"); 334 285 } 335 286 } 336 287 337 - fn handlePost(stream: std.net.Stream, path: []const u8, headers: []const u8, body: []const u8, persist: *event_log_mod.DiskPersist, slurper: *slurper_mod.Slurper) void { 288 + fn handlePost(request: *http.Server.Request, path: []const u8, persist: *event_log_mod.DiskPersist, slurper: *slurper_mod.Slurper) void { 338 289 if (std.mem.eql(u8, path, "/admin/repo/ban")) { 339 - handleBan(stream, headers, body, persist); 290 + handleBan(request, persist); 340 291 } else if (std.mem.eql(u8, path, "/xrpc/com.atproto.sync.requestCrawl")) { 341 - handleRequestCrawl(stream, body, slurper); 292 + handleRequestCrawl(request, slurper); 342 293 } else if (std.mem.eql(u8, path, "/admin/hosts/block")) { 343 - handleAdminBlockHost(stream, headers, body, persist); 294 + handleAdminBlockHost(request, persist); 344 295 } else if (std.mem.eql(u8, path, "/admin/hosts/unblock")) { 345 - handleAdminUnblockHost(stream, headers, body, persist); 296 + handleAdminUnblockHost(request, persist); 346 297 } else { 347 - httpRespond(stream, "404 Not Found", "text/plain", "not found"); 298 + respondText(request, .not_found, "not found"); 348 299 } 349 300 } 350 301 351 - fn handleBan(stream: std.net.Stream, headers: []const u8, body: []const u8, persist: *event_log_mod.DiskPersist) void { 352 - if (!checkAdmin(stream, headers)) return; 302 + fn handleBan(request: *http.Server.Request, persist: *event_log_mod.DiskPersist) void { 303 + if (!checkAdmin(request)) return; 304 + 305 + // read body (after checkAdmin which uses iterateHeaders) 306 + var transfer_buf: [4096]u8 = undefined; 307 + const body_reader = request.readerExpectNone(&transfer_buf); 308 + var body_buf: [4096]u8 = undefined; 309 + const body_len = body_reader.readSliceShort(&body_buf) catch { 310 + respondJson(request, .bad_request, "{\"error\":\"failed to read request body\"}"); 311 + return; 312 + }; 313 + const body = body_buf[0..body_len]; 353 314 354 315 // parse JSON body for "did" field 355 316 const parsed = std.json.parseFromSlice(struct { did: []const u8 }, persist.allocator, body, .{ .ignore_unknown_fields = true }) catch { 356 - httpRespond(stream, "400 Bad Request", "application/json", "{\"error\":\"invalid JSON, expected {\\\"did\\\":\\\"...\\\"}\"}"); 317 + respondJson(request, .bad_request, "{\"error\":\"invalid JSON, expected {\\\"did\\\":\\\"...\\\"}\"}"); 357 318 return; 358 319 }; 359 320 defer parsed.deinit(); ··· 361 322 362 323 // resolve DID → UID and take down 363 324 const uid = persist.uidForDid(did) catch { 364 - httpRespond(stream, "500 Internal Server Error", "application/json", "{\"error\":\"failed to resolve DID\"}"); 325 + respondJson(request, .internal_server_error, "{\"error\":\"failed to resolve DID\"}"); 365 326 return; 366 327 }; 367 328 persist.takeDownUser(uid) catch { 368 - httpRespond(stream, "500 Internal Server Error", "application/json", "{\"error\":\"takedown failed\"}"); 329 + respondJson(request, .internal_server_error, "{\"error\":\"takedown failed\"}"); 369 330 return; 370 331 }; 371 332 372 333 log.info("admin: banned {s} (uid={d})", .{ did, uid }); 373 - httpRespond(stream, "200 OK", "application/json", "{\"success\":true}"); 334 + respondJson(request, .ok, "{\"success\":true}"); 374 335 } 375 336 376 - fn handleRequestCrawl(stream: std.net.Stream, body: []const u8, slurper: *slurper_mod.Slurper) void { 337 + fn handleRequestCrawl(request: *http.Server.Request, slurper: *slurper_mod.Slurper) void { 338 + var transfer_buf: [4096]u8 = undefined; 339 + const body_reader = request.readerExpectNone(&transfer_buf); 340 + var body_buf: [4096]u8 = undefined; 341 + const body_len = body_reader.readSliceShort(&body_buf) catch { 342 + respondJson(request, .bad_request, "{\"error\":\"failed to read request body\"}"); 343 + return; 344 + }; 345 + const body = body_buf[0..body_len]; 346 + 377 347 const parsed = std.json.parseFromSlice(struct { hostname: []const u8 }, slurper.allocator, body, .{ .ignore_unknown_fields = true }) catch { 378 - httpRespond(stream, "400 Bad Request", "application/json", "{\"error\":\"invalid JSON, expected {\\\"hostname\\\":\\\"...\\\"}\"}"); 348 + respondJson(request, .bad_request, "{\"error\":\"invalid JSON, expected {\\\"hostname\\\":\\\"...\\\"}\"}"); 379 349 return; 380 350 }; 381 351 defer parsed.deinit(); ··· 383 353 // fast validation: hostname format (Go relay does this synchronously in handler) 384 354 const hostname = slurper_mod.validateHostname(slurper.allocator, parsed.value.hostname) catch |err| { 385 355 log.warn("requestCrawl rejected '{s}': {s}", .{ parsed.value.hostname, @errorName(err) }); 386 - httpRespond(stream, "400 Bad Request", "application/json", switch (err) { 356 + respondJson(request, .bad_request, switch (err) { 387 357 error.EmptyHostname => "{\"error\":\"empty hostname\"}", 388 358 error.InvalidCharacter => "{\"error\":\"hostname contains invalid characters\"}", 389 359 error.InvalidLabel => "{\"error\":\"hostname has invalid label\"}", ··· 400 370 // fast validation: domain ban check 401 371 if (slurper.persist.isDomainBanned(hostname)) { 402 372 log.warn("requestCrawl rejected '{s}': domain banned", .{hostname}); 403 - httpRespond(stream, "400 Bad Request", "application/json", "{\"error\":\"domain is banned\"}"); 373 + respondJson(request, .bad_request, "{\"error\":\"domain is banned\"}"); 404 374 return; 405 375 } 406 376 407 377 // enqueue for async processing (describeServer check happens in crawl processor) 408 378 slurper.addCrawlRequest(hostname) catch { 409 - httpRespond(stream, "500 Internal Server Error", "application/json", "{\"error\":\"failed to store crawl request\"}"); 379 + respondJson(request, .internal_server_error, "{\"error\":\"failed to store crawl request\"}"); 410 380 return; 411 381 }; 412 382 413 383 log.info("crawl requested: {s}", .{hostname}); 414 - httpRespond(stream, "200 OK", "application/json", "{\"success\":true}"); 384 + respondJson(request, .ok, "{\"success\":true}"); 415 385 } 416 386 417 387 // --- admin host management --- 418 388 419 - fn handleAdminListHosts(stream: std.net.Stream, persist: *event_log_mod.DiskPersist, slurper: *slurper_mod.Slurper) void { 389 + fn handleAdminListHosts(request: *http.Server.Request, persist: *event_log_mod.DiskPersist, slurper: *slurper_mod.Slurper) void { 420 390 const hosts = persist.listAllHosts(persist.allocator) catch { 421 - httpRespondJson(stream, "500 Internal Server Error", "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}"); 391 + respondJson(request, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}"); 422 392 return; 423 393 }; 424 394 defer { ··· 447 417 } 448 418 449 419 std.fmt.format(w, "],\"active_workers\":{d}}}", .{slurper.workerCount()}) catch return; 450 - httpRespondJson(stream, "200 OK", fbs.getWritten()); 420 + respondJson(request, .ok, fbs.getWritten()); 451 421 } 452 422 453 - fn handleAdminBlockHost(stream: std.net.Stream, headers: []const u8, body: []const u8, persist: *event_log_mod.DiskPersist) void { 454 - if (!checkAdmin(stream, headers)) return; 423 + fn handleAdminBlockHost(request: *http.Server.Request, persist: *event_log_mod.DiskPersist) void { 424 + if (!checkAdmin(request)) return; 425 + 426 + var transfer_buf: [4096]u8 = undefined; 427 + const body_reader = request.readerExpectNone(&transfer_buf); 428 + var body_buf: [4096]u8 = undefined; 429 + const body_len = body_reader.readSliceShort(&body_buf) catch { 430 + respondJson(request, .bad_request, "{\"error\":\"failed to read request body\"}"); 431 + return; 432 + }; 433 + const body = body_buf[0..body_len]; 455 434 456 435 const parsed = std.json.parseFromSlice(struct { hostname: []const u8 }, persist.allocator, body, .{ .ignore_unknown_fields = true }) catch { 457 - httpRespondJson(stream, "400 Bad Request", "{\"error\":\"BadRequest\",\"message\":\"invalid JSON\"}"); 436 + respondJson(request, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid JSON\"}"); 458 437 return; 459 438 }; 460 439 defer parsed.deinit(); 461 440 462 441 const host_info = persist.getOrCreateHost(parsed.value.hostname) catch { 463 - httpRespondJson(stream, "500 Internal Server Error", "{\"error\":\"DatabaseError\",\"message\":\"host lookup failed\"}"); 442 + respondJson(request, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"host lookup failed\"}"); 464 443 return; 465 444 }; 466 445 467 446 persist.updateHostStatus(host_info.id, "blocked") catch { 468 - httpRespondJson(stream, "500 Internal Server Error", "{\"error\":\"DatabaseError\",\"message\":\"status update failed\"}"); 447 + respondJson(request, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"status update failed\"}"); 469 448 return; 470 449 }; 471 450 472 451 log.info("admin: blocked host {s} (id={d})", .{ parsed.value.hostname, host_info.id }); 473 - httpRespondJson(stream, "200 OK", "{\"success\":true}"); 452 + respondJson(request, .ok, "{\"success\":true}"); 474 453 } 475 454 476 - fn handleAdminUnblockHost(stream: std.net.Stream, headers: []const u8, body: []const u8, persist: *event_log_mod.DiskPersist) void { 477 - if (!checkAdmin(stream, headers)) return; 455 + fn handleAdminUnblockHost(request: *http.Server.Request, persist: *event_log_mod.DiskPersist) void { 456 + if (!checkAdmin(request)) return; 457 + 458 + var transfer_buf: [4096]u8 = undefined; 459 + const body_reader = request.readerExpectNone(&transfer_buf); 460 + var body_buf: [4096]u8 = undefined; 461 + const body_len = body_reader.readSliceShort(&body_buf) catch { 462 + respondJson(request, .bad_request, "{\"error\":\"failed to read request body\"}"); 463 + return; 464 + }; 465 + const body = body_buf[0..body_len]; 478 466 479 467 const parsed = std.json.parseFromSlice(struct { hostname: []const u8 }, persist.allocator, body, .{ .ignore_unknown_fields = true }) catch { 480 - httpRespondJson(stream, "400 Bad Request", "{\"error\":\"BadRequest\",\"message\":\"invalid JSON\"}"); 468 + respondJson(request, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid JSON\"}"); 481 469 return; 482 470 }; 483 471 defer parsed.deinit(); 484 472 485 473 const host_info = persist.getOrCreateHost(parsed.value.hostname) catch { 486 - httpRespondJson(stream, "500 Internal Server Error", "{\"error\":\"DatabaseError\",\"message\":\"host lookup failed\"}"); 474 + respondJson(request, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"host lookup failed\"}"); 487 475 return; 488 476 }; 489 477 490 478 persist.updateHostStatus(host_info.id, "active") catch { 491 - httpRespondJson(stream, "500 Internal Server Error", "{\"error\":\"DatabaseError\",\"message\":\"status update failed\"}"); 479 + respondJson(request, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"status update failed\"}"); 492 480 return; 493 481 }; 494 482 persist.resetHostFailures(host_info.id) catch {}; 495 483 496 484 log.info("admin: unblocked host {s} (id={d})", .{ parsed.value.hostname, host_info.id }); 497 - httpRespondJson(stream, "200 OK", "{\"success\":true}"); 485 + respondJson(request, .ok, "{\"success\":true}"); 498 486 } 499 487 500 488 /// check admin auth, send error response if not authorized. returns true if authorized. 501 - fn checkAdmin(stream: std.net.Stream, headers: []const u8) bool { 489 + fn checkAdmin(request: *http.Server.Request) bool { 502 490 const admin_pw = std.posix.getenv("RELAY_ADMIN_PASSWORD") orelse { 503 - httpRespond(stream, "403 Forbidden", "application/json", "{\"error\":\"admin endpoint not configured\"}"); 491 + respondJson(request, .forbidden, "{\"error\":\"admin endpoint not configured\"}"); 504 492 return false; 505 493 }; 506 494 507 - const auth_value = findHeader(headers, "authorization") orelse { 508 - httpRespond(stream, "401 Unauthorized", "application/json", "{\"error\":\"missing authorization header\"}"); 509 - return false; 510 - }; 511 - const bearer_prefix = "Bearer "; 512 - if (!std.mem.startsWith(u8, auth_value, bearer_prefix)) { 513 - httpRespond(stream, "401 Unauthorized", "application/json", "{\"error\":\"invalid authorization scheme\"}"); 514 - return false; 515 - } 516 - const token = auth_value[bearer_prefix.len..]; 517 - if (!std.mem.eql(u8, token, admin_pw)) { 518 - httpRespond(stream, "401 Unauthorized", "application/json", "{\"error\":\"invalid token\"}"); 519 - return false; 495 + var iter = request.iterateHeaders(); 496 + while (iter.next()) |header| { 497 + if (std.ascii.eqlIgnoreCase(header.name, "authorization")) { 498 + const bearer_prefix = "Bearer "; 499 + if (!std.mem.startsWith(u8, header.value, bearer_prefix)) { 500 + respondJson(request, .unauthorized, "{\"error\":\"invalid authorization scheme\"}"); 501 + return false; 502 + } 503 + const token = header.value[bearer_prefix.len..]; 504 + if (!std.mem.eql(u8, token, admin_pw)) { 505 + respondJson(request, .unauthorized, "{\"error\":\"invalid token\"}"); 506 + return false; 507 + } 508 + return true; 509 + } 520 510 } 521 - return true; 511 + 512 + respondJson(request, .unauthorized, "{\"error\":\"missing authorization header\"}"); 513 + return false; 522 514 } 523 515 524 516 // --- XRPC endpoint handlers --- 525 517 526 - fn handleListRepos(stream: std.net.Stream, query: []const u8, persist: *event_log_mod.DiskPersist) void { 518 + fn handleListRepos(request: *http.Server.Request, query: []const u8, persist: *event_log_mod.DiskPersist) void { 527 519 const cursor_str = queryParam(query, "cursor") orelse "0"; 528 520 const limit_str = queryParam(query, "limit") orelse "500"; 529 521 530 522 const cursor_val = std.fmt.parseInt(i64, cursor_str, 10) catch { 531 - httpRespondJson(stream, "400 Bad Request", "{\"error\":\"BadRequest\",\"message\":\"invalid cursor\"}"); 523 + respondJson(request, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid cursor\"}"); 532 524 return; 533 525 }; 534 526 if (cursor_val < 0) { 535 - httpRespondJson(stream, "400 Bad Request", "{\"error\":\"BadRequest\",\"message\":\"cursor must be >= 0\"}"); 527 + respondJson(request, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"cursor must be >= 0\"}"); 536 528 return; 537 529 } 538 530 539 531 const limit = std.fmt.parseInt(i64, limit_str, 10) catch { 540 - httpRespondJson(stream, "400 Bad Request", "{\"error\":\"BadRequest\",\"message\":\"invalid limit\"}"); 532 + respondJson(request, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid limit\"}"); 541 533 return; 542 534 }; 543 535 if (limit < 1 or limit > 1000) { 544 - httpRespondJson(stream, "400 Bad Request", "{\"error\":\"BadRequest\",\"message\":\"limit must be 1..1000\"}"); 536 + respondJson(request, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"limit must be 1..1000\"}"); 545 537 return; 546 538 } 547 539 ··· 552 544 \\FROM account a LEFT JOIN account_repo r ON a.uid = r.uid 553 545 \\WHERE a.uid > $1 ORDER BY a.uid ASC LIMIT $2 554 546 , .{ cursor_val, limit }) catch { 555 - httpRespondJson(stream, "500 Internal Server Error", "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}"); 547 + respondJson(request, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}"); 556 548 return; 557 549 }; 558 550 defer result.deinit(); ··· 621 613 622 614 w.writeByte('}') catch return; 623 615 624 - const resp_body = fbs.getWritten(); 625 - httpRespondJson(stream, "200 OK", resp_body); 616 + respondJson(request, .ok, fbs.getWritten()); 626 617 } 627 618 628 - fn handleGetRepoStatus(stream: std.net.Stream, query: []const u8, persist: *event_log_mod.DiskPersist) void { 619 + fn handleGetRepoStatus(request: *http.Server.Request, query: []const u8, persist: *event_log_mod.DiskPersist) void { 629 620 var did_buf: [256]u8 = undefined; 630 621 const did = queryParamDecoded(query, "did", &did_buf) orelse { 631 - httpRespondJson(stream, "400 Bad Request", "{\"error\":\"BadRequest\",\"message\":\"did parameter required\"}"); 622 + respondJson(request, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"did parameter required\"}"); 632 623 return; 633 624 }; 634 625 635 626 // basic DID syntax check 636 627 if (!std.mem.startsWith(u8, did, "did:")) { 637 - httpRespondJson(stream, "400 Bad Request", "{\"error\":\"BadRequest\",\"message\":\"invalid DID\"}"); 628 + respondJson(request, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid DID\"}"); 638 629 return; 639 630 } 640 631 ··· 643 634 "SELECT a.uid, a.status, a.upstream_status, COALESCE(r.rev, '') FROM account a LEFT JOIN account_repo r ON a.uid = r.uid WHERE a.did = $1", 644 635 .{did}, 645 636 ) catch { 646 - httpRespondJson(stream, "500 Internal Server Error", "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}"); 637 + respondJson(request, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}"); 647 638 return; 648 639 }) orelse { 649 - httpRespondJson(stream, "404 Not Found", "{\"error\":\"RepoNotFound\",\"message\":\"account not found\"}"); 640 + respondJson(request, .not_found, "{\"error\":\"RepoNotFound\",\"message\":\"account not found\"}"); 650 641 return; 651 642 }; 652 643 defer row.deinit() catch {}; ··· 683 674 } 684 675 685 676 w.writeByte('}') catch return; 686 - httpRespondJson(stream, "200 OK", fbs.getWritten()); 677 + respondJson(request, .ok, fbs.getWritten()); 687 678 } 688 679 689 - fn handleGetLatestCommit(stream: std.net.Stream, query: []const u8, persist: *event_log_mod.DiskPersist) void { 680 + fn handleGetLatestCommit(request: *http.Server.Request, query: []const u8, persist: *event_log_mod.DiskPersist) void { 690 681 var did_buf: [256]u8 = undefined; 691 682 const did = queryParamDecoded(query, "did", &did_buf) orelse { 692 - httpRespondJson(stream, "400 Bad Request", "{\"error\":\"BadRequest\",\"message\":\"did parameter required\"}"); 683 + respondJson(request, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"did parameter required\"}"); 693 684 return; 694 685 }; 695 686 696 687 if (!std.mem.startsWith(u8, did, "did:")) { 697 - httpRespondJson(stream, "400 Bad Request", "{\"error\":\"BadRequest\",\"message\":\"invalid DID\"}"); 688 + respondJson(request, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid DID\"}"); 698 689 return; 699 690 } 700 691 ··· 703 694 "SELECT a.status, a.upstream_status, COALESCE(r.rev, ''), COALESCE(r.commit_data_cid, '') FROM account a LEFT JOIN account_repo r ON a.uid = r.uid WHERE a.did = $1", 704 695 .{did}, 705 696 ) catch { 706 - httpRespondJson(stream, "500 Internal Server Error", "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}"); 697 + respondJson(request, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}"); 707 698 return; 708 699 }) orelse { 709 - httpRespondJson(stream, "404 Not Found", "{\"error\":\"RepoNotFound\",\"message\":\"account not found\"}"); 700 + respondJson(request, .not_found, "{\"error\":\"RepoNotFound\",\"message\":\"account not found\"}"); 710 701 return; 711 702 }; 712 703 defer row.deinit() catch {}; ··· 721 712 722 713 // check account status (match Go relay behavior) 723 714 if (std.mem.eql(u8, status, "takendown") or std.mem.eql(u8, status, "suspended")) { 724 - httpRespondJson(stream, "403 Forbidden", "{\"error\":\"RepoTakendown\",\"message\":\"account has been taken down\"}"); 715 + respondJson(request, .forbidden, "{\"error\":\"RepoTakendown\",\"message\":\"account has been taken down\"}"); 725 716 return; 726 717 } else if (std.mem.eql(u8, status, "deactivated")) { 727 - httpRespondJson(stream, "403 Forbidden", "{\"error\":\"RepoDeactivated\",\"message\":\"account is deactivated\"}"); 718 + respondJson(request, .forbidden, "{\"error\":\"RepoDeactivated\",\"message\":\"account is deactivated\"}"); 728 719 return; 729 720 } else if (std.mem.eql(u8, status, "deleted")) { 730 - httpRespondJson(stream, "403 Forbidden", "{\"error\":\"RepoDeleted\",\"message\":\"account is deleted\"}"); 721 + respondJson(request, .forbidden, "{\"error\":\"RepoDeleted\",\"message\":\"account is deleted\"}"); 731 722 return; 732 723 } else if (!std.mem.eql(u8, status, "active")) { 733 - httpRespondJson(stream, "403 Forbidden", "{\"error\":\"RepoInactive\",\"message\":\"account is not active\"}"); 724 + respondJson(request, .forbidden, "{\"error\":\"RepoInactive\",\"message\":\"account is not active\"}"); 734 725 return; 735 726 } 736 727 737 728 if (rev.len == 0 or cid.len == 0) { 738 - httpRespondJson(stream, "404 Not Found", "{\"error\":\"RepoNotSynchronized\",\"message\":\"relay has no repo data for this account\"}"); 729 + respondJson(request, .not_found, "{\"error\":\"RepoNotSynchronized\",\"message\":\"relay has no repo data for this account\"}"); 739 730 return; 740 731 } 741 732 ··· 749 740 w.writeAll(rev) catch return; 750 741 w.writeAll("\"}") catch return; 751 742 752 - httpRespondJson(stream, "200 OK", fbs.getWritten()); 743 + respondJson(request, .ok, fbs.getWritten()); 753 744 } 754 745 755 - fn handleListReposByCollection(stream: std.net.Stream, query: []const u8, ci: *collection_index_mod.CollectionIndex) void { 746 + fn handleListReposByCollection(request: *http.Server.Request, query: []const u8, ci: *collection_index_mod.CollectionIndex) void { 756 747 const collection = queryParam(query, "collection") orelse { 757 - httpRespondJson(stream, "400 Bad Request", "{\"error\":\"BadRequest\",\"message\":\"collection parameter required\"}"); 748 + respondJson(request, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"collection parameter required\"}"); 758 749 return; 759 750 }; 760 751 761 752 if (collection.len == 0 or !std.mem.containsAtLeast(u8, collection, 1, ".")) { 762 - httpRespondJson(stream, "400 Bad Request", "{\"error\":\"BadRequest\",\"message\":\"invalid collection NSID\"}"); 753 + respondJson(request, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid collection NSID\"}"); 763 754 return; 764 755 } 765 756 766 757 const limit_str = queryParam(query, "limit") orelse "500"; 767 758 const limit = std.fmt.parseInt(usize, limit_str, 10) catch { 768 - httpRespondJson(stream, "400 Bad Request", "{\"error\":\"BadRequest\",\"message\":\"invalid limit\"}"); 759 + respondJson(request, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid limit\"}"); 769 760 return; 770 761 }; 771 762 if (limit < 1 or limit > 1000) { 772 - httpRespondJson(stream, "400 Bad Request", "{\"error\":\"BadRequest\",\"message\":\"limit must be 1..1000\"}"); 763 + respondJson(request, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"limit must be 1..1000\"}"); 773 764 return; 774 765 } 775 766 ··· 779 770 // scan collection index 780 771 var did_buf: [65536]u8 = undefined; 781 772 const result = ci.listReposByCollection(collection, limit, cursor_did, &did_buf) catch { 782 - httpRespondJson(stream, "500 Internal Server Error", "{\"error\":\"InternalError\",\"message\":\"index scan failed\"}"); 773 + respondJson(request, .internal_server_error, "{\"error\":\"InternalError\",\"message\":\"index scan failed\"}"); 783 774 return; 784 775 }; 785 776 ··· 806 797 } 807 798 808 799 w.writeByte('}') catch return; 809 - httpRespondJson(stream, "200 OK", fbs.getWritten()); 800 + respondJson(request, .ok, fbs.getWritten()); 810 801 } 811 802 812 803 // --- query string helpers --- ··· 873 864 }; 874 865 } 875 866 876 - fn httpRespondJson(stream: std.net.Stream, status: []const u8, body: []const u8) void { 877 - httpRespond(stream, status, "application/json", body); 878 - } 867 + // --- response helpers --- 879 868 880 - fn findHeader(headers: []const u8, name: []const u8) ?[]const u8 { 881 - var iter = std.mem.splitScalar(u8, headers, '\n'); 882 - while (iter.next()) |line| { 883 - const trimmed = std.mem.trimRight(u8, line, "\r"); 884 - const colon = std.mem.indexOfScalar(u8, trimmed, ':') orelse continue; 885 - const key = std.mem.trim(u8, trimmed[0..colon], " "); 886 - if (std.ascii.eqlIgnoreCase(key, name)) { 887 - return std.mem.trim(u8, trimmed[colon + 1 ..], " "); 888 - } 889 - } 890 - return null; 869 + fn respondJson(request: *http.Server.Request, status: http.Status, body: []const u8) void { 870 + request.respond(body, .{ .status = status, .keep_alive = false, .extra_headers = &.{.{ .name = "content-type", .value = "application/json" }} }) catch {}; 891 871 } 892 872 893 - fn httpRespond(stream: std.net.Stream, status: []const u8, content_type: []const u8, body: []const u8) void { 894 - // write headers first, then body separately (body can be much larger than header buffer) 895 - var hdr_buf: [512]u8 = undefined; 896 - const hdr = std.fmt.bufPrint(&hdr_buf, "HTTP/1.1 {s}\r\nContent-Type: {s}\r\nContent-Length: {d}\r\nServer: zlay (atproto-relay)\r\nConnection: close\r\n\r\n", .{ status, content_type, body.len }) catch return; 897 - _ = stream.write(hdr) catch return; 898 - _ = stream.write(body) catch {}; 873 + fn respondText(request: *http.Server.Request, status: http.Status, body: []const u8) void { 874 + request.respond(body, .{ .status = status, .keep_alive = false, .extra_headers = &.{.{ .name = "content-type", .value = "text/plain" }} }) catch {}; 899 875 } 900 876 901 877 fn parseEnvInt(comptime T: type, key: []const u8, default: T) T {