···11+MIT License
22+33+Copyright (c) 2026 Josh Ghiloni
44+55+Permission is hereby granted, free of charge, to any person obtaining a copy
66+of this software and associated documentation files (the "Software"), to deal
77+in the Software without restriction, including without limitation the rights
88+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
99+copies of the Software, and to permit persons to whom the Software is
1010+furnished to do so, subject to the following conditions:
1111+1212+The above copyright notice and this permission notice shall be included in all
1313+copies or substantial portions of the Software.
1414+1515+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
1616+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
1717+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
1818+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
1919+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
2020+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2121+SOFTWARE.
+331
client.go
···11+package jsonrpc
22+33+import (
44+ "context"
55+ "crypto/x509"
66+ "encoding/pem"
77+ "errors"
88+ "fmt"
99+ "log/slog"
1010+ "net/http"
1111+ "strconv"
1212+ "sync"
1313+ "sync/atomic"
1414+ "time"
1515+1616+ "github.com/coder/websocket"
1717+ "github.com/coder/websocket/wsjson"
1818+ "github.com/goccy/go-json"
1919+)
2020+2121+// WebsocketClient allows users to connect to a websocket that serves a JSON-RPC
2222+// 2.0 API and send method calls, receive responses, and subscribe to notifications
2323+// It is safe to use a single [WebsocketClient] across multiple goroutines.
2424+//
2525+// Because of the non-trivial overhead associated with connecting to websockets
2626+// and typically the overahead of authenticating a session, it's recommended to
2727+// keep a session alive for as long as possible, though the details of that will
2828+// vary from provider to provider
2929+type WebsocketClient struct {
3030+ conn *websocket.Conn
3131+ subscribers map[string][]NotificationSubscription
3232+ subscriberMapMu sync.RWMutex
3333+ activeRequests map[ID]chan Response
3434+ reqMapMu sync.Mutex
3535+ lastID atomic.Pointer[ID]
3636+ nextID func(ID) ID
3737+ idGenMu sync.Mutex
3838+ logger *slog.Logger
3939+ active atomic.Bool
4040+}
4141+4242+// ClientOptions are optional parameters to send to [NewClient] that have reasonable
4343+// defaults if omitted
4444+type ClientOptions struct {
4545+ // SkipTLSValidation tells clients connecting to wss:// URLs to not verify
4646+ // TLS certs, allowing it to connect to servers without trusted certificates.
4747+ // If possible, opt for CACert instead
4848+ SkipTLSValidation bool
4949+5050+ // CACert adds provided Certificate to the underlying HTTP Client's
5151+ // trusted certificate pool. It must be a []byte, string, or [*crypto/x509.Certificate].
5252+ // If it is a []byte or string, it must be a PEM-encoded x509 string.
5353+ CACert any
5454+5555+ // Logger adds a specific [*log/slog.Logger] to the client. If not provided,
5656+ // [log/slog.Default] will be used
5757+ Logger *slog.Logger
5858+5959+ // HTTPClient allows a user to use a specific underlying [*net/http.Client]
6060+ // to be used when making the websocket connection. If this is set, the values
6161+ // of SkipTLSValidation and CACert are ignored. If it is not set, [net/http.DefaultClient]
6262+ // is used
6363+ HTTPClient *http.Client
6464+6565+ // IDGenerator allows the user to specify a func to take the last ID known
6666+ // to the collection and generate the next. Because JSON-RPC correlates
6767+ // requests and responses with IDs, it is imperative these are unique within
6868+ // a connection, or unexpected behavior may occur. If this is omitted, the
6969+ // function will respond with the current timestamp with nanosecond precision,
7070+ // as a base-32 integer (as opposed to being base-32 encoded)
7171+ IDGenerator func(lastID ID) ID
7272+7373+ // DialOptions contain further options to send to the underlying websocket
7474+ // connection. Note that these will be passed as-is, with the exception being
7575+ // HTTPClient, which will be overridden with the value of HTTPClient generated
7676+ // in NewClient
7777+ DialOptions *websocket.DialOptions
7878+}
7979+8080+func defaultIDGen(_ ID) ID {
8181+ return ID(strconv.FormatInt(time.Now().UTC().UnixNano(), 32))
8282+}
8383+8484+// NewClient attempts to connect to the given URL (which must have a ws:// or
8585+// wss:// scheme) with the given options. If successful, it returns a client
8686+// with an open connection but is not yet listening. For that, you must call
8787+// [*WebsocketClient.Start]
8888+func NewClient(ctx context.Context, serverURL string, options *ClientOptions) (*WebsocketClient, error) {
8989+ var (
9090+ err error
9191+ resp *http.Response
9292+ )
9393+9494+ nextID := defaultIDGen
9595+ if options.IDGenerator != nil {
9696+ nextID = options.IDGenerator
9797+ }
9898+ c := &WebsocketClient{
9999+ subscribers: make(map[string][]NotificationSubscription),
100100+ activeRequests: make(map[ID]chan Response),
101101+ nextID: nextID,
102102+ logger: options.Logger,
103103+ }
104104+105105+ if c.logger == nil {
106106+ c.logger = slog.Default()
107107+ }
108108+109109+ httpClient, err := getHTTPClient(options)
110110+ if err != nil {
111111+ return nil, err
112112+ }
113113+114114+ if options.DialOptions == nil {
115115+ options.DialOptions = new(websocket.DialOptions)
116116+ }
117117+118118+ options.DialOptions.HTTPClient = httpClient
119119+120120+ c.conn, resp, err = websocket.Dial(ctx, serverURL, options.DialOptions)
121121+ if resp != nil && resp.Body != nil {
122122+ resp.Body.Close()
123123+ }
124124+ return c, err
125125+}
126126+127127+// Start generates the first ID for the server, and starts listening for messages
128128+func (w *WebsocketClient) Start(ctx context.Context) {
129129+ // set the first ID
130130+ w.getNextID()
131131+ go w.listen(ctx)
132132+}
133133+134134+// Close implements [io.Closer]. It closes any active response channels and deletes
135135+// all notification subscriptions, as well as closing the websocket connection
136136+// without waiting for a response
137137+func (w *WebsocketClient) Close() error {
138138+ safeCloseChan := func(c chan Response) {
139139+ defer func() {
140140+ // closing a closed channel panics, but we just want to ignore that
141141+ recover()
142142+ }()
143143+ close(c)
144144+ }
145145+146146+ w.reqMapMu.Lock()
147147+ for _, r := range w.activeRequests {
148148+ safeCloseChan(r)
149149+ }
150150+ w.reqMapMu.Unlock()
151151+152152+ w.subscriberMapMu.Lock()
153153+ for m := range w.subscribers {
154154+ delete(w.subscribers, m)
155155+ }
156156+157157+ return w.conn.CloseNow()
158158+}
159159+160160+func (w *WebsocketClient) Call(ctx context.Context, method string, params ...any) (response chan Response, err error) {
161161+ select {
162162+ case <-ctx.Done():
163163+ return nil, ctx.Err()
164164+ default:
165165+ }
166166+167167+ request := Request{
168168+ Notification: Notification{
169169+ Method: method,
170170+ },
171171+ ID: w.getNextID(),
172172+ }
173173+174174+ request.Params, err = json.Marshal(params)
175175+ if err != nil {
176176+ return
177177+ }
178178+179179+ response = make(chan Response, 1)
180180+ w.reqMapMu.Lock()
181181+ w.activeRequests[request.ID] = response
182182+ w.reqMapMu.Unlock()
183183+184184+ return response, wsjson.Write(ctx, w.conn, request)
185185+}
186186+187187+func (w *WebsocketClient) CallSynchronous(ctx context.Context, method string, params ...any) (any, error) {
188188+ select {
189189+ case <-ctx.Done():
190190+ return nil, ctx.Err()
191191+ default:
192192+ }
193193+194194+ response, err := w.Call(ctx, method, params...)
195195+ if err != nil {
196196+ return nil, err
197197+ }
198198+199199+ select {
200200+ case <-ctx.Done():
201201+ return nil, ctx.Err()
202202+ case r := <-response:
203203+ var a any
204204+ err = r.Result(&a)
205205+ return a, errors.Join(r.Error(), err)
206206+ }
207207+}
208208+209209+func (w *WebsocketClient) listen(ctx context.Context) {
210210+ for {
211211+ select {
212212+ case <-ctx.Done():
213213+ return
214214+ default:
215215+ }
216216+217217+ w.active.Store(true)
218218+219219+ var e envelope
220220+ err := wsjson.Read(ctx, w.conn, &e)
221221+ if err != nil {
222222+ w.logger.Error("error reading message from websocket", "error", err)
223223+ continue
224224+ }
225225+226226+ // first try to unmarshal it as a response
227227+ var response apiResponse
228228+ if err = e.unwrap(&response); err == nil {
229229+ w.reqMapMu.Lock()
230230+ r, ok := w.activeRequests[response.ID()]
231231+ if !ok {
232232+ w.logger.Warn("message received for unknown request ID", "id", response.ID())
233233+ w.reqMapMu.Unlock()
234234+ continue
235235+ }
236236+ delete(w.activeRequests, response.ID())
237237+ w.reqMapMu.Unlock()
238238+239239+ r <- response
240240+ close(r)
241241+ continue
242242+ }
243243+244244+ var notif Notification
245245+ if err = e.unwrap(¬if); err == nil {
246246+ w.subscriberMapMu.RLock()
247247+ listeners := w.subscribers[notif.Method]
248248+ w.subscriberMapMu.RUnlock()
249249+250250+ for _, listener := range listeners {
251251+ listener(notif)
252252+ }
253253+ }
254254+ }
255255+}
256256+257257+func (w *WebsocketClient) IsActive() bool {
258258+ return w.active.Load()
259259+}
260260+261261+func (w *WebsocketClient) SubscribeToNotification(name string, listener NotificationSubscription) {
262262+ w.subscriberMapMu.Lock()
263263+ defer w.subscriberMapMu.Unlock()
264264+265265+ w.subscribers[name] = append(w.subscribers[name], listener)
266266+}
267267+268268+func (w *WebsocketClient) getNextID() ID {
269269+ w.idGenMu.Lock()
270270+ next := new(w.nextID(""))
271271+ w.lastID.Store(next)
272272+ w.idGenMu.Unlock()
273273+274274+ return *next
275275+}
276276+277277+func getHTTPClient(options *ClientOptions) (*http.Client, error) {
278278+ // if the client is explicitly set, we're done. ignore skiptlsvalidation
279279+ if options.HTTPClient != nil {
280280+ return options.HTTPClient, nil
281281+ }
282282+283283+ client := http.DefaultClient
284284+ tr := client.Transport.(*http.Transport)
285285+ var err error
286286+ switch {
287287+ case options.CACert != nil:
288288+ var cert *x509.Certificate
289289+ switch c := options.CACert.(type) {
290290+ case string:
291291+ cert, err = getCertFromBytes([]byte(c))
292292+ if err != nil {
293293+ return nil, err
294294+ }
295295+ case []byte:
296296+ cert, err = getCertFromBytes(c)
297297+ if err != nil {
298298+ return nil, err
299299+ }
300300+ case *x509.Certificate:
301301+ cert = c
302302+ default:
303303+ return nil, fmt.Errorf("expected CACert to be []byte, string, or *x509.Certificate, got %T", options.CACert)
304304+ }
305305+306306+ certPool := tr.TLSClientConfig.RootCAs
307307+ if certPool == nil {
308308+ certPool, err = x509.SystemCertPool()
309309+ if err != nil {
310310+ return nil, err
311311+ }
312312+ }
313313+ certPool.AddCert(cert)
314314+315315+ // TODO check if this is required with debugging
316316+ tr.TLSClientConfig.RootCAs = certPool
317317+ case options.SkipTLSValidation:
318318+ tr.TLSClientConfig.InsecureSkipVerify = true
319319+ }
320320+321321+ return client, nil
322322+}
323323+324324+func getCertFromBytes(b []byte) (*x509.Certificate, error) {
325325+ block, _ := pem.Decode(b)
326326+ if block != nil {
327327+ return x509.ParseCertificate(block.Bytes)
328328+ }
329329+330330+ return nil, errors.New("bytes contained no PEM data")
331331+}