//! zlay (zlay.waow.tech): atproto relay implementation in zig.
//!
//! admin endpoint handlers for relay management.
//!
//! all handlers require Bearer token auth against RELAY_ADMIN_PASSWORD.
//! includes host blocking/unblocking, account bans, and backfill control.
//!
//! DB-accessing handlers use DbRequest + DbRequestQueue to route queries
//! through pool_io workers.

const std = @import("std");
const Io = std.Io;
const h = @import("http.zig");
const router = @import("router.zig");
const websocket = @import("websocket");
const event_log_mod = @import("../event_log.zig");
const backfill_mod = @import("../backfill.zig");
const cleaner_mod = @import("../cleaner.zig");
const resync_mod = @import("../resync.zig");

const log = std.log.scoped(.relay);

const HttpContext = router.HttpContext;
const DbRequest = event_log_mod.DbRequest;
const DiskPersist = event_log_mod.DiskPersist;

/// check admin auth via headers, send error response if not authorized. returns true if authorized.
pub fn checkAdmin(conn: *h.Conn, headers: ?*const websocket.Handshake.KeyValue) bool {
    const admin_pw = getenv("RELAY_ADMIN_PASSWORD") orelse {
        h.respondJson(conn, .forbidden, "{\"error\":\"admin endpoint not configured\"}");
        return false;
    };

    const kv = headers orelse {
        h.respondJson(conn, .unauthorized, "{\"error\":\"missing authorization header\"}");
        return false;
    };

    // handshake parser lowercases all header names
    const auth_value = kv.get("authorization") orelse {
        h.respondJson(conn, .unauthorized, "{\"error\":\"missing authorization header\"}");
        return false;
    };

    const bearer_prefix = "Bearer ";
    if (!std.mem.startsWith(u8, auth_value, bearer_prefix)) {
        h.respondJson(conn, .unauthorized, "{\"error\":\"invalid authorization scheme\"}");
        return false;
    }
    const token = auth_value[bearer_prefix.len..];
    if (!std.mem.eql(u8, token, admin_pw)) {
        h.respondJson(conn, .unauthorized, "{\"error\":\"invalid token\"}");
        return false;
    }
    return true;
}
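
// example client call (the exact route is defined in router.zig; the path
// below is illustrative only):
//
//   curl -X POST https://relay.example/admin/ban \
//        -H "Authorization: Bearer $RELAY_ADMIN_PASSWORD" \
//        --data '{"did":"did:plc:abc123"}'
//
// a missing or non-matching token yields 401; a relay started without
// RELAY_ADMIN_PASSWORD in the environment rejects every admin call with 403.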

pub fn handleBan(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    if (!checkAdmin(conn, headers)) return;

    const parsed = std.json.parseFromSlice(struct { did: []const u8 }, ctx.persist.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"invalid JSON, expected {\\\"did\\\":\\\"...\\\"}\"}");
        return;
    };
    defer parsed.deinit();
    const did = parsed.value.did;

    // resolve DID → UID via DbRequestQueue
    const UidReq = struct {
        base: DbRequest = .{ .callback = &execute },
        did_buf: [256]u8 = undefined,
        did_len: usize = 0,
        uid: u64 = 0,

        fn execute(b: *DbRequest, dp: *DiskPersist) void {
            const self: *@This() = @fieldParentPtr("base", b);
            const d = self.did_buf[0..self.did_len];
            // check database
            if (dp.db.rowUnsafe("SELECT uid FROM account WHERE did = $1", .{d}) catch null) |row| {
                var r = row;
                defer r.deinit() catch {};
                self.uid = @intCast(r.get(i64, 0));
                return;
            }
            // create new account row
            _ = dp.db.exec("INSERT INTO account (did) VALUES ($1) ON CONFLICT (did) DO NOTHING", .{d}) catch {
                b.err = error.DatabaseError;
                return;
            };
            var row = dp.db.rowUnsafe("SELECT uid FROM account WHERE did = $1", .{d}) catch {
                b.err = error.DatabaseError;
                return;
            } orelse {
                b.err = error.AccountCreationFailed;
                return;
            };
            defer row.deinit() catch {};
            self.uid = @intCast(row.get(i64, 0));
        }
    };
    var uid_req: UidReq = .{};
    const copy_len = @min(did.len, uid_req.did_buf.len);
    @memcpy(uid_req.did_buf[0..copy_len], did[0..copy_len]);
    uid_req.did_len = copy_len;
    ctx.db_queue.push(&uid_req.base);
    uid_req.base.wait(ctx.io, ctx.shutdown);

    if (uid_req.base.err != null) {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to resolve DID\"}");
        return;
    }

    // remove from collection index so banned accounts don't appear in listReposByCollection
    ctx.collection_index.removeAll(did) catch |err| {
        log.debug("collection removeAll after ban failed: {s}", .{@errorName(err)});
    };

    // build CBOR #account frame and route takedown + persist + broadcast
    // through the host_ops queue (pool_io thread); fire and forget.
    const host_ops_mod = @import("../host_ops.zig");
    var td: host_ops_mod.HostOp.Payload.Takedown = .{ .uid = uid_req.uid };

    if (buildAccountFrame(ctx.persist.allocator, did)) |frame_bytes| {
        defer ctx.persist.allocator.free(frame_bytes);
        if (frame_bytes.len <= td.frame_buf.len) {
            @memcpy(td.frame_buf[0..frame_bytes.len], frame_bytes);
            td.frame_len = @intCast(frame_bytes.len);
        }
    }

    ctx.host_ops.push(.{
        .host_id = 0, // not host-specific
        .kind = .takedown_user,
        .payload = .{ .takedown = td },
    });

    log.info("admin: banned {s} (uid={d}), takedown enqueued", .{ did, uid_req.uid });
    h.respondJson(conn, .ok, "{\"success\":true}");
}
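
// the handler above shows the round trip every DB-touching handler below
// repeats: embed a DbRequest as `base`, copy inputs into fixed buffers (the
// request must not reference request-scoped memory once pushed), push onto
// ctx.db_queue, then block on wait() until a pool_io worker has run execute().
// a minimal sketch of the shape (illustrative only):
//
//   const Req = struct {
//       base: DbRequest = .{ .callback = &execute },
//       fn execute(b: *DbRequest, dp: *DiskPersist) void {
//           const self: *@This() = @fieldParentPtr("base", b);
//           _ = self; // run queries against dp.db, set b.err on failure
//       }
//   };
//   var req: Req = .{};
//   ctx.db_queue.push(&req.base);
//   req.base.wait(ctx.io, ctx.shutdown);
//   if (req.base.err != null) { ... } // respond 500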

pub fn handleAdminListHosts(conn: *h.Conn, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    if (!checkAdmin(conn, headers)) return;

    // list all hosts via DbRequestQueue
    const ListAllHostsReq = struct {
        base: DbRequest = .{ .callback = &execute },
        alloc: std.mem.Allocator,
        result: ?[]event_log_mod.DiskPersist.Host = null,

        fn execute(b: *DbRequest, dp: *DiskPersist) void {
            const self: *@This() = @fieldParentPtr("base", b);
            self.result = dp.listAllHosts(self.alloc) catch |e| {
                b.err = e;
                return;
            };
        }
    };
    var list_req: ListAllHostsReq = .{ .alloc = ctx.persist.allocator };
    ctx.db_queue.push(&list_req.base);
    list_req.base.wait(ctx.io, ctx.shutdown);

    if (list_req.base.err != null or list_req.result == null) {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}");
        return;
    }

    const hosts = list_req.result.?;
    defer {
        for (hosts) |host| {
            ctx.persist.allocator.free(host.hostname);
            ctx.persist.allocator.free(host.status);
        }
        ctx.persist.allocator.free(hosts);
    }

    var aw: Io.Writer.Allocating = .init(ctx.persist.allocator);
    defer aw.deinit();
    const w = &aw.writer;

    w.writeAll("{\"hosts\":[") catch return;

    for (hosts, 0..) |host, i| {
        if (i > 0) w.writeByte(',') catch return;
        if (host.account_limit) |limit| {
            w.print("{{\"id\":{d},\"hostname\":\"{s}\",\"status\":\"{s}\",\"last_seq\":{d},\"failed_attempts\":{d},\"account_limit\":{d}}}", .{
                host.id,
                host.hostname,
                host.status,
                host.last_seq,
                host.failed_attempts,
                limit,
            }) catch return;
        } else {
            w.print("{{\"id\":{d},\"hostname\":\"{s}\",\"status\":\"{s}\",\"last_seq\":{d},\"failed_attempts\":{d},\"account_limit\":null}}", .{
                host.id,
                host.hostname,
                host.status,
                host.last_seq,
                host.failed_attempts,
            }) catch return;
        }
    }

    w.print("],\"active_workers\":{d}}}", .{ctx.slurper.workerCount()}) catch return;
    h.respondJson(conn, .ok, aw.written());
}
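
// response shape, reconstructed from the print format strings above
// (values illustrative):
//
//   {"hosts":[{"id":1,"hostname":"pds.example.com","status":"active",
//              "last_seq":123456,"failed_attempts":0,"account_limit":null}],
//    "active_workers":8}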

pub fn handleAdminBlockHost(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    if (!checkAdmin(conn, headers)) return;

    const parsed = std.json.parseFromSlice(struct { hostname: []const u8 }, ctx.persist.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid JSON\"}");
        return;
    };
    defer parsed.deinit();

    const BlockHostReq = struct {
        base: DbRequest = .{ .callback = &execute },
        hostname_buf: [256]u8 = undefined,
        hostname_len: usize = 0,
        host_id: u64 = 0,

        fn execute(b: *DbRequest, dp: *DiskPersist) void {
            const self: *@This() = @fieldParentPtr("base", b);
            const hn = self.hostname_buf[0..self.hostname_len];
            const info = dp.getOrCreateHost(hn) catch |e| {
                b.err = e;
                return;
            };
            self.host_id = info.id;
            dp.updateHostStatus(info.id, "blocked") catch |e| {
                b.err = e;
                return;
            };
        }
    };
    var req: BlockHostReq = .{};
    const copy_len = @min(parsed.value.hostname.len, req.hostname_buf.len);
    @memcpy(req.hostname_buf[0..copy_len], parsed.value.hostname[0..copy_len]);
    req.hostname_len = copy_len;
    ctx.db_queue.push(&req.base);
    req.base.wait(ctx.io, ctx.shutdown);

    if (req.base.err != null) {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"operation failed\"}");
        return;
    }

    log.info("admin: blocked host {s} (id={d})", .{ parsed.value.hostname, req.host_id });
    h.respondJson(conn, .ok, "{\"success\":true}");
}

pub fn handleAdminUnblockHost(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    if (!checkAdmin(conn, headers)) return;

    const parsed = std.json.parseFromSlice(struct { hostname: []const u8 }, ctx.persist.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid JSON\"}");
        return;
    };
    defer parsed.deinit();

    const UnblockHostReq = struct {
        base: DbRequest = .{ .callback = &execute },
        hostname_buf: [256]u8 = undefined,
        hostname_len: usize = 0,
        host_id: u64 = 0,

        fn execute(b: *DbRequest, dp: *DiskPersist) void {
            const self: *@This() = @fieldParentPtr("base", b);
            const hn = self.hostname_buf[0..self.hostname_len];
            const info = dp.getOrCreateHost(hn) catch |e| {
                b.err = e;
                return;
            };
            self.host_id = info.id;
            dp.updateHostStatus(info.id, "active") catch |e| {
                b.err = e;
                return;
            };
            dp.resetHostFailures(info.id) catch {};
        }
    };
    var req: UnblockHostReq = .{};
    const copy_len = @min(parsed.value.hostname.len, req.hostname_buf.len);
    @memcpy(req.hostname_buf[0..copy_len], parsed.value.hostname[0..copy_len]);
    req.hostname_len = copy_len;
    ctx.db_queue.push(&req.base);
    req.base.wait(ctx.io, ctx.shutdown);

    if (req.base.err != null) {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"operation failed\"}");
        return;
    }

    log.info("admin: unblocked host {s} (id={d})", .{ parsed.value.hostname, req.host_id });
    h.respondJson(conn, .ok, "{\"success\":true}");
}
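
// block/unblock round trip (paths illustrative; actual routes live in router.zig):
//
//   curl -X POST .../admin/blockHost   -H "Authorization: Bearer $PW" --data '{"hostname":"pds.example.com"}'
//   curl -X POST .../admin/unblockHost -H "Authorization: Bearer $PW" --data '{"hostname":"pds.example.com"}'
//
// both upsert the host row via getOrCreateHost, so blocking a host the relay
// has never seen still works; unblock additionally clears the failure counter.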

/// set or clear the account_limit override for a host.
pub fn handleAdminChangeLimits(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    if (!checkAdmin(conn, headers)) return;

    const parsed = std.json.parseFromSlice(
        struct { host: []const u8, account_limit: ?u64 },
        ctx.persist.allocator,
        body,
        .{ .ignore_unknown_fields = true },
    ) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"invalid JSON, expected {\\\"host\\\":\\\"...\\\",\\\"account_limit\\\":...}\"}");
        return;
    };
    defer parsed.deinit();

    const ChangeLimitsReq = struct {
        base: DbRequest = .{ .callback = &execute },
        hostname_buf: [256]u8 = undefined,
        hostname_len: usize = 0,
        new_limit: ?u64,
        host_id: ?u64 = null,
        effective: u64 = 0,

        fn execute(b: *DbRequest, dp: *DiskPersist) void {
            const self: *@This() = @fieldParentPtr("base", b);
            const hn = self.hostname_buf[0..self.hostname_len];
            self.host_id = dp.getHostIdForHostname(hn) catch |e| {
                b.err = e;
                return;
            };
            const hid = self.host_id orelse return;
            dp.setHostAccountLimit(hid, self.new_limit) catch |e| {
                b.err = e;
                return;
            };
            self.effective = if (self.new_limit) |l| l else dp.getHostAccountCount(hid);
        }
    };
    var req: ChangeLimitsReq = .{ .new_limit = parsed.value.account_limit };
    const copy_len = @min(parsed.value.host.len, req.hostname_buf.len);
    @memcpy(req.hostname_buf[0..copy_len], parsed.value.host[0..copy_len]);
    req.hostname_len = copy_len;
    ctx.db_queue.push(&req.base);
    req.base.wait(ctx.io, ctx.shutdown);

    if (req.base.err != null) {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"database error\"}");
        return;
    }
    const host_id = req.host_id orelse {
        h.respondJson(conn, .not_found, "{\"error\":\"host not found\"}");
        return;
    };

    // update running subscriber's rate limits immediately
    ctx.slurper.updateHostLimits(host_id, req.effective);

    if (parsed.value.account_limit) |limit| {
        log.info("admin: set account_limit for {s} (id={d}): {d}", .{ parsed.value.host, host_id, limit });
    } else {
        log.info("admin: cleared account_limit for {s} (id={d}), reverted to COUNT(*)", .{ parsed.value.host, host_id });
    }
    h.respondJson(conn, .ok, "{\"success\":true}");
}
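
// request bodies (derived from the parse struct above; hostname illustrative):
//
//   {"host":"pds.example.com","account_limit":500}    set an override
//   {"host":"pds.example.com","account_limit":null}   clear it; the effective
//       limit then falls back to the host's current account count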

pub fn handleAdminBackfillTrigger(conn: *h.Conn, query: []const u8, headers: *const websocket.Handshake.KeyValue, backfiller: *backfill_mod.Backfiller) void {
    if (!checkAdmin(conn, headers)) return;

    const source = h.queryParam(query, "source") orelse "bsky.network";

    backfiller.start(source) catch |err| {
        switch (err) {
            error.AlreadyRunning => {
                h.respondJson(conn, .conflict, "{\"error\":\"backfill already in progress\"}");
            },
            else => {
                h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to start backfill\"}");
            },
        }
        return;
    };

    var buf: [256]u8 = undefined;
    const resp_body = std.fmt.bufPrint(&buf, "{{\"status\":\"started\",\"source\":\"{s}\"}}", .{source}) catch {
        h.respondJson(conn, .ok, "{\"status\":\"started\"}");
        return;
    };
    h.respondJson(conn, .ok, resp_body);
}
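
// the source host comes from the query string (path illustrative):
//
//   .../admin/backfill?source=pds.example.com
//
// omitting ?source falls back to "bsky.network"; triggering again while a
// backfill is running returns 409.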

pub fn handleAdminBackfillStatus(conn: *h.Conn, headers: *const websocket.Handshake.KeyValue, backfiller: *backfill_mod.Backfiller) void {
    if (!checkAdmin(conn, headers)) return;

    const body = backfiller.getStatus(backfiller.allocator) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to query backfill status\"}");
        return;
    };
    defer backfiller.allocator.free(body);

    h.respondJson(conn, .ok, body);
}

pub fn handleCleanupTrigger(conn: *h.Conn, headers: *const websocket.Handshake.KeyValue, cleaner: *cleaner_mod.Cleaner) void {
    if (!checkAdmin(conn, headers)) return;

    cleaner.start() catch |err| {
        switch (err) {
            error.AlreadyRunning => {
                h.respondJson(conn, .conflict, "{\"error\":\"cleanup already in progress\"}");
            },
            else => {
                h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to start cleanup\"}");
            },
        }
        return;
    };

    h.respondJson(conn, .ok, "{\"status\":\"started\"}");
}

pub fn handleCleanupStatus(conn: *h.Conn, headers: *const websocket.Handshake.KeyValue, cleaner: *cleaner_mod.Cleaner) void {
    if (!checkAdmin(conn, headers)) return;

    const status = cleaner.getStatus();
    var buf: [256]u8 = undefined;
    const body = std.fmt.bufPrint(&buf, "{{\"running\":{},\"scanned\":{d},\"removed\":{d}}}", .{
        status.running,
        status.scanned,
        status.removed,
    }) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"format failed\"}");
        return;
    };
    h.respondJson(conn, .ok, body);
}

pub fn handleResyncStatus(conn: *h.Conn, headers: *const websocket.Handshake.KeyValue, resyncer: *resync_mod.Resyncer) void {
    if (!checkAdmin(conn, headers)) return;

    var buf: [256]u8 = undefined;
    const body = std.fmt.bufPrint(&buf, "{{\"processed\":{d},\"failed\":{d},\"dropped\":{d},\"queue_depth\":{d}}}", .{
        resyncer.processed.load(.monotonic),
        resyncer.failed.load(.monotonic),
        resyncer.dropped.load(.monotonic),
        resyncer.queueDepth(),
    }) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"format failed\"}");
        return;
    };
    h.respondJson(conn, .ok, body);
}

pub fn handleResyncTrigger(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, resyncer: *resync_mod.Resyncer) void {
    if (!checkAdmin(conn, headers)) return;

    const parsed = std.json.parseFromSlice(
        struct { did: []const u8, hostname: []const u8 },
        std.heap.c_allocator,
        body,
        .{ .ignore_unknown_fields = true },
    ) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"invalid JSON, expected {\\\"did\\\":\\\"...\\\",\\\"hostname\\\":\\\"...\\\"}\"}");
        return;
    };
    defer parsed.deinit();

    resyncer.enqueue(parsed.value.did, parsed.value.hostname);
    h.respondJson(conn, .ok, "{\"status\":\"enqueued\"}");
}

// --- protocol helpers (used only by handleBan) ---

/// build a CBOR #account frame for a takedown event.
fn buildAccountFrame(allocator: std.mem.Allocator, did: []const u8) ?[]const u8 {
    const zat = @import("zat");
    const cbor = zat.cbor;

    const header: cbor.Value = .{ .map = &.{
        .{ .key = "op", .value = .{ .unsigned = 1 } },
        .{ .key = "t", .value = .{ .text = "#account" } },
    } };

    var time_buf: [24]u8 = undefined;
    const time_str = formatTimestamp(&time_buf);

    const payload: cbor.Value = .{ .map = &.{
        .{ .key = "seq", .value = .{ .unsigned = 0 } },
        .{ .key = "did", .value = .{ .text = did } },
        .{ .key = "time", .value = .{ .text = time_str } },
        .{ .key = "active", .value = .{ .boolean = false } },
        .{ .key = "status", .value = .{ .text = "takendown" } },
    } };

    const header_bytes = cbor.encodeAlloc(allocator, header) catch return null;
    const payload_bytes = cbor.encodeAlloc(allocator, payload) catch {
        allocator.free(header_bytes);
        return null;
    };

    // const: the slice itself is never reassigned, only written through
    const frame = allocator.alloc(u8, header_bytes.len + payload_bytes.len) catch {
        allocator.free(header_bytes);
        allocator.free(payload_bytes);
        return null;
    };
    @memcpy(frame[0..header_bytes.len], header_bytes);
    @memcpy(frame[header_bytes.len..], payload_bytes);

    allocator.free(header_bytes);
    allocator.free(payload_bytes);

    return frame;
}
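
// the frame is two bare CBOR values concatenated, matching the firehose wire
// format: a header map {"op": 1, "t": "#account"} followed by the event body.
// decoded, the body carries (time value illustrative):
//
//   {"seq": 0, "did": "did:...", "time": "2025-01-01T00:00:00Z",
//    "active": false, "status": "takendown"}
//
// seq is left 0 at build time.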

fn formatTimestamp(buf: *[24]u8) []const u8 {
    var tp: std.c.timespec = undefined;
    _ = std.c.clock_gettime(.REALTIME, &tp);
    const ts: u64 = @intCast(tp.sec);
    const es = std.time.epoch.EpochSeconds{ .secs = ts };
    const day = es.getEpochDay();
    const yd = day.calculateYearDay();
    const md = yd.calculateMonthDay();
    const ds = es.getDaySeconds();

    // Month is already 1-based (jan = 1), so use numeric() directly;
    // day_index is 0-based and needs the +1.
    return std.fmt.bufPrint(buf, "{d:0>4}-{d:0>2}-{d:0>2}T{d:0>2}:{d:0>2}:{d:0>2}Z", .{
        yd.year,
        md.month.numeric(),
        @as(u32, md.day_index) + 1,
        ds.getHoursIntoDay(),
        ds.getMinutesIntoHour(),
        ds.getSecondsIntoMinute(),
    }) catch "1970-01-01T00:00:00Z";
}
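
// a minimal smoke test for the timestamp shape (assumes the test binary links
// libc, since formatTimestamp calls std.c.clock_gettime):
test "formatTimestamp emits YYYY-MM-DDTHH:MM:SSZ" {
    var buf: [24]u8 = undefined;
    const s = formatTimestamp(&buf);
    try std.testing.expectEqual(@as(usize, 20), s.len);
    try std.testing.expectEqual(@as(u8, '-'), s[4]);
    try std.testing.expectEqual(@as(u8, '-'), s[7]);
    try std.testing.expectEqual(@as(u8, 'T'), s[10]);
    try std.testing.expectEqual(@as(u8, 'Z'), s[19]);
}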

fn getenv(key: [*:0]const u8) ?[]const u8 {
    const ptr = std.c.getenv(key) orelse return null;
    return std.mem.sliceTo(ptr, 0);
}