import logging
import math
import sys
from datetime import datetime, timedelta, timezone
from itertools import combinations
from typing import List

import click
import numpy as np
from qdrant_client.models import (
    DatetimeRange,
    Direction,
    FieldCondition,
    Filter,
    MatchValue,
    OrderBy,
)
from rich import box
from rich.console import Console
from rich.markup import escape
from rich.panel import Panel
from rich.table import Table

from config import CONFIG
from database import QDRANT_SERVICE, Result
from embedder import EMBEDDING_SERVICE
from retina import RETINA_CLIENT, binary_to_float_vector, hex_to_binary

logging.basicConfig(
    level=logging.INFO,
    format=logging.BASIC_FORMAT,
)
logger = logging.getLogger(__name__)
console = Console()

# Search types accepted by the ``search`` command.
SEARCH_TYPES = ("profile", "avatar", "post")

# Number of "/"-separated parts in a CDN avatar URL of the form
# https://cdn.bsky.app/img/feed_thumbnail/plain/<did>/<cid>@jpeg
_AVATAR_URL_PARTS = 8


def _more_details(search_type: str, result: Result) -> str:
    """Return the text for the "More" column of one result ("" if unavailable)."""
    payload = result.payload
    if payload is None:
        return ""
    if search_type == "profile":
        more = payload.get("description")
    elif search_type == "avatar":
        cid = payload.get("cid")
        more = f"https://cdn.bsky.app/img/feed_thumbnail/plain/{result.did}/{cid}@jpeg"
    elif search_type == "post":
        more = payload.get("text")
    else:
        more = None
    return "" if more is None else str(more)


def display_results(
    type: str, query: str, results: List[Result], show_more: bool = False
):
    """Print search results as a rich table.

    Args:
        type: Search type ("profile", "avatar" or "post"); selects which payload
            detail fills the optional "More" column.
        query: Text shown in the header panel above the table.
        results: Hits to display, in rank order.
        show_more: When true, add a "More" column with the profile description,
            avatar thumbnail URL, or post text.
    """
    if not results:
        console.print(f"[yellow]No similar {type}s found.[/yellow]")
        return

    # Queries, descriptions and post bodies are user-supplied and may contain
    # square brackets, which rich would otherwise try to parse as markup.
    console.print(
        Panel(f"[bold blue]Query: {escape(str(query))}[/bold blue]", box=box.ROUNDED)
    )
    console.print()

    table = Table(
        title=f"Found {len(results)} similar results",
        box=box.ROUNDED,
        header_style="bold magenta",
        expand=True,
        show_header=True,
        show_lines=True,
    )
    table.add_column("#", style="dim", width=4)
    table.add_column("DID", style="cyan", width=35)
    table.add_column("Similarity", justify="right", style="green", width=10)
    if show_more:
        table.add_column("More", style="white", overflow="fold")

    for idx, result in enumerate(results, 1):
        # Scale the score itself (not just the fallback) into a percentage.
        score = result.score or 0.0
        row: List[str] = [str(idx), result.did, f"{score * 100:.4f}%"]
        if show_more:
            row.append(escape(_more_details(type, result)))
        table.add_row(*row)

    console.print(table)
    console.print()


@click.group()
def main():
    """Similarity search tools over the profile, avatar and post collections."""
    pass


def _search_profiles(query: str, did: str, limit: int, threshold: float, more: bool):
    """Find profiles similar to free-text QUERY, or to the stored vector of profile DID."""
    if not query:
        console.print("[cyan]Looking up profile...[/cyan]")
        profile = QDRANT_SERVICE.get_profile_by_did(did)
        if not profile:
            console.print(f"[red]Profile not found: {did}[/red]")
            sys.exit(1)
        description = (profile.payload or {}).get("description")
        query_vector = profile.vector
        console.print("[green]Found profile[/green]")
    else:
        EMBEDDING_SERVICE.initialize()
        description = query
        query_vector = EMBEDDING_SERVICE.encode(query)

    console.print("[cyan]Looking up similar profiles...[/cyan]")
    results = QDRANT_SERVICE.search_similar(
        collection_name=CONFIG.qdrant_profile_collection_name,
        query_vector=query_vector,
        limit=limit,
        # NOTE(review): only the profile search passes sqrt(threshold);
        # presumably tied to how profile scores are scaled - confirm.
        score_threshold=math.sqrt(threshold),
    )
    display_results("profile", description, results, more)


def _search_avatars(query: str, did: str, limit: int, threshold: float, more: bool):
    """Find avatars similar to the CDN avatar URL in QUERY, or to DID's stored avatar."""
    if not query:
        console.print("[cyan]Looking up avatar...[/cyan]")
        avatar = QDRANT_SERVICE.get_avatar_by_did(did)
        if not avatar:
            console.print(f"[red]Avatar not found: {did}[/red]")
            sys.exit(1)
        cid = (avatar.payload or {}).get("cid")
        query_vector = avatar.vector
    else:
        parts = query.split("/")
        if len(parts) != _AVATAR_URL_PARTS:
            console.print("[red]Invalid avatar URL provided[/red]")
            sys.exit(1)
        url_did = parts[6]
        cid = parts[7].split("@")[0]
        resp = RETINA_CLIENT.get_image_hash(url_did, cid)
        if resp.quality_too_low or resp.hash is None:
            console.print("[red]Hash quality too low[/red]")
            sys.exit(1)
        query_vector = binary_to_float_vector(hex_to_binary(resp.hash))

    console.print("[cyan]Looking up similar avatars...[/cyan]")
    results = QDRANT_SERVICE.search_similar(
        collection_name=CONFIG.qdrant_avatar_collection_name,
        query_vector=query_vector,
        limit=limit,
        score_threshold=threshold,
    )
    display_results("avatar", cid, results, more)


def _search_posts(query: str, limit: int, threshold: float, more: bool):
    """Find posts similar to free-text QUERY (a query is mandatory here)."""
    if not query:
        console.print("[red]Must supply input for post search[/red]")
        sys.exit(1)

    EMBEDDING_SERVICE.initialize()
    query_vector = EMBEDDING_SERVICE.encode(query)

    console.print("[cyan]Looking up similar posts...[/cyan]")
    results = QDRANT_SERVICE.search_similar(
        collection_name=CONFIG.qdrant_post_collection_name,
        query_vector=query_vector,
        limit=limit,
        score_threshold=threshold,
    )
    display_results("post", query, results, more)


@main.command()
@click.argument("query", required=False)
@click.option(
    "--type",
    default="profile",
    type=click.Choice(SEARCH_TYPES),
    show_default=True,
)
@click.option(
    "--did",
)
@click.option(
    "--limit",
    default=10,
    show_default=True,
)
@click.option(
    "--threshold",
    default=0.7,
    type=float,
    show_default=True,
)
@click.option(
    "--more",
    is_flag=True,
    default=False,
    show_default=True,
)
def search(
    query: str,
    type: str,
    did: str,
    limit: int,
    threshold: float,
    more: bool,
):
    """Find profiles, avatars or posts similar to QUERY (or to --did's own entry)."""
    # TODO: would be nice if these were flags instead
    # Profile/avatar searches fall back to the --did entry when QUERY is absent,
    # so at least one of the two is needed.
    if type in ("profile", "avatar") and not query and not did:
        raise click.UsageError(f"Provide a QUERY or --did for {type} search.")

    QDRANT_SERVICE.initialize()

    try:
        if type == "profile":
            _search_profiles(query, did, limit, threshold, more)
        elif type == "avatar":
            _search_avatars(query, did, limit, threshold, more)
        else:
            _search_posts(query, limit, threshold, more)
    except Exception as e:
        # sys.exit() raises SystemExit, which is not an Exception subclass, so the
        # helpers' deliberate exits pass through untouched.
        console.print(f"[red]Error: {escape(str(e))}[/red]")
        logger.exception("Search error: %s", e)
        sys.exit(1)


@main.command()
@click.argument("text", required=True)
@click.option("--did", required=True)
@click.option(
    "--more",
    is_flag=True,
    default=False,
)
@click.option("--limit", default=30, show_default=True)
@click.option("--threshold", default=0.85, type=float, show_default=True)
def did_similar_posts(
    text: str, did: str, more: bool, limit: int = 30, threshold: float = 0.85
):
    """Count DID's posts similar to TEXT and report their average similarity."""
    QDRANT_SERVICE.initialize()
    EMBEDDING_SERVICE.initialize()

    vector = EMBEDDING_SERVICE.encode(text)
    client = QDRANT_SERVICE.get_client()

    console.print(f"[cyan]Searching for [bold]{did}[/bold]'s posts...[/cyan]")

    results = client.query_points(
        collection_name=CONFIG.qdrant_post_collection_name,
        query=vector,
        query_filter=Filter(
            must=[FieldCondition(key="did", match=MatchValue(value=did))]
        ),
        limit=limit,
        score_threshold=threshold,
        with_payload=True,
    ).points

    # Without this guard the average below divides by zero.
    if not results:
        console.print(
            f"[yellow]No similar posts found from [bold]{did}[/bold].[/yellow]"
        )
        return

    avg = sum(hit.score for hit in results) / len(results)

    console.print(
        f"[green]Found [bold]{len(results)}[/bold] similar posts from [bold]{did}[/bold]. Average similarity was [bold]{avg}[/bold].[/green]"
    )

    if more:
        for hit in results:
            post_text = (hit.payload or {}).get("text") or ""
            # Post bodies are user content; print them verbatim, not as markup.
            console.print(post_text, markup=False)
            console.print()


@main.command()
@click.option("--did", required=True)
@click.option(
    "--more",
    is_flag=True,
    default=False,
)
@click.option("--limit", default=30, show_default=True)
@click.option("--hours", default=24, type=int, show_default=True)
def did_similar_recent(did: str, more: bool, limit: int = 30, hours: int = 24):
    """Report pairwise cosine similarity among DID's posts from the last HOURS."""
    QDRANT_SERVICE.initialize()
    client = QDRANT_SERVICE.get_client()

    since = datetime.now(timezone.utc) - timedelta(hours=hours)

    results = client.scroll(
        collection_name=CONFIG.qdrant_post_collection_name,
        scroll_filter=Filter(
            must=[
                FieldCondition(key="did", match=MatchValue(value=did)),
                FieldCondition(
                    key="timestamp",
                    range=DatetimeRange(gte=since),
                ),
            ]
        ),
        order_by=OrderBy(
            key="timestamp",
            direction=Direction.DESC,
        ),
        limit=limit,
        with_payload=True,
        with_vectors=True,
    )[0]

    if len(results) < 2:
        console.print(
            f"[yellow]Found only {len(results)} post(s). Need at least 2 to compare.[/yellow]"
        )
        return

    # Normalise each vector once; every pairwise dot product is then a cosine
    # similarity, without recomputing norms for each pair.
    matrix = np.asarray([point.vector for point in results], dtype=float)
    unit = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
    similarities = [
        float(np.dot(unit[i], unit[j]))
        for i, j in combinations(range(len(unit)), 2)
    ]

    avg_similarity = np.mean(similarities)
    min_similarity = np.min(similarities)
    max_similarity = np.max(similarities)

    console.print(
        f"[green]Found [bold]{len(results)}[/bold] posts from [bold]{did}[/bold] in the last {hours}h.[/green]\n"
        f"[cyan]Average pairwise similarity: [bold]{avg_similarity:.4f}[/bold][/cyan]\n"
        f"[cyan]Min: {min_similarity:.4f}, Max: {max_similarity:.4f}[/cyan]"
    )

    if more:
        console.print("\n[bold]Posts:[/bold]\n")
        for i, point in enumerate(results, 1):
            payload = point.payload or {}
            text = payload.get("text", "")
            timestamp = payload.get("timestamp", "")
            console.print(f"[dim]{i}. {timestamp}[/dim]")
            # Post bodies are user content; print them verbatim, not as markup.
            console.print(text, markup=False)
            console.print()


if __name__ == "__main__":
    main()