Running EmbeddingGemma ONNX with Hugot
0
hugot_embeddinggemma_onnx.go edited
89 lines 2.5 kB view raw
1// Experiment to load the ONNX port of EmbeddingGemma with Hugot. 2// See: https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX 3// which this script ports from the Python/Transformers.js equivalent. 4 5package main 6 7import ( 8 "context" 9 "fmt" 10 "log" 11 12 "github.com/knights-analytics/hugot" 13 "github.com/knights-analytics/hugot/options" 14 "github.com/knights-analytics/hugot/pipelines" 15) 16 17func check(err error) { 18 if err != nil { 19 log.Fatal(err) 20 } 21} 22 23func main() { 24 ctx := context.Background() 25 26 // Running on MacOS with vanilla 'brew install onnx' 27 session, err := hugot.NewORTSession(ctx, options.WithOnnxLibraryPath("/opt/homebrew/lib/")) 28 check(err) 29 30 defer func(session *hugot.Session) { 31 err := session.Destroy() 32 check(err) 33 }(session) 34 35 dl_opts := hugot.NewDownloadOptions() 36 // model graph 37 dl_opts.OnnxFilePath = "onnx/model_quantized.onnx" 38 // model weights (also required) 39 dl_opts.ExternalDataPath = "onnx/model_quantized.onnx_data" 40 dl_opts.Verbose = true 41 42 modelPath, err := hugot.DownloadModel(ctx, 43 "onnx-community/embeddinggemma-300m-ONNX", 44 "./models/", 45 dl_opts) 46 check(err) 47 48 config := hugot.FeatureExtractionConfig{ 49 ModelPath: modelPath, 50 OnnxFilename: "onnx/model_quantized.onnx", 51 Name: "embeddinggemma", 52 Options: []hugot.FeatureExtractionOption{ 53 pipelines.WithOutputName("sentence_embedding"), 54 pipelines.WithNormalization(), 55 }, 56 } 57 pipe, err := hugot.NewPipeline(session, config) 58 check(err) 59 60 inputs := []string{ 61 "task: search result | query: Which planet is known as the Red Planet?", 62 "title: none | text: Venus is often called Earth's twin because of its similar size and proximity.", 63 "title: none | text: Mars, known for its reddish appearance, is often referred to as the Red Planet.", 64 "title: none | text: Jupiter, the largest planet in our solar system, has a prominent red spot.", 65 "title: none | text: Saturn, famous for its rings, is sometimes mistaken for the Red Planet.", 66 } 67 68 out, err := pipe.RunPipeline(ctx, inputs) 69 // out is *pipelines.FeatureExtractionOutput 70 // out.Embeddings is [][]float32 71 72 fmt.Println(len(out.Embeddings)) // 5, len(inputs) 73 fmt.Println(len(out.Embeddings[0])) // 768, EmbeddingGemma's full dim 74 75 query_emb := out.Embeddings[0] 76 doc_embs := out.Embeddings[1:] 77 for i, doc_emb := range doc_embs { 78 var sim float32 79 for j := range query_emb { 80 sim += query_emb[j] * doc_emb[j] 81 } 82 fmt.Printf("doc %d sim: %.4f\n", i, sim) 83 } 84} 85 86// doc 0 sim: 0.3050 87// doc 1 sim: 0.6361 88// doc 2 sim: 0.4965 89// doc 3 sim: 0.4938