atproto relay implementation in zig zlay.waow.tech
9
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 533 lines 20 kB view raw
//! admin endpoint handlers for relay management.
//!
//! all handlers require Bearer token auth against RELAY_ADMIN_PASSWORD.
//! includes host blocking/unblocking, account bans, and backfill control.
//!
//! DB-accessing handlers use DbRequest + DbRequestQueue to route queries
//! through pool_io workers.

const std = @import("std");
const Io = std.Io;
const h = @import("http.zig");
const router = @import("router.zig");
const websocket = @import("websocket");
const event_log_mod = @import("../event_log.zig");
const backfill_mod = @import("../backfill.zig");
const cleaner_mod = @import("../cleaner.zig");
const resync_mod = @import("../resync.zig");

const log = std.log.scoped(.relay);

const HttpContext = router.HttpContext;
const DbRequest = event_log_mod.DbRequest;
const DiskPersist = event_log_mod.DiskPersist;

/// check admin auth via headers, send error response if not authorized. returns true if authorized.
///
/// an unset *or empty* RELAY_ADMIN_PASSWORD disables the admin API (403), so an empty
/// bearer token can never authenticate. the auth scheme is matched case-insensitively
/// (RFC 7235) and the token is compared in constant time.
pub fn checkAdmin(conn: *h.Conn, headers: ?*const websocket.Handshake.KeyValue) bool {
    const admin_pw = getenv("RELAY_ADMIN_PASSWORD") orelse "";
    if (admin_pw.len == 0) {
        h.respondJson(conn, .forbidden, "{\"error\":\"admin endpoint not configured\"}");
        return false;
    }

    const kv = headers orelse {
        h.respondJson(conn, .unauthorized, "{\"error\":\"missing authorization header\"}");
        return false;
    };

    // handshake parser lowercases all header names
    const auth_value = kv.get("authorization") orelse {
        h.respondJson(conn, .unauthorized, "{\"error\":\"missing authorization header\"}");
        return false;
    };

    const bearer_prefix = "Bearer ";
    if (!std.ascii.startsWithIgnoreCase(auth_value, bearer_prefix)) {
        h.respondJson(conn, .unauthorized, "{\"error\":\"invalid authorization scheme\"}");
        return false;
    }
    const token = auth_value[bearer_prefix.len..];
    if (!tokenMatches(token, admin_pw)) {
        h.respondJson(conn, .unauthorized, "{\"error\":\"invalid token\"}");
        return false;
    }
    return true;
}

/// constant-time comparison of a presented token against the configured secret.
/// both inputs are hashed first so neither the contents nor the length of the
/// secret leak through response timing.
fn tokenMatches(presented: []const u8, expected: []const u8) bool {
    const Sha256 = std.crypto.hash.sha2.Sha256;
    var presented_digest: [Sha256.digest_length]u8 = undefined;
    var expected_digest: [Sha256.digest_length]u8 = undefined;
    Sha256.hash(presented, &presented_digest, .{});
    Sha256.hash(expected, &expected_digest, .{});
    return std.crypto.timing_safe.eql([Sha256.digest_length]u8, presented_digest, expected_digest);
}

/// copy `src` into `buf` and return the number of bytes copied, or null when `src`
/// is empty or longer than `buf`. callers reject such input instead of truncating it:
/// a truncated DID/hostname would make the handler act on a different record.
fn copyInput(buf: []u8, src: []const u8) ?usize {
    if (src.len == 0 or src.len > buf.len) return null;
    @memcpy(buf[0..src.len], src);
    return src.len;
}

/// write `s` as a quoted JSON string, escaping `"`, `\` and control characters
/// (RFC 8259 §7). bytes >= 0x80 are passed through unchanged (input assumed UTF-8).
fn writeJsonString(w: *Io.Writer, s: []const u8) Io.Writer.Error!void {
    try w.writeByte('"');
    var start: usize = 0;
    for (s, 0..) |c, i| {
        const escaped: ?[]const u8 = switch (c) {
            '"' => "\\\"",
            '\\' => "\\\\",
            '\n' => "\\n",
            '\r' => "\\r",
            '\t' => "\\t",
            else => null,
        };
        if (escaped == null and c >= 0x20) continue;
        // flush the run of plain bytes preceding this one, then its escape
        try w.writeAll(s[start..i]);
        if (escaped) |e| {
            try w.writeAll(e);
        } else {
            try w.print("\\u{x:0>4}", .{c});
        }
        start = i + 1;
    }
    try w.writeAll(s[start..]);
    try w.writeByte('"');
}

/// ban an account: resolve (or create) its UID, drop it from the collection index,
/// and enqueue a takedown (with a CBOR #account frame when one can be built) on host_ops.
/// expects a JSON body `{"did":"did:..."}`.
pub fn handleBan(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    if (!checkAdmin(conn, headers)) return;

    const parsed = std.json.parseFromSlice(struct { did: []const u8 }, ctx.persist.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"invalid JSON, expected {\\\"did\\\":\\\"...\\\"}\"}");
        return;
    };
    defer parsed.deinit();
    const did = parsed.value.did;

    // refuse anything that is not a DID: the lookup below inserts a new account row
    if (!std.mem.startsWith(u8, did, "did:")) {
        h.respondJson(conn, .bad_request, "{\"error\":\"invalid DID\"}");
        return;
    }

    // resolve DID → UID via DbRequestQueue
    const UidReq = struct {
        base: DbRequest = .{ .callback = &execute },
        did_buf: [256]u8 = undefined,
        did_len: usize = 0,
        uid: u64 = 0,

        fn execute(b: *DbRequest, dp: *DiskPersist) void {
            const self: *@This() = @fieldParentPtr("base", b);
            const d = self.did_buf[0..self.did_len];
            // check database
            if (dp.db.rowUnsafe("SELECT uid FROM account WHERE did = $1", .{d}) catch null) |row| {
                var r = row;
                defer r.deinit() catch {};
                self.uid = @intCast(r.get(i64, 0));
                return;
            }
            // create new account row
            _ = dp.db.exec("INSERT INTO account (did) VALUES ($1) ON CONFLICT (did) DO NOTHING", .{d}) catch {
                b.err = error.DatabaseError;
                return;
            };
            var row = dp.db.rowUnsafe("SELECT uid FROM account WHERE did = $1", .{d}) catch {
                b.err = error.DatabaseError;
                return;
            } orelse {
                b.err = error.AccountCreationFailed;
                return;
            };
            defer row.deinit() catch {};
            self.uid = @intCast(row.get(i64, 0));
        }
    };
    var uid_req: UidReq = .{};
    // reject rather than truncate: a truncated DID would ban (and create) the wrong account
    uid_req.did_len = copyInput(&uid_req.did_buf, did) orelse {
        h.respondJson(conn, .bad_request, "{\"error\":\"DID too long\"}");
        return;
    };
    ctx.db_queue.push(&uid_req.base);
    uid_req.base.wait(ctx.io, ctx.shutdown);

    if (uid_req.base.err != null) {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to resolve DID\"}");
        return;
    }

    // remove from collection index so banned accounts don't appear in listReposByCollection
    ctx.collection_index.removeAll(did) catch |err| {
        log.debug("collection removeAll after ban failed: {s}", .{@errorName(err)});
    };

    // build CBOR #account frame and route takedown + persist + broadcast
    // through host_ops queue (pool_io thread) — fire and forget.
    const host_ops_mod = @import("../host_ops.zig");
    var td: host_ops_mod.HostOp.Payload.Takedown = .{ .uid = uid_req.uid };

    if (buildAccountFrame(ctx.persist.allocator, did)) |frame_bytes| {
        defer ctx.persist.allocator.free(frame_bytes);
        if (frame_bytes.len <= td.frame_buf.len) {
            @memcpy(td.frame_buf[0..frame_bytes.len], frame_bytes);
            td.frame_len = @intCast(frame_bytes.len);
        } else {
            log.warn("admin: #account frame for {s} too large ({d} bytes), takedown enqueued without frame", .{ did, frame_bytes.len });
        }
    } else {
        log.warn("admin: failed to build #account frame for {s}, takedown enqueued without frame", .{did});
    }

    ctx.host_ops.push(.{
        .host_id = 0, // not host-specific
        .kind = .takedown_user,
        .payload = .{ .takedown = td },
    });

    log.info("admin: banned {s} (uid={d}), takedown enqueued", .{ did, uid_req.uid });
    h.respondJson(conn, .ok, "{\"success\":true}");
}

/// list every known host with its status, cursor, failure count and account limit,
/// plus the number of active subscriber workers.
pub fn handleAdminListHosts(conn: *h.Conn, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    if (!checkAdmin(conn, headers)) return;

    // list all hosts via DbRequestQueue
    const ListAllHostsReq = struct {
        base: DbRequest = .{ .callback = &execute },
        alloc: std.mem.Allocator,
        result: ?[]event_log_mod.DiskPersist.Host = null,

        fn execute(b: *DbRequest, dp: *DiskPersist) void {
            const self: *@This() = @fieldParentPtr("base", b);
            self.result = dp.listAllHosts(self.alloc) catch |e| {
                b.err = e;
                return;
            };
        }
    };
    var list_req: ListAllHostsReq = .{ .alloc = ctx.persist.allocator };
    ctx.db_queue.push(&list_req.base);
    list_req.base.wait(ctx.io, ctx.shutdown);

    if (list_req.base.err != null or list_req.result == null) {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}");
        return;
    }

    const hosts = list_req.result.?;
    defer {
        for (hosts) |host| {
            ctx.persist.allocator.free(host.hostname);
            ctx.persist.allocator.free(host.status);
        }
        ctx.persist.allocator.free(hosts);
    }

    var aw: Io.Writer.Allocating = .init(ctx.persist.allocator);
    defer aw.deinit();

    // previously a write failure returned without any response at all
    writeHostsJson(&aw.writer, hosts, ctx.slurper.workerCount()) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"InternalError\",\"message\":\"failed to render response\"}");
        return;
    };
    h.respondJson(conn, .ok, aw.written());
}

/// render the listHosts body `{"hosts":[...],"active_workers":N}`.
/// hostname/status come from the database and are JSON-escaped.
fn writeHostsJson(w: *Io.Writer, hosts: []const DiskPersist.Host, active_workers: anytype) Io.Writer.Error!void {
    try w.writeAll("{\"hosts\":[");
    for (hosts, 0..) |host, i| {
        if (i > 0) try w.writeByte(',');
        try w.print("{{\"id\":{d},\"hostname\":", .{host.id});
        try writeJsonString(w, host.hostname);
        try w.writeAll(",\"status\":");
        try writeJsonString(w, host.status);
        try w.print(",\"last_seq\":{d},\"failed_attempts\":{d},\"account_limit\":", .{ host.last_seq, host.failed_attempts });
        if (host.account_limit) |limit| {
            try w.print("{d}", .{limit});
        } else {
            try w.writeAll("null");
        }
        try w.writeByte('}');
    }
    try w.print("],\"active_workers\":{d}}}", .{active_workers});
}

/// mark a host as blocked (creating the host row if it does not exist yet).
/// expects a JSON body `{"hostname":"..."}`.
pub fn handleAdminBlockHost(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    handleHostStatusChange(conn, body, headers, ctx, "blocked", false, "blocked");
}

/// mark a host as active again and reset its failure counter (creating the host row
/// if it does not exist yet). expects a JSON body `{"hostname":"..."}`.
pub fn handleAdminUnblockHost(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    handleHostStatusChange(conn, body, headers, ctx, "active", true, "unblocked");
}

/// shared implementation of block/unblock: authenticate, parse `{"hostname":"..."}`,
/// then on a pool_io worker get-or-create the host, set its status to `new_status`
/// and (optionally, best effort) reset its failure count. `verb` is only used in the log line.
fn handleHostStatusChange(
    conn: *h.Conn,
    body: []const u8,
    headers: *const websocket.Handshake.KeyValue,
    ctx: *HttpContext,
    comptime new_status: [:0]const u8,
    comptime reset_failures: bool,
    verb: []const u8,
) void {
    if (!checkAdmin(conn, headers)) return;

    const parsed = std.json.parseFromSlice(struct { hostname: []const u8 }, ctx.persist.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid JSON\"}");
        return;
    };
    defer parsed.deinit();

    const HostStatusReq = struct {
        base: DbRequest = .{ .callback = &execute },
        hostname_buf: [256]u8 = undefined,
        hostname_len: usize = 0,
        host_id: u64 = 0,

        fn execute(b: *DbRequest, dp: *DiskPersist) void {
            const self: *@This() = @fieldParentPtr("base", b);
            const hn = self.hostname_buf[0..self.hostname_len];
            const info = dp.getOrCreateHost(hn) catch |e| {
                b.err = e;
                return;
            };
            self.host_id = info.id;
            dp.updateHostStatus(info.id, new_status) catch |e| {
                b.err = e;
                return;
            };
            if (reset_failures) {
                // best effort: the status change itself already succeeded
                dp.resetHostFailures(info.id) catch |e| {
                    log.warn("admin: failed to reset failure count for host id={d}: {s}", .{ info.id, @errorName(e) });
                };
            }
        }
    };
    var req: HostStatusReq = .{};
    // reject rather than truncate: a truncated hostname would change a different host
    req.hostname_len = copyInput(&req.hostname_buf, parsed.value.hostname) orelse {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"hostname must be 1-256 bytes\"}");
        return;
    };
    ctx.db_queue.push(&req.base);
    req.base.wait(ctx.io, ctx.shutdown);

    if (req.base.err != null) {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"operation failed\"}");
        return;
    }

    log.info("admin: {s} host {s} (id={d})", .{ verb, parsed.value.hostname, req.host_id });
    h.respondJson(conn, .ok, "{\"success\":true}");
}

/// set or clear the account_limit override for a host.
/// expects `{"host":"...","account_limit":N|null}`; null reverts to the host's account count.
/// responds 404 when the host is unknown.
pub fn handleAdminChangeLimits(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    if (!checkAdmin(conn, headers)) return;

    const parsed = std.json.parseFromSlice(
        struct { host: []const u8, account_limit: ?u64 },
        ctx.persist.allocator,
        body,
        .{ .ignore_unknown_fields = true },
    ) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"invalid JSON, expected {\\\"host\\\":\\\"...\\\",\\\"account_limit\\\":...}\"}");
        return;
    };
    defer parsed.deinit();

    const ChangeLimitsReq = struct {
        base: DbRequest = .{ .callback = &execute },
        hostname_buf: [256]u8 = undefined,
        hostname_len: usize = 0,
        new_limit: ?u64,
        host_id: ?u64 = null,
        effective: u64 = 0,

        fn execute(b: *DbRequest, dp: *DiskPersist) void {
            const self: *@This() = @fieldParentPtr("base", b);
            const hn = self.hostname_buf[0..self.hostname_len];
            self.host_id = dp.getHostIdForHostname(hn) catch |e| {
                b.err = e;
                return;
            };
            // unknown host: leave host_id null, handler responds 404
            const hid = self.host_id orelse return;
            dp.setHostAccountLimit(hid, self.new_limit) catch |e| {
                b.err = e;
                return;
            };
            self.effective = if (self.new_limit) |l| l else dp.getHostAccountCount(hid);
        }
    };
    var req: ChangeLimitsReq = .{ .new_limit = parsed.value.account_limit };
    // reject rather than truncate: a truncated hostname could match a different host
    req.hostname_len = copyInput(&req.hostname_buf, parsed.value.host) orelse {
        h.respondJson(conn, .bad_request, "{\"error\":\"host must be 1-256 bytes\"}");
        return;
    };
    ctx.db_queue.push(&req.base);
    req.base.wait(ctx.io, ctx.shutdown);

    if (req.base.err != null) {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"database error\"}");
        return;
    }
    const host_id = req.host_id orelse {
        h.respondJson(conn, .not_found, "{\"error\":\"host not found\"}");
        return;
    };

    // update running subscriber's rate limits immediately
    ctx.slurper.updateHostLimits(host_id, req.effective);

    if (parsed.value.account_limit) |limit| {
        log.info("admin: set account_limit for {s} (id={d}): {d}", .{ parsed.value.host, host_id, limit });
    } else {
        log.info("admin: cleared account_limit for {s} (id={d}), reverted to COUNT(*)", .{ parsed.value.host, host_id });
    }
    h.respondJson(conn, .ok, "{\"success\":true}");
}

/// start a backfill from `?source=` (default "bsky.network"); 409 if one is already running.
pub fn handleAdminBackfillTrigger(conn: *h.Conn, query: []const u8, headers: *const websocket.Handshake.KeyValue, backfiller: *backfill_mod.Backfiller) void {
    if (!checkAdmin(conn, headers)) return;

    const source = h.queryParam(query, "source") orelse "bsky.network";

    backfiller.start(source) catch |err| {
        switch (err) {
            error.AlreadyRunning => {
                h.respondJson(conn, .conflict, "{\"error\":\"backfill already in progress\"}");
            },
            else => {
                h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to start backfill\"}");
            },
        }
        return;
    };

    // `source` is caller-controlled, so it is JSON-escaped; if it does not fit the
    // buffer, fall back to a response without it.
    var buf: [256]u8 = undefined;
    var fw: Io.Writer = .fixed(&buf);
    const resp_body: ?[]const u8 = blk: {
        fw.writeAll("{\"status\":\"started\",\"source\":") catch break :blk null;
        writeJsonString(&fw, source) catch break :blk null;
        fw.writeByte('}') catch break :blk null;
        break :blk fw.buffered();
    };
    h.respondJson(conn, .ok, resp_body orelse "{\"status\":\"started\"}");
}

/// report backfill progress as produced by `Backfiller.getStatus`.
pub fn handleAdminBackfillStatus(conn: *h.Conn, headers: *const websocket.Handshake.KeyValue, backfiller: *backfill_mod.Backfiller) void {
    if (!checkAdmin(conn, headers)) return;

    const body = backfiller.getStatus(backfiller.allocator) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to query backfill status\"}");
        return;
    };
    defer backfiller.allocator.free(body);

    h.respondJson(conn, .ok, body);
}

/// start a cleanup pass; 409 if one is already running.
pub fn handleCleanupTrigger(conn: *h.Conn, headers: *const websocket.Handshake.KeyValue, cleaner: *cleaner_mod.Cleaner) void {
    if (!checkAdmin(conn, headers)) return;

    cleaner.start() catch |err| {
        switch (err) {
            error.AlreadyRunning => {
                h.respondJson(conn, .conflict, "{\"error\":\"cleanup already in progress\"}");
            },
            else => {
                h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to start cleanup\"}");
            },
        }
        return;
    };

    h.respondJson(conn, .ok, "{\"status\":\"started\"}");
}

/// report cleaner state: whether it is running plus scanned/removed counters.
pub fn handleCleanupStatus(conn: *h.Conn, headers: *const websocket.Handshake.KeyValue, cleaner: *cleaner_mod.Cleaner) void {
    if (!checkAdmin(conn, headers)) return;

    const status = cleaner.getStatus();
    var buf: [256]u8 = undefined;
    const body = std.fmt.bufPrint(&buf, "{{\"running\":{},\"scanned\":{d},\"removed\":{d}}}", .{
        status.running,
        status.scanned,
        status.removed,
    }) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"format failed\"}");
        return;
    };
    h.respondJson(conn, .ok, body);
}

/// report resyncer counters and current queue depth.
pub fn handleResyncStatus(conn: *h.Conn, headers: *const websocket.Handshake.KeyValue, resyncer: *resync_mod.Resyncer) void {
    if (!checkAdmin(conn, headers)) return;

    var buf: [256]u8 = undefined;
    const body = std.fmt.bufPrint(&buf, "{{\"processed\":{d},\"failed\":{d},\"dropped\":{d},\"queue_depth\":{d}}}", .{
        resyncer.processed.load(.monotonic),
        resyncer.failed.load(.monotonic),
        resyncer.dropped.load(.monotonic),
        resyncer.queueDepth(),
    }) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"format failed\"}");
        return;
    };
    h.respondJson(conn, .ok, body);
}

/// enqueue a resync for `{"did":"...","hostname":"..."}`; both fields must be non-empty.
pub fn handleResyncTrigger(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, resyncer: *resync_mod.Resyncer) void {
    if (!checkAdmin(conn, headers)) return;

    const parsed = std.json.parseFromSlice(
        struct { did: []const u8, hostname: []const u8 },
        std.heap.c_allocator,
        body,
        .{ .ignore_unknown_fields = true },
    ) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"invalid JSON, expected {\\\"did\\\":\\\"...\\\",\\\"hostname\\\":\\\"...\\\"}\"}");
        return;
    };
    defer parsed.deinit();

    if (parsed.value.did.len == 0 or parsed.value.hostname.len == 0) {
        h.respondJson(conn, .bad_request, "{\"error\":\"invalid JSON, expected {\\\"did\\\":\\\"...\\\",\\\"hostname\\\":\\\"...\\\"}\"}");
        return;
    }

    resyncer.enqueue(parsed.value.did, parsed.value.hostname);
    h.respondJson(conn, .ok, "{\"status\":\"enqueued\"}");
}

// --- protocol helpers (used only by handleBan) ---

/// build a CBOR #account frame for a takedown event.
/// The frame is the CBOR-encoded header map `{op: 1, t: "#account"}` immediately
/// followed by the CBOR-encoded payload map (seq 0, did, time, active = false,
/// status = "takendown"). Returns null on any encoding or allocation failure.
/// Caller owns the returned slice and must free it with `allocator`.
fn buildAccountFrame(allocator: std.mem.Allocator, did: []const u8) ?[]const u8 {
    const zat = @import("zat");
    const cbor = zat.cbor;

    const header: cbor.Value = .{ .map = &.{
        .{ .key = "op", .value = .{ .unsigned = 1 } },
        .{ .key = "t", .value = .{ .text = "#account" } },
    } };

    var time_buf: [24]u8 = undefined;
    const time_str = formatTimestamp(&time_buf);

    const payload: cbor.Value = .{ .map = &.{
        .{ .key = "seq", .value = .{ .unsigned = 0 } },
        .{ .key = "did", .value = .{ .text = did } },
        .{ .key = "time", .value = .{ .text = time_str } },
        .{ .key = "active", .value = .{ .boolean = false } },
        .{ .key = "status", .value = .{ .text = "takendown" } },
    } };

    // the intermediate encodings are always released; only the concatenated frame escapes
    const header_bytes = cbor.encodeAlloc(allocator, header) catch return null;
    defer allocator.free(header_bytes);
    const payload_bytes = cbor.encodeAlloc(allocator, payload) catch return null;
    defer allocator.free(payload_bytes);

    return std.mem.concat(allocator, u8, &.{ header_bytes, payload_bytes }) catch null;
}

/// Format the current wall-clock time (UTC, whole seconds) as an RFC 3339 string,
/// e.g. "2024-01-02T03:04:05Z", into `buf` and return the written prefix.
/// Returns "1970-01-01T00:00:00Z" if the clock cannot be read or reports a
/// time before the Unix epoch.
fn formatTimestamp(buf: *[24]u8) []const u8 {
    const fallback = "1970-01-01T00:00:00Z";

    var tp: std.c.timespec = undefined;
    // on failure `tp` is left unspecified, so it must not be read
    if (std.c.clock_gettime(.REALTIME, &tp) != 0) return fallback;
    // a negative tv_sec would trap/UB under @intCast
    const secs = std.math.cast(u64, tp.sec) orelse return fallback;

    const es = std.time.epoch.EpochSeconds{ .secs = secs };
    const day = es.getEpochDay();
    const yd = day.calculateYearDay();
    const md = yd.calculateMonthDay();
    const ds = es.getDaySeconds();

    return std.fmt.bufPrint(buf, "{d:0>4}-{d:0>2}-{d:0>2}T{d:0>2}:{d:0>2}:{d:0>2}Z", .{
        yd.year,
        @as(u32, @intFromEnum(md.month)) + 1,
        @as(u32, md.day_index) + 1,
        ds.getHoursIntoDay(),
        ds.getMinutesIntoHour(),
        ds.getSecondsIntoMinute(),
    }) catch fallback;
}

/// Read environment variable `key` via libc; returns null when unset.
/// The returned slice points into the process environment and is not owned by the caller.
fn getenv(key: [*:0]const u8) ?[]const u8 {
    const ptr = std.c.getenv(key) orelse return null;
    return std.mem.sliceTo(ptr, 0);
}