this repo has no description
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

lspaux/protocol: initial implementation for protocol

This provides the types and their binary encoding for communication
between the LSP and upstream servers.

Signed-off-by: Matthew Sackman <matthew@cue.works>
Change-Id: I204ce1e4dc199fb0ce1a6372646f3736043c8b2a
Reviewed-on: https://cue.gerrithub.io/c/cue-lang/cue/+/1232171
Reviewed-by: Marcel van Lohuizen <mpvl@gmail.com>
TryBot-Result: CUEcueckoo <cueckoo@cuelang.org>

+645
+5
unstable/README.md
··· 1 + # Unstable 2 + 3 + Every package within this directory should be considered unstable and 4 + experimental. They are not subject to any backwards compatibility 5 + guarantees. Use at your own risk.
+304
unstable/lspaux/protocol/protocol.go
··· 1 + // Copyright 2026 CUE Authors 2 + // 3 + // Licensed under the Apache License, Version 2.0 (the "License"); 4 + // you may not use this file except in compliance with the License. 5 + // You may obtain a copy of the License at 6 + // 7 + // http://www.apache.org/licenses/LICENSE-2.0 8 + // 9 + // Unless required by applicable law or agreed to in writing, software 10 + // distributed under the License is distributed on an "AS IS" BASIS, 11 + // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 + // See the License for the specific language governing permissions and 13 + // limitations under the License. 14 + 15 + // WARNING: THIS PACKAGE IS EXPERIMENTAL. 16 + // ITS API MAY CHANGE AT ANY TIME. 17 + package protocol 18 + 19 + import ( 20 + "encoding/binary" 21 + "fmt" 22 + ) 23 + 24 + const ( 25 + MsgTypeChanged byte = 0x01 26 + MsgTypeEvalRequest byte = 0x02 27 + MsgTypeEvalResult byte = 0x03 28 + MsgTypeEvalFinished byte = 0x04 29 + ) 30 + 31 + // PeekMessageType returns the message type byte from a raw message, 32 + // allowing callers to determine which struct to decode into. 33 + func PeekMessageType(data []byte) (byte, error) { 34 + if len(data) == 0 { 35 + return 0, fmt.Errorf("empty message") 36 + } 37 + return data[0], nil 38 + } 39 + 40 + // Encoding / decoding helpers 41 + 42 + func appendUint32(buf []byte, v uint32) []byte { 43 + var b [4]byte 44 + binary.BigEndian.PutUint32(b[:], v) 45 + return append(buf, b[:]...) 46 + } 47 + 48 + func appendString(buf []byte, s string) []byte { 49 + buf = appendUint32(buf, uint32(len(s))) 50 + return append(buf, s...) 51 + } 52 + 53 + func appendBytes(buf []byte, data []byte) []byte { 54 + buf = appendUint32(buf, uint32(len(data))) 55 + return append(buf, data...) 56 + } 57 + 58 + // reader provides sequential decoding from a byte slice. 59 + type reader struct { 60 + data []byte 61 + offset int 62 + } 63 + 64 + func (r *reader) remaining() int { return len(r.data) - r.offset } 65 + 66 + func (r *reader) readByte() (byte, error) { 67 + if r.remaining() < 1 { 68 + return 0, fmt.Errorf("unexpected end of message reading byte at offset %d", r.offset) 69 + } 70 + b := r.data[r.offset] 71 + r.offset++ 72 + return b, nil 73 + } 74 + 75 + func (r *reader) readUint32() (uint32, error) { 76 + if r.remaining() < 4 { 77 + return 0, fmt.Errorf("unexpected end of message reading uint32 at offset %d", r.offset) 78 + } 79 + v := binary.BigEndian.Uint32(r.data[r.offset:]) 80 + r.offset += 4 81 + return v, nil 82 + } 83 + 84 + func (r *reader) readString() (string, error) { 85 + n, err := r.readUint32() 86 + if err != nil { 87 + return "", fmt.Errorf("reading string length: %w", err) 88 + } 89 + nInt := int(n) 90 + if r.remaining() < nInt { 91 + return "", fmt.Errorf("unexpected end of message reading string of length %d at offset %d (only %d bytes remain)", n, r.offset, r.remaining()) 92 + } 93 + s := string(r.data[r.offset : r.offset+nInt]) 94 + r.offset += nInt 95 + return s, nil 96 + } 97 + 98 + func (r *reader) readBytes() ([]byte, error) { 99 + n, err := r.readUint32() 100 + if err != nil { 101 + return nil, fmt.Errorf("reading byte array length: %w", err) 102 + } 103 + nInt := int(n) 104 + if r.remaining() < nInt { 105 + return nil, fmt.Errorf("unexpected end of message reading byte array of length %d at offset %d (only %d bytes remain)", n, r.offset, r.remaining()) 106 + } 107 + b := make([]byte, n) 108 + copy(b, r.data[r.offset:r.offset+nInt]) 109 + r.offset += nInt 110 + return b, nil 111 + } 112 + 113 + func (r *reader) expectEmpty(msgName string) error { 114 + if r.remaining() != 0 { 115 + return fmt.Errorf("%s: %d unexpected trailing bytes", msgName, r.remaining()) 116 + } 117 + return nil 118 + } 119 + 120 + // ChangedMsg indicates that something has changed and the client may 121 + // wish to request a re-evaluation. 122 + type ChangedMsg struct{} 123 + 124 + func (m *ChangedMsg) MarshalBytes() []byte { 125 + return []byte{MsgTypeChanged} 126 + } 127 + 128 + func (m *ChangedMsg) UnmarshalBytes(data []byte) error { 129 + r := &reader{data: data} 130 + 131 + typ, err := r.readByte() 132 + if err != nil { 133 + return fmt.Errorf("changed: %w", err) 134 + } 135 + if typ != MsgTypeChanged { 136 + return fmt.Errorf("changed: wrong message type: expected 0x%02x, got 0x%02x", MsgTypeChanged, typ) 137 + } 138 + 139 + return r.expectEmpty("changed") 140 + } 141 + 142 + // EvalRequestMsg is sent to request an evaluation of the 143 + // configuration for a given repository at a given commit, overlaid 144 + // with local modifications supplied as a zip file. 145 + type EvalRequestMsg struct { 146 + RequestID string // opaque evaluation-request identifier 147 + RepoName string // e.g. "https://cue.gerrithub.io/a/cue-lang/cue" 148 + CommitID string // current commit id of the git repo 149 + ZipData []byte // zip file of local modifications 150 + } 151 + 152 + func (m *EvalRequestMsg) MarshalBytes() []byte { 153 + buf := []byte{MsgTypeEvalRequest} 154 + buf = appendString(buf, m.RequestID) 155 + buf = appendString(buf, m.RepoName) 156 + buf = appendString(buf, m.CommitID) 157 + buf = appendBytes(buf, m.ZipData) 158 + return buf 159 + } 160 + 161 + func (m *EvalRequestMsg) UnmarshalBytes(data []byte) error { 162 + r := &reader{data: data} 163 + 164 + typ, err := r.readByte() 165 + if err != nil { 166 + return fmt.Errorf("eval request: %w", err) 167 + } 168 + if typ != MsgTypeEvalRequest { 169 + return fmt.Errorf("eval request: wrong message type: expected 0x%02x, got 0x%02x", MsgTypeEvalRequest, typ) 170 + } 171 + 172 + if m.RequestID, err = r.readString(); err != nil { 173 + return fmt.Errorf("eval request: request ID: %w", err) 174 + } 175 + if m.RepoName, err = r.readString(); err != nil { 176 + return fmt.Errorf("eval request: repo name: %w", err) 177 + } 178 + if m.CommitID, err = r.readString(); err != nil { 179 + return fmt.Errorf("eval request: commit ID: %w", err) 180 + } 181 + if m.ZipData, err = r.readBytes(); err != nil { 182 + return fmt.Errorf("eval request: zip data: %w", err) 183 + } 184 + 185 + return r.expectEmpty("eval request") 186 + } 187 + 188 + // FileCoordinate identifies a position within a file. 189 + type FileCoordinate struct { 190 + // slash-separated path relative to the git repo root. In the 191 + // future, this could become a URI, with the git repo name as the 192 + // prefix 193 + Path string 194 + ByteOffset uint32 // byte offset to the start of the token 195 + } 196 + 197 + // EvalError represents a single error produced during evaluation. 198 + type EvalError struct { 199 + Message string // human-readable error message 200 + Coordinates []FileCoordinate // may be empty 201 + } 202 + 203 + // EvalResultMsg is sent in response to an EvalRequestMsg. Zero or 204 + // more results are sent per request, matched by RequestID. 205 + type EvalResultMsg struct { 206 + RequestID string // echoed from the corresponding EvalRequestMsg 207 + Errors []EvalError // may be empty if evaluation succeeded 208 + } 209 + 210 + func (m *EvalResultMsg) MarshalBytes() []byte { 211 + buf := []byte{MsgTypeEvalResult} 212 + buf = appendString(buf, m.RequestID) 213 + 214 + buf = appendUint32(buf, uint32(len(m.Errors))) 215 + for _, e := range m.Errors { 216 + buf = appendString(buf, e.Message) 217 + 218 + buf = appendUint32(buf, uint32(len(e.Coordinates))) 219 + for _, c := range e.Coordinates { 220 + buf = appendString(buf, c.Path) 221 + buf = appendUint32(buf, c.ByteOffset) 222 + } 223 + } 224 + 225 + return buf 226 + } 227 + 228 + func (m *EvalResultMsg) UnmarshalBytes(data []byte) error { 229 + r := &reader{data: data} 230 + 231 + typ, err := r.readByte() 232 + if err != nil { 233 + return fmt.Errorf("eval result: %w", err) 234 + } 235 + if typ != MsgTypeEvalResult { 236 + return fmt.Errorf("eval result: wrong message type: expected 0x%02x, got 0x%02x", MsgTypeEvalResult, typ) 237 + } 238 + 239 + if m.RequestID, err = r.readString(); err != nil { 240 + return fmt.Errorf("eval result: request ID: %w", err) 241 + } 242 + 243 + numErrors, err := r.readUint32() 244 + if err != nil { 245 + return fmt.Errorf("eval result: error count: %w", err) 246 + } 247 + 248 + m.Errors = make([]EvalError, numErrors) 249 + for i := range m.Errors { 250 + evalErr := &m.Errors[i] 251 + if evalErr.Message, err = r.readString(); err != nil { 252 + return fmt.Errorf("eval result: error[%d] message: %w", i, err) 253 + } 254 + 255 + numCoords, err := r.readUint32() 256 + if err != nil { 257 + return fmt.Errorf("eval result: error[%d] coordinate count: %w", i, err) 258 + } 259 + 260 + evalErr.Coordinates = make([]FileCoordinate, numCoords) 261 + for j := range evalErr.Coordinates { 262 + coord := &evalErr.Coordinates[j] 263 + if coord.Path, err = r.readString(); err != nil { 264 + return fmt.Errorf("eval result: error[%d] coordinate[%d] path: %w", i, j, err) 265 + } 266 + if coord.ByteOffset, err = r.readUint32(); err != nil { 267 + return fmt.Errorf("eval result: error[%d] coordinate[%d] byte offset: %w", i, j, err) 268 + } 269 + } 270 + } 271 + 272 + return r.expectEmpty("eval result") 273 + } 274 + 275 + // EvalFinishedMsg is sent in response to an EvalRequestMsg. Exactly 276 + // one finished message is sent per request, matched by RequestID. 277 + type EvalFinishedMsg struct { 278 + RequestID string // echoed from the corresponding EvalRequestMsg 279 + } 280 + 281 + func (m *EvalFinishedMsg) MarshalBytes() []byte { 282 + buf := []byte{MsgTypeEvalFinished} 283 + buf = appendString(buf, m.RequestID) 284 + 285 + return buf 286 + } 287 + 288 + func (m *EvalFinishedMsg) UnmarshalBytes(data []byte) error { 289 + r := &reader{data: data} 290 + 291 + typ, err := r.readByte() 292 + if err != nil { 293 + return fmt.Errorf("eval finished: %w", err) 294 + } 295 + if typ != MsgTypeEvalFinished { 296 + return fmt.Errorf("eval finished: wrong message type: expected 0x%02x, got 0x%02x", MsgTypeEvalFinished, typ) 297 + } 298 + 299 + if m.RequestID, err = r.readString(); err != nil { 300 + return fmt.Errorf("eval finished: request ID: %w", err) 301 + } 302 + 303 + return r.expectEmpty("eval finished") 304 + }
+336
unstable/lspaux/protocol/protocol_test.go
··· 1 + // Copyright 2026 CUE Authors 2 + // 3 + // Licensed under the Apache License, Version 2.0 (the "License"); 4 + // you may not use this file except in compliance with the License. 5 + // You may obtain a copy of the License at 6 + // 7 + // http://www.apache.org/licenses/LICENSE-2.0 8 + // 9 + // Unless required by applicable law or agreed to in writing, software 10 + // distributed under the License is distributed on an "AS IS" BASIS, 11 + // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 + // See the License for the specific language governing permissions and 13 + // limitations under the License. 14 + 15 + package protocol_test 16 + 17 + import ( 18 + "bytes" 19 + "testing" 20 + 21 + "cuelang.org/go/unstable/lspaux/protocol" 22 + ) 23 + 24 + func TestChangedMsgRoundTrip(t *testing.T) { 25 + orig := &protocol.ChangedMsg{} 26 + data := orig.MarshalBytes() 27 + 28 + if len(data) != 1 || data[0] != protocol.MsgTypeChanged { 29 + t.Fatalf("MarshalBytes: got %x, want [01]", data) 30 + } 31 + 32 + var decoded protocol.ChangedMsg 33 + if err := decoded.UnmarshalBytes(data); err != nil { 34 + t.Fatalf("UnmarshalBytes: %v", err) 35 + } 36 + } 37 + 38 + func TestChangedMsgRejectsTrailingBytes(t *testing.T) { 39 + var m protocol.ChangedMsg 40 + if err := m.UnmarshalBytes([]byte{0x01, 0x00}); err == nil { 41 + t.Fatal("expected error for trailing bytes") 42 + } 43 + } 44 + 45 + func TestChangedMsgRejectsWrongType(t *testing.T) { 46 + var m protocol.ChangedMsg 47 + if err := m.UnmarshalBytes([]byte{0x02}); err == nil { 48 + t.Fatal("expected error for wrong type byte") 49 + } 50 + } 51 + 52 + func TestChangedMsgRejectsEmpty(t *testing.T) { 53 + var m protocol.ChangedMsg 54 + if err := m.UnmarshalBytes(nil); err == nil { 55 + t.Fatal("expected error for empty input") 56 + } 57 + } 58 + 59 + func TestEvalRequestMsgRoundTrip(t *testing.T) { 60 + orig := &protocol.EvalRequestMsg{ 61 + RequestID: "req-42", 62 + RepoName: "https://cue.gerrithub.io/a/cue-lang/cue", 63 + CommitID: "abc123def456", 64 + ZipData: []byte{0x50, 0x4B, 0x03, 0x04, 0xDE, 0xAD}, 65 + } 66 + 67 + data := orig.MarshalBytes() 68 + 69 + var decoded protocol.EvalRequestMsg 70 + if err := decoded.UnmarshalBytes(data); err != nil { 71 + t.Fatalf("UnmarshalBytes: %v", err) 72 + } 73 + 74 + if decoded.RequestID != orig.RequestID { 75 + t.Errorf("RequestID: got %q, want %q", decoded.RequestID, orig.RequestID) 76 + } 77 + if decoded.RepoName != orig.RepoName { 78 + t.Errorf("RepoName: got %q, want %q", decoded.RepoName, orig.RepoName) 79 + } 80 + if decoded.CommitID != orig.CommitID { 81 + t.Errorf("CommitID: got %q, want %q", decoded.CommitID, orig.CommitID) 82 + } 83 + if !bytes.Equal(decoded.ZipData, orig.ZipData) { 84 + t.Errorf("ZipData: got %x, want %x", decoded.ZipData, orig.ZipData) 85 + } 86 + } 87 + 88 + func TestEvalRequestMsgEmptyZip(t *testing.T) { 89 + orig := &protocol.EvalRequestMsg{ 90 + RequestID: "r1", 91 + RepoName: "repo", 92 + CommitID: "aaa", 93 + ZipData: nil, 94 + } 95 + 96 + data := orig.MarshalBytes() 97 + 98 + var decoded protocol.EvalRequestMsg 99 + if err := decoded.UnmarshalBytes(data); err != nil { 100 + t.Fatalf("UnmarshalBytes: %v", err) 101 + } 102 + 103 + if len(decoded.ZipData) != 0 { 104 + t.Errorf("ZipData: got length %d, want 0", len(decoded.ZipData)) 105 + } 106 + } 107 + 108 + func TestEvalRequestMsgRejectsWrongType(t *testing.T) { 109 + var m protocol.EvalRequestMsg 110 + if err := m.UnmarshalBytes([]byte{0x01}); err == nil { 111 + t.Fatal("expected error for wrong type byte") 112 + } 113 + } 114 + 115 + func TestEvalRequestMsgRejectsTruncated(t *testing.T) { 116 + orig := &protocol.EvalRequestMsg{ 117 + RequestID: "req-1", 118 + RepoName: "repo", 119 + CommitID: "commit", 120 + ZipData: []byte{0x01, 0x02}, 121 + } 122 + full := orig.MarshalBytes() 123 + 124 + // Try every possible truncation. 125 + for i := 0; i < len(full)-1; i++ { 126 + var m protocol.EvalRequestMsg 127 + if err := m.UnmarshalBytes(full[:i]); err == nil { 128 + t.Errorf("expected error for truncation at %d/%d bytes", i, len(full)) 129 + } 130 + } 131 + } 132 + 133 + func TestEvalResultMsgRoundTrip(t *testing.T) { 134 + orig := &protocol.EvalResultMsg{ 135 + RequestID: "req-42", 136 + Errors: []protocol.EvalError{ 137 + { 138 + Message: "field not allowed: foo", 139 + Coordinates: []protocol.FileCoordinate{ 140 + {Path: "pkg/config.cue", ByteOffset: 123}, 141 + {Path: "pkg/other.cue", ByteOffset: 456}, 142 + }, 143 + }, 144 + { 145 + Message: "some repo-level error", 146 + Coordinates: nil, // empty coordinates 147 + }, 148 + }, 149 + } 150 + 151 + data := orig.MarshalBytes() 152 + 153 + var decoded protocol.EvalResultMsg 154 + if err := decoded.UnmarshalBytes(data); err != nil { 155 + t.Fatalf("UnmarshalBytes: %v", err) 156 + } 157 + 158 + if decoded.RequestID != orig.RequestID { 159 + t.Errorf("RequestID: got %q, want %q", decoded.RequestID, orig.RequestID) 160 + } 161 + if len(decoded.Errors) != len(orig.Errors) { 162 + t.Fatalf("Errors length: got %d, want %d", len(decoded.Errors), len(orig.Errors)) 163 + } 164 + for i, oe := range orig.Errors { 165 + de := decoded.Errors[i] 166 + if de.Message != oe.Message { 167 + t.Errorf("Errors[%d].Message: got %q, want %q", i, de.Message, oe.Message) 168 + } 169 + if len(de.Coordinates) != len(oe.Coordinates) { 170 + t.Fatalf("Errors[%d].Coordinates length: got %d, want %d", i, len(de.Coordinates), len(oe.Coordinates)) 171 + } 172 + for j, oc := range oe.Coordinates { 173 + dc := de.Coordinates[j] 174 + if dc.Path != oc.Path { 175 + t.Errorf("Errors[%d].Coordinates[%d].Path: got %q, want %q", i, j, dc.Path, oc.Path) 176 + } 177 + if dc.ByteOffset != oc.ByteOffset { 178 + t.Errorf("Errors[%d].Coordinates[%d].ByteOffset: got %d, want %d", i, j, dc.ByteOffset, oc.ByteOffset) 179 + } 180 + } 181 + } 182 + } 183 + 184 + func TestEvalResultMsgNoErrors(t *testing.T) { 185 + orig := &protocol.EvalResultMsg{ 186 + RequestID: "ok-1", 187 + Errors: nil, 188 + } 189 + 190 + data := orig.MarshalBytes() 191 + 192 + var decoded protocol.EvalResultMsg 193 + if err := decoded.UnmarshalBytes(data); err != nil { 194 + t.Fatalf("UnmarshalBytes: %v", err) 195 + } 196 + 197 + if decoded.RequestID != orig.RequestID { 198 + t.Errorf("RequestID: got %q, want %q", decoded.RequestID, orig.RequestID) 199 + } 200 + if len(decoded.Errors) != 0 { 201 + t.Errorf("Errors: got length %d, want 0", len(decoded.Errors)) 202 + } 203 + } 204 + 205 + func TestEvalResultMsgRejectsTrailingBytes(t *testing.T) { 206 + orig := &protocol.EvalResultMsg{RequestID: "x", Errors: nil} 207 + data := append(orig.MarshalBytes(), 0x00) 208 + 209 + var m protocol.EvalResultMsg 210 + if err := m.UnmarshalBytes(data); err == nil { 211 + t.Fatal("expected error for trailing bytes") 212 + } 213 + } 214 + 215 + func TestEvalResultMsgRejectsTruncated(t *testing.T) { 216 + orig := &protocol.EvalResultMsg{ 217 + RequestID: "req-1", 218 + Errors: []protocol.EvalError{ 219 + { 220 + Message: "err", 221 + Coordinates: []protocol.FileCoordinate{ 222 + {Path: "a.cue", ByteOffset: 10}, 223 + }, 224 + }, 225 + }, 226 + } 227 + full := orig.MarshalBytes() 228 + 229 + for i := 0; i < len(full)-1; i++ { 230 + var m protocol.EvalResultMsg 231 + if err := m.UnmarshalBytes(full[:i]); err == nil { 232 + t.Errorf("expected error for truncation at %d/%d bytes", i, len(full)) 233 + } 234 + } 235 + } 236 + 237 + func TestEvalFinishedMsgRoundTrip(t *testing.T) { 238 + orig := &protocol.EvalFinishedMsg{ 239 + RequestID: "req-42", 240 + } 241 + 242 + data := orig.MarshalBytes() 243 + 244 + var decoded protocol.EvalFinishedMsg 245 + if err := decoded.UnmarshalBytes(data); err != nil { 246 + t.Fatalf("UnmarshalBytes: %v", err) 247 + } 248 + 249 + if decoded.RequestID != orig.RequestID { 250 + t.Errorf("RequestID: got %q, want %q", decoded.RequestID, orig.RequestID) 251 + } 252 + } 253 + 254 + func TestEvalFinishedMsgEmptyRequestID(t *testing.T) { 255 + orig := &protocol.EvalFinishedMsg{ 256 + RequestID: "", 257 + } 258 + 259 + data := orig.MarshalBytes() 260 + 261 + var decoded protocol.EvalFinishedMsg 262 + if err := decoded.UnmarshalBytes(data); err != nil { 263 + t.Fatalf("UnmarshalBytes: %v", err) 264 + } 265 + 266 + if decoded.RequestID != "" { 267 + t.Errorf("RequestID: got %q, want %q", decoded.RequestID, "") 268 + } 269 + } 270 + 271 + func TestEvalFinishedMsgRejectsTrailingBytes(t *testing.T) { 272 + orig := &protocol.EvalFinishedMsg{RequestID: "x"} 273 + data := append(orig.MarshalBytes(), 0x00) 274 + 275 + var m protocol.EvalFinishedMsg 276 + if err := m.UnmarshalBytes(data); err == nil { 277 + t.Fatal("expected error for trailing bytes") 278 + } 279 + } 280 + 281 + func TestEvalFinishedMsgRejectsWrongType(t *testing.T) { 282 + var m protocol.EvalFinishedMsg 283 + if err := m.UnmarshalBytes([]byte{0x01}); err == nil { 284 + t.Fatal("expected error for wrong type byte") 285 + } 286 + } 287 + 288 + func TestEvalFinishedMsgRejectsEmpty(t *testing.T) { 289 + var m protocol.EvalFinishedMsg 290 + if err := m.UnmarshalBytes(nil); err == nil { 291 + t.Fatal("expected error for empty input") 292 + } 293 + } 294 + 295 + func TestEvalFinishedMsgRejectsTruncated(t *testing.T) { 296 + orig := &protocol.EvalFinishedMsg{ 297 + RequestID: "req-1", 298 + } 299 + full := orig.MarshalBytes() 300 + 301 + // Try every possible truncation. 302 + for i := 0; i < len(full)-1; i++ { 303 + var m protocol.EvalFinishedMsg 304 + if err := m.UnmarshalBytes(full[:i]); err == nil { 305 + t.Errorf("expected error for truncation at %d/%d bytes", i, len(full)) 306 + } 307 + } 308 + } 309 + 310 + func TestPeekMessageType(t *testing.T) { 311 + tests := []struct { 312 + name string 313 + data []byte 314 + want byte 315 + }{ 316 + {"changed", (&protocol.ChangedMsg{}).MarshalBytes(), protocol.MsgTypeChanged}, 317 + {"eval request", (&protocol.EvalRequestMsg{}).MarshalBytes(), protocol.MsgTypeEvalRequest}, 318 + {"eval result", (&protocol.EvalResultMsg{}).MarshalBytes(), protocol.MsgTypeEvalResult}, 319 + {"eval finished", (&protocol.EvalFinishedMsg{}).MarshalBytes(), protocol.MsgTypeEvalFinished}, 320 + } 321 + for _, tt := range tests { 322 + t.Run(tt.name, func(t *testing.T) { 323 + got, err := protocol.PeekMessageType(tt.data) 324 + if err != nil { 325 + t.Fatalf("MessageType: %v", err) 326 + } 327 + if got != tt.want { 328 + t.Errorf("MessageType: got 0x%02x, want 0x%02x", got, tt.want) 329 + } 330 + }) 331 + } 332 + 333 + if _, err := protocol.PeekMessageType(nil); err == nil { 334 + t.Fatal("expected error for empty data") 335 + } 336 + }