GET /xrpc/app.bsky.actor.searchActorsTypeahead typeahead.waow.tech
16
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 306 lines 12 kB view raw
//! Search handler: tiered ranking (exact handle → handle prefix → FTS5 prefix).
//! Replicates src/handlers/search.ts logic using local SQLite FTS5.

const std = @import("std");
const json = std.json;
const mem = std.mem;
const Allocator = mem.Allocator;
const LocalDb = @import("db/LocalDb.zig");
const Col = LocalDb.Col;

const log = std.log.scoped(.search);

// Dedup capacity: 1 exact + 100 prefix + 100 FTS. Sized for limit <= 100;
// once full, trackDid stops recording and cross-tier dedup becomes best-effort.
const MAX_SEEN = 201;
const MAX_DID_LEN = 64;

/// Maximum length in bytes of a sanitized search term.
const SANITIZE_MAX = 512;

/// Perform the tiered search and write the JSON response (`{"actors":[...]}`)
/// into `writer`.
///
/// Ranking (matches the CF Worker's search.ts):
///   1. exact handle match (case-insensitive, non-hidden)
///   2. handle prefix match via LIKE, shortest handles first
///   3. FTS5 prefix match ordered by FTS rank (skipped for 1-char terms)
/// A DID already emitted by an earlier tier is skipped, and at most `limit`
/// actors are written.
///
/// Results are written inline while SQLite statements are alive to avoid
/// use-after-free on row text pointers (which are only valid until the
/// next step() or deinit() on the statement). DIDs that must outlive a
/// statement are copied into stack buffers first.
///
/// Query failures are logged and the affected tier (or row) is skipped;
/// only errors from writing the response are returned.
pub fn search(local: *LocalDb, raw_query: []const u8, limit: usize, writer: anytype) !void {
    const term = sanitize(raw_query);
    if (term.len == 0) {
        try writer.writeAll("{\"actors\":[]}");
        return;
    }

    // DID dedup: copy DIDs into stack storage so they outlive each query
    var seen_storage: [MAX_SEEN * MAX_DID_LEN]u8 = undefined;
    var seen_fba = std.heap.FixedBufferAllocator.init(&seen_storage);
    var seen_dids: [MAX_SEEN][]const u8 = undefined;
    var seen_count: usize = 0;

    var jw: json.Stringify = .{ .writer = writer };
    try jw.beginObject();
    try jw.objectField("actors");
    try jw.beginArray();

    var emitted: usize = 0;

    // 1. exact handle match
    {
        var rows = local.query(
            "SELECT " ++ LocalDb.actor_cols ++ " FROM actors WHERE handle = ? COLLATE NOCASE AND hidden = 0 LIMIT 1",
            .{term},
        ) catch |err| blk: {
            log.err("exact query failed: {}", .{err});
            break :blk null;
        };
        if (rows) |*r| {
            defer r.deinit();
            if (r.next()) |row| {
                if (emitted < limit) {
                    try writeActorFromRow(&jw, row);
                    _ = trackDid(&seen_fba, &seen_dids, &seen_count, row.text(Col.did));
                    emitted += 1;
                }
            }
        }
    }

    // Tiers 2 and 3 embed the term in a LIKE pattern and an FTS5 phrase-prefix
    // query. The buffers are sized for the longest possible sanitized term, so
    // formatting cannot run out of space; the fallback just ends the response.
    ranked: {
        var like_query_buf: [SANITIZE_MAX + 1]u8 = undefined;
        var fts_query_buf: [SANITIZE_MAX + 3]u8 = undefined;
        // sanitize() strips '%', '_' and '"', so the term cannot inject LIKE
        // wildcards or break out of the FTS5 phrase quotes.
        const like_query = std.fmt.bufPrint(&like_query_buf, "{s}%", .{term}) catch break :ranked;
        const fts_query = std.fmt.bufPrint(&fts_query_buf, "\"{s}\"*", .{term}) catch break :ranked;

        // 2. two-phase handle prefix match
        // Phase 1: index-friendly LIKE with no ORDER BY (enables index early termination)
        // Phase 2: sort candidates by handle length in Zig, point-lookup for full rows
        // This restores the Worker's length-based ranking without the pathological full-scan sort.
        {
            const max_prefix = 100;
            // saturating multiply: a huge limit must not overflow-trap
            const prefix_fetch = @min(limit *| 5, max_prefix);

            var hit_storage: [max_prefix * MAX_DID_LEN]u8 = undefined;
            var hit_fba = std.heap.FixedBufferAllocator.init(&hit_storage);
            var hits: [max_prefix]PrefixHit = undefined;
            var hit_count: usize = 0;

            // phase 1: fast index scan (no ORDER BY — uses handle COLLATE NOCASE index)
            {
                var rows = local.query(
                    \\SELECT did, handle FROM actors
                    \\WHERE handle LIKE ? COLLATE NOCASE AND hidden = 0 LIMIT ?
                , .{ like_query, prefix_fetch }) catch |err| blk: {
                    log.err("prefix query failed: {}", .{err});
                    break :blk null;
                };
                if (rows) |*r| {
                    defer r.deinit();
                    while (r.next()) |row| {
                        // defensive: LIMIT already bounds this, never overrun `hits`
                        if (hit_count == hits.len) break;
                        const did = row.text(0);
                        const handle = row.text(1);
                        // skip exact match (already in tier 1)
                        if (std.ascii.eqlIgnoreCase(handle, term)) continue;
                        if (isDuplicate(seen_dids[0..seen_count], did)) continue;
                        const copy = hit_fba.allocator().dupe(u8, did) catch break;
                        hits[hit_count] = .{ .did = copy, .handle_len = handle.len };
                        hit_count += 1;
                    }
                }
            }

            // rank by handle length; the sort is stable, so equal lengths keep
            // their scan order (at most 100 elements)
            std.sort.insertion(PrefixHit, hits[0..hit_count], {}, PrefixHit.shorterHandle);

            // phase 2: point lookups in handle-length order
            for (hits[0..hit_count]) |hit| {
                if (emitted >= limit) break;
                var rows = local.query(
                    "SELECT " ++ LocalDb.actor_cols ++ " FROM actors WHERE did = ?",
                    .{hit.did},
                ) catch |err| {
                    log.err("prefix lookup failed: {}", .{err});
                    continue;
                };
                defer rows.deinit();
                if (rows.next()) |row| {
                    try writeActorFromRow(&jw, row);
                    _ = trackDid(&seen_fba, &seen_dids, &seen_count, row.text(Col.did));
                    emitted += 1;
                }
            }
        }

        // 3. two-phase FTS5 prefix search
        // Phase 1: pure FTS5 ranked query (no JOIN — enables rank optimization + early termination)
        // Phase 2: point lookups on actors table by primary key
        // Skip if tiers 1+2 already filled the limit, or term too short (single-char prefix is pathological)
        if (emitted < limit and term.len >= 2) {
            const max_candidates = 500;
            const fetch_count = @min(limit *| 5, max_candidates);

            // collect candidate DIDs into stack buffer (row text pointers die on step/deinit)
            var candidate_storage: [max_candidates * MAX_DID_LEN]u8 = undefined;
            var candidate_fba = std.heap.FixedBufferAllocator.init(&candidate_storage);
            var candidate_dids: [max_candidates][]const u8 = undefined;
            var candidate_count: usize = 0;

            // phase 1: FTS5-only ranked query
            {
                var rows = local.query(
                    \\SELECT did FROM actors_fts WHERE actors_fts MATCH ? ORDER BY rank LIMIT ?
                , .{ fts_query, fetch_count }) catch |err| blk: {
                    log.err("FTS query failed for '{s}': {}", .{ fts_query, err });
                    break :blk null;
                };
                if (rows) |*r| {
                    defer r.deinit();
                    while (r.next()) |row| {
                        // defensive: LIMIT already bounds this, never overrun `candidate_dids`
                        if (candidate_count == candidate_dids.len) break;
                        const did = row.text(0);
                        if (isDuplicate(seen_dids[0..seen_count], did)) continue;
                        const copy = candidate_fba.allocator().dupe(u8, did) catch break;
                        candidate_dids[candidate_count] = copy;
                        candidate_count += 1;
                    }
                }
            }

            // phase 2: point lookups by primary key
            for (candidate_dids[0..candidate_count]) |did| {
                if (emitted >= limit) break;
                var rows = local.query(
                    "SELECT " ++ LocalDb.actor_cols ++ " FROM actors WHERE did = ? AND handle != '' AND hidden = 0",
                    .{did},
                ) catch |err| {
                    log.err("FTS lookup failed: {}", .{err});
                    continue;
                };
                defer rows.deinit();
                if (rows.next()) |row| {
                    // re-check: a DID may already have been emitted earlier in
                    // this loop (e.g. if the FTS table returned it twice)
                    if (!isDuplicate(seen_dids[0..seen_count], row.text(Col.did))) {
                        try writeActorFromRow(&jw, row);
                        _ = trackDid(&seen_fba, &seen_dids, &seen_count, row.text(Col.did));
                        emitted += 1;
                    }
                }
            }
        }
    }

    try jw.endArray();
    try jw.endObject();
}

/// A tier-2 candidate: a DID copied out of its row plus the length of its handle.
const PrefixHit = struct {
    did: []const u8,
    handle_len: usize,

    /// Tier-2 ordering: shorter handles rank first.
    fn shorterHandle(_: void, a: PrefixHit, b: PrefixHit) bool {
        return a.handle_len < b.handle_len;
    }
};

/// Write a single actor object from a live SQLite row.
/// Must be called while the row's statement is still alive.
/// Row uses LocalDb.actor_cols column order.
///
/// String fields (including the constructed avatar URL) are JSON-escaped.
/// `labels` and `associated` are copied verbatim, so they are assumed to hold
/// valid JSON in the database — NOTE(review): confirm this is enforced at ingest.
fn writeActorFromRow(jw: anytype, row: LocalDb.Row) !void {
    const did = row.text(Col.did);
    const handle = row.text(Col.handle);
    const display_name = row.text(Col.display_name);
    const avatar_url = row.text(Col.avatar_url);
    const labels = row.text(Col.labels);
    const created_at = row.text(Col.created_at);
    const associated = row.text(Col.associated);
    const pds = row.text(Col.pds);

    try jw.beginObject();

    try jw.objectField("did");
    try jw.write(did);
    try jw.objectField("handle");
    try jw.write(handle);

    if (display_name.len > 0) {
        try jw.objectField("displayName");
        try jw.write(display_name);
    }
    if (avatar_url.len > 0) {
        // Build the URL in a buffer and emit it through jw.write so that
        // pds/did/cid are escaped instead of being spliced raw into the JSON.
        var url_buf: [2048]u8 = undefined;
        const url: ?[]const u8 = if (mem.startsWith(u8, avatar_url, "https://"))
            avatar_url
        else if (pds.len > 0 and !mem.endsWith(u8, pds, ".bsky.network"))
            // non-bsky PDS: construct blob URL from PDS + DID + CID
            std.fmt.bufPrint(&url_buf, "{s}/xrpc/com.atproto.sync.getBlob?did={s}&cid={s}", .{ pds, did, avatar_url }) catch null
        else
            std.fmt.bufPrint(&url_buf, "https://cdn.bsky.app/img/avatar/plain/{s}/{s}", .{ did, avatar_url }) catch null;

        if (url) |u| {
            try jw.objectField("avatar");
            try jw.write(u);
        } else {
            log.warn("avatar URL for {s} exceeds {d} bytes; omitting", .{ did, url_buf.len });
        }
    }
    // "{}" (or anything shorter) carries no information, so skip it
    if (associated.len > 2) {
        try jw.objectField("associated");
        try jw.print("{s}", .{associated});
    }
    try jw.objectField("labels");
    if (labels.len > 0) {
        try jw.print("{s}", .{labels});
    } else {
        // an empty column would otherwise leave "labels" without a value
        try jw.beginArray();
        try jw.endArray();
    }

    if (created_at.len > 0) {
        try jw.objectField("createdAt");
        try jw.write(created_at);
    }

    try jw.endObject();
}

/// Copy `did` into `fba` and record it in `seen` for cross-query dedup.
/// Returns false (recording nothing) when `seen` is full or `fba` has no room
/// for the copy; dedup is best-effort past that point.
fn trackDid(fba: *std.heap.FixedBufferAllocator, seen: *[MAX_SEEN][]const u8, count: *usize, did: []const u8) bool {
    if (count.* >= MAX_SEEN) return false;
    const copy = fba.allocator().dupe(u8, did) catch return false;
    seen[count.*] = copy;
    count.* += 1;
    return true;
}

/// Linear scan: true if `did` is byte-equal to any entry of `seen`.
fn isDuplicate(seen: []const []const u8, did: []const u8) bool {
    for (seen) |s| {
        if (mem.eql(u8, s, did)) return true;
    }
    return false;
}

/// Per-thread scratch buffer backing the slice returned by `sanitize`.
/// `threadlocal` so concurrent requests on different threads cannot clobber
/// each other's term. The returned slice stays valid only until the next
/// `sanitize` call on the same thread.
threadlocal var sanitize_buf: [SANITIZE_MAX]u8 = undefined;

/// Keep ASCII letters, digits, space, tab, '.', '-' and every non-ASCII byte,
/// then trim surrounding whitespace. Everything else (FTS5 special chars, LIKE
/// wildcards, quotes, punctuation) is stripped.
/// Output is capped at SANITIZE_MAX bytes; if the cap splits a multi-byte UTF-8
/// sequence, the partial sequence is dropped so valid input stays valid.
/// Intended to match the worker's sanitize() in src/utils.ts; note that
/// non-ASCII bytes are passed through without unicode-class checks.
fn sanitize(input: []const u8) []const u8 {
    var out: usize = 0;
    for (input) |c| {
        if (out >= sanitize_buf.len) break;
        const keep = std.ascii.isAlphanumeric(c) or
            c == '.' or c == '-' or c == ' ' or c == '\t' or
            c >= 0x80; // UTF-8 lead/continuation bytes
        if (keep) {
            sanitize_buf[out] = c;
            out += 1;
        }
    }
    return mem.trim(u8, dropPartialUtf8Tail(sanitize_buf[0..out]), " \t\n\r");
}

/// Return `s` without a trailing incomplete UTF-8 sequence (e.g. one cut by the
/// SANITIZE_MAX cap). Complete sequences and invalid lead bytes are left as-is.
fn dropPartialUtf8Tail(s: []const u8) []const u8 {
    // A sequence is at most 4 bytes, so its lead byte is within the last 4.
    var i = s.len;
    while (i > 0 and s.len - i < 4) {
        i -= 1;
        const b = s[i];
        if ((b & 0xC0) == 0x80) continue; // continuation byte: keep looking for the lead
        const seq_len = std.unicode.utf8ByteSequenceLength(b) catch return s;
        return if (s.len - i < seq_len) s[0..i] else s;
    }
    return s;
}