A social RSS reader built on the AT Protocol. glean.at
glean atproto atmosphere rss feed social app
14
fork

Configure Feed

Select the types of activity you want to include in your feed.

Refactor embedding logic to support instructions and sqlite-vec KNN

+109 -120
+7 -17
internal/cluster/article.go
··· 76 76 texts[j] = a.text 77 77 } 78 78 79 - embeddings, err := e.embedder.Embed(ctx, texts) 79 + embeddings, err := e.embedder.Embed(ctx, texts, "Represent this news article for retrieving topically similar articles. Focus on the subjects, themes, and key entities discussed.") 80 80 if err != nil { 81 81 return fmt.Errorf("embed batch %d: %w", i/embedBatchSize, err) 82 82 } ··· 153 153 } 154 154 155 155 dim := e.embedder.Dimension() 156 - sumVec := make([]float32, dim) 157 - count := 0 156 + var blobs [][]byte 158 157 likedSet := make(map[int64]bool) 159 158 for embRows.Next() { 160 159 var id int64 ··· 163 162 embRows.Close() 164 163 return err 165 164 } 166 - v := deserializeFloat32(blob) 167 - if len(v) != dim { 165 + if len(blob) != dim*4 { 168 166 continue 169 167 } 170 - for j := range sumVec { 171 - sumVec[j] += v[j] 172 - } 173 - count++ 168 + blobs = append(blobs, blob) 174 169 likedSet[id] = true 175 170 } 176 171 embRows.Close() 177 172 178 - if count == 0 { 173 + if len(blobs) == 0 { 179 174 return nil 180 175 } 181 176 182 - avgVec := make([]float32, dim) 183 - for j := range avgVec { 184 - avgVec[j] = sumVec[j] / float32(count) 185 - } 186 - 187 - queryBlob, err := vec.SerializeFloat32(avgVec) 177 + queryBlob, err := avgEmbeddings(blobs, dim) 188 178 if err != nil { 189 179 return fmt.Errorf("serialize query vector: %w", err) 190 180 } ··· 308 298 texts[j] = f.text 309 299 } 310 300 311 - embeddings, err := e.embedder.Embed(ctx, texts) 301 + embeddings, err := e.embedder.Embed(ctx, texts, "Represent this RSS feed description for discovering feeds with similar editorial focus and topic coverage.") 312 302 if err != nil { 313 303 return fmt.Errorf("embed feed batch %d: %w", i/embedBatchSize, err) 314 304 }
+43 -18
internal/cluster/embed.go
··· 1 1 package cluster 2 2 3 3 import ( 4 - "bytes" 5 4 "context" 6 - "encoding/binary" 5 + "unsafe" 7 6 7 + vec "github.com/asg017/sqlite-vec-go-bindings/cgo" 8 8 "github.com/openai/openai-go" 9 9 "github.com/openai/openai-go/option" 10 10 ) 11 11 12 - // Embedder generates vector embeddings for text inputs. Implementations must be 13 - // safe for concurrent use. 12 + // Embedder generates vector embeddings for text inputs. 14 13 type Embedder interface { 15 - Embed(ctx context.Context, texts []string) ([][]float32, error) 14 + Embed(ctx context.Context, texts []string, instruction string) ([][]float32, error) 16 15 Dimension() int 17 16 } 18 17 19 - type OpenAIEmbedder struct { 18 + type EmbedderClient struct { 20 19 client openai.Client 21 20 model string 22 21 dimension int 23 22 } 24 23 25 - type OpenAIEmbedderConfig struct { 24 + type EmbedderClientConfig struct { 26 25 BaseURL string 27 26 APIKey string 28 27 Model string 29 28 Dimension int 30 29 } 31 30 32 - func NewOpenAIEmbedder(cfg OpenAIEmbedderConfig) *OpenAIEmbedder { 31 + func NewEmbedderClient(cfg EmbedderClientConfig) *EmbedderClient { 33 32 opts := []option.RequestOption{} 34 33 if cfg.BaseURL != "" { 35 34 opts = append(opts, option.WithBaseURL(cfg.BaseURL)) ··· 37 36 if cfg.APIKey != "" { 38 37 opts = append(opts, option.WithAPIKey(cfg.APIKey)) 39 38 } 40 - return &OpenAIEmbedder{ 39 + return &EmbedderClient{ 41 40 client: openai.NewClient(opts...), 42 41 model: cfg.Model, 43 42 dimension: cfg.Dimension, 44 43 } 45 44 } 46 45 47 - func (e *OpenAIEmbedder) Dimension() int { 46 + func (e *EmbedderClient) Dimension() int { 48 47 return e.dimension 49 48 } 50 49 51 - func (e *OpenAIEmbedder) Embed(ctx context.Context, texts []string) ([][]float32, error) { 50 + func (e *EmbedderClient) Embed(ctx context.Context, texts []string, instruction string) ([][]float32, error) { 51 + inputs := texts 52 + if instruction != "" { 53 + inputs = make([]string, len(texts)) 54 + for i, t := range texts { 55 + 
inputs[i] = instruction + "\n" + t 56 + } 57 + } 52 58 resp, err := e.client.Embeddings.New(ctx, openai.EmbeddingNewParams{ 53 59 Model: e.model, 54 60 Input: openai.EmbeddingNewParamsInputUnion{ 55 - OfArrayOfStrings: texts, 61 + OfArrayOfStrings: inputs, 56 62 }, 57 63 }) 58 64 if err != nil { ··· 69 75 return embeddings, nil 70 76 } 71 77 72 - func deserializeFloat32(data []byte) []float32 { 73 - if len(data)%4 != 0 { 78 + func avgEmbeddings(blobs [][]byte, dim int) ([]byte, error) { 79 + sum := make([]float32, dim) 80 + count := 0 81 + for _, blob := range blobs { 82 + v := bytesToFloat32s(blob, dim) 83 + if v == nil { 84 + continue 85 + } 86 + for j := range sum { 87 + sum[j] += v[j] 88 + } 89 + count++ 90 + } 91 + if count == 0 { 92 + return nil, nil 93 + } 94 + for j := range sum { 95 + sum[j] /= float32(count) 96 + } 97 + return vec.SerializeFloat32(sum) 98 + } 99 + 100 + func bytesToFloat32s(data []byte, expectedDim int) []float32 { 101 + if len(data) != expectedDim*4 { 74 102 return nil 75 103 } 76 - result := make([]float32, len(data)/4) 77 - r := bytes.NewReader(data) 78 - _ = binary.Read(r, binary.LittleEndian, &result) 79 - return result 104 + return unsafe.Slice((*float32)(unsafe.Pointer(&data[0])), expectedDim) 80 105 }
+51 -65
internal/cluster/jaccard.go
··· 5 5 "database/sql" 6 6 "fmt" 7 7 "log/slog" 8 - "math" 9 8 "sync" 10 9 ) 11 10 ··· 122 121 if e.embedder == nil { 123 122 return nil 124 123 } 125 - e.logger.Debug("computing embedding similarity") 124 + e.logger.Debug("computing embedding similarity via vec0 KNN") 126 125 127 - stagingRows, err := tx.QueryContext(ctx, `SELECT feed_a, feed_b FROM _feed_sim_staging`) 126 + feedRows, err := tx.QueryContext(ctx, `SELECT feed_url, embedding FROM recs.feed_embeddings`) 128 127 if err != nil { 129 128 return err 130 129 } 131 130 132 - type pair struct{ a, b string } 133 - var pairs []pair 134 - feedSet := make(map[string]bool) 135 - for stagingRows.Next() { 136 - var p pair 137 - if err := stagingRows.Scan(&p.a, &p.b); err != nil { 138 - stagingRows.Close() 131 + type feedEmb struct { 132 + url string 133 + vec []byte 134 + } 135 + var feeds []feedEmb 136 + for feedRows.Next() { 137 + var f feedEmb 138 + if err := feedRows.Scan(&f.url, &f.vec); err != nil { 139 + feedRows.Close() 139 140 return err 140 141 } 141 - pairs = append(pairs, p) 142 - feedSet[p.a] = true 143 - feedSet[p.b] = true 142 + feeds = append(feeds, f) 144 143 } 145 - stagingRows.Close() 144 + feedRows.Close() 146 145 147 - if len(feedSet) == 0 { 146 + if len(feeds) == 0 { 148 147 return nil 149 148 } 150 149 151 - ph := make([]string, 0, len(feedSet)) 152 - args := make([]any, 0, len(feedSet)) 153 - for url := range feedSet { 154 - ph = append(ph, "?") 155 - args = append(args, url) 150 + const knnLimit = 50 151 + stmt, err := tx.PrepareContext(ctx, ` 152 + SELECT feed_url, distance 153 + FROM recs.feed_embeddings 154 + WHERE embedding MATCH ? AND k = ? 
155 + ORDER BY distance 156 + `) 157 + if err != nil { 158 + return err 156 159 } 157 - embRows, err := tx.QueryContext(ctx, 158 - fmt.Sprintf("SELECT feed_url, embedding FROM recs.feed_embeddings WHERE feed_url IN (%s)", joinPh(ph)), 159 - args..., 160 + defer stmt.Close() 161 + 162 + updateStmt, err := tx.PrepareContext(ctx, 163 + `UPDATE _feed_sim_staging SET jaccard = jaccard + ? WHERE feed_a = ? AND feed_b = ?`, 160 164 ) 161 165 if err != nil { 162 166 return err 163 167 } 168 + defer updateStmt.Close() 164 169 165 - embeddings := make(map[string][]float32) 166 - for embRows.Next() { 167 - var url string 168 - var blob []byte 169 - if err := embRows.Scan(&url, &blob); err != nil { 170 - embRows.Close() 170 + for _, f := range feeds { 171 + knnRows, err := stmt.QueryContext(ctx, f.vec, knnLimit) 172 + if err != nil { 171 173 return err 172 174 } 173 - v := deserializeFloat32(blob) 174 - if len(v) > 0 { 175 - embeddings[url] = v 175 + for knnRows.Next() { 176 + var neighborURL string 177 + var dist float64 178 + if err := knnRows.Scan(&neighborURL, &dist); err != nil { 179 + knnRows.Close() 180 + return err 181 + } 182 + if neighborURL == f.url || dist <= 0 { 183 + continue 184 + } 185 + boost := 1.0 / (1.0 + dist) * e.config.DescriptionWeight 186 + a, b := f.url, neighborURL 187 + if a > b { 188 + a, b = b, a 189 + } 190 + if _, err := updateStmt.ExecContext(ctx, boost, a, b); err != nil { 191 + knnRows.Close() 192 + return err 193 + } 176 194 } 177 - } 178 - embRows.Close() 179 - 180 - for _, p := range pairs { 181 - vecA, okA := embeddings[p.a] 182 - vecB, okB := embeddings[p.b] 183 - if !okA || !okB { 184 - continue 185 - } 186 - sim := cosineSimilarity(vecA, vecB) 187 - if sim <= 0 { 188 - continue 189 - } 190 - boost := sim * e.config.DescriptionWeight 191 - if _, err := tx.ExecContext(ctx, 192 - `UPDATE _feed_sim_staging SET jaccard = jaccard + ? WHERE feed_a = ? 
AND feed_b = ?`, 193 - boost, p.a, p.b, 194 - ); err != nil { 195 - return err 196 - } 195 + knnRows.Close() 197 196 } 198 197 199 198 return nil 200 - } 201 - 202 - func cosineSimilarity(a, b []float32) float64 { 203 - var dot, normA, normB float64 204 - for i := range a { 205 - dot += float64(a[i]) * float64(b[i]) 206 - normA += float64(a[i]) * float64(a[i]) 207 - normB += float64(b[i]) * float64(b[i]) 208 - } 209 - if normA == 0 || normB == 0 { 210 - return 0 211 - } 212 - return dot / (math.Sqrt(normA) * math.Sqrt(normB)) 213 199 } 214 200 215 201 // ComputeUserSimilarity recomputes the user_similarity table: subscription
+1 -1
internal/cluster/jaccard_test.go
··· 584 584 return &MockEmbedder{dimension: dimension} 585 585 } 586 586 587 - func (m *MockEmbedder) Embed(_ context.Context, texts []string) ([][]float32, error) { 587 + func (m *MockEmbedder) Embed(_ context.Context, texts []string, _ string) ([][]float32, error) { 588 588 result := make([][]float32, len(texts)) 589 589 for i, text := range texts { 590 590 vec := make([]float32, m.dimension)
+6 -18
internal/cluster/scoring.go
··· 4 4 "context" 5 5 "database/sql" 6 6 "fmt" 7 - 8 - vec "github.com/asg017/sqlite-vec-go-bindings/cgo" 9 7 ) 10 8 11 9 type FeedRecommendation struct { ··· 292 290 } 293 291 defer conn.Close() 294 292 293 + dim := e.embedder.Dimension() 295 294 subRows, err := conn.QueryContext(ctx, ` 296 295 SELECT fe.feed_url, fe.embedding FROM articles.subscriptions s 297 296 JOIN recs.feed_embeddings fe ON fe.feed_url = s.feed_url ··· 301 300 return nil, err 302 301 } 303 302 304 - dim := e.embedder.Dimension() 305 - sumVec := make([]float32, dim) 306 - subCount := 0 307 303 var subFeedURLs []string 304 + var blobs [][]byte 308 305 for subRows.Next() { 309 306 var url string 310 307 var blob []byte ··· 312 309 subRows.Close() 313 310 return nil, err 314 311 } 315 - v := deserializeFloat32(blob) 316 - if len(v) != dim { 312 + if len(blob) != dim*4 { 317 313 continue 318 314 } 319 - for j := range sumVec { 320 - sumVec[j] += v[j] 321 - } 322 - subCount++ 323 315 subFeedURLs = append(subFeedURLs, url) 316 + blobs = append(blobs, blob) 324 317 } 325 318 subRows.Close() 326 319 327 - if subCount == 0 { 320 + if len(blobs) == 0 { 328 321 return nil, nil 329 322 } 330 323 331 - avgVec := make([]float32, dim) 332 - for j := range avgVec { 333 - avgVec[j] = sumVec[j] / float32(subCount) 334 - } 335 - 336 324 subSet := make(map[string]bool, len(subFeedURLs)) 337 325 for _, u := range subFeedURLs { 338 326 subSet[u] = true 339 327 } 340 328 341 - queryBlob, err := vec.SerializeFloat32(avgVec) 329 + queryBlob, err := avgEmbeddings(blobs, dim) 342 330 if err != nil { 343 331 return nil, fmt.Errorf("serialize query vector: %w", err) 344 332 }
+1 -1
main.go
··· 58 58 59 59 var embedder cluster.Embedder 60 60 if embedURL := envOr("GLEAN_EMBED_BASE_URL", ""); embedURL != "" { 61 - embedder = cluster.NewOpenAIEmbedder(cluster.OpenAIEmbedderConfig{ 61 + embedder = cluster.NewEmbedderClient(cluster.EmbedderClientConfig{ 62 62 BaseURL: embedURL, 63 63 APIKey: envOr("GLEAN_EMBED_API_KEY", ""), 64 64 Model: envOr("GLEAN_EMBED_MODEL", "text-embedding-3-small"),