// Experiment to load the ONNX port of EmbeddingGemma with Hugot. // See: https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX // which this script ports from the Python/Transformers.js equivalent. package main import ( "context" "fmt" "log" "github.com/knights-analytics/hugot" "github.com/knights-analytics/hugot/options" "github.com/knights-analytics/hugot/pipelines" ) func check(err error) { if err != nil { log.Fatal(err) } } func main() { ctx := context.Background() // Running on MacOS with vanilla 'brew install onnx' session, err := hugot.NewORTSession(ctx, options.WithOnnxLibraryPath("/opt/homebrew/lib/")) check(err) defer func(session *hugot.Session) { err := session.Destroy() check(err) }(session) dl_opts := hugot.NewDownloadOptions() // model graph dl_opts.OnnxFilePath = "onnx/model_quantized.onnx" // model weights (also required) dl_opts.ExternalDataPath = "onnx/model_quantized.onnx_data" dl_opts.Verbose = true modelPath, err := hugot.DownloadModel(ctx, "onnx-community/embeddinggemma-300m-ONNX", "./models/", dl_opts) check(err) config := hugot.FeatureExtractionConfig{ ModelPath: modelPath, OnnxFilename: "onnx/model_quantized.onnx", Name: "embeddinggemma", Options: []hugot.FeatureExtractionOption{ pipelines.WithOutputName("sentence_embedding"), pipelines.WithNormalization(), }, } pipe, err := hugot.NewPipeline(session, config) check(err) inputs := []string{ "task: search result | query: Which planet is known as the Red Planet?", "title: none | text: Venus is often called Earth's twin because of its similar size and proximity.", "title: none | text: Mars, known for its reddish appearance, is often referred to as the Red Planet.", "title: none | text: Jupiter, the largest planet in our solar system, has a prominent red spot.", "title: none | text: Saturn, famous for its rings, is sometimes mistaken for the Red Planet.", } out, err := pipe.RunPipeline(ctx, inputs) // out is *pipelines.FeatureExtractionOutput // out.Embeddings is [][]float32 fmt.Println(len(out.Embeddings)) // 5, len(inputs) fmt.Println(len(out.Embeddings[0])) // 768, EmbeddingGemma's full dim query_emb := out.Embeddings[0] doc_embs := out.Embeddings[1:] for i, doc_emb := range doc_embs { var sim float32 for j := range query_emb { sim += query_emb[j] * doc_emb[j] } fmt.Printf("doc %d sim: %.4f\n", i, sim) } } // doc 0 sim: 0.3050 // doc 1 sim: 0.6361 // doc 2 sim: 0.4965 // doc 3 sim: 0.4938