A lil service that creates embeddings of posts, profiles, and avatars to store them in Qdrant
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

add more search tools

Hailey 9bed6c62 f876a822

+126 -5
+2
.gitignore
··· 8 8 9 9 # Virtual environments 10 10 .venv 11 + 12 + .env
+124 -5
search.py
import math
import sys
from typing import List
from datetime import datetime, timezone, timedelta
import click
from qdrant_client.models import (
    DatetimeRange,
    FieldCondition,
    Filter,
    MatchValue,
)
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich import box
from itertools import combinations
import numpy as np

from config import CONFIG
from database import QDRANT_SERVICE, Result

# NOTE(review): the rest of the module header and the `search` command are
# elided in this view and are deliberately not reproduced here. The only
# visible change to `search` is the rename of its `--show-more` flag (and the
# matching `show_more` parameter / `display_results` argument) to `--more` /
# `more`. `main`, `console`, `logger` and `EMBEDDING_SERVICE` are used below
# but are not defined in the visible lines — presumably they live in the
# elided part of the file; verify before applying.

# Upper bound on how many recent posts are fetched and compared pairwise.
# `scroll` is paged, so an explicit limit keeps the O(n^2) comparison bounded
# and lets the command report when it did not see everything.
_RECENT_POST_LIMIT = 500


def _pairwise_cosine_similarities(vectors):
    """Return the cosine similarity of every unordered pair of *vectors*.

    Pairs are produced in the same order as
    ``itertools.combinations(range(n), 2)``. Vectors with zero norm are
    dropped first because cosine similarity is undefined for them (dividing
    by their norm would yield NaN). The result is a 1-D NumPy array and is
    empty when fewer than two usable vectors remain.

    Assumes each entry is a single dense vector of equal length — TODO
    confirm the collection does not use named or sparse vectors.
    """
    matrix = np.asarray(vectors, dtype=float)
    # Compute each norm once instead of once per pair.
    norms = np.linalg.norm(matrix, axis=1)
    usable = norms > 0
    unit = matrix[usable] / norms[usable, np.newaxis]
    rows, cols = np.triu_indices(len(unit), k=1)
    return (unit @ unit.T)[rows, cols]


@main.command()
@click.argument("text", required=True)
@click.option(
    "--did",
    required=True,
    help="DID whose posts are searched.",
)
@click.option(
    "--more",
    is_flag=True,
    default=False,
    help="Also print the text of every matching post.",
)
def did_similar_posts(text: str, did: str, more: bool):
    """Find posts from DID that are similar to TEXT.

    Encodes TEXT with ``EMBEDDING_SERVICE``, queries the post collection for
    at most 30 points whose ``did`` payload field equals DID using a score
    threshold of 0.85, and prints how many were found together with their
    mean score. With ``--more`` the ``text`` payload of each hit is printed.
    """
    QDRANT_SERVICE.initialize()
    EMBEDDING_SERVICE.initialize()

    vector = EMBEDDING_SERVICE.encode(text)
    client = QDRANT_SERVICE.get_client()

    console.print(f"[cyan]Searching for [bold]{did}[/bold]'s posts...[/cyan]")

    hits = client.query_points(
        collection_name=CONFIG.qdrant_post_collection_name,
        query=vector,
        query_filter=Filter(
            must=[FieldCondition(key="did", match=MatchValue(value=did))]
        ),
        limit=30,
        score_threshold=0.85,
        with_payload=True,
    ).points

    # Nothing passed the filter/threshold: report it instead of dividing by
    # zero when computing the average below.
    if not hits:
        console.print(
            f"[yellow]Found no similar posts from [bold]{did}[/bold].[/yellow]"
        )
        return

    avg = sum(hit.score for hit in hits) / len(hits)

    console.print(
        f"[green]Found [bold]{len(hits)}[/bold] similar posts from [bold]{did}[/bold]. Average similarity was [bold]{avg}[/bold].[/green]"
    )

    if more:
        for hit in hits:
            console.print(hit.payload.get("text"))
            console.print()


@main.command()
@click.option(
    "--did",
    required=True,
    help="DID whose recent posts are compared.",
)
@click.option(
    "--more",
    is_flag=True,
    default=False,
    help="Also print the timestamp and text of every post.",
)
def did_similar_recent(did: str, more: bool):
    """Measure how similar DID's posts from the last 24 hours are to each other.

    Scrolls the post collection for points whose ``did`` payload field equals
    DID and whose ``timestamp`` is within the last day (UTC), then prints the
    average, minimum and maximum pairwise cosine similarity of their vectors.
    At most ``_RECENT_POST_LIMIT`` posts are fetched; a warning is printed
    when more matched. With ``--more`` each post's timestamp and text are
    printed.
    """
    QDRANT_SERVICE.initialize()

    client = QDRANT_SERVICE.get_client()

    day_ago = datetime.now(timezone.utc) - timedelta(days=1)

    # `scroll` returns (points, next_page_offset). Passing an explicit limit
    # avoids silently comparing only the client's default page size, and the
    # offset tells us whether anything was left unread.
    results, next_offset = client.scroll(
        collection_name=CONFIG.qdrant_post_collection_name,
        scroll_filter=Filter(
            must=[
                FieldCondition(key="did", match=MatchValue(value=did)),
                FieldCondition(
                    key="timestamp",
                    range=DatetimeRange(gte=day_ago),
                ),
            ]
        ),
        limit=_RECENT_POST_LIMIT,
        with_payload=True,
        with_vectors=True,
    )

    if len(results) < 2:
        console.print(
            f"[yellow]Found only {len(results)} post(s). Need at least 2 to compare.[/yellow]"
        )
        return

    similarities = _pairwise_cosine_similarities(
        [point.vector for point in results]
    )

    # Every pair involved a zero-norm vector, so there is nothing to average.
    if similarities.size == 0:
        console.print(
            "[yellow]Fewer than 2 posts have a non-zero vector. Nothing to compare.[/yellow]"
        )
        return

    avg_similarity = similarities.mean()
    min_similarity = similarities.min()
    max_similarity = similarities.max()

    console.print(
        f"[green]Found [bold]{len(results)}[/bold] posts from [bold]{did}[/bold] in the last 24h.[/green]\n"
        f"[cyan]Average pairwise similarity: [bold]{avg_similarity:.4f}[/bold][/cyan]\n"
        f"[cyan]Min: {min_similarity:.4f}, Max: {max_similarity:.4f}[/cyan]"
    )

    if next_offset is not None:
        console.print(
            f"[yellow]More posts matched; only the first {_RECENT_POST_LIMIT} were compared.[/yellow]"
        )

    if more:
        console.print("\n[bold]Posts:[/bold]\n")
        for position, point in enumerate(results, start=1):
            timestamp = point.payload.get("timestamp", "")
            console.print(f"[dim]{position}. {timestamp}[/dim]")
            console.print(point.payload.get("text", ""))
            console.print()


# NOTE(review): the original view ends with the `if __name__ == "__main__":`
# guard line, whose body is outside this view; it is not reproduced here.