a geicko-2 based round robin ranking system designed to test c++ battleship submissions battleship.dunkirk.sh
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat: allow playing against the submissions

+5413 -7
battleship-arena

This is a binary file and will not be displayed.

+8
cmd/battleship-arena/main.go
··· 19 19 "github.com/go-chi/chi/v5" 20 20 "github.com/go-chi/chi/v5/middleware" 21 21 22 + "battleship-arena/internal/game" 22 23 "battleship-arena/internal/runner" 23 24 "battleship-arena/internal/server" 24 25 "battleship-arena/internal/storage" ··· 77 78 78 79 server.InitSSE() 79 80 server.SetConfig(cfg.AdminPasscode, cfg.ExternalURL) 81 + 82 + // Start game cleanup worker 83 + game.Manager.StartCleanupWorker() 80 84 81 85 workerCtx, workerCancel := context.WithCancel(context.Background()) 82 86 defer workerCancel() ··· 131 135 r.Get("/player/{player}", server.HandlePlayerPage) 132 136 r.Get("/user/{username}", server.HandleUserProfile) 133 137 r.Get("/users", server.HandleUsers) 138 + r.Get("/play", server.HandlePlayPage) 139 + r.Get("/play/{aiName}", server.HandlePlayPage) 140 + r.Get("/api/available-ais", server.HandleAvailableAIs) 141 + r.HandleFunc("/ws/game", game.HandleGameWebSocket) 134 142 r.Get("/", server.HandleLeaderboard) 135 143 136 144 log.Println("Server running at " + cfg.ExternalURL)
+1
go.mod
··· 31 31 github.com/creack/pty v1.1.21 // indirect 32 32 github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect 33 33 github.com/go-logfmt/logfmt v0.6.0 // indirect 34 + github.com/gorilla/websocket v1.5.3 // indirect 34 35 github.com/kr/fs v0.1.0 // indirect 35 36 github.com/lucasb-eyer/go-colorful v1.2.0 // indirect 36 37 github.com/mattn/go-isatty v0.0.20 // indirect
+2
go.sum
··· 46 46 github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= 47 47 github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= 48 48 github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= 49 + github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= 50 + github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 49 51 github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= 50 52 github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= 51 53 github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
+698
internal/game/game.go
··· 1 + // Package game provides human vs AI battleship game functionality 2 + // using the isolated player process architecture 3 + package game 4 + 5 + import ( 6 + "bufio" 7 + "context" 8 + "fmt" 9 + "io" 10 + "log" 11 + "math/rand" 12 + "os" 13 + "os/exec" 14 + "path/filepath" 15 + "regexp" 16 + "strconv" 17 + "strings" 18 + "sync" 19 + "time" 20 + ) 21 + 22 + const ( 23 + BoardSize = 10 24 + 25 + // Cell states 26 + CellEmpty = ' ' 27 + CellShip = 'S' 28 + CellHit = 'X' 29 + CellMiss = 'O' 30 + CellSunk = '#' 31 + 32 + // Result codes (matching C++ kasbs.h) 33 + ResultMiss = 0 34 + ResultHit = 8 35 + ResultSunk = 16 36 + ResultShip = 7 37 + 38 + // Ship types (matching C++ kasbs.h) 39 + ShipAC = 1 40 + ShipBS = 2 41 + ShipCR = 3 42 + ShipSB = 4 43 + ShipDS = 5 44 + ) 45 + 46 + // Ship represents a battleship 47 + type Ship struct { 48 + Name string 49 + Size int 50 + ShipType int // Ship number (AC=1, BS=2, etc.) 51 + Marker byte 52 + Hits int 53 + Cells [][2]int // [row, col] pairs 54 + } 55 + 56 + // Board represents a player's game board 57 + type Board struct { 58 + Grid [BoardSize][BoardSize]byte 59 + Ships []Ship 60 + } 61 + 62 + // Game represents a human vs AI battleship game 63 + type Game struct { 64 + ID string 65 + HumanBoard Board 66 + AIBoard Board 67 + PlayerName string 68 + AIName string 69 + CurrentTurn string // "human" or "ai" 70 + GameOver bool 71 + Winner string 72 + MoveCount int 73 + 74 + aiProcess *AIProcess 75 + mu sync.Mutex 76 + } 77 + 78 + // AIProcess manages communication with an AI player process 79 + type AIProcess struct { 80 + cmd *exec.Cmd 81 + stdin io.WriteCloser 82 + stdout *bufio.Reader 83 + alive bool 84 + } 85 + 86 + var enginePath = getEnginePath() 87 + 88 + func getEnginePath() string { 89 + if path := os.Getenv("BATTLESHIP_ENGINE_PATH"); path != "" { 90 + return path 91 + } 92 + return "./battleship-engine" 93 + } 94 + 95 + // compileAI compiles an AI submission if the source exists 96 + func compileAI(prefix string) error { 97 + srcDir := filepath.Join(enginePath, "src") 98 + buildDir := filepath.Join(enginePath, "build") 99 + 100 + srcFile := filepath.Join(srcDir, fmt.Sprintf("memory_functions_%s.cpp", prefix)) 101 + headerFile := filepath.Join(srcDir, fmt.Sprintf("memory_functions_%s.h", prefix)) 102 + 103 + // Check if source exists 104 + if _, err := os.Stat(srcFile); os.IsNotExist(err) { 105 + return fmt.Errorf("source file not found: %s", srcFile) 106 + } 107 + 108 + // Parse function suffix from source 109 + content, err := os.ReadFile(srcFile) 110 + if err != nil { 111 + return err 112 + } 113 + 114 + functionSuffix, err := parseFunctionNames(string(content)) 115 + if err != nil { 116 + return fmt.Errorf("failed to parse function names: %v", err) 117 + } 118 + 119 + // Generate header if missing 120 + if _, err := os.Stat(headerFile); os.IsNotExist(err) { 121 + headerContent := generateHeader(fmt.Sprintf("memory_functions_%s.h", prefix), functionSuffix) 122 + if err := os.WriteFile(headerFile, []byte(headerContent), 0644); err != nil { 123 + return err 124 + } 125 + } 126 + 127 + // Compile player binary 128 + playerBinary := filepath.Join(buildDir, "ai_"+prefix) 129 + 130 + compileArgs := []string{ 131 + "g++", "-std=c++11", "-O3", 132 + "-I", srcDir, 133 + fmt.Sprintf("-DPLAYER_SUFFIX=%s", functionSuffix), 134 + fmt.Sprintf(`-DPLAYER_HEADER="memory_functions_%s.h"`, prefix), 135 + "-o", playerBinary, 136 + filepath.Join(srcDir, "player_wrapper.cpp"), 137 + filepath.Join(srcDir, "battleship.cpp"), 138 + srcFile, 139 + } 140 + 141 + log.Printf("Compiling AI binary: %s", prefix) 142 + 143 + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) 144 + defer cancel() 145 + 146 + cmd := exec.CommandContext(ctx, compileArgs[0], compileArgs[1:]...) 147 + output, err := cmd.CombinedOutput() 148 + if err != nil { 149 + return fmt.Errorf("compilation failed: %s", output) 150 + } 151 + 152 + log.Printf("AI binary compiled: %s", playerBinary) 153 + return nil 154 + } 155 + 156 + func parseFunctionNames(cppContent string) (string, error) { 157 + re := regexp.MustCompile(`void\s+initMemory(\w+)\s*\(`) 158 + matches := re.FindStringSubmatch(cppContent) 159 + if len(matches) < 2 { 160 + return "", fmt.Errorf("could not find initMemory function") 161 + } 162 + return matches[1], nil 163 + } 164 + 165 + func generateHeader(filename, prefix string) string { 166 + guard := strings.ToUpper(strings.Replace(filename, ".", "_", -1)) 167 + 168 + return fmt.Sprintf(`#ifndef %s 169 + #define %s 170 + 171 + #include "memory.h" 172 + #include "battleship.h" 173 + #include <string> 174 + 175 + void initMemory%s(ComputerMemory &memory); 176 + std::string smartMove%s(const ComputerMemory &memory); 177 + void updateMemory%s(int row, int col, int result, ComputerMemory &memory); 178 + 179 + #endif 180 + `, guard, guard, prefix, prefix, prefix) 181 + } 182 + 183 + // NewGame creates a new human vs AI game 184 + func NewGame(playerName, aiSubmissionPrefix string) (*Game, error) { 185 + game := &Game{ 186 + ID: generateGameID(), 187 + PlayerName: playerName, 188 + AIName: aiSubmissionPrefix, 189 + CurrentTurn: "human", 190 + } 191 + 192 + // Initialize boards with random ship placement 193 + initializeBoard(&game.HumanBoard) 194 + initializeBoard(&game.AIBoard) 195 + 196 + // Start AI process (compile if needed) 197 + aiPath := filepath.Join(enginePath, "build", "ai_"+aiSubmissionPrefix) 198 + if _, err := os.Stat(aiPath); os.IsNotExist(err) { 199 + // Try to compile the AI 200 + log.Printf("AI binary not found, attempting to compile: %s", aiSubmissionPrefix) 201 + if err := compileAI(aiSubmissionPrefix); err != nil { 202 + return nil, fmt.Errorf("failed to compile AI %s: %v", aiSubmissionPrefix, err) 203 + } 204 + } 205 + 206 + aiProc, err := startAIProcess(aiPath) 207 + if err != nil { 208 + return nil, fmt.Errorf("failed to start AI: %v", err) 209 + } 210 + game.aiProcess = aiProc 211 + 212 + // Initialize AI 213 + if err := game.initAI(); err != nil { 214 + game.Close() 215 + return nil, fmt.Errorf("failed to initialize AI: %v", err) 216 + } 217 + 218 + return game, nil 219 + } 220 + 221 + func generateGameID() string { 222 + return fmt.Sprintf("game-%d", time.Now().UnixNano()) 223 + } 224 + 225 + func startAIProcess(binaryPath string) (*AIProcess, error) { 226 + cmd := exec.Command(binaryPath) 227 + 228 + stdin, err := cmd.StdinPipe() 229 + if err != nil { 230 + return nil, err 231 + } 232 + 233 + stdout, err := cmd.StdoutPipe() 234 + if err != nil { 235 + stdin.Close() 236 + return nil, err 237 + } 238 + 239 + if err := cmd.Start(); err != nil { 240 + stdin.Close() 241 + stdout.Close() 242 + return nil, err 243 + } 244 + 245 + return &AIProcess{ 246 + cmd: cmd, 247 + stdin: stdin, 248 + stdout: bufio.NewReader(stdout), 249 + alive: true, 250 + }, nil 251 + } 252 + 253 + func (p *AIProcess) sendLine(line string) error { 254 + if !p.alive { 255 + return fmt.Errorf("AI process not alive") 256 + } 257 + _, err := fmt.Fprintln(p.stdin, line) 258 + return err 259 + } 260 + 261 + func (p *AIProcess) readLine() (string, error) { 262 + if !p.alive { 263 + return "", fmt.Errorf("AI process not alive") 264 + } 265 + line, err := p.stdout.ReadString('\n') 266 + if err != nil { 267 + p.alive = false 268 + return "", err 269 + } 270 + return strings.TrimSpace(line), nil 271 + } 272 + 273 + func (p *AIProcess) close() { 274 + if p.stdin != nil { 275 + p.sendLine("QUIT") 276 + p.stdin.Close() 277 + } 278 + if p.cmd != nil && p.cmd.Process != nil { 279 + p.cmd.Process.Kill() 280 + p.cmd.Wait() 281 + } 282 + p.alive = false 283 + } 284 + 285 + func (g *Game) initAI() error { 286 + // Handshake 287 + if err := g.aiProcess.sendLine("HELLO 1"); err != nil { 288 + return err 289 + } 290 + 291 + resp, err := g.aiProcess.readLine() 292 + if err != nil { 293 + return err 294 + } 295 + if resp != "HELLO OK" { 296 + return fmt.Errorf("bad handshake response: %s", resp) 297 + } 298 + 299 + // Init for game 300 + if err := g.aiProcess.sendLine("INIT"); err != nil { 301 + return err 302 + } 303 + 304 + resp, err = g.aiProcess.readLine() 305 + if err != nil { 306 + return err 307 + } 308 + if resp != "OK" { 309 + return fmt.Errorf("bad init response: %s", resp) 310 + } 311 + 312 + return nil 313 + } 314 + 315 + // Close cleans up game resources 316 + func (g *Game) Close() { 317 + g.mu.Lock() 318 + defer g.mu.Unlock() 319 + 320 + if g.aiProcess != nil { 321 + g.aiProcess.close() 322 + g.aiProcess = nil 323 + } 324 + } 325 + 326 + // HumanMove processes a human player's move 327 + func (g *Game) HumanMove(move string) (*MoveResult, error) { 328 + g.mu.Lock() 329 + defer g.mu.Unlock() 330 + 331 + if g.GameOver { 332 + return nil, fmt.Errorf("game is over") 333 + } 334 + 335 + if g.CurrentTurn != "human" { 336 + return nil, fmt.Errorf("not your turn") 337 + } 338 + 339 + // Parse move (e.g., "A5" -> row=0, col=4) 340 + row, col, err := parseMove(move) 341 + if err != nil { 342 + return nil, err 343 + } 344 + 345 + // Check if already targeted 346 + if g.AIBoard.Grid[row][col] == CellHit || g.AIBoard.Grid[row][col] == CellMiss || g.AIBoard.Grid[row][col] == CellSunk { 347 + return nil, fmt.Errorf("cell already targeted") 348 + } 349 + 350 + // Execute move on AI's board 351 + result := g.executeMove(&g.AIBoard, row, col) 352 + g.MoveCount++ 353 + 354 + // Check for win 355 + if g.checkAllSunk(&g.AIBoard) { 356 + g.GameOver = true 357 + g.Winner = "human" 358 + } else { 359 + g.CurrentTurn = "ai" 360 + } 361 + 362 + return result, nil 363 + } 364 + 365 + // AIMove gets and executes the AI's move 366 + func (g *Game) AIMove() (*MoveResult, error) { 367 + g.mu.Lock() 368 + defer g.mu.Unlock() 369 + 370 + if g.GameOver { 371 + return nil, fmt.Errorf("game is over") 372 + } 373 + 374 + if g.CurrentTurn != "ai" { 375 + return nil, fmt.Errorf("not AI's turn") 376 + } 377 + 378 + // Get move from AI 379 + if err := g.aiProcess.sendLine("GET_MOVE"); err != nil { 380 + return nil, fmt.Errorf("failed to request AI move: %v", err) 381 + } 382 + 383 + resp, err := g.aiProcess.readLine() 384 + if err != nil { 385 + return nil, fmt.Errorf("failed to read AI move: %v", err) 386 + } 387 + 388 + if !strings.HasPrefix(resp, "MOVE ") { 389 + return nil, fmt.Errorf("invalid AI response: %s", resp) 390 + } 391 + 392 + move := strings.TrimPrefix(resp, "MOVE ") 393 + row, col, err := parseMove(move) 394 + if err != nil { 395 + // AI gave invalid move, use random 396 + row, col = g.randomValidMove(&g.HumanBoard) 397 + } 398 + 399 + // Check if already targeted, use random if so 400 + for g.HumanBoard.Grid[row][col] == CellHit || g.HumanBoard.Grid[row][col] == CellMiss || g.HumanBoard.Grid[row][col] == CellSunk { 401 + row, col = g.randomValidMove(&g.HumanBoard) 402 + } 403 + 404 + // Execute move on human's board 405 + result := g.executeMove(&g.HumanBoard, row, col) 406 + result.Move = formatMove(row, col) 407 + g.MoveCount++ 408 + 409 + // Update AI with result 410 + updateCmd := fmt.Sprintf("UPDATE %d %d %d", row, col, result.ResultCode) 411 + if err := g.aiProcess.sendLine(updateCmd); err != nil { 412 + return nil, fmt.Errorf("failed to update AI: %v", err) 413 + } 414 + 415 + resp, err = g.aiProcess.readLine() 416 + if err != nil { 417 + return nil, fmt.Errorf("failed to read AI update response: %v", err) 418 + } 419 + 420 + // Check for win 421 + if g.checkAllSunk(&g.HumanBoard) { 422 + g.GameOver = true 423 + // Only set winner if not already set (human won on their last turn) 424 + if g.Winner == "" { 425 + g.Winner = "ai" 426 + } 427 + } else { 428 + g.CurrentTurn = "human" 429 + } 430 + 431 + return result, nil 432 + } 433 + 434 + func (g *Game) randomValidMove(board *Board) (int, int) { 435 + for { 436 + row := rand.Intn(BoardSize) 437 + col := rand.Intn(BoardSize) 438 + cell := board.Grid[row][col] 439 + if cell != CellHit && cell != CellMiss && cell != CellSunk { 440 + return row, col 441 + } 442 + } 443 + } 444 + 445 + // MoveResult represents the result of a move 446 + type MoveResult struct { 447 + Move string `json:"move"` 448 + Row int `json:"row"` 449 + Col int `json:"col"` 450 + Hit bool `json:"hit"` 451 + Sunk bool `json:"sunk"` 452 + ShipName string `json:"shipName,omitempty"` 453 + ResultCode int `json:"resultCode"` 454 + } 455 + 456 + func (g *Game) executeMove(board *Board, row, col int) *MoveResult { 457 + result := &MoveResult{ 458 + Move: formatMove(row, col), 459 + Row: row, 460 + Col: col, 461 + } 462 + 463 + cell := board.Grid[row][col] 464 + 465 + // Check if it's a ship 466 + for i := range board.Ships { 467 + ship := &board.Ships[i] 468 + for _, pos := range ship.Cells { 469 + if pos[0] == row && pos[1] == col { 470 + // Hit! 471 + ship.Hits++ 472 + result.Hit = true 473 + 474 + if ship.Hits >= ship.Size { 475 + // Sunk! 476 + result.Sunk = true 477 + result.ShipName = ship.Name 478 + result.ResultCode = ResultSunk | ship.ShipType 479 + // Mark all cells as sunk 480 + for _, sunkPos := range ship.Cells { 481 + board.Grid[sunkPos[0]][sunkPos[1]] = CellSunk 482 + } 483 + } else { 484 + board.Grid[row][col] = CellHit 485 + result.ResultCode = ResultHit | ship.ShipType 486 + } 487 + return result 488 + } 489 + } 490 + } 491 + 492 + // Miss 493 + if cell == CellEmpty { 494 + board.Grid[row][col] = CellMiss 495 + } 496 + result.Hit = false 497 + result.ResultCode = ResultMiss 498 + return result 499 + } 500 + 501 + func (g *Game) checkAllSunk(board *Board) bool { 502 + for _, ship := range board.Ships { 503 + if ship.Hits < ship.Size { 504 + return false 505 + } 506 + } 507 + return true 508 + } 509 + 510 + // GetState returns the current game state for the UI 511 + func (g *Game) GetState() *GameState { 512 + g.mu.Lock() 513 + defer g.mu.Unlock() 514 + 515 + state := &GameState{ 516 + GameID: g.ID, 517 + PlayerName: g.PlayerName, 518 + AIName: g.AIName, 519 + CurrentTurn: g.CurrentTurn, 520 + GameOver: g.GameOver, 521 + Winner: g.Winner, 522 + MoveCount: g.MoveCount, 523 + } 524 + 525 + // Human's view of AI board (hide ships) 526 + for row := 0; row < BoardSize; row++ { 527 + for col := 0; col < BoardSize; col++ { 528 + cell := g.AIBoard.Grid[row][col] 529 + switch cell { 530 + case CellHit, CellMiss, CellSunk: 531 + state.EnemyBoard[row][col] = cell 532 + default: 533 + state.EnemyBoard[row][col] = CellEmpty 534 + } 535 + } 536 + } 537 + 538 + // Human's own board (show ships) 539 + for row := 0; row < BoardSize; row++ { 540 + for col := 0; col < BoardSize; col++ { 541 + state.OwnBoard[row][col] = g.HumanBoard.Grid[row][col] 542 + } 543 + } 544 + 545 + // Ship status 546 + for _, ship := range g.HumanBoard.Ships { 547 + state.OwnShips = append(state.OwnShips, ShipStatus{ 548 + Name: ship.Name, 549 + Size: ship.Size, 550 + Hits: ship.Hits, 551 + Sunk: ship.Hits >= ship.Size, 552 + }) 553 + } 554 + 555 + for _, ship := range g.AIBoard.Ships { 556 + state.EnemyShips = append(state.EnemyShips, ShipStatus{ 557 + Name: ship.Name, 558 + Size: ship.Size, 559 + Hits: ship.Hits, 560 + Sunk: ship.Hits >= ship.Size, 561 + }) 562 + } 563 + 564 + return state 565 + } 566 + 567 + // GameState represents the game state for JSON serialization 568 + type GameState struct { 569 + GameID string `json:"gameId"` 570 + PlayerName string `json:"playerName"` 571 + AIName string `json:"aiName"` 572 + CurrentTurn string `json:"currentTurn"` 573 + GameOver bool `json:"gameOver"` 574 + Winner string `json:"winner,omitempty"` 575 + MoveCount int `json:"moveCount"` 576 + OwnBoard [BoardSize][BoardSize]byte `json:"ownBoard"` 577 + EnemyBoard [BoardSize][BoardSize]byte `json:"enemyBoard"` 578 + OwnShips []ShipStatus `json:"ownShips"` 579 + EnemyShips []ShipStatus `json:"enemyShips"` 580 + } 581 + 582 + type ShipStatus struct { 583 + Name string `json:"name"` 584 + Size int `json:"size"` 585 + Hits int `json:"hits"` 586 + Sunk bool `json:"sunk"` 587 + } 588 + 589 + // Helper functions 590 + 591 + func parseMove(move string) (int, int, error) { 592 + move = strings.ToUpper(strings.TrimSpace(move)) 593 + if len(move) < 2 || len(move) > 3 { 594 + return 0, 0, fmt.Errorf("invalid move format") 595 + } 596 + 597 + row := int(move[0] - 'A') 598 + if row < 0 || row >= BoardSize { 599 + return 0, 0, fmt.Errorf("invalid row") 600 + } 601 + 602 + col, err := strconv.Atoi(move[1:]) 603 + if err != nil { 604 + return 0, 0, fmt.Errorf("invalid column") 605 + } 606 + col-- // Convert to 0-indexed 607 + 608 + if col < 0 || col >= BoardSize { 609 + return 0, 0, fmt.Errorf("invalid column") 610 + } 611 + 612 + return row, col, nil 613 + } 614 + 615 + func formatMove(row, col int) string { 616 + return fmt.Sprintf("%c%d", 'A'+row, col+1) 617 + } 618 + 619 + func initializeBoard(board *Board) { 620 + // Initialize empty grid 621 + for row := 0; row < BoardSize; row++ { 622 + for col := 0; col < BoardSize; col++ { 623 + board.Grid[row][col] = CellEmpty 624 + } 625 + } 626 + 627 + // Define ships 628 + shipDefs := []struct { 629 + name string 630 + size int 631 + shipType int 632 + marker byte 633 + }{ 634 + {"Aircraft Carrier", 5, ShipAC, 'A'}, 635 + {"Battleship", 4, ShipBS, 'B'}, 636 + {"Cruiser", 3, ShipCR, 'C'}, 637 + {"Submarine", 3, ShipSB, 'S'}, 638 + {"Destroyer", 2, ShipDS, 'D'}, 639 + } 640 + 641 + rand.Seed(time.Now().UnixNano()) 642 + 643 + for _, def := range shipDefs { 644 + ship := Ship{ 645 + Name: def.name, 646 + Size: def.size, 647 + ShipType: def.shipType, 648 + Marker: def.marker, 649 + } 650 + 651 + // Try to place ship randomly 652 + placed := false 653 + for attempts := 0; attempts < 1000 && !placed; attempts++ { 654 + row := rand.Intn(BoardSize) 655 + col := rand.Intn(BoardSize) 656 + horizontal := rand.Intn(2) == 0 657 + 658 + if canPlaceShip(board, row, col, def.size, horizontal) { 659 + placeShip(board, &ship, row, col, horizontal) 660 + placed = true 661 + } 662 + } 663 + 664 + board.Ships = append(board.Ships, ship) 665 + } 666 + } 667 + 668 + func canPlaceShip(board *Board, row, col, size int, horizontal bool) bool { 669 + for i := 0; i < size; i++ { 670 + r, c := row, col 671 + if horizontal { 672 + c += i 673 + } else { 674 + r += i 675 + } 676 + 677 + if r >= BoardSize || c >= BoardSize { 678 + return false 679 + } 680 + if board.Grid[r][c] != CellEmpty { 681 + return false 682 + } 683 + } 684 + return true 685 + } 686 + 687 + func placeShip(board *Board, ship *Ship, row, col int, horizontal bool) { 688 + for i := 0; i < ship.Size; i++ { 689 + r, c := row, col 690 + if horizontal { 691 + c += i 692 + } else { 693 + r += i 694 + } 695 + board.Grid[r][c] = CellShip 696 + ship.Cells = append(ship.Cells, [2]int{r, c}) 697 + } 698 + }
+253
internal/game/websocket.go
··· 1 + package game 2 + 3 + import ( 4 + "encoding/json" 5 + "log" 6 + "net/http" 7 + "sync" 8 + "time" 9 + 10 + "github.com/gorilla/websocket" 11 + ) 12 + 13 + var upgrader = websocket.Upgrader{ 14 + CheckOrigin: func(r *http.Request) bool { 15 + return true // Allow all origins for now 16 + }, 17 + } 18 + 19 + // GameManager manages active games 20 + type GameManager struct { 21 + games map[string]*Game 22 + gameConns map[string]*websocket.Conn // Track which conn owns which game 23 + mu sync.RWMutex 24 + } 25 + 26 + var Manager = &GameManager{ 27 + games: make(map[string]*Game), 28 + gameConns: make(map[string]*websocket.Conn), 29 + } 30 + 31 + // CleanupOrphanedGames removes games that have no active connection 32 + func (m *GameManager) CleanupOrphanedGames() { 33 + m.mu.Lock() 34 + defer m.mu.Unlock() 35 + 36 + for gameID, game := range m.games { 37 + if _, exists := m.gameConns[gameID]; !exists { 38 + log.Printf("Cleaning up orphaned game: %s", gameID) 39 + game.Close() 40 + delete(m.games, gameID) 41 + } 42 + } 43 + } 44 + 45 + // StartCleanupWorker starts a background worker to clean up orphaned games 46 + func (m *GameManager) StartCleanupWorker() { 47 + go func() { 48 + ticker := time.NewTicker(30 * time.Second) 49 + defer ticker.Stop() 50 + 51 + for range ticker.C { 52 + m.CleanupOrphanedGames() 53 + } 54 + }() 55 + } 56 + 57 + // Message types for WebSocket communication 58 + type WSMessageIn struct { 59 + Type string `json:"type"` 60 + Payload json.RawMessage `json:"payload,omitempty"` 61 + } 62 + 63 + type WSMessageOut struct { 64 + Type string `json:"type"` 65 + Payload interface{} `json:"payload,omitempty"` 66 + } 67 + 68 + type StartGamePayload struct { 69 + AIName string `json:"aiName"` 70 + } 71 + 72 + type MovePayload struct { 73 + Move string `json:"move"` 74 + } 75 + 76 + type ErrorPayload struct { 77 + Error string `json:"error"` 78 + } 79 + 80 + // HandleGameWebSocket handles WebSocket connections for gameplay 81 + func HandleGameWebSocket(w http.ResponseWriter, r *http.Request) { 82 + conn, err := upgrader.Upgrade(w, r, nil) 83 + if err != nil { 84 + log.Printf("WebSocket upgrade failed: %v", err) 85 + return 86 + } 87 + defer conn.Close() 88 + 89 + var currentGame *Game 90 + connID := time.Now().UnixNano() 91 + 92 + log.Printf("WebSocket connected: %d", connID) 93 + 94 + // Cleanup function to ensure game is properly closed 95 + cleanupCurrentGame := func() { 96 + if currentGame != nil { 97 + log.Printf("Cleaning up game %s for connection %d", currentGame.ID, connID) 98 + Manager.mu.Lock() 99 + delete(Manager.games, currentGame.ID) 100 + delete(Manager.gameConns, currentGame.ID) 101 + Manager.mu.Unlock() 102 + currentGame.Close() 103 + currentGame = nil 104 + } 105 + } 106 + 107 + defer func() { 108 + log.Printf("WebSocket disconnected: %d", connID) 109 + cleanupCurrentGame() 110 + }() 111 + 112 + for { 113 + _, message, err := conn.ReadMessage() 114 + if err != nil { 115 + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { 116 + log.Printf("WebSocket error (conn %d): %v", connID, err) 117 + } 118 + break 119 + } 120 + 121 + var msg WSMessageIn 122 + if err := json.Unmarshal(message, &msg); err != nil { 123 + sendError(conn, "Invalid message format") 124 + continue 125 + } 126 + 127 + switch msg.Type { 128 + case "start_game": 129 + var payload StartGamePayload 130 + if err := json.Unmarshal(msg.Payload, &payload); err != nil { 131 + sendError(conn, "Invalid start_game payload") 132 + continue 133 + } 134 + 135 + // Clean up previous game for this connection 136 + cleanupCurrentGame() 137 + 138 + // Create new game 139 + game, err := NewGame("Player", payload.AIName) 140 + if err != nil { 141 + sendError(conn, err.Error()) 142 + continue 143 + } 144 + 145 + currentGame = game 146 + Manager.mu.Lock() 147 + Manager.games[game.ID] = game 148 + Manager.gameConns[game.ID] = conn 149 + Manager.mu.Unlock() 150 + 151 + log.Printf("Started game %s for connection %d (AI: %s)", game.ID, connID, payload.AIName) 152 + 153 + // Send initial state 154 + sendGameState(conn, game) 155 + 156 + case "move": 157 + if currentGame == nil { 158 + sendError(conn, "No active game") 159 + continue 160 + } 161 + 162 + var payload MovePayload 163 + if err := json.Unmarshal(msg.Payload, &payload); err != nil { 164 + sendError(conn, "Invalid move payload") 165 + continue 166 + } 167 + 168 + // Process human move 169 + result, err := currentGame.HumanMove(payload.Move) 170 + if err != nil { 171 + sendError(conn, err.Error()) 172 + continue 173 + } 174 + 175 + // Send human move result 176 + sendMoveResult(conn, "human_move_result", result) 177 + 178 + // Only let AI move if game is still ongoing 179 + // Check both GameOver flag and that it's actually AI's turn 180 + if !currentGame.GameOver && currentGame.CurrentTurn == "ai" && currentGame.Winner == "" { 181 + aiResult, err := currentGame.AIMove() 182 + if err != nil { 183 + sendError(conn, "AI error: "+err.Error()) 184 + continue 185 + } 186 + sendMoveResult(conn, "ai_move_result", aiResult) 187 + } 188 + 189 + // Send updated state 190 + sendGameState(conn, currentGame) 191 + 192 + // Clean up game if it's over to free AI process immediately 193 + if currentGame.GameOver { 194 + log.Printf("Game %s finished (winner: %s), cleaning up", currentGame.ID, currentGame.Winner) 195 + cleanupCurrentGame() 196 + } 197 + 198 + case "get_state": 199 + if currentGame == nil { 200 + sendError(conn, "No active game") 201 + continue 202 + } 203 + sendGameState(conn, currentGame) 204 + 205 + case "quit": 206 + cleanupCurrentGame() 207 + sendJSON(conn, WSMessageOut{Type: "game_ended"}) 208 + 209 + default: 210 + sendError(conn, "Unknown message type: "+msg.Type) 211 + } 212 + } 213 + } 214 + 215 + func sendJSON(conn *websocket.Conn, v interface{}) { 216 + data, err := json.Marshal(v) 217 + if err != nil { 218 + log.Printf("JSON marshal error: %v", err) 219 + return 220 + } 221 + if err := conn.WriteMessage(websocket.TextMessage, data); err != nil { 222 + log.Printf("WebSocket write error: %v", err) 223 + } 224 + } 225 + 226 + func sendError(conn *websocket.Conn, errMsg string) { 227 + sendJSON(conn, WSMessageOut{ 228 + Type: "error", 229 + Payload: ErrorPayload{Error: errMsg}, 230 + }) 231 + } 232 + 233 + func sendGameState(conn *websocket.Conn, game *Game) { 234 + state := game.GetState() 235 + sendJSON(conn, WSMessageOut{ 236 + Type: "game_state", 237 + Payload: state, 238 + }) 239 + } 240 + 241 + func sendMoveResult(conn *websocket.Conn, msgType string, result *MoveResult) { 242 + sendJSON(conn, WSMessageOut{ 243 + Type: msgType, 244 + Payload: result, 245 + }) 246 + } 247 + 248 + // GetAvailableAIs returns list of AI submissions that can be played against 249 + func GetAvailableAIs() ([]string, error) { 250 + // This would query the storage for active submissions 251 + // For now, return a placeholder 252 + return []string{}, nil 253 + }
+731
internal/server/play.go
··· 1 + package server 2 + 3 + import ( 4 + "encoding/json" 5 + "html/template" 6 + "net/http" 7 + "regexp" 8 + 9 + "github.com/go-chi/chi/v5" 10 + 11 + "battleship-arena/internal/storage" 12 + ) 13 + 14 + func HandlePlayPage(w http.ResponseWriter, r *http.Request) { 15 + aiName := chi.URLParam(r, "aiName") 16 + 17 + tmpl := template.Must(template.New("play").Parse(playPageHTML)) 18 + tmpl.Execute(w, map[string]string{ 19 + "AIName": aiName, 20 + "ServerURL": GetServerURL(), 21 + }) 22 + } 23 + 24 + func HandleAvailableAIs(w http.ResponseWriter, r *http.Request) { 25 + entries, err := storage.GetLeaderboard(50) 26 + if err != nil { 27 + http.Error(w, "Failed to load AIs", http.StatusInternalServerError) 28 + return 29 + } 30 + 31 + type AIInfo struct { 32 + Name string `json:"name"` 33 + Rating int `json:"rating"` 34 + } 35 + 36 + var ais []AIInfo 37 + re := regexp.MustCompile(`memory_functions_(\w+)\.cpp`) 38 + 39 + for _, e := range entries { 40 + if e.IsBroken { 41 + continue 42 + } 43 + // Get submission to extract AI name from filename 44 + subs, err := storage.GetUserSubmissions(e.Username) 45 + if err != nil || len(subs) == 0 { 46 + continue 47 + } 48 + 49 + // Find active submission 50 + for _, sub := range subs { 51 + if sub.Status == "completed" { 52 + matches := re.FindStringSubmatch(sub.Filename) 53 + if len(matches) >= 2 { 54 + ais = append(ais, AIInfo{ 55 + Name: matches[1], 56 + Rating: e.Rating, 57 + }) 58 + break 59 + } 60 + } 61 + } 62 + } 63 + 64 + w.Header().Set("Content-Type", "application/json") 65 + json.NewEncoder(w).Encode(ais) 66 + } 67 + 68 + const playPageHTML = ` 69 + <!DOCTYPE html> 70 + <html lang="en"> 71 + <head> 72 + <title>Play Battleship - Battleship Arena</title> 73 + <meta charset="UTF-8"> 74 + <meta name="viewport" content="width=device-width, initial-scale=1.0"> 75 + <link rel="icon" href="data:image/svg+xml,<svg xmlns=%22http://www.w3.org/2000/svg%22 viewBox=%220 0 100 100%22><text y=%22.9em%22 font-size=%2290%22>⚓</text></svg>"> 76 + <style> 77 + * { 78 + margin: 0; 79 + padding: 0; 80 + box-sizing: border-box; 81 + } 82 + 83 + body { 84 + font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; 85 + background: #0f172a; 86 + color: #e2e8f0; 87 + min-height: 100vh; 88 + padding: 2rem 1rem; 89 + } 90 + 91 + .container { 92 + max-width: 1400px; 93 + margin: 0 auto; 94 + } 95 + 96 + header { 97 + text-align: center; 98 + margin-bottom: 3rem; 99 + } 100 + 101 + h1 { 102 + font-size: 3rem; 103 + font-weight: 800; 104 + background: linear-gradient(135deg, #3b82f6 0%, #8b5cf6 50%, #ec4899 100%); 105 + -webkit-background-clip: text; 106 + -webkit-text-fill-color: transparent; 107 + background-clip: text; 108 + margin-bottom: 0.5rem; 109 + } 110 + 111 + .back-link { 112 + display: inline-block; 113 + margin-bottom: 1.5rem; 114 + color: #60a5fa; 115 + text-decoration: none; 116 + font-size: 0.875rem; 117 + transition: color 0.2s; 118 + } 119 + 120 + .back-link:hover { 121 + color: #93c5fd; 122 + } 123 + 124 + .game-container { 125 + display: flex; 126 + gap: 2rem; 127 + justify-content: center; 128 + flex-wrap: wrap; 129 + } 130 + 131 + .board-section { 132 + background: #1e293b; 133 + border: 1px solid #334155; 134 + border-radius: 0.75rem; 135 + padding: 1.5rem; 136 + } 137 + 138 + .board-title { 139 + font-size: 1.125rem; 140 + font-weight: 700; 141 + margin-bottom: 1.25rem; 142 + color: #e2e8f0; 143 + text-transform: uppercase; 144 + letter-spacing: 0.05em; 145 + font-size: 0.875rem; 146 + } 147 + 148 + .board { 149 + display: grid; 150 + grid-template-columns: 30px repeat(10, 40px); 151 + grid-template-rows: 30px repeat(10, 40px); 152 + gap: 1px; 153 + background: #0f172a; 154 + padding: 0.5rem; 155 + border-radius: 0.5rem; 156 + margin: 0 auto; 157 + width: fit-content; 158 + } 159 + 160 + .board.enemy { 161 + cursor: crosshair; 162 + } 163 + 164 + .board.enemy.disabled { 165 + cursor: not-allowed; 166 + opacity: 0.6; 167 + } 168 + 169 + .header-cell { 170 + display: flex; 171 + align-items: center; 172 + justify-content: center; 173 + font-weight: 600; 174 + color: #64748b; 175 + font-size: 0.75rem; 176 + text-transform: uppercase; 177 + } 178 + 179 + .cell { 180 + width: 40px; 181 + height: 40px; 182 + background: #1e293b; 183 + border: 1px solid #334155; 184 + display: flex; 185 + align-items: center; 186 + justify-content: center; 187 + font-size: 1.25rem; 188 + transition: all 0.15s; 189 + border-radius: 2px; 190 + } 191 + 192 + .board.enemy .cell:not(.hit):not(.miss):not(.sunk):hover { 193 + background: rgba(59, 130, 246, 0.2); 194 + border-color: #3b82f6; 195 + transform: scale(1.05); 196 + cursor: pointer; 197 + } 198 + 199 + .cell.ship { 200 + background: #334155; 201 + } 202 + 203 + .cell.hit { 204 + background: rgba(220, 38, 38, 0.2); 205 + border-color: #dc2626; 206 + color: #fca5a5; 207 + } 208 + 209 + .cell.miss { 210 + background: rgba(30, 64, 175, 0.15); 211 + border-color: #1e40af; 212 + color: #60a5fa; 213 + } 214 + 215 + .cell.sunk { 216 + background: rgba(127, 29, 29, 0.3); 217 + border-color: #7f1d1d; 218 + color: #fca5a5; 219 + } 220 + 221 + .ships-panel { 222 + background: #0f172a; 223 + border: 1px solid #334155; 224 + border-radius: 0.5rem; 225 + padding: 1.25rem; 226 + margin-top: 1.25rem; 227 + } 228 + 229 + .ships-title { 230 + font-size: 0.75rem; 231 + color: #94a3b8; 232 + margin-bottom: 0.75rem; 233 + text-transform: uppercase; 234 + letter-spacing: 0.05em; 235 + font-weight: 600; 236 + } 237 + 238 + .ship-row { 239 + display: flex; 240 + align-items: center; 241 + gap: 0.5rem; 242 + padding: 0.5rem; 243 + font-size: 0.875rem; 244 + border-radius: 0.25rem; 245 + transition: background 0.15s; 246 + } 247 + 248 + .ship-row:hover { 249 + background: rgba(59, 130, 246, 0.05); 250 + } 251 + 252 + .ship-row.sunk { 253 + text-decoration: line-through; 254 + opacity: 0.4; 255 + } 256 + 257 + .ship-icon { 258 + font-size: 1rem; 259 + } 260 + 261 + .ship-name { 262 + font-size: 0.875rem; 263 + color: #e2e8f0; 264 + } 265 + 266 + .status-bar { 267 + text-align: center; 268 + padding: 1.25rem; 269 + margin: 2rem auto; 270 + max-width: 600px; 271 + border-radius: 0.75rem; 272 + font-size: 1.125rem; 273 + font-weight: 600; 274 + border: 1px solid; 275 + } 276 + 277 + .status-bar.your-turn { 278 + background: rgba(16, 185, 129, 0.1); 279 + border-color: rgba(16, 185, 129, 0.3); 280 + color: #10b981; 281 + } 282 + 283 + .status-bar.ai-turn { 284 + background: rgba(245, 158, 11, 0.1); 285 + border-color: rgba(245, 158, 11, 0.3); 286 + color: #f59e0b; 287 + } 288 + 289 + .status-bar.game-over { 290 + background: rgba(139, 92, 246, 0.1); 291 + border-color: rgba(139, 92, 246, 0.3); 292 + color: #8b5cf6; 293 + } 294 + 295 + .ai-selector { 296 + background: #1e293b; 297 + border: 1px solid #334155; 298 + border-radius: 0.75rem; 299 + padding: 2.5rem; 300 + text-align: center; 301 + max-width: 800px; 302 + margin: 0 auto; 303 + } 304 + 305 + .ai-selector h2 { 306 + font-size: 1.5rem; 307 + margin-bottom: 0.5rem; 308 + } 309 + 310 + .ai-selector p { 311 + color: #94a3b8; 312 + margin-bottom: 2rem; 313 + } 314 + 315 + .ai-list { 316 + display: grid; 317 + grid-template-columns: repeat(auto-fill, minmax(160px, 1fr)); 318 + gap: 1rem; 319 + } 320 + 321 + .ai-card { 322 + background: #0f172a; 323 + border: 1px solid #334155; 324 + border-radius: 0.5rem; 325 + padding: 1.25rem; 326 + cursor: pointer; 327 + transition: all 0.15s; 328 + } 329 + 330 + .ai-card:hover { 331 + border-color: #3b82f6; 332 + background: rgba(59, 130, 246, 0.05); 333 + transform: translateY(-2px); 334 + } 335 + 336 + .ai-name { 337 + font-weight: 600; 338 + margin-bottom: 0.5rem; 339 + color: #e2e8f0; 340 + } 341 + 342 + .ai-rating { 343 + font-size: 0.875rem; 344 + color: #94a3b8; 345 + } 346 + 347 + .btn { 348 + background: linear-gradient(135deg, #3b82f6, #8b5cf6); 349 + color: white; 350 + border: none; 351 + padding: 0.875rem 2rem; 352 + border-radius: 0.5rem; 353 + font-size: 0.9375rem; 354 + font-weight: 600; 355 + cursor: pointer; 356 + transition: all 0.15s; 357 + } 358 + 359 + .btn:hover { 360 + transform: translateY(-1px); 361 + box-shadow: 0 4px 12px rgba(59, 130, 246, 0.4); 362 + } 363 + 364 + .btn:disabled { 365 + opacity: 0.5; 366 + cursor: not-allowed; 367 + transform: none; 368 + } 369 + 370 + .move-log { 371 + background: #0f172a; 372 + border: 1px solid #334155; 373 + border-radius: 0.5rem; 374 + padding: 1.25rem; 375 + margin-top: 2rem; 376 + max-width: 800px; 377 + margin-left: auto; 378 + margin-right: auto; 379 + } 380 + 381 + .move-log-title { 382 + font-size: 0.75rem; 383 + color: #94a3b8; 384 + margin-bottom: 0.75rem; 385 + text-transform: uppercase; 386 + letter-spacing: 0.05em; 387 + font-weight: 600; 388 + } 389 + 390 + .move-log-content { 391 + max-height: 200px; 392 + overflow-y: auto; 393 + } 394 + 395 + .move-entry { 396 + font-size: 0.875rem; 397 + padding: 0.5rem; 398 + margin-bottom: 0.25rem; 399 + border-radius: 0.25rem; 400 + background: #1e293b; 401 + } 402 + 403 + .move-entry:last-child { 404 + margin-bottom: 0; 405 + } 406 + 407 + .move-entry.hit { 408 + color: #fca5a5; 409 + border-left: 3px solid #dc2626; 410 + } 411 + 412 + .move-entry.miss { 413 + color: #93c5fd; 414 + border-left: 3px solid #3b82f6; 415 + } 416 + 417 + .move-entry.sunk { 418 + color: #fbbf24; 419 + border-left: 3px solid #f59e0b; 420 + font-weight: 600; 421 + } 422 + 423 + #loading { 424 + text-align: center; 425 + padding: 2rem; 426 + color: #64748b; 427 + } 428 + 429 + .error-toast { 430 + position: fixed; 431 + bottom: 2rem; 432 + left: 50%; 433 + transform: translateX(-50%); 434 + background: #7f1d1d; 435 + border: 1px solid #dc2626; 436 + color: #fca5a5; 437 + padding: 1rem 2rem; 438 + border-radius: 8px; 439 + font-size: 0.9rem; 440 + z-index: 1000; 441 + animation: slideUp 0.3s ease-out; 442 + } 443 + 444 + .error-toast.fade-out { 445 + animation: fadeOut 0.3s ease-out forwards; 446 + } 447 + 448 + @keyframes slideUp { 449 + from { transform: translateX(-50%) translateY(100%); opacity: 0; } 450 + to { transform: translateX(-50%) translateY(0); opacity: 1; } 451 + } 452 + 453 + @keyframes fadeOut { 454 + to { opacity: 0; transform: translateX(-50%) translateY(20px); } 455 + } 456 + </style> 457 + </head> 458 + <body> 459 + <div class="container"> 460 + <a href="/" class="back-link">← Back to Leaderboard</a> 461 + 462 + <header> 463 + <h1>⚓ Battleship Arena</h1> 464 + <p style="color: #94a3b8; font-size: 1.125rem;">Challenge an AI opponent</p> 465 + </header> 466 + 467 + <div id="game-area"> 468 + <div id="loading">Loading available AIs...</div> 469 + </div> 470 + </div> 471 + 472 + <script> 473 + let ws = null; 474 + let gameState = null; 475 + let selectedAI = "{{.AIName}}"; 476 + 477 + const CELL_EMPTY = 32; // ' ' 478 + const CELL_SHIP = 83; // 'S' 479 + const CELL_HIT = 88; // 'X' 480 + const CELL_MISS = 79; // 'O' 481 + const CELL_SUNK = 35; // '#' 482 + 483 + async function init() { 484 + if (selectedAI) { 485 + startGame(selectedAI); 486 + } else { 487 + await showAISelector(); 488 + } 489 + } 490 + 491 + async function showAISelector() { 492 + const res = await fetch('/api/available-ais'); 493 + const ais = await res.json(); 494 + 495 + const area = document.getElementById('game-area'); 496 + 497 + if (!ais || ais.length === 0) { 498 + area.innerHTML = '<div class="ai-selector"><h2>No AIs Available</h2><p>No AI submissions are ready to play against yet.</p></div>'; 499 + return; 500 + } 501 + 502 + let html = '<div class="ai-selector"><h2>Choose an Opponent</h2><div class="ai-list">'; 503 + 504 + for (const ai of ais) { 505 + html += '<div class="ai-card" onclick="startGame(\'' + ai.name + '\')">' + 506 + '<div class="ai-name">' + ai.name + '</div>' + 507 + '<div class="ai-rating">Rating: ' + ai.rating + '</div></div>'; 508 + } 509 + 510 + html += '</div></div>'; 511 + area.innerHTML = html; 512 + } 513 + 514 + function startGame(aiName) { 515 + selectedAI = aiName; 516 + 517 + // Connect WebSocket 518 + const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; 519 + ws = new WebSocket(protocol + '//' + window.location.host + '/ws/game'); 520 + 521 + ws.onopen = () => { 522 + ws.send(JSON.stringify({ 523 + type: 'start_game', 524 + payload: { aiName: aiName } 525 + })); 526 + }; 527 + 528 + ws.onmessage = (event) => { 529 + const msg = JSON.parse(event.data); 530 + handleMessage(msg); 531 + }; 532 + 533 + ws.onclose = () => { 534 + console.log('WebSocket closed'); 535 + }; 536 + 537 + ws.onerror = (err) => { 538 + console.error('WebSocket error:', err); 539 + document.getElementById('game-area').innerHTML = 540 + '<div class="ai-selector"><h2>Connection Error</h2><p>Failed to connect to game server.</p>' + 541 + '<button class="btn" onclick="location.reload()">Retry</button></div>'; 542 + }; 543 + 544 + document.getElementById('game-area').innerHTML = '<div id="loading">Starting game against ' + aiName + '...</div>'; 545 + } 546 + 547 + function handleMessage(msg) { 548 + switch (msg.type) { 549 + case 'game_state': 550 + gameState = msg.payload; 551 + renderGame(); 552 + break; 553 + 554 + case 'human_move_result': 555 + case 'ai_move_result': 556 + addMoveLog(msg.type === 'human_move_result' ? 'You' : 'AI', msg.payload); 557 + break; 558 + 559 + case 'error': 560 + showError(msg.payload.error); 561 + break; 562 + 563 + case 'game_ended': 564 + showAISelector(); 565 + break; 566 + } 567 + } 568 + 569 + let moveLog = []; 570 + 571 + function showError(message) { 572 + // Remove existing toast if any 573 + const existing = document.querySelector('.error-toast'); 574 + if (existing) existing.remove(); 575 + 576 + const toast = document.createElement('div'); 577 + toast.className = 'error-toast'; 578 + toast.textContent = message; 579 + document.body.appendChild(toast); 580 + 581 + setTimeout(() => { 582 + toast.classList.add('fade-out'); 583 + setTimeout(() => toast.remove(), 300); 584 + }, 4000); 585 + } 586 + 587 + function addMoveLog(player, result) { 588 + let text = player + ' fired at ' + result.move + ': '; 589 + let cls = 'miss'; 590 + 591 + if (result.sunk) { 592 + text += 'Sunk ' + result.shipName + '!'; 593 + cls = 'sunk'; 594 + } else if (result.hit) { 595 + text += 'Hit!'; 596 + cls = 'hit'; 597 + } else { 598 + text += 'Miss'; 599 + } 600 + 601 + moveLog.unshift({ text, cls }); 602 + if (moveLog.length > 20) moveLog.pop(); 603 + } 604 + 605 + function renderGame() { 606 + if (!gameState) return; 607 + 608 + let statusClass = 'your-turn'; 609 + let statusText = "Your turn - click enemy board to fire!"; 610 + 611 + if (gameState.gameOver) { 612 + statusClass = 'game-over'; 613 + statusText = gameState.winner === 'human' ? '🎉 You Won!' : '💥 AI Wins!'; 614 + } else if (gameState.currentTurn === 'ai') { 615 + statusClass = 'ai-turn'; 616 + statusText = "AI is thinking..."; 617 + } 618 + 619 + let html = '<div class="status-bar ' + statusClass + '">' + statusText + '</div>'; 620 + 621 + html += '<div class="game-container">'; 622 + 623 + // Enemy board (for attacking) 624 + html += '<div class="board-section">'; 625 + html += '<div class="board-title">Enemy Fleet (' + gameState.aiName + ')</div>'; 626 + html += renderBoard(gameState.enemyBoard, true); 627 + html += renderShips(gameState.enemyShips, true); 628 + html += '</div>'; 629 + 630 + // Own board 631 + html += '<div class="board-section">'; 632 + html += '<div class="board-title">Your Fleet</div>'; 633 + html += renderBoard(gameState.ownBoard, false); 634 + html += renderShips(gameState.ownShips, false); 635 + html += '</div>'; 636 + 637 + html += '</div>'; 638 + 639 + // Move log 640 + html += '<div class="move-log"><div class="move-log-title">Battle Log</div><div class="move-log-content">'; 641 + for (const entry of moveLog) { 642 + html += '<div class="move-entry ' + entry.cls + '">' + entry.text + '</div>'; 643 + } 644 + html += '</div></div>'; 645 + 646 + // Play again button 647 + if (gameState.gameOver) { 648 + html += '<div style="text-align:center;margin-top:2rem;">' + 649 + '<button class="btn" onclick="location.reload()">Play Again</button></div>'; 650 + } 651 + 652 + document.getElementById('game-area').innerHTML = html; 653 + } 654 + 655 + function renderBoard(board, isEnemy) { 656 + const disabled = gameState.currentTurn !== 'human' || gameState.gameOver; 657 + let html = '<div class="board ' + (isEnemy ? 'enemy' : '') + (disabled ? ' disabled' : '') + '">'; 658 + 659 + // Header row 660 + html += '<div class="header-cell"></div>'; 661 + for (let c = 1; c <= 10; c++) { 662 + html += '<div class="header-cell">' + c + '</div>'; 663 + } 664 + 665 + // Board cells 666 + for (let r = 0; r < 10; r++) { 667 + html += '<div class="header-cell">' + String.fromCharCode(65 + r) + '</div>'; 668 + 669 + for (let c = 0; c < 10; c++) { 670 + const cell = board[r][c]; 671 + let cls = 'cell'; 672 + let content = ''; 673 + 674 + if (cell === CELL_HIT) { 675 + cls += ' hit'; 676 + content = '💥'; 677 + } else if (cell === CELL_MISS) { 678 + cls += ' miss'; 679 + content = '•'; 680 + } else if (cell === CELL_SUNK) { 681 + cls += ' sunk'; 682 + content = '☠️'; 683 + } else if (cell === CELL_SHIP && !isEnemy) { 684 + cls += ' ship'; 685 + content = '🚢'; 686 + } 687 + 688 + const onclick = isEnemy && !disabled ? ' onclick="fireAt(' + r + ',' + c + ')"' : ''; 689 + html += '<div class="' + cls + '"' + onclick + '>' + content + '</div>'; 690 + } 691 + } 692 + 693 + html += '</div>'; 694 + return html; 695 + } 696 + 697 + function renderShips(ships, isEnemy) { 698 + let html = '<div class="ships-panel"><div class="ships-title">' + 699 + (isEnemy ? 'Enemy Ships' : 'Your Ships') + '</div>'; 700 + 701 + for (const ship of ships) { 702 + const sunkClass = ship.sunk ? ' sunk' : ''; 703 + const icon = ship.sunk ? '💀' : '🚢'; 704 + html += '<div class="ship-row' + sunkClass + '">' + 705 + '<span class="ship-icon">' + icon + '</span>' + 706 + '<span class="ship-name">' + ship.name + ' (' + ship.size + ')</span>' + 707 + '</div>'; 708 + } 709 + 710 + html += '</div>'; 711 + return html; 712 + } 713 + 714 + function fireAt(row, col) { 715 + if (!ws || !gameState || gameState.currentTurn !== 'human' || gameState.gameOver) { 716 + return; 717 + } 718 + 719 + const move = String.fromCharCode(65 + row) + (col + 1); 720 + 721 + ws.send(JSON.stringify({ 722 + type: 'move', 723 + payload: { move: move } 724 + })); 725 + } 726 + 727 + init(); 728 + </script> 729 + </body> 730 + </html> 731 + `
+21 -1
internal/server/web.go
··· 248 248 color: #ef4444; 249 249 } 250 250 251 + .play-btn { 252 + display: inline-flex; 253 + align-items: center; 254 + justify-content: center; 255 + width: 32px; 256 + height: 32px; 257 + background: rgba(59, 130, 246, 0.1); 258 + border-radius: 6px; 259 + text-decoration: none; 260 + font-size: 1.25rem; 261 + transition: all 0.2s; 262 + } 263 + 264 + .play-btn:hover { 265 + background: rgba(59, 130, 246, 0.3); 266 + transform: scale(1.1); 267 + } 268 + 251 269 .info-card { 252 270 background: #1e293b; 253 271 border: 1px solid #334155; ··· 845 863 <th>Win Rate</th> 846 864 <th><span class="tooltip" data-tooltip="Average moves to win (lower is better)">Avg Moves</span></th> 847 865 <th>Last Active</th> 866 + <th>Play</th> 848 867 </tr> 849 868 </thead> 850 869 <tbody> ··· 859 878 <td>{{if $e.IsPending}}-{{else}}<span class="win-rate {{winRateClass $e}}">{{winRate $e}}%</span>{{end}}</td> 860 879 <td>{{if $e.IsPending}}-{{else}}{{printf "%.1f" $e.AvgMoves}}{{end}}</td> 861 880 <td style="color: #64748b;">{{if $e.IsPending}}Waiting...{{else}}{{$e.LastPlayed.Format "Jan 2, 3:04 PM"}}{{end}}</td> 881 + <td>{{if and (not $e.IsPending) (not $e.IsBroken) $e.AIName}}<a href="/play/{{$e.AIName}}" class="play-btn" title="Play against {{$e.Username}}">🎮</a>{{else}}-{{end}}</td> 862 882 </tr> 863 883 {{end}} 864 884 {{else}} 865 885 <tr> 866 - <td colspan="8"> 886 + <td colspan="9"> 867 887 <div class="empty-state"> 868 888 <div class="empty-state-icon">🎯</div> 869 889 <div>No submissions yet. Be the first to compete!</div>
+30 -6
internal/storage/database.go
··· 23 23 IsPending bool 24 24 IsBroken bool 25 25 FailureMessage string 26 + AIName string // extracted from filename (e.g., "klukas" from "memory_functions_klukas.cpp") 26 27 } 27 28 28 29 type Submission struct { ··· 208 209 MAX(m.timestamp) as last_played, 209 210 0 as is_pending, 210 211 0 as is_broken, 211 - '' as failure_message 212 + '' as failure_message, 213 + s.filename 212 214 FROM submissions s 213 215 LEFT JOIN matches m ON (m.player1_id = s.id OR m.player2_id = s.id) AND m.is_valid = 1 214 216 WHERE s.is_active = 1 AND s.status NOT IN ('compilation_failed', 'match_failed') 215 - GROUP BY s.username, s.glicko_rating, s.glicko_rd 217 + GROUP BY s.username, s.glicko_rating, s.glicko_rd, s.filename 216 218 HAVING COUNT(m.id) > 0 217 219 218 220 UNION ALL ··· 227 229 s.upload_time as last_played, 228 230 1 as is_pending, 229 231 0 as is_broken, 230 - '' as failure_message 232 + '' as failure_message, 233 + s.filename 231 234 FROM submissions s 232 235 LEFT JOIN matches m ON (m.player1_id = s.id OR m.player2_id = s.id) AND m.is_valid = 1 233 236 WHERE s.is_active = 1 AND s.status IN ('pending', 'testing', 'completed') 234 - GROUP BY s.username, s.upload_time 237 + GROUP BY s.username, s.upload_time, s.filename 235 238 HAVING COUNT(m.id) = 0 236 239 237 240 UNION ALL ··· 246 249 s.upload_time as last_played, 247 250 0 as is_pending, 248 251 1 as is_broken, 249 - COALESCE(s.failure_message, '') as failure_message 252 + COALESCE(s.failure_message, '') as failure_message, 253 + s.filename 250 254 FROM submissions s 251 255 WHERE s.is_active = 1 AND s.status IN ('compilation_failed', 'match_failed') 252 256 ··· 266 270 var lastPlayed string 267 271 var rating, rd float64 268 272 var isPending, isBroken int 269 - err := rows.Scan(&e.Username, &rating, &rd, &e.Wins, &e.Losses, &e.AvgMoves, &lastPlayed, &isPending, &isBroken, &e.FailureMessage) 273 + var filename string 274 + err := rows.Scan(&e.Username, &rating, &rd, &e.Wins, &e.Losses, &e.AvgMoves, &lastPlayed, &isPending, &isBroken, &e.FailureMessage, &filename) 270 275 if err != nil { 271 276 return nil, err 272 277 } ··· 276 281 e.IsPending = isPending == 1 277 282 e.IsBroken = isBroken == 1 278 283 284 + // Extract AI name from filename (e.g., "memory_functions_klukas.cpp" -> "klukas") 285 + e.AIName = extractAIName(filename) 286 + 279 287 totalGames := e.Wins + e.Losses 280 288 if totalGames > 0 { 281 289 e.WinPct = float64(e.Wins) / float64(totalGames) * 100.0 ··· 286 294 } 287 295 288 296 return entries, rows.Err() 297 + } 298 + 299 + func extractAIName(filename string) string { 300 + // Extract AI name from "memory_functions_NAME.cpp" 301 + if len(filename) < 22 { 302 + return "" 303 + } 304 + // Remove "memory_functions_" prefix and ".cpp" suffix 305 + name := filename 306 + if len(name) > 17 && name[:17] == "memory_functions_" { 307 + name = name[17:] 308 + } 309 + if len(name) > 4 && name[len(name)-4:] == ".cpp" { 310 + name = name[:len(name)-4] 311 + } 312 + return name 289 313 } 290 314 291 315 func AddSubmission(username, filename string) (int64, error) {
+25
vendor/github.com/gorilla/websocket/.gitignore
··· 1 + # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 + *.o 3 + *.a 4 + *.so 5 + 6 + # Folders 7 + _obj 8 + _test 9 + 10 + # Architecture specific extensions/prefixes 11 + *.[568vq] 12 + [568vq].out 13 + 14 + *.cgo1.go 15 + *.cgo2.c 16 + _cgo_defun.c 17 + _cgo_gotypes.go 18 + _cgo_export.* 19 + 20 + _testmain.go 21 + 22 + *.exe 23 + 24 + .idea/ 25 + *.iml
+9
vendor/github.com/gorilla/websocket/AUTHORS
··· 1 + # This is the official list of Gorilla WebSocket authors for copyright 2 + # purposes. 3 + # 4 + # Please keep the list sorted. 5 + 6 + Gary Burd <gary@beagledreams.com> 7 + Google LLC (https://opensource.google.com/) 8 + Joachim Bauch <mail@joachim-bauch.de> 9 +
+22
vendor/github.com/gorilla/websocket/LICENSE
··· 1 + Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved. 2 + 3 + Redistribution and use in source and binary forms, with or without 4 + modification, are permitted provided that the following conditions are met: 5 + 6 + Redistributions of source code must retain the above copyright notice, this 7 + list of conditions and the following disclaimer. 8 + 9 + Redistributions in binary form must reproduce the above copyright notice, 10 + this list of conditions and the following disclaimer in the documentation 11 + and/or other materials provided with the distribution. 12 + 13 + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 14 + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 15 + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 16 + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 17 + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 18 + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 19 + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 20 + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 21 + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 22 + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+33
vendor/github.com/gorilla/websocket/README.md
··· 1 + # Gorilla WebSocket 2 + 3 + [![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket) 4 + [![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket) 5 + 6 + Gorilla WebSocket is a [Go](http://golang.org/) implementation of the 7 + [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. 8 + 9 + 10 + ### Documentation 11 + 12 + * [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc) 13 + * [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat) 14 + * [Command example](https://github.com/gorilla/websocket/tree/master/examples/command) 15 + * [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo) 16 + * [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch) 17 + 18 + ### Status 19 + 20 + The Gorilla WebSocket package provides a complete and tested implementation of 21 + the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. The 22 + package API is stable. 23 + 24 + ### Installation 25 + 26 + go get github.com/gorilla/websocket 27 + 28 + ### Protocol Compliance 29 + 30 + The Gorilla WebSocket package passes the server tests in the [Autobahn Test 31 + Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn 32 + subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). 33 +
+434
vendor/github.com/gorilla/websocket/client.go
··· 1 + // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 + // Use of this source code is governed by a BSD-style 3 + // license that can be found in the LICENSE file. 4 + 5 + package websocket 6 + 7 + import ( 8 + "bytes" 9 + "context" 10 + "crypto/tls" 11 + "errors" 12 + "fmt" 13 + "io" 14 + "io/ioutil" 15 + "net" 16 + "net/http" 17 + "net/http/httptrace" 18 + "net/url" 19 + "strings" 20 + "time" 21 + ) 22 + 23 + // ErrBadHandshake is returned when the server response to opening handshake is 24 + // invalid. 25 + var ErrBadHandshake = errors.New("websocket: bad handshake") 26 + 27 + var errInvalidCompression = errors.New("websocket: invalid compression negotiation") 28 + 29 + // NewClient creates a new client connection using the given net connection. 30 + // The URL u specifies the host and request URI. Use requestHeader to specify 31 + // the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies 32 + // (Cookie). Use the response.Header to get the selected subprotocol 33 + // (Sec-WebSocket-Protocol) and cookies (Set-Cookie). 34 + // 35 + // If the WebSocket handshake fails, ErrBadHandshake is returned along with a 36 + // non-nil *http.Response so that callers can handle redirects, authentication, 37 + // etc. 38 + // 39 + // Deprecated: Use Dialer instead. 40 + func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) { 41 + d := Dialer{ 42 + ReadBufferSize: readBufSize, 43 + WriteBufferSize: writeBufSize, 44 + NetDial: func(net, addr string) (net.Conn, error) { 45 + return netConn, nil 46 + }, 47 + } 48 + return d.Dial(u.String(), requestHeader) 49 + } 50 + 51 + // A Dialer contains options for connecting to WebSocket server. 52 + // 53 + // It is safe to call Dialer's methods concurrently. 54 + type Dialer struct { 55 + // NetDial specifies the dial function for creating TCP connections. If 56 + // NetDial is nil, net.Dial is used. 57 + NetDial func(network, addr string) (net.Conn, error) 58 + 59 + // NetDialContext specifies the dial function for creating TCP connections. If 60 + // NetDialContext is nil, NetDial is used. 61 + NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) 62 + 63 + // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If 64 + // NetDialTLSContext is nil, NetDialContext is used. 65 + // If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and 66 + // TLSClientConfig is ignored. 67 + NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) 68 + 69 + // Proxy specifies a function to return a proxy for a given 70 + // Request. If the function returns a non-nil error, the 71 + // request is aborted with the provided error. 72 + // If Proxy is nil or returns a nil *URL, no proxy is used. 73 + Proxy func(*http.Request) (*url.URL, error) 74 + 75 + // TLSClientConfig specifies the TLS configuration to use with tls.Client. 76 + // If nil, the default configuration is used. 77 + // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake 78 + // is done there and TLSClientConfig is ignored. 79 + TLSClientConfig *tls.Config 80 + 81 + // HandshakeTimeout specifies the duration for the handshake to complete. 82 + HandshakeTimeout time.Duration 83 + 84 + // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer 85 + // size is zero, then a useful default size is used. The I/O buffer sizes 86 + // do not limit the size of the messages that can be sent or received. 87 + ReadBufferSize, WriteBufferSize int 88 + 89 + // WriteBufferPool is a pool of buffers for write operations. If the value 90 + // is not set, then write buffers are allocated to the connection for the 91 + // lifetime of the connection. 92 + // 93 + // A pool is most useful when the application has a modest volume of writes 94 + // across a large number of connections. 95 + // 96 + // Applications should use a single pool for each unique value of 97 + // WriteBufferSize. 98 + WriteBufferPool BufferPool 99 + 100 + // Subprotocols specifies the client's requested subprotocols. 101 + Subprotocols []string 102 + 103 + // EnableCompression specifies if the client should attempt to negotiate 104 + // per message compression (RFC 7692). Setting this value to true does not 105 + // guarantee that compression will be supported. Currently only "no context 106 + // takeover" modes are supported. 107 + EnableCompression bool 108 + 109 + // Jar specifies the cookie jar. 110 + // If Jar is nil, cookies are not sent in requests and ignored 111 + // in responses. 112 + Jar http.CookieJar 113 + } 114 + 115 + // Dial creates a new client connection by calling DialContext with a background context. 116 + func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { 117 + return d.DialContext(context.Background(), urlStr, requestHeader) 118 + } 119 + 120 + var errMalformedURL = errors.New("malformed ws or wss URL") 121 + 122 + func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { 123 + hostPort = u.Host 124 + hostNoPort = u.Host 125 + if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { 126 + hostNoPort = hostNoPort[:i] 127 + } else { 128 + switch u.Scheme { 129 + case "wss": 130 + hostPort += ":443" 131 + case "https": 132 + hostPort += ":443" 133 + default: 134 + hostPort += ":80" 135 + } 136 + } 137 + return hostPort, hostNoPort 138 + } 139 + 140 + // DefaultDialer is a dialer with all fields set to the default values. 141 + var DefaultDialer = &Dialer{ 142 + Proxy: http.ProxyFromEnvironment, 143 + HandshakeTimeout: 45 * time.Second, 144 + } 145 + 146 + // nilDialer is dialer to use when receiver is nil. 147 + var nilDialer = *DefaultDialer 148 + 149 + // DialContext creates a new client connection. Use requestHeader to specify the 150 + // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). 151 + // Use the response.Header to get the selected subprotocol 152 + // (Sec-WebSocket-Protocol) and cookies (Set-Cookie). 153 + // 154 + // The context will be used in the request and in the Dialer. 155 + // 156 + // If the WebSocket handshake fails, ErrBadHandshake is returned along with a 157 + // non-nil *http.Response so that callers can handle redirects, authentication, 158 + // etcetera. The response body may not contain the entire response and does not 159 + // need to be closed by the application. 160 + func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { 161 + if d == nil { 162 + d = &nilDialer 163 + } 164 + 165 + challengeKey, err := generateChallengeKey() 166 + if err != nil { 167 + return nil, nil, err 168 + } 169 + 170 + u, err := url.Parse(urlStr) 171 + if err != nil { 172 + return nil, nil, err 173 + } 174 + 175 + switch u.Scheme { 176 + case "ws": 177 + u.Scheme = "http" 178 + case "wss": 179 + u.Scheme = "https" 180 + default: 181 + return nil, nil, errMalformedURL 182 + } 183 + 184 + if u.User != nil { 185 + // User name and password are not allowed in websocket URIs. 186 + return nil, nil, errMalformedURL 187 + } 188 + 189 + req := &http.Request{ 190 + Method: http.MethodGet, 191 + URL: u, 192 + Proto: "HTTP/1.1", 193 + ProtoMajor: 1, 194 + ProtoMinor: 1, 195 + Header: make(http.Header), 196 + Host: u.Host, 197 + } 198 + req = req.WithContext(ctx) 199 + 200 + // Set the cookies present in the cookie jar of the dialer 201 + if d.Jar != nil { 202 + for _, cookie := range d.Jar.Cookies(u) { 203 + req.AddCookie(cookie) 204 + } 205 + } 206 + 207 + // Set the request headers using the capitalization for names and values in 208 + // RFC examples. Although the capitalization shouldn't matter, there are 209 + // servers that depend on it. The Header.Set method is not used because the 210 + // method canonicalizes the header names. 211 + req.Header["Upgrade"] = []string{"websocket"} 212 + req.Header["Connection"] = []string{"Upgrade"} 213 + req.Header["Sec-WebSocket-Key"] = []string{challengeKey} 214 + req.Header["Sec-WebSocket-Version"] = []string{"13"} 215 + if len(d.Subprotocols) > 0 { 216 + req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")} 217 + } 218 + for k, vs := range requestHeader { 219 + switch { 220 + case k == "Host": 221 + if len(vs) > 0 { 222 + req.Host = vs[0] 223 + } 224 + case k == "Upgrade" || 225 + k == "Connection" || 226 + k == "Sec-Websocket-Key" || 227 + k == "Sec-Websocket-Version" || 228 + k == "Sec-Websocket-Extensions" || 229 + (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): 230 + return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) 231 + case k == "Sec-Websocket-Protocol": 232 + req.Header["Sec-WebSocket-Protocol"] = vs 233 + default: 234 + req.Header[k] = vs 235 + } 236 + } 237 + 238 + if d.EnableCompression { 239 + req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"} 240 + } 241 + 242 + if d.HandshakeTimeout != 0 { 243 + var cancel func() 244 + ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout) 245 + defer cancel() 246 + } 247 + 248 + // Get network dial function. 249 + var netDial func(network, add string) (net.Conn, error) 250 + 251 + switch u.Scheme { 252 + case "http": 253 + if d.NetDialContext != nil { 254 + netDial = func(network, addr string) (net.Conn, error) { 255 + return d.NetDialContext(ctx, network, addr) 256 + } 257 + } else if d.NetDial != nil { 258 + netDial = d.NetDial 259 + } 260 + case "https": 261 + if d.NetDialTLSContext != nil { 262 + netDial = func(network, addr string) (net.Conn, error) { 263 + return d.NetDialTLSContext(ctx, network, addr) 264 + } 265 + } else if d.NetDialContext != nil { 266 + netDial = func(network, addr string) (net.Conn, error) { 267 + return d.NetDialContext(ctx, network, addr) 268 + } 269 + } else if d.NetDial != nil { 270 + netDial = d.NetDial 271 + } 272 + default: 273 + return nil, nil, errMalformedURL 274 + } 275 + 276 + if netDial == nil { 277 + netDialer := &net.Dialer{} 278 + netDial = func(network, addr string) (net.Conn, error) { 279 + return netDialer.DialContext(ctx, network, addr) 280 + } 281 + } 282 + 283 + // If needed, wrap the dial function to set the connection deadline. 284 + if deadline, ok := ctx.Deadline(); ok { 285 + forwardDial := netDial 286 + netDial = func(network, addr string) (net.Conn, error) { 287 + c, err := forwardDial(network, addr) 288 + if err != nil { 289 + return nil, err 290 + } 291 + err = c.SetDeadline(deadline) 292 + if err != nil { 293 + c.Close() 294 + return nil, err 295 + } 296 + return c, nil 297 + } 298 + } 299 + 300 + // If needed, wrap the dial function to connect through a proxy. 301 + if d.Proxy != nil { 302 + proxyURL, err := d.Proxy(req) 303 + if err != nil { 304 + return nil, nil, err 305 + } 306 + if proxyURL != nil { 307 + dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial)) 308 + if err != nil { 309 + return nil, nil, err 310 + } 311 + netDial = dialer.Dial 312 + } 313 + } 314 + 315 + hostPort, hostNoPort := hostPortNoPort(u) 316 + trace := httptrace.ContextClientTrace(ctx) 317 + if trace != nil && trace.GetConn != nil { 318 + trace.GetConn(hostPort) 319 + } 320 + 321 + netConn, err := netDial("tcp", hostPort) 322 + if err != nil { 323 + return nil, nil, err 324 + } 325 + if trace != nil && trace.GotConn != nil { 326 + trace.GotConn(httptrace.GotConnInfo{ 327 + Conn: netConn, 328 + }) 329 + } 330 + 331 + defer func() { 332 + if netConn != nil { 333 + netConn.Close() 334 + } 335 + }() 336 + 337 + if u.Scheme == "https" && d.NetDialTLSContext == nil { 338 + // If NetDialTLSContext is set, assume that the TLS handshake has already been done 339 + 340 + cfg := cloneTLSConfig(d.TLSClientConfig) 341 + if cfg.ServerName == "" { 342 + cfg.ServerName = hostNoPort 343 + } 344 + tlsConn := tls.Client(netConn, cfg) 345 + netConn = tlsConn 346 + 347 + if trace != nil && trace.TLSHandshakeStart != nil { 348 + trace.TLSHandshakeStart() 349 + } 350 + err := doHandshake(ctx, tlsConn, cfg) 351 + if trace != nil && trace.TLSHandshakeDone != nil { 352 + trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) 353 + } 354 + 355 + if err != nil { 356 + return nil, nil, err 357 + } 358 + } 359 + 360 + conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil) 361 + 362 + if err := req.Write(netConn); err != nil { 363 + return nil, nil, err 364 + } 365 + 366 + if trace != nil && trace.GotFirstResponseByte != nil { 367 + if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 { 368 + trace.GotFirstResponseByte() 369 + } 370 + } 371 + 372 + resp, err := http.ReadResponse(conn.br, req) 373 + if err != nil { 374 + if d.TLSClientConfig != nil { 375 + for _, proto := range d.TLSClientConfig.NextProtos { 376 + if proto != "http/1.1" { 377 + return nil, nil, fmt.Errorf( 378 + "websocket: protocol %q was given but is not supported;"+ 379 + "sharing tls.Config with net/http Transport can cause this error: %w", 380 + proto, err, 381 + ) 382 + } 383 + } 384 + } 385 + return nil, nil, err 386 + } 387 + 388 + if d.Jar != nil { 389 + if rc := resp.Cookies(); len(rc) > 0 { 390 + d.Jar.SetCookies(u, rc) 391 + } 392 + } 393 + 394 + if resp.StatusCode != 101 || 395 + !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || 396 + !tokenListContainsValue(resp.Header, "Connection", "upgrade") || 397 + resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { 398 + // Before closing the network connection on return from this 399 + // function, slurp up some of the response to aid application 400 + // debugging. 401 + buf := make([]byte, 1024) 402 + n, _ := io.ReadFull(resp.Body, buf) 403 + resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) 404 + return nil, resp, ErrBadHandshake 405 + } 406 + 407 + for _, ext := range parseExtensions(resp.Header) { 408 + if ext[""] != "permessage-deflate" { 409 + continue 410 + } 411 + _, snct := ext["server_no_context_takeover"] 412 + _, cnct := ext["client_no_context_takeover"] 413 + if !snct || !cnct { 414 + return nil, resp, errInvalidCompression 415 + } 416 + conn.newCompressionWriter = compressNoContextTakeover 417 + conn.newDecompressionReader = decompressNoContextTakeover 418 + break 419 + } 420 + 421 + resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) 422 + conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") 423 + 424 + netConn.SetDeadline(time.Time{}) 425 + netConn = nil // to avoid close in defer. 426 + return conn, resp, nil 427 + } 428 + 429 + func cloneTLSConfig(cfg *tls.Config) *tls.Config { 430 + if cfg == nil { 431 + return &tls.Config{} 432 + } 433 + return cfg.Clone() 434 + }
+148
vendor/github.com/gorilla/websocket/compression.go
··· 1 + // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 + // Use of this source code is governed by a BSD-style 3 + // license that can be found in the LICENSE file. 4 + 5 + package websocket 6 + 7 + import ( 8 + "compress/flate" 9 + "errors" 10 + "io" 11 + "strings" 12 + "sync" 13 + ) 14 + 15 + const ( 16 + minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 17 + maxCompressionLevel = flate.BestCompression 18 + defaultCompressionLevel = 1 19 + ) 20 + 21 + var ( 22 + flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool 23 + flateReaderPool = sync.Pool{New: func() interface{} { 24 + return flate.NewReader(nil) 25 + }} 26 + ) 27 + 28 + func decompressNoContextTakeover(r io.Reader) io.ReadCloser { 29 + const tail = 30 + // Add four bytes as specified in RFC 31 + "\x00\x00\xff\xff" + 32 + // Add final block to squelch unexpected EOF error from flate reader. 33 + "\x01\x00\x00\xff\xff" 34 + 35 + fr, _ := flateReaderPool.Get().(io.ReadCloser) 36 + fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) 37 + return &flateReadWrapper{fr} 38 + } 39 + 40 + func isValidCompressionLevel(level int) bool { 41 + return minCompressionLevel <= level && level <= maxCompressionLevel 42 + } 43 + 44 + func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { 45 + p := &flateWriterPools[level-minCompressionLevel] 46 + tw := &truncWriter{w: w} 47 + fw, _ := p.Get().(*flate.Writer) 48 + if fw == nil { 49 + fw, _ = flate.NewWriter(tw, level) 50 + } else { 51 + fw.Reset(tw) 52 + } 53 + return &flateWriteWrapper{fw: fw, tw: tw, p: p} 54 + } 55 + 56 + // truncWriter is an io.Writer that writes all but the last four bytes of the 57 + // stream to another io.Writer. 58 + type truncWriter struct { 59 + w io.WriteCloser 60 + n int 61 + p [4]byte 62 + } 63 + 64 + func (w *truncWriter) Write(p []byte) (int, error) { 65 + n := 0 66 + 67 + // fill buffer first for simplicity. 68 + if w.n < len(w.p) { 69 + n = copy(w.p[w.n:], p) 70 + p = p[n:] 71 + w.n += n 72 + if len(p) == 0 { 73 + return n, nil 74 + } 75 + } 76 + 77 + m := len(p) 78 + if m > len(w.p) { 79 + m = len(w.p) 80 + } 81 + 82 + if nn, err := w.w.Write(w.p[:m]); err != nil { 83 + return n + nn, err 84 + } 85 + 86 + copy(w.p[:], w.p[m:]) 87 + copy(w.p[len(w.p)-m:], p[len(p)-m:]) 88 + nn, err := w.w.Write(p[:len(p)-m]) 89 + return n + nn, err 90 + } 91 + 92 + type flateWriteWrapper struct { 93 + fw *flate.Writer 94 + tw *truncWriter 95 + p *sync.Pool 96 + } 97 + 98 + func (w *flateWriteWrapper) Write(p []byte) (int, error) { 99 + if w.fw == nil { 100 + return 0, errWriteClosed 101 + } 102 + return w.fw.Write(p) 103 + } 104 + 105 + func (w *flateWriteWrapper) Close() error { 106 + if w.fw == nil { 107 + return errWriteClosed 108 + } 109 + err1 := w.fw.Flush() 110 + w.p.Put(w.fw) 111 + w.fw = nil 112 + if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { 113 + return errors.New("websocket: internal error, unexpected bytes at end of flate stream") 114 + } 115 + err2 := w.tw.w.Close() 116 + if err1 != nil { 117 + return err1 118 + } 119 + return err2 120 + } 121 + 122 + type flateReadWrapper struct { 123 + fr io.ReadCloser 124 + } 125 + 126 + func (r *flateReadWrapper) Read(p []byte) (int, error) { 127 + if r.fr == nil { 128 + return 0, io.ErrClosedPipe 129 + } 130 + n, err := r.fr.Read(p) 131 + if err == io.EOF { 132 + // Preemptively place the reader back in the pool. This helps with 133 + // scenarios where the application does not call NextReader() soon after 134 + // this final read. 135 + r.Close() 136 + } 137 + return n, err 138 + } 139 + 140 + func (r *flateReadWrapper) Close() error { 141 + if r.fr == nil { 142 + return io.ErrClosedPipe 143 + } 144 + err := r.fr.Close() 145 + flateReaderPool.Put(r.fr) 146 + r.fr = nil 147 + return err 148 + }
+1238
vendor/github.com/gorilla/websocket/conn.go
··· 1 + // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 + // Use of this source code is governed by a BSD-style 3 + // license that can be found in the LICENSE file. 4 + 5 + package websocket 6 + 7 + import ( 8 + "bufio" 9 + "encoding/binary" 10 + "errors" 11 + "io" 12 + "io/ioutil" 13 + "math/rand" 14 + "net" 15 + "strconv" 16 + "strings" 17 + "sync" 18 + "time" 19 + "unicode/utf8" 20 + ) 21 + 22 + const ( 23 + // Frame header byte 0 bits from Section 5.2 of RFC 6455 24 + finalBit = 1 << 7 25 + rsv1Bit = 1 << 6 26 + rsv2Bit = 1 << 5 27 + rsv3Bit = 1 << 4 28 + 29 + // Frame header byte 1 bits from Section 5.2 of RFC 6455 30 + maskBit = 1 << 7 31 + 32 + maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask 33 + maxControlFramePayloadSize = 125 34 + 35 + writeWait = time.Second 36 + 37 + defaultReadBufferSize = 4096 38 + defaultWriteBufferSize = 4096 39 + 40 + continuationFrame = 0 41 + noFrame = -1 42 + ) 43 + 44 + // Close codes defined in RFC 6455, section 11.7. 45 + const ( 46 + CloseNormalClosure = 1000 47 + CloseGoingAway = 1001 48 + CloseProtocolError = 1002 49 + CloseUnsupportedData = 1003 50 + CloseNoStatusReceived = 1005 51 + CloseAbnormalClosure = 1006 52 + CloseInvalidFramePayloadData = 1007 53 + ClosePolicyViolation = 1008 54 + CloseMessageTooBig = 1009 55 + CloseMandatoryExtension = 1010 56 + CloseInternalServerErr = 1011 57 + CloseServiceRestart = 1012 58 + CloseTryAgainLater = 1013 59 + CloseTLSHandshake = 1015 60 + ) 61 + 62 + // The message types are defined in RFC 6455, section 11.8. 63 + const ( 64 + // TextMessage denotes a text data message. The text message payload is 65 + // interpreted as UTF-8 encoded text data. 66 + TextMessage = 1 67 + 68 + // BinaryMessage denotes a binary data message. 69 + BinaryMessage = 2 70 + 71 + // CloseMessage denotes a close control message. The optional message 72 + // payload contains a numeric code and text. Use the FormatCloseMessage 73 + // function to format a close message payload. 74 + CloseMessage = 8 75 + 76 + // PingMessage denotes a ping control message. The optional message payload 77 + // is UTF-8 encoded text. 78 + PingMessage = 9 79 + 80 + // PongMessage denotes a pong control message. The optional message payload 81 + // is UTF-8 encoded text. 82 + PongMessage = 10 83 + ) 84 + 85 + // ErrCloseSent is returned when the application writes a message to the 86 + // connection after sending a close message. 87 + var ErrCloseSent = errors.New("websocket: close sent") 88 + 89 + // ErrReadLimit is returned when reading a message that is larger than the 90 + // read limit set for the connection. 91 + var ErrReadLimit = errors.New("websocket: read limit exceeded") 92 + 93 + // netError satisfies the net Error interface. 94 + type netError struct { 95 + msg string 96 + temporary bool 97 + timeout bool 98 + } 99 + 100 + func (e *netError) Error() string { return e.msg } 101 + func (e *netError) Temporary() bool { return e.temporary } 102 + func (e *netError) Timeout() bool { return e.timeout } 103 + 104 + // CloseError represents a close message. 105 + type CloseError struct { 106 + // Code is defined in RFC 6455, section 11.7. 107 + Code int 108 + 109 + // Text is the optional text payload. 110 + Text string 111 + } 112 + 113 + func (e *CloseError) Error() string { 114 + s := []byte("websocket: close ") 115 + s = strconv.AppendInt(s, int64(e.Code), 10) 116 + switch e.Code { 117 + case CloseNormalClosure: 118 + s = append(s, " (normal)"...) 119 + case CloseGoingAway: 120 + s = append(s, " (going away)"...) 121 + case CloseProtocolError: 122 + s = append(s, " (protocol error)"...) 123 + case CloseUnsupportedData: 124 + s = append(s, " (unsupported data)"...) 125 + case CloseNoStatusReceived: 126 + s = append(s, " (no status)"...) 127 + case CloseAbnormalClosure: 128 + s = append(s, " (abnormal closure)"...) 129 + case CloseInvalidFramePayloadData: 130 + s = append(s, " (invalid payload data)"...) 131 + case ClosePolicyViolation: 132 + s = append(s, " (policy violation)"...) 133 + case CloseMessageTooBig: 134 + s = append(s, " (message too big)"...) 135 + case CloseMandatoryExtension: 136 + s = append(s, " (mandatory extension missing)"...) 137 + case CloseInternalServerErr: 138 + s = append(s, " (internal server error)"...) 139 + case CloseTLSHandshake: 140 + s = append(s, " (TLS handshake error)"...) 141 + } 142 + if e.Text != "" { 143 + s = append(s, ": "...) 144 + s = append(s, e.Text...) 145 + } 146 + return string(s) 147 + } 148 + 149 + // IsCloseError returns boolean indicating whether the error is a *CloseError 150 + // with one of the specified codes. 151 + func IsCloseError(err error, codes ...int) bool { 152 + if e, ok := err.(*CloseError); ok { 153 + for _, code := range codes { 154 + if e.Code == code { 155 + return true 156 + } 157 + } 158 + } 159 + return false 160 + } 161 + 162 + // IsUnexpectedCloseError returns boolean indicating whether the error is a 163 + // *CloseError with a code not in the list of expected codes. 164 + func IsUnexpectedCloseError(err error, expectedCodes ...int) bool { 165 + if e, ok := err.(*CloseError); ok { 166 + for _, code := range expectedCodes { 167 + if e.Code == code { 168 + return false 169 + } 170 + } 171 + return true 172 + } 173 + return false 174 + } 175 + 176 + var ( 177 + errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true, temporary: true} 178 + errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()} 179 + errBadWriteOpCode = errors.New("websocket: bad write message type") 180 + errWriteClosed = errors.New("websocket: write closed") 181 + errInvalidControlFrame = errors.New("websocket: invalid control frame") 182 + ) 183 + 184 + func newMaskKey() [4]byte { 185 + n := rand.Uint32() 186 + return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} 187 + } 188 + 189 + func hideTempErr(err error) error { 190 + if e, ok := err.(net.Error); ok && e.Temporary() { 191 + err = &netError{msg: e.Error(), timeout: e.Timeout()} 192 + } 193 + return err 194 + } 195 + 196 + func isControl(frameType int) bool { 197 + return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage 198 + } 199 + 200 + func isData(frameType int) bool { 201 + return frameType == TextMessage || frameType == BinaryMessage 202 + } 203 + 204 + var validReceivedCloseCodes = map[int]bool{ 205 + // see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number 206 + 207 + CloseNormalClosure: true, 208 + CloseGoingAway: true, 209 + CloseProtocolError: true, 210 + CloseUnsupportedData: true, 211 + CloseNoStatusReceived: false, 212 + CloseAbnormalClosure: false, 213 + CloseInvalidFramePayloadData: true, 214 + ClosePolicyViolation: true, 215 + CloseMessageTooBig: true, 216 + CloseMandatoryExtension: true, 217 + CloseInternalServerErr: true, 218 + CloseServiceRestart: true, 219 + CloseTryAgainLater: true, 220 + CloseTLSHandshake: false, 221 + } 222 + 223 + func isValidReceivedCloseCode(code int) bool { 224 + return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) 225 + } 226 + 227 + // BufferPool represents a pool of buffers. The *sync.Pool type satisfies this 228 + // interface. The type of the value stored in a pool is not specified. 229 + type BufferPool interface { 230 + // Get gets a value from the pool or returns nil if the pool is empty. 231 + Get() interface{} 232 + // Put adds a value to the pool. 233 + Put(interface{}) 234 + } 235 + 236 + // writePoolData is the type added to the write buffer pool. This wrapper is 237 + // used to prevent applications from peeking at and depending on the values 238 + // added to the pool. 239 + type writePoolData struct{ buf []byte } 240 + 241 + // The Conn type represents a WebSocket connection. 242 + type Conn struct { 243 + conn net.Conn 244 + isServer bool 245 + subprotocol string 246 + 247 + // Write fields 248 + mu chan struct{} // used as mutex to protect write to conn 249 + writeBuf []byte // frame is constructed in this buffer. 250 + writePool BufferPool 251 + writeBufSize int 252 + writeDeadline time.Time 253 + writer io.WriteCloser // the current writer returned to the application 254 + isWriting bool // for best-effort concurrent write detection 255 + 256 + writeErrMu sync.Mutex 257 + writeErr error 258 + 259 + enableWriteCompression bool 260 + compressionLevel int 261 + newCompressionWriter func(io.WriteCloser, int) io.WriteCloser 262 + 263 + // Read fields 264 + reader io.ReadCloser // the current reader returned to the application 265 + readErr error 266 + br *bufio.Reader 267 + // bytes remaining in current frame. 268 + // set setReadRemaining to safely update this value and prevent overflow 269 + readRemaining int64 270 + readFinal bool // true the current message has more frames. 271 + readLength int64 // Message size. 272 + readLimit int64 // Maximum message size. 273 + readMaskPos int 274 + readMaskKey [4]byte 275 + handlePong func(string) error 276 + handlePing func(string) error 277 + handleClose func(int, string) error 278 + readErrCount int 279 + messageReader *messageReader // the current low-level reader 280 + 281 + readDecompress bool // whether last read frame had RSV1 set 282 + newDecompressionReader func(io.Reader) io.ReadCloser 283 + } 284 + 285 + func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn { 286 + 287 + if br == nil { 288 + if readBufferSize == 0 { 289 + readBufferSize = defaultReadBufferSize 290 + } else if readBufferSize < maxControlFramePayloadSize { 291 + // must be large enough for control frame 292 + readBufferSize = maxControlFramePayloadSize 293 + } 294 + br = bufio.NewReaderSize(conn, readBufferSize) 295 + } 296 + 297 + if writeBufferSize <= 0 { 298 + writeBufferSize = defaultWriteBufferSize 299 + } 300 + writeBufferSize += maxFrameHeaderSize 301 + 302 + if writeBuf == nil && writeBufferPool == nil { 303 + writeBuf = make([]byte, writeBufferSize) 304 + } 305 + 306 + mu := make(chan struct{}, 1) 307 + mu <- struct{}{} 308 + c := &Conn{ 309 + isServer: isServer, 310 + br: br, 311 + conn: conn, 312 + mu: mu, 313 + readFinal: true, 314 + writeBuf: writeBuf, 315 + writePool: writeBufferPool, 316 + writeBufSize: writeBufferSize, 317 + enableWriteCompression: true, 318 + compressionLevel: defaultCompressionLevel, 319 + } 320 + c.SetCloseHandler(nil) 321 + c.SetPingHandler(nil) 322 + c.SetPongHandler(nil) 323 + return c 324 + } 325 + 326 + // setReadRemaining tracks the number of bytes remaining on the connection. If n 327 + // overflows, an ErrReadLimit is returned. 328 + func (c *Conn) setReadRemaining(n int64) error { 329 + if n < 0 { 330 + return ErrReadLimit 331 + } 332 + 333 + c.readRemaining = n 334 + return nil 335 + } 336 + 337 + // Subprotocol returns the negotiated protocol for the connection. 338 + func (c *Conn) Subprotocol() string { 339 + return c.subprotocol 340 + } 341 + 342 + // Close closes the underlying network connection without sending or waiting 343 + // for a close message. 344 + func (c *Conn) Close() error { 345 + return c.conn.Close() 346 + } 347 + 348 + // LocalAddr returns the local network address. 349 + func (c *Conn) LocalAddr() net.Addr { 350 + return c.conn.LocalAddr() 351 + } 352 + 353 + // RemoteAddr returns the remote network address. 354 + func (c *Conn) RemoteAddr() net.Addr { 355 + return c.conn.RemoteAddr() 356 + } 357 + 358 + // Write methods 359 + 360 + func (c *Conn) writeFatal(err error) error { 361 + err = hideTempErr(err) 362 + c.writeErrMu.Lock() 363 + if c.writeErr == nil { 364 + c.writeErr = err 365 + } 366 + c.writeErrMu.Unlock() 367 + return err 368 + } 369 + 370 + func (c *Conn) read(n int) ([]byte, error) { 371 + p, err := c.br.Peek(n) 372 + if err == io.EOF { 373 + err = errUnexpectedEOF 374 + } 375 + c.br.Discard(len(p)) 376 + return p, err 377 + } 378 + 379 + func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error { 380 + <-c.mu 381 + defer func() { c.mu <- struct{}{} }() 382 + 383 + c.writeErrMu.Lock() 384 + err := c.writeErr 385 + c.writeErrMu.Unlock() 386 + if err != nil { 387 + return err 388 + } 389 + 390 + c.conn.SetWriteDeadline(deadline) 391 + if len(buf1) == 0 { 392 + _, err = c.conn.Write(buf0) 393 + } else { 394 + err = c.writeBufs(buf0, buf1) 395 + } 396 + if err != nil { 397 + return c.writeFatal(err) 398 + } 399 + if frameType == CloseMessage { 400 + c.writeFatal(ErrCloseSent) 401 + } 402 + return nil 403 + } 404 + 405 + func (c *Conn) writeBufs(bufs ...[]byte) error { 406 + b := net.Buffers(bufs) 407 + _, err := b.WriteTo(c.conn) 408 + return err 409 + } 410 + 411 + // WriteControl writes a control message with the given deadline. The allowed 412 + // message types are CloseMessage, PingMessage and PongMessage. 413 + func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { 414 + if !isControl(messageType) { 415 + return errBadWriteOpCode 416 + } 417 + if len(data) > maxControlFramePayloadSize { 418 + return errInvalidControlFrame 419 + } 420 + 421 + b0 := byte(messageType) | finalBit 422 + b1 := byte(len(data)) 423 + if !c.isServer { 424 + b1 |= maskBit 425 + } 426 + 427 + buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize) 428 + buf = append(buf, b0, b1) 429 + 430 + if c.isServer { 431 + buf = append(buf, data...) 432 + } else { 433 + key := newMaskKey() 434 + buf = append(buf, key[:]...) 435 + buf = append(buf, data...) 436 + maskBytes(key, 0, buf[6:]) 437 + } 438 + 439 + d := 1000 * time.Hour 440 + if !deadline.IsZero() { 441 + d = deadline.Sub(time.Now()) 442 + if d < 0 { 443 + return errWriteTimeout 444 + } 445 + } 446 + 447 + timer := time.NewTimer(d) 448 + select { 449 + case <-c.mu: 450 + timer.Stop() 451 + case <-timer.C: 452 + return errWriteTimeout 453 + } 454 + defer func() { c.mu <- struct{}{} }() 455 + 456 + c.writeErrMu.Lock() 457 + err := c.writeErr 458 + c.writeErrMu.Unlock() 459 + if err != nil { 460 + return err 461 + } 462 + 463 + c.conn.SetWriteDeadline(deadline) 464 + _, err = c.conn.Write(buf) 465 + if err != nil { 466 + return c.writeFatal(err) 467 + } 468 + if messageType == CloseMessage { 469 + c.writeFatal(ErrCloseSent) 470 + } 471 + return err 472 + } 473 + 474 + // beginMessage prepares a connection and message writer for a new message. 475 + func (c *Conn) beginMessage(mw *messageWriter, messageType int) error { 476 + // Close previous writer if not already closed by the application. It's 477 + // probably better to return an error in this situation, but we cannot 478 + // change this without breaking existing applications. 479 + if c.writer != nil { 480 + c.writer.Close() 481 + c.writer = nil 482 + } 483 + 484 + if !isControl(messageType) && !isData(messageType) { 485 + return errBadWriteOpCode 486 + } 487 + 488 + c.writeErrMu.Lock() 489 + err := c.writeErr 490 + c.writeErrMu.Unlock() 491 + if err != nil { 492 + return err 493 + } 494 + 495 + mw.c = c 496 + mw.frameType = messageType 497 + mw.pos = maxFrameHeaderSize 498 + 499 + if c.writeBuf == nil { 500 + wpd, ok := c.writePool.Get().(writePoolData) 501 + if ok { 502 + c.writeBuf = wpd.buf 503 + } else { 504 + c.writeBuf = make([]byte, c.writeBufSize) 505 + } 506 + } 507 + return nil 508 + } 509 + 510 + // NextWriter returns a writer for the next message to send. The writer's Close 511 + // method flushes the complete message to the network. 512 + // 513 + // There can be at most one open writer on a connection. NextWriter closes the 514 + // previous writer if the application has not already done so. 515 + // 516 + // All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and 517 + // PongMessage) are supported. 518 + func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { 519 + var mw messageWriter 520 + if err := c.beginMessage(&mw, messageType); err != nil { 521 + return nil, err 522 + } 523 + c.writer = &mw 524 + if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { 525 + w := c.newCompressionWriter(c.writer, c.compressionLevel) 526 + mw.compress = true 527 + c.writer = w 528 + } 529 + return c.writer, nil 530 + } 531 + 532 + type messageWriter struct { 533 + c *Conn 534 + compress bool // whether next call to flushFrame should set RSV1 535 + pos int // end of data in writeBuf. 536 + frameType int // type of the current frame. 537 + err error 538 + } 539 + 540 + func (w *messageWriter) endMessage(err error) error { 541 + if w.err != nil { 542 + return err 543 + } 544 + c := w.c 545 + w.err = err 546 + c.writer = nil 547 + if c.writePool != nil { 548 + c.writePool.Put(writePoolData{buf: c.writeBuf}) 549 + c.writeBuf = nil 550 + } 551 + return err 552 + } 553 + 554 + // flushFrame writes buffered data and extra as a frame to the network. The 555 + // final argument indicates that this is the last frame in the message. 556 + func (w *messageWriter) flushFrame(final bool, extra []byte) error { 557 + c := w.c 558 + length := w.pos - maxFrameHeaderSize + len(extra) 559 + 560 + // Check for invalid control frames. 561 + if isControl(w.frameType) && 562 + (!final || length > maxControlFramePayloadSize) { 563 + return w.endMessage(errInvalidControlFrame) 564 + } 565 + 566 + b0 := byte(w.frameType) 567 + if final { 568 + b0 |= finalBit 569 + } 570 + if w.compress { 571 + b0 |= rsv1Bit 572 + } 573 + w.compress = false 574 + 575 + b1 := byte(0) 576 + if !c.isServer { 577 + b1 |= maskBit 578 + } 579 + 580 + // Assume that the frame starts at beginning of c.writeBuf. 581 + framePos := 0 582 + if c.isServer { 583 + // Adjust up if mask not included in the header. 584 + framePos = 4 585 + } 586 + 587 + switch { 588 + case length >= 65536: 589 + c.writeBuf[framePos] = b0 590 + c.writeBuf[framePos+1] = b1 | 127 591 + binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length)) 592 + case length > 125: 593 + framePos += 6 594 + c.writeBuf[framePos] = b0 595 + c.writeBuf[framePos+1] = b1 | 126 596 + binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length)) 597 + default: 598 + framePos += 8 599 + c.writeBuf[framePos] = b0 600 + c.writeBuf[framePos+1] = b1 | byte(length) 601 + } 602 + 603 + if !c.isServer { 604 + key := newMaskKey() 605 + copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) 606 + maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos]) 607 + if len(extra) > 0 { 608 + return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))) 609 + } 610 + } 611 + 612 + // Write the buffers to the connection with best-effort detection of 613 + // concurrent writes. See the concurrency section in the package 614 + // documentation for more info. 615 + 616 + if c.isWriting { 617 + panic("concurrent write to websocket connection") 618 + } 619 + c.isWriting = true 620 + 621 + err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra) 622 + 623 + if !c.isWriting { 624 + panic("concurrent write to websocket connection") 625 + } 626 + c.isWriting = false 627 + 628 + if err != nil { 629 + return w.endMessage(err) 630 + } 631 + 632 + if final { 633 + w.endMessage(errWriteClosed) 634 + return nil 635 + } 636 + 637 + // Setup for next frame. 638 + w.pos = maxFrameHeaderSize 639 + w.frameType = continuationFrame 640 + return nil 641 + } 642 + 643 + func (w *messageWriter) ncopy(max int) (int, error) { 644 + n := len(w.c.writeBuf) - w.pos 645 + if n <= 0 { 646 + if err := w.flushFrame(false, nil); err != nil { 647 + return 0, err 648 + } 649 + n = len(w.c.writeBuf) - w.pos 650 + } 651 + if n > max { 652 + n = max 653 + } 654 + return n, nil 655 + } 656 + 657 + func (w *messageWriter) Write(p []byte) (int, error) { 658 + if w.err != nil { 659 + return 0, w.err 660 + } 661 + 662 + if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { 663 + // Don't buffer large messages. 664 + err := w.flushFrame(false, p) 665 + if err != nil { 666 + return 0, err 667 + } 668 + return len(p), nil 669 + } 670 + 671 + nn := len(p) 672 + for len(p) > 0 { 673 + n, err := w.ncopy(len(p)) 674 + if err != nil { 675 + return 0, err 676 + } 677 + copy(w.c.writeBuf[w.pos:], p[:n]) 678 + w.pos += n 679 + p = p[n:] 680 + } 681 + return nn, nil 682 + } 683 + 684 + func (w *messageWriter) WriteString(p string) (int, error) { 685 + if w.err != nil { 686 + return 0, w.err 687 + } 688 + 689 + nn := len(p) 690 + for len(p) > 0 { 691 + n, err := w.ncopy(len(p)) 692 + if err != nil { 693 + return 0, err 694 + } 695 + copy(w.c.writeBuf[w.pos:], p[:n]) 696 + w.pos += n 697 + p = p[n:] 698 + } 699 + return nn, nil 700 + } 701 + 702 + func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { 703 + if w.err != nil { 704 + return 0, w.err 705 + } 706 + for { 707 + if w.pos == len(w.c.writeBuf) { 708 + err = w.flushFrame(false, nil) 709 + if err != nil { 710 + break 711 + } 712 + } 713 + var n int 714 + n, err = r.Read(w.c.writeBuf[w.pos:]) 715 + w.pos += n 716 + nn += int64(n) 717 + if err != nil { 718 + if err == io.EOF { 719 + err = nil 720 + } 721 + break 722 + } 723 + } 724 + return nn, err 725 + } 726 + 727 + func (w *messageWriter) Close() error { 728 + if w.err != nil { 729 + return w.err 730 + } 731 + return w.flushFrame(true, nil) 732 + } 733 + 734 + // WritePreparedMessage writes prepared message into connection. 735 + func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error { 736 + frameType, frameData, err := pm.frame(prepareKey{ 737 + isServer: c.isServer, 738 + compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), 739 + compressionLevel: c.compressionLevel, 740 + }) 741 + if err != nil { 742 + return err 743 + } 744 + if c.isWriting { 745 + panic("concurrent write to websocket connection") 746 + } 747 + c.isWriting = true 748 + err = c.write(frameType, c.writeDeadline, frameData, nil) 749 + if !c.isWriting { 750 + panic("concurrent write to websocket connection") 751 + } 752 + c.isWriting = false 753 + return err 754 + } 755 + 756 + // WriteMessage is a helper method for getting a writer using NextWriter, 757 + // writing the message and closing the writer. 758 + func (c *Conn) WriteMessage(messageType int, data []byte) error { 759 + 760 + if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) { 761 + // Fast path with no allocations and single frame. 762 + 763 + var mw messageWriter 764 + if err := c.beginMessage(&mw, messageType); err != nil { 765 + return err 766 + } 767 + n := copy(c.writeBuf[mw.pos:], data) 768 + mw.pos += n 769 + data = data[n:] 770 + return mw.flushFrame(true, data) 771 + } 772 + 773 + w, err := c.NextWriter(messageType) 774 + if err != nil { 775 + return err 776 + } 777 + if _, err = w.Write(data); err != nil { 778 + return err 779 + } 780 + return w.Close() 781 + } 782 + 783 + // SetWriteDeadline sets the write deadline on the underlying network 784 + // connection. After a write has timed out, the websocket state is corrupt and 785 + // all future writes will return an error. A zero value for t means writes will 786 + // not time out. 787 + func (c *Conn) SetWriteDeadline(t time.Time) error { 788 + c.writeDeadline = t 789 + return nil 790 + } 791 + 792 + // Read methods 793 + 794 + func (c *Conn) advanceFrame() (int, error) { 795 + // 1. Skip remainder of previous frame. 796 + 797 + if c.readRemaining > 0 { 798 + if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil { 799 + return noFrame, err 800 + } 801 + } 802 + 803 + // 2. Read and parse first two bytes of frame header. 804 + // To aid debugging, collect and report all errors in the first two bytes 805 + // of the header. 806 + 807 + var errors []string 808 + 809 + p, err := c.read(2) 810 + if err != nil { 811 + return noFrame, err 812 + } 813 + 814 + frameType := int(p[0] & 0xf) 815 + final := p[0]&finalBit != 0 816 + rsv1 := p[0]&rsv1Bit != 0 817 + rsv2 := p[0]&rsv2Bit != 0 818 + rsv3 := p[0]&rsv3Bit != 0 819 + mask := p[1]&maskBit != 0 820 + c.setReadRemaining(int64(p[1] & 0x7f)) 821 + 822 + c.readDecompress = false 823 + if rsv1 { 824 + if c.newDecompressionReader != nil { 825 + c.readDecompress = true 826 + } else { 827 + errors = append(errors, "RSV1 set") 828 + } 829 + } 830 + 831 + if rsv2 { 832 + errors = append(errors, "RSV2 set") 833 + } 834 + 835 + if rsv3 { 836 + errors = append(errors, "RSV3 set") 837 + } 838 + 839 + switch frameType { 840 + case CloseMessage, PingMessage, PongMessage: 841 + if c.readRemaining > maxControlFramePayloadSize { 842 + errors = append(errors, "len > 125 for control") 843 + } 844 + if !final { 845 + errors = append(errors, "FIN not set on control") 846 + } 847 + case TextMessage, BinaryMessage: 848 + if !c.readFinal { 849 + errors = append(errors, "data before FIN") 850 + } 851 + c.readFinal = final 852 + case continuationFrame: 853 + if c.readFinal { 854 + errors = append(errors, "continuation after FIN") 855 + } 856 + c.readFinal = final 857 + default: 858 + errors = append(errors, "bad opcode "+strconv.Itoa(frameType)) 859 + } 860 + 861 + if mask != c.isServer { 862 + errors = append(errors, "bad MASK") 863 + } 864 + 865 + if len(errors) > 0 { 866 + return noFrame, c.handleProtocolError(strings.Join(errors, ", ")) 867 + } 868 + 869 + // 3. Read and parse frame length as per 870 + // https://tools.ietf.org/html/rfc6455#section-5.2 871 + // 872 + // The length of the "Payload data", in bytes: if 0-125, that is the payload 873 + // length. 874 + // - If 126, the following 2 bytes interpreted as a 16-bit unsigned 875 + // integer are the payload length. 876 + // - If 127, the following 8 bytes interpreted as 877 + // a 64-bit unsigned integer (the most significant bit MUST be 0) are the 878 + // payload length. Multibyte length quantities are expressed in network byte 879 + // order. 880 + 881 + switch c.readRemaining { 882 + case 126: 883 + p, err := c.read(2) 884 + if err != nil { 885 + return noFrame, err 886 + } 887 + 888 + if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil { 889 + return noFrame, err 890 + } 891 + case 127: 892 + p, err := c.read(8) 893 + if err != nil { 894 + return noFrame, err 895 + } 896 + 897 + if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil { 898 + return noFrame, err 899 + } 900 + } 901 + 902 + // 4. Handle frame masking. 903 + 904 + if mask { 905 + c.readMaskPos = 0 906 + p, err := c.read(len(c.readMaskKey)) 907 + if err != nil { 908 + return noFrame, err 909 + } 910 + copy(c.readMaskKey[:], p) 911 + } 912 + 913 + // 5. For text and binary messages, enforce read limit and return. 914 + 915 + if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { 916 + 917 + c.readLength += c.readRemaining 918 + // Don't allow readLength to overflow in the presence of a large readRemaining 919 + // counter. 920 + if c.readLength < 0 { 921 + return noFrame, ErrReadLimit 922 + } 923 + 924 + if c.readLimit > 0 && c.readLength > c.readLimit { 925 + c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) 926 + return noFrame, ErrReadLimit 927 + } 928 + 929 + return frameType, nil 930 + } 931 + 932 + // 6. Read control frame payload. 933 + 934 + var payload []byte 935 + if c.readRemaining > 0 { 936 + payload, err = c.read(int(c.readRemaining)) 937 + c.setReadRemaining(0) 938 + if err != nil { 939 + return noFrame, err 940 + } 941 + if c.isServer { 942 + maskBytes(c.readMaskKey, 0, payload) 943 + } 944 + } 945 + 946 + // 7. Process control frame payload. 947 + 948 + switch frameType { 949 + case PongMessage: 950 + if err := c.handlePong(string(payload)); err != nil { 951 + return noFrame, err 952 + } 953 + case PingMessage: 954 + if err := c.handlePing(string(payload)); err != nil { 955 + return noFrame, err 956 + } 957 + case CloseMessage: 958 + closeCode := CloseNoStatusReceived 959 + closeText := "" 960 + if len(payload) >= 2 { 961 + closeCode = int(binary.BigEndian.Uint16(payload)) 962 + if !isValidReceivedCloseCode(closeCode) { 963 + return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode)) 964 + } 965 + closeText = string(payload[2:]) 966 + if !utf8.ValidString(closeText) { 967 + return noFrame, c.handleProtocolError("invalid utf8 payload in close frame") 968 + } 969 + } 970 + if err := c.handleClose(closeCode, closeText); err != nil { 971 + return noFrame, err 972 + } 973 + return noFrame, &CloseError{Code: closeCode, Text: closeText} 974 + } 975 + 976 + return frameType, nil 977 + } 978 + 979 + func (c *Conn) handleProtocolError(message string) error { 980 + data := FormatCloseMessage(CloseProtocolError, message) 981 + if len(data) > maxControlFramePayloadSize { 982 + data = data[:maxControlFramePayloadSize] 983 + } 984 + c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)) 985 + return errors.New("websocket: " + message) 986 + } 987 + 988 + // NextReader returns the next data message received from the peer. The 989 + // returned messageType is either TextMessage or BinaryMessage. 990 + // 991 + // There can be at most one open reader on a connection. NextReader discards 992 + // the previous message if the application has not already consumed it. 993 + // 994 + // Applications must break out of the application's read loop when this method 995 + // returns a non-nil error value. Errors returned from this method are 996 + // permanent. Once this method returns a non-nil error, all subsequent calls to 997 + // this method return the same error. 998 + func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { 999 + // Close previous reader, only relevant for decompression. 1000 + if c.reader != nil { 1001 + c.reader.Close() 1002 + c.reader = nil 1003 + } 1004 + 1005 + c.messageReader = nil 1006 + c.readLength = 0 1007 + 1008 + for c.readErr == nil { 1009 + frameType, err := c.advanceFrame() 1010 + if err != nil { 1011 + c.readErr = hideTempErr(err) 1012 + break 1013 + } 1014 + 1015 + if frameType == TextMessage || frameType == BinaryMessage { 1016 + c.messageReader = &messageReader{c} 1017 + c.reader = c.messageReader 1018 + if c.readDecompress { 1019 + c.reader = c.newDecompressionReader(c.reader) 1020 + } 1021 + return frameType, c.reader, nil 1022 + } 1023 + } 1024 + 1025 + // Applications that do handle the error returned from this method spin in 1026 + // tight loop on connection failure. To help application developers detect 1027 + // this error, panic on repeated reads to the failed connection. 1028 + c.readErrCount++ 1029 + if c.readErrCount >= 1000 { 1030 + panic("repeated read on failed websocket connection") 1031 + } 1032 + 1033 + return noFrame, nil, c.readErr 1034 + } 1035 + 1036 + type messageReader struct{ c *Conn } 1037 + 1038 + func (r *messageReader) Read(b []byte) (int, error) { 1039 + c := r.c 1040 + if c.messageReader != r { 1041 + return 0, io.EOF 1042 + } 1043 + 1044 + for c.readErr == nil { 1045 + 1046 + if c.readRemaining > 0 { 1047 + if int64(len(b)) > c.readRemaining { 1048 + b = b[:c.readRemaining] 1049 + } 1050 + n, err := c.br.Read(b) 1051 + c.readErr = hideTempErr(err) 1052 + if c.isServer { 1053 + c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) 1054 + } 1055 + rem := c.readRemaining 1056 + rem -= int64(n) 1057 + c.setReadRemaining(rem) 1058 + if c.readRemaining > 0 && c.readErr == io.EOF { 1059 + c.readErr = errUnexpectedEOF 1060 + } 1061 + return n, c.readErr 1062 + } 1063 + 1064 + if c.readFinal { 1065 + c.messageReader = nil 1066 + return 0, io.EOF 1067 + } 1068 + 1069 + frameType, err := c.advanceFrame() 1070 + switch { 1071 + case err != nil: 1072 + c.readErr = hideTempErr(err) 1073 + case frameType == TextMessage || frameType == BinaryMessage: 1074 + c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") 1075 + } 1076 + } 1077 + 1078 + err := c.readErr 1079 + if err == io.EOF && c.messageReader == r { 1080 + err = errUnexpectedEOF 1081 + } 1082 + return 0, err 1083 + } 1084 + 1085 + func (r *messageReader) Close() error { 1086 + return nil 1087 + } 1088 + 1089 + // ReadMessage is a helper method for getting a reader using NextReader and 1090 + // reading from that reader to a buffer. 1091 + func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { 1092 + var r io.Reader 1093 + messageType, r, err = c.NextReader() 1094 + if err != nil { 1095 + return messageType, nil, err 1096 + } 1097 + p, err = ioutil.ReadAll(r) 1098 + return messageType, p, err 1099 + } 1100 + 1101 + // SetReadDeadline sets the read deadline on the underlying network connection. 1102 + // After a read has timed out, the websocket connection state is corrupt and 1103 + // all future reads will return an error. A zero value for t means reads will 1104 + // not time out. 1105 + func (c *Conn) SetReadDeadline(t time.Time) error { 1106 + return c.conn.SetReadDeadline(t) 1107 + } 1108 + 1109 + // SetReadLimit sets the maximum size in bytes for a message read from the peer. If a 1110 + // message exceeds the limit, the connection sends a close message to the peer 1111 + // and returns ErrReadLimit to the application. 1112 + func (c *Conn) SetReadLimit(limit int64) { 1113 + c.readLimit = limit 1114 + } 1115 + 1116 + // CloseHandler returns the current close handler 1117 + func (c *Conn) CloseHandler() func(code int, text string) error { 1118 + return c.handleClose 1119 + } 1120 + 1121 + // SetCloseHandler sets the handler for close messages received from the peer. 1122 + // The code argument to h is the received close code or CloseNoStatusReceived 1123 + // if the close message is empty. The default close handler sends a close 1124 + // message back to the peer. 1125 + // 1126 + // The handler function is called from the NextReader, ReadMessage and message 1127 + // reader Read methods. The application must read the connection to process 1128 + // close messages as described in the section on Control Messages above. 1129 + // 1130 + // The connection read methods return a CloseError when a close message is 1131 + // received. Most applications should handle close messages as part of their 1132 + // normal error handling. Applications should only set a close handler when the 1133 + // application must perform some action before sending a close message back to 1134 + // the peer. 1135 + func (c *Conn) SetCloseHandler(h func(code int, text string) error) { 1136 + if h == nil { 1137 + h = func(code int, text string) error { 1138 + message := FormatCloseMessage(code, "") 1139 + c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) 1140 + return nil 1141 + } 1142 + } 1143 + c.handleClose = h 1144 + } 1145 + 1146 + // PingHandler returns the current ping handler 1147 + func (c *Conn) PingHandler() func(appData string) error { 1148 + return c.handlePing 1149 + } 1150 + 1151 + // SetPingHandler sets the handler for ping messages received from the peer. 1152 + // The appData argument to h is the PING message application data. The default 1153 + // ping handler sends a pong to the peer. 1154 + // 1155 + // The handler function is called from the NextReader, ReadMessage and message 1156 + // reader Read methods. The application must read the connection to process 1157 + // ping messages as described in the section on Control Messages above. 1158 + func (c *Conn) SetPingHandler(h func(appData string) error) { 1159 + if h == nil { 1160 + h = func(message string) error { 1161 + err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) 1162 + if err == ErrCloseSent { 1163 + return nil 1164 + } else if e, ok := err.(net.Error); ok && e.Temporary() { 1165 + return nil 1166 + } 1167 + return err 1168 + } 1169 + } 1170 + c.handlePing = h 1171 + } 1172 + 1173 + // PongHandler returns the current pong handler 1174 + func (c *Conn) PongHandler() func(appData string) error { 1175 + return c.handlePong 1176 + } 1177 + 1178 + // SetPongHandler sets the handler for pong messages received from the peer. 1179 + // The appData argument to h is the PONG message application data. The default 1180 + // pong handler does nothing. 1181 + // 1182 + // The handler function is called from the NextReader, ReadMessage and message 1183 + // reader Read methods. The application must read the connection to process 1184 + // pong messages as described in the section on Control Messages above. 1185 + func (c *Conn) SetPongHandler(h func(appData string) error) { 1186 + if h == nil { 1187 + h = func(string) error { return nil } 1188 + } 1189 + c.handlePong = h 1190 + } 1191 + 1192 + // NetConn returns the underlying connection that is wrapped by c. 1193 + // Note that writing to or reading from this connection directly will corrupt the 1194 + // WebSocket connection. 1195 + func (c *Conn) NetConn() net.Conn { 1196 + return c.conn 1197 + } 1198 + 1199 + // UnderlyingConn returns the internal net.Conn. This can be used to further 1200 + // modifications to connection specific flags. 1201 + // Deprecated: Use the NetConn method. 1202 + func (c *Conn) UnderlyingConn() net.Conn { 1203 + return c.conn 1204 + } 1205 + 1206 + // EnableWriteCompression enables and disables write compression of 1207 + // subsequent text and binary messages. This function is a noop if 1208 + // compression was not negotiated with the peer. 1209 + func (c *Conn) EnableWriteCompression(enable bool) { 1210 + c.enableWriteCompression = enable 1211 + } 1212 + 1213 + // SetCompressionLevel sets the flate compression level for subsequent text and 1214 + // binary messages. This function is a noop if compression was not negotiated 1215 + // with the peer. See the compress/flate package for a description of 1216 + // compression levels. 1217 + func (c *Conn) SetCompressionLevel(level int) error { 1218 + if !isValidCompressionLevel(level) { 1219 + return errors.New("websocket: invalid compression level") 1220 + } 1221 + c.compressionLevel = level 1222 + return nil 1223 + } 1224 + 1225 + // FormatCloseMessage formats closeCode and text as a WebSocket close message. 1226 + // An empty message is returned for code CloseNoStatusReceived. 1227 + func FormatCloseMessage(closeCode int, text string) []byte { 1228 + if closeCode == CloseNoStatusReceived { 1229 + // Return empty message because it's illegal to send 1230 + // CloseNoStatusReceived. Return non-nil value in case application 1231 + // checks for nil. 1232 + return []byte{} 1233 + } 1234 + buf := make([]byte, 2+len(text)) 1235 + binary.BigEndian.PutUint16(buf, uint16(closeCode)) 1236 + copy(buf[2:], text) 1237 + return buf 1238 + }
+227
vendor/github.com/gorilla/websocket/doc.go
··· 1 + // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 + // Use of this source code is governed by a BSD-style 3 + // license that can be found in the LICENSE file. 4 + 5 + // Package websocket implements the WebSocket protocol defined in RFC 6455. 6 + // 7 + // Overview 8 + // 9 + // The Conn type represents a WebSocket connection. A server application calls 10 + // the Upgrader.Upgrade method from an HTTP request handler to get a *Conn: 11 + // 12 + // var upgrader = websocket.Upgrader{ 13 + // ReadBufferSize: 1024, 14 + // WriteBufferSize: 1024, 15 + // } 16 + // 17 + // func handler(w http.ResponseWriter, r *http.Request) { 18 + // conn, err := upgrader.Upgrade(w, r, nil) 19 + // if err != nil { 20 + // log.Println(err) 21 + // return 22 + // } 23 + // ... Use conn to send and receive messages. 24 + // } 25 + // 26 + // Call the connection's WriteMessage and ReadMessage methods to send and 27 + // receive messages as a slice of bytes. This snippet of code shows how to echo 28 + // messages using these methods: 29 + // 30 + // for { 31 + // messageType, p, err := conn.ReadMessage() 32 + // if err != nil { 33 + // log.Println(err) 34 + // return 35 + // } 36 + // if err := conn.WriteMessage(messageType, p); err != nil { 37 + // log.Println(err) 38 + // return 39 + // } 40 + // } 41 + // 42 + // In above snippet of code, p is a []byte and messageType is an int with value 43 + // websocket.BinaryMessage or websocket.TextMessage. 44 + // 45 + // An application can also send and receive messages using the io.WriteCloser 46 + // and io.Reader interfaces. To send a message, call the connection NextWriter 47 + // method to get an io.WriteCloser, write the message to the writer and close 48 + // the writer when done. To receive a message, call the connection NextReader 49 + // method to get an io.Reader and read until io.EOF is returned. This snippet 50 + // shows how to echo messages using the NextWriter and NextReader methods: 51 + // 52 + // for { 53 + // messageType, r, err := conn.NextReader() 54 + // if err != nil { 55 + // return 56 + // } 57 + // w, err := conn.NextWriter(messageType) 58 + // if err != nil { 59 + // return err 60 + // } 61 + // if _, err := io.Copy(w, r); err != nil { 62 + // return err 63 + // } 64 + // if err := w.Close(); err != nil { 65 + // return err 66 + // } 67 + // } 68 + // 69 + // Data Messages 70 + // 71 + // The WebSocket protocol distinguishes between text and binary data messages. 72 + // Text messages are interpreted as UTF-8 encoded text. The interpretation of 73 + // binary messages is left to the application. 74 + // 75 + // This package uses the TextMessage and BinaryMessage integer constants to 76 + // identify the two data message types. The ReadMessage and NextReader methods 77 + // return the type of the received message. The messageType argument to the 78 + // WriteMessage and NextWriter methods specifies the type of a sent message. 79 + // 80 + // It is the application's responsibility to ensure that text messages are 81 + // valid UTF-8 encoded text. 82 + // 83 + // Control Messages 84 + // 85 + // The WebSocket protocol defines three types of control messages: close, ping 86 + // and pong. Call the connection WriteControl, WriteMessage or NextWriter 87 + // methods to send a control message to the peer. 88 + // 89 + // Connections handle received close messages by calling the handler function 90 + // set with the SetCloseHandler method and by returning a *CloseError from the 91 + // NextReader, ReadMessage or the message Read method. The default close 92 + // handler sends a close message to the peer. 93 + // 94 + // Connections handle received ping messages by calling the handler function 95 + // set with the SetPingHandler method. The default ping handler sends a pong 96 + // message to the peer. 97 + // 98 + // Connections handle received pong messages by calling the handler function 99 + // set with the SetPongHandler method. The default pong handler does nothing. 100 + // If an application sends ping messages, then the application should set a 101 + // pong handler to receive the corresponding pong. 102 + // 103 + // The control message handler functions are called from the NextReader, 104 + // ReadMessage and message reader Read methods. The default close and ping 105 + // handlers can block these methods for a short time when the handler writes to 106 + // the connection. 107 + // 108 + // The application must read the connection to process close, ping and pong 109 + // messages sent from the peer. If the application is not otherwise interested 110 + // in messages from the peer, then the application should start a goroutine to 111 + // read and discard messages from the peer. A simple example is: 112 + // 113 + // func readLoop(c *websocket.Conn) { 114 + // for { 115 + // if _, _, err := c.NextReader(); err != nil { 116 + // c.Close() 117 + // break 118 + // } 119 + // } 120 + // } 121 + // 122 + // Concurrency 123 + // 124 + // Connections support one concurrent reader and one concurrent writer. 125 + // 126 + // Applications are responsible for ensuring that no more than one goroutine 127 + // calls the write methods (NextWriter, SetWriteDeadline, WriteMessage, 128 + // WriteJSON, EnableWriteCompression, SetCompressionLevel) concurrently and 129 + // that no more than one goroutine calls the read methods (NextReader, 130 + // SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler) 131 + // concurrently. 132 + // 133 + // The Close and WriteControl methods can be called concurrently with all other 134 + // methods. 135 + // 136 + // Origin Considerations 137 + // 138 + // Web browsers allow Javascript applications to open a WebSocket connection to 139 + // any host. It's up to the server to enforce an origin policy using the Origin 140 + // request header sent by the browser. 141 + // 142 + // The Upgrader calls the function specified in the CheckOrigin field to check 143 + // the origin. If the CheckOrigin function returns false, then the Upgrade 144 + // method fails the WebSocket handshake with HTTP status 403. 145 + // 146 + // If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail 147 + // the handshake if the Origin request header is present and the Origin host is 148 + // not equal to the Host request header. 149 + // 150 + // The deprecated package-level Upgrade function does not perform origin 151 + // checking. The application is responsible for checking the Origin header 152 + // before calling the Upgrade function. 153 + // 154 + // Buffers 155 + // 156 + // Connections buffer network input and output to reduce the number 157 + // of system calls when reading or writing messages. 158 + // 159 + // Write buffers are also used for constructing WebSocket frames. See RFC 6455, 160 + // Section 5 for a discussion of message framing. A WebSocket frame header is 161 + // written to the network each time a write buffer is flushed to the network. 162 + // Decreasing the size of the write buffer can increase the amount of framing 163 + // overhead on the connection. 164 + // 165 + // The buffer sizes in bytes are specified by the ReadBufferSize and 166 + // WriteBufferSize fields in the Dialer and Upgrader. The Dialer uses a default 167 + // size of 4096 when a buffer size field is set to zero. The Upgrader reuses 168 + // buffers created by the HTTP server when a buffer size field is set to zero. 169 + // The HTTP server buffers have a size of 4096 at the time of this writing. 170 + // 171 + // The buffer sizes do not limit the size of a message that can be read or 172 + // written by a connection. 173 + // 174 + // Buffers are held for the lifetime of the connection by default. If the 175 + // Dialer or Upgrader WriteBufferPool field is set, then a connection holds the 176 + // write buffer only when writing a message. 177 + // 178 + // Applications should tune the buffer sizes to balance memory use and 179 + // performance. Increasing the buffer size uses more memory, but can reduce the 180 + // number of system calls to read or write the network. In the case of writing, 181 + // increasing the buffer size can reduce the number of frame headers written to 182 + // the network. 183 + // 184 + // Some guidelines for setting buffer parameters are: 185 + // 186 + // Limit the buffer sizes to the maximum expected message size. Buffers larger 187 + // than the largest message do not provide any benefit. 188 + // 189 + // Depending on the distribution of message sizes, setting the buffer size to 190 + // a value less than the maximum expected message size can greatly reduce memory 191 + // use with a small impact on performance. Here's an example: If 99% of the 192 + // messages are smaller than 256 bytes and the maximum message size is 512 193 + // bytes, then a buffer size of 256 bytes will result in 1.01 more system calls 194 + // than a buffer size of 512 bytes. The memory savings is 50%. 195 + // 196 + // A write buffer pool is useful when the application has a modest number 197 + // writes over a large number of connections. when buffers are pooled, a larger 198 + // buffer size has a reduced impact on total memory use and has the benefit of 199 + // reducing system calls and frame overhead. 200 + // 201 + // Compression EXPERIMENTAL 202 + // 203 + // Per message compression extensions (RFC 7692) are experimentally supported 204 + // by this package in a limited capacity. Setting the EnableCompression option 205 + // to true in Dialer or Upgrader will attempt to negotiate per message deflate 206 + // support. 207 + // 208 + // var upgrader = websocket.Upgrader{ 209 + // EnableCompression: true, 210 + // } 211 + // 212 + // If compression was successfully negotiated with the connection's peer, any 213 + // message received in compressed form will be automatically decompressed. 214 + // All Read methods will return uncompressed bytes. 215 + // 216 + // Per message compression of messages written to a connection can be enabled 217 + // or disabled by calling the corresponding Conn method: 218 + // 219 + // conn.EnableWriteCompression(false) 220 + // 221 + // Currently this package does not support compression with "context takeover". 222 + // This means that messages must be compressed and decompressed in isolation, 223 + // without retaining sliding window or dictionary state across messages. For 224 + // more details refer to RFC 7692. 225 + // 226 + // Use of compression is experimental and may result in decreased performance. 227 + package websocket
+42
vendor/github.com/gorilla/websocket/join.go
··· 1 + // Copyright 2019 The Gorilla WebSocket Authors. All rights reserved. 2 + // Use of this source code is governed by a BSD-style 3 + // license that can be found in the LICENSE file. 4 + 5 + package websocket 6 + 7 + import ( 8 + "io" 9 + "strings" 10 + ) 11 + 12 + // JoinMessages concatenates received messages to create a single io.Reader. 13 + // The string term is appended to each message. The returned reader does not 14 + // support concurrent calls to the Read method. 15 + func JoinMessages(c *Conn, term string) io.Reader { 16 + return &joinReader{c: c, term: term} 17 + } 18 + 19 + type joinReader struct { 20 + c *Conn 21 + term string 22 + r io.Reader 23 + } 24 + 25 + func (r *joinReader) Read(p []byte) (int, error) { 26 + if r.r == nil { 27 + var err error 28 + _, r.r, err = r.c.NextReader() 29 + if err != nil { 30 + return 0, err 31 + } 32 + if r.term != "" { 33 + r.r = io.MultiReader(r.r, strings.NewReader(r.term)) 34 + } 35 + } 36 + n, err := r.r.Read(p) 37 + if err == io.EOF { 38 + err = nil 39 + r.r = nil 40 + } 41 + return n, err 42 + }
+60
vendor/github.com/gorilla/websocket/json.go
··· 1 + // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 + // Use of this source code is governed by a BSD-style 3 + // license that can be found in the LICENSE file. 4 + 5 + package websocket 6 + 7 + import ( 8 + "encoding/json" 9 + "io" 10 + ) 11 + 12 + // WriteJSON writes the JSON encoding of v as a message. 13 + // 14 + // Deprecated: Use c.WriteJSON instead. 15 + func WriteJSON(c *Conn, v interface{}) error { 16 + return c.WriteJSON(v) 17 + } 18 + 19 + // WriteJSON writes the JSON encoding of v as a message. 20 + // 21 + // See the documentation for encoding/json Marshal for details about the 22 + // conversion of Go values to JSON. 23 + func (c *Conn) WriteJSON(v interface{}) error { 24 + w, err := c.NextWriter(TextMessage) 25 + if err != nil { 26 + return err 27 + } 28 + err1 := json.NewEncoder(w).Encode(v) 29 + err2 := w.Close() 30 + if err1 != nil { 31 + return err1 32 + } 33 + return err2 34 + } 35 + 36 + // ReadJSON reads the next JSON-encoded message from the connection and stores 37 + // it in the value pointed to by v. 38 + // 39 + // Deprecated: Use c.ReadJSON instead. 40 + func ReadJSON(c *Conn, v interface{}) error { 41 + return c.ReadJSON(v) 42 + } 43 + 44 + // ReadJSON reads the next JSON-encoded message from the connection and stores 45 + // it in the value pointed to by v. 46 + // 47 + // See the documentation for the encoding/json Unmarshal function for details 48 + // about the conversion of JSON to a Go value. 49 + func (c *Conn) ReadJSON(v interface{}) error { 50 + _, r, err := c.NextReader() 51 + if err != nil { 52 + return err 53 + } 54 + err = json.NewDecoder(r).Decode(v) 55 + if err == io.EOF { 56 + // One value is expected in the message. 57 + err = io.ErrUnexpectedEOF 58 + } 59 + return err 60 + }
+55
vendor/github.com/gorilla/websocket/mask.go
··· 1 + // Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of 2 + // this source code is governed by a BSD-style license that can be found in the 3 + // LICENSE file. 4 + 5 + //go:build !appengine 6 + // +build !appengine 7 + 8 + package websocket 9 + 10 + import "unsafe" 11 + 12 + const wordSize = int(unsafe.Sizeof(uintptr(0))) 13 + 14 + func maskBytes(key [4]byte, pos int, b []byte) int { 15 + // Mask one byte at a time for small buffers. 16 + if len(b) < 2*wordSize { 17 + for i := range b { 18 + b[i] ^= key[pos&3] 19 + pos++ 20 + } 21 + return pos & 3 22 + } 23 + 24 + // Mask one byte at a time to word boundary. 25 + if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 { 26 + n = wordSize - n 27 + for i := range b[:n] { 28 + b[i] ^= key[pos&3] 29 + pos++ 30 + } 31 + b = b[n:] 32 + } 33 + 34 + // Create aligned word size key. 35 + var k [wordSize]byte 36 + for i := range k { 37 + k[i] = key[(pos+i)&3] 38 + } 39 + kw := *(*uintptr)(unsafe.Pointer(&k)) 40 + 41 + // Mask one word at a time. 42 + n := (len(b) / wordSize) * wordSize 43 + for i := 0; i < n; i += wordSize { 44 + *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw 45 + } 46 + 47 + // Mask one byte at a time for remaining bytes. 48 + b = b[n:] 49 + for i := range b { 50 + b[i] ^= key[pos&3] 51 + pos++ 52 + } 53 + 54 + return pos & 3 55 + }
+16
vendor/github.com/gorilla/websocket/mask_safe.go
··· 1 + // Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of 2 + // this source code is governed by a BSD-style license that can be found in the 3 + // LICENSE file. 4 + 5 + //go:build appengine 6 + // +build appengine 7 + 8 + package websocket 9 + 10 + func maskBytes(key [4]byte, pos int, b []byte) int { 11 + for i := range b { 12 + b[i] ^= key[pos&3] 13 + pos++ 14 + } 15 + return pos & 3 16 + }
+102
vendor/github.com/gorilla/websocket/prepared.go
··· 1 + // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 + // Use of this source code is governed by a BSD-style 3 + // license that can be found in the LICENSE file. 4 + 5 + package websocket 6 + 7 + import ( 8 + "bytes" 9 + "net" 10 + "sync" 11 + "time" 12 + ) 13 + 14 + // PreparedMessage caches on the wire representations of a message payload. 15 + // Use PreparedMessage to efficiently send a message payload to multiple 16 + // connections. PreparedMessage is especially useful when compression is used 17 + // because the CPU and memory expensive compression operation can be executed 18 + // once for a given set of compression options. 19 + type PreparedMessage struct { 20 + messageType int 21 + data []byte 22 + mu sync.Mutex 23 + frames map[prepareKey]*preparedFrame 24 + } 25 + 26 + // prepareKey defines a unique set of options to cache prepared frames in PreparedMessage. 27 + type prepareKey struct { 28 + isServer bool 29 + compress bool 30 + compressionLevel int 31 + } 32 + 33 + // preparedFrame contains data in wire representation. 34 + type preparedFrame struct { 35 + once sync.Once 36 + data []byte 37 + } 38 + 39 + // NewPreparedMessage returns an initialized PreparedMessage. You can then send 40 + // it to connection using WritePreparedMessage method. Valid wire 41 + // representation will be calculated lazily only once for a set of current 42 + // connection options. 43 + func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) { 44 + pm := &PreparedMessage{ 45 + messageType: messageType, 46 + frames: make(map[prepareKey]*preparedFrame), 47 + data: data, 48 + } 49 + 50 + // Prepare a plain server frame. 51 + _, frameData, err := pm.frame(prepareKey{isServer: true, compress: false}) 52 + if err != nil { 53 + return nil, err 54 + } 55 + 56 + // To protect against caller modifying the data argument, remember the data 57 + // copied to the plain server frame. 58 + pm.data = frameData[len(frameData)-len(data):] 59 + return pm, nil 60 + } 61 + 62 + func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { 63 + pm.mu.Lock() 64 + frame, ok := pm.frames[key] 65 + if !ok { 66 + frame = &preparedFrame{} 67 + pm.frames[key] = frame 68 + } 69 + pm.mu.Unlock() 70 + 71 + var err error 72 + frame.once.Do(func() { 73 + // Prepare a frame using a 'fake' connection. 74 + // TODO: Refactor code in conn.go to allow more direct construction of 75 + // the frame. 76 + mu := make(chan struct{}, 1) 77 + mu <- struct{}{} 78 + var nc prepareConn 79 + c := &Conn{ 80 + conn: &nc, 81 + mu: mu, 82 + isServer: key.isServer, 83 + compressionLevel: key.compressionLevel, 84 + enableWriteCompression: true, 85 + writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize), 86 + } 87 + if key.compress { 88 + c.newCompressionWriter = compressNoContextTakeover 89 + } 90 + err = c.WriteMessage(pm.messageType, pm.data) 91 + frame.data = nc.buf.Bytes() 92 + }) 93 + return pm.messageType, frame.data, err 94 + } 95 + 96 + type prepareConn struct { 97 + buf bytes.Buffer 98 + net.Conn 99 + } 100 + 101 + func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) } 102 + func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil }
+77
vendor/github.com/gorilla/websocket/proxy.go
··· 1 + // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. 2 + // Use of this source code is governed by a BSD-style 3 + // license that can be found in the LICENSE file. 4 + 5 + package websocket 6 + 7 + import ( 8 + "bufio" 9 + "encoding/base64" 10 + "errors" 11 + "net" 12 + "net/http" 13 + "net/url" 14 + "strings" 15 + ) 16 + 17 + type netDialerFunc func(network, addr string) (net.Conn, error) 18 + 19 + func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { 20 + return fn(network, addr) 21 + } 22 + 23 + func init() { 24 + proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { 25 + return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil 26 + }) 27 + } 28 + 29 + type httpProxyDialer struct { 30 + proxyURL *url.URL 31 + forwardDial func(network, addr string) (net.Conn, error) 32 + } 33 + 34 + func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) { 35 + hostPort, _ := hostPortNoPort(hpd.proxyURL) 36 + conn, err := hpd.forwardDial(network, hostPort) 37 + if err != nil { 38 + return nil, err 39 + } 40 + 41 + connectHeader := make(http.Header) 42 + if user := hpd.proxyURL.User; user != nil { 43 + proxyUser := user.Username() 44 + if proxyPassword, passwordSet := user.Password(); passwordSet { 45 + credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) 46 + connectHeader.Set("Proxy-Authorization", "Basic "+credential) 47 + } 48 + } 49 + 50 + connectReq := &http.Request{ 51 + Method: http.MethodConnect, 52 + URL: &url.URL{Opaque: addr}, 53 + Host: addr, 54 + Header: connectHeader, 55 + } 56 + 57 + if err := connectReq.Write(conn); err != nil { 58 + conn.Close() 59 + return nil, err 60 + } 61 + 62 + // Read response. It's OK to use and discard buffered reader here becaue 63 + // the remote server does not speak until spoken to. 64 + br := bufio.NewReader(conn) 65 + resp, err := http.ReadResponse(br, connectReq) 66 + if err != nil { 67 + conn.Close() 68 + return nil, err 69 + } 70 + 71 + if resp.StatusCode != 200 { 72 + conn.Close() 73 + f := strings.SplitN(resp.Status, " ", 2) 74 + return nil, errors.New(f[1]) 75 + } 76 + return conn, nil 77 + }
+365
vendor/github.com/gorilla/websocket/server.go
··· 1 + // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 + // Use of this source code is governed by a BSD-style 3 + // license that can be found in the LICENSE file. 4 + 5 + package websocket 6 + 7 + import ( 8 + "bufio" 9 + "errors" 10 + "io" 11 + "net/http" 12 + "net/url" 13 + "strings" 14 + "time" 15 + ) 16 + 17 + // HandshakeError describes an error with the handshake from the peer. 18 + type HandshakeError struct { 19 + message string 20 + } 21 + 22 + func (e HandshakeError) Error() string { return e.message } 23 + 24 + // Upgrader specifies parameters for upgrading an HTTP connection to a 25 + // WebSocket connection. 26 + // 27 + // It is safe to call Upgrader's methods concurrently. 28 + type Upgrader struct { 29 + // HandshakeTimeout specifies the duration for the handshake to complete. 30 + HandshakeTimeout time.Duration 31 + 32 + // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer 33 + // size is zero, then buffers allocated by the HTTP server are used. The 34 + // I/O buffer sizes do not limit the size of the messages that can be sent 35 + // or received. 36 + ReadBufferSize, WriteBufferSize int 37 + 38 + // WriteBufferPool is a pool of buffers for write operations. If the value 39 + // is not set, then write buffers are allocated to the connection for the 40 + // lifetime of the connection. 41 + // 42 + // A pool is most useful when the application has a modest volume of writes 43 + // across a large number of connections. 44 + // 45 + // Applications should use a single pool for each unique value of 46 + // WriteBufferSize. 47 + WriteBufferPool BufferPool 48 + 49 + // Subprotocols specifies the server's supported protocols in order of 50 + // preference. If this field is not nil, then the Upgrade method negotiates a 51 + // subprotocol by selecting the first match in this list with a protocol 52 + // requested by the client. If there's no match, then no protocol is 53 + // negotiated (the Sec-Websocket-Protocol header is not included in the 54 + // handshake response). 55 + Subprotocols []string 56 + 57 + // Error specifies the function for generating HTTP error responses. If Error 58 + // is nil, then http.Error is used to generate the HTTP response. 59 + Error func(w http.ResponseWriter, r *http.Request, status int, reason error) 60 + 61 + // CheckOrigin returns true if the request Origin header is acceptable. If 62 + // CheckOrigin is nil, then a safe default is used: return false if the 63 + // Origin request header is present and the origin host is not equal to 64 + // request Host header. 65 + // 66 + // A CheckOrigin function should carefully validate the request origin to 67 + // prevent cross-site request forgery. 68 + CheckOrigin func(r *http.Request) bool 69 + 70 + // EnableCompression specify if the server should attempt to negotiate per 71 + // message compression (RFC 7692). Setting this value to true does not 72 + // guarantee that compression will be supported. Currently only "no context 73 + // takeover" modes are supported. 74 + EnableCompression bool 75 + } 76 + 77 + func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { 78 + err := HandshakeError{reason} 79 + if u.Error != nil { 80 + u.Error(w, r, status, err) 81 + } else { 82 + w.Header().Set("Sec-Websocket-Version", "13") 83 + http.Error(w, http.StatusText(status), status) 84 + } 85 + return nil, err 86 + } 87 + 88 + // checkSameOrigin returns true if the origin is not set or is equal to the request host. 89 + func checkSameOrigin(r *http.Request) bool { 90 + origin := r.Header["Origin"] 91 + if len(origin) == 0 { 92 + return true 93 + } 94 + u, err := url.Parse(origin[0]) 95 + if err != nil { 96 + return false 97 + } 98 + return equalASCIIFold(u.Host, r.Host) 99 + } 100 + 101 + func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { 102 + if u.Subprotocols != nil { 103 + clientProtocols := Subprotocols(r) 104 + for _, serverProtocol := range u.Subprotocols { 105 + for _, clientProtocol := range clientProtocols { 106 + if clientProtocol == serverProtocol { 107 + return clientProtocol 108 + } 109 + } 110 + } 111 + } else if responseHeader != nil { 112 + return responseHeader.Get("Sec-Websocket-Protocol") 113 + } 114 + return "" 115 + } 116 + 117 + // Upgrade upgrades the HTTP server connection to the WebSocket protocol. 118 + // 119 + // The responseHeader is included in the response to the client's upgrade 120 + // request. Use the responseHeader to specify cookies (Set-Cookie). To specify 121 + // subprotocols supported by the server, set Upgrader.Subprotocols directly. 122 + // 123 + // If the upgrade fails, then Upgrade replies to the client with an HTTP error 124 + // response. 125 + func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { 126 + const badHandshake = "websocket: the client is not using the websocket protocol: " 127 + 128 + if !tokenListContainsValue(r.Header, "Connection", "upgrade") { 129 + return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header") 130 + } 131 + 132 + if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { 133 + return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header") 134 + } 135 + 136 + if r.Method != http.MethodGet { 137 + return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET") 138 + } 139 + 140 + if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { 141 + return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header") 142 + } 143 + 144 + if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok { 145 + return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported") 146 + } 147 + 148 + checkOrigin := u.CheckOrigin 149 + if checkOrigin == nil { 150 + checkOrigin = checkSameOrigin 151 + } 152 + if !checkOrigin(r) { 153 + return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin") 154 + } 155 + 156 + challengeKey := r.Header.Get("Sec-Websocket-Key") 157 + if !isValidChallengeKey(challengeKey) { 158 + return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length") 159 + } 160 + 161 + subprotocol := u.selectSubprotocol(r, responseHeader) 162 + 163 + // Negotiate PMCE 164 + var compress bool 165 + if u.EnableCompression { 166 + for _, ext := range parseExtensions(r.Header) { 167 + if ext[""] != "permessage-deflate" { 168 + continue 169 + } 170 + compress = true 171 + break 172 + } 173 + } 174 + 175 + h, ok := w.(http.Hijacker) 176 + if !ok { 177 + return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") 178 + } 179 + var brw *bufio.ReadWriter 180 + netConn, brw, err := h.Hijack() 181 + if err != nil { 182 + return u.returnError(w, r, http.StatusInternalServerError, err.Error()) 183 + } 184 + 185 + if brw.Reader.Buffered() > 0 { 186 + netConn.Close() 187 + return nil, errors.New("websocket: client sent data before handshake is complete") 188 + } 189 + 190 + var br *bufio.Reader 191 + if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 { 192 + // Reuse hijacked buffered reader as connection reader. 193 + br = brw.Reader 194 + } 195 + 196 + buf := bufioWriterBuffer(netConn, brw.Writer) 197 + 198 + var writeBuf []byte 199 + if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 { 200 + // Reuse hijacked write buffer as connection buffer. 201 + writeBuf = buf 202 + } 203 + 204 + c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf) 205 + c.subprotocol = subprotocol 206 + 207 + if compress { 208 + c.newCompressionWriter = compressNoContextTakeover 209 + c.newDecompressionReader = decompressNoContextTakeover 210 + } 211 + 212 + // Use larger of hijacked buffer and connection write buffer for header. 213 + p := buf 214 + if len(c.writeBuf) > len(p) { 215 + p = c.writeBuf 216 + } 217 + p = p[:0] 218 + 219 + p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) 220 + p = append(p, computeAcceptKey(challengeKey)...) 221 + p = append(p, "\r\n"...) 222 + if c.subprotocol != "" { 223 + p = append(p, "Sec-WebSocket-Protocol: "...) 224 + p = append(p, c.subprotocol...) 225 + p = append(p, "\r\n"...) 226 + } 227 + if compress { 228 + p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) 229 + } 230 + for k, vs := range responseHeader { 231 + if k == "Sec-Websocket-Protocol" { 232 + continue 233 + } 234 + for _, v := range vs { 235 + p = append(p, k...) 236 + p = append(p, ": "...) 237 + for i := 0; i < len(v); i++ { 238 + b := v[i] 239 + if b <= 31 { 240 + // prevent response splitting. 241 + b = ' ' 242 + } 243 + p = append(p, b) 244 + } 245 + p = append(p, "\r\n"...) 246 + } 247 + } 248 + p = append(p, "\r\n"...) 249 + 250 + // Clear deadlines set by HTTP server. 251 + netConn.SetDeadline(time.Time{}) 252 + 253 + if u.HandshakeTimeout > 0 { 254 + netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) 255 + } 256 + if _, err = netConn.Write(p); err != nil { 257 + netConn.Close() 258 + return nil, err 259 + } 260 + if u.HandshakeTimeout > 0 { 261 + netConn.SetWriteDeadline(time.Time{}) 262 + } 263 + 264 + return c, nil 265 + } 266 + 267 + // Upgrade upgrades the HTTP server connection to the WebSocket protocol. 268 + // 269 + // Deprecated: Use websocket.Upgrader instead. 270 + // 271 + // Upgrade does not perform origin checking. The application is responsible for 272 + // checking the Origin header before calling Upgrade. An example implementation 273 + // of the same origin policy check is: 274 + // 275 + // if req.Header.Get("Origin") != "http://"+req.Host { 276 + // http.Error(w, "Origin not allowed", http.StatusForbidden) 277 + // return 278 + // } 279 + // 280 + // If the endpoint supports subprotocols, then the application is responsible 281 + // for negotiating the protocol used on the connection. Use the Subprotocols() 282 + // function to get the subprotocols requested by the client. Use the 283 + // Sec-Websocket-Protocol response header to specify the subprotocol selected 284 + // by the application. 285 + // 286 + // The responseHeader is included in the response to the client's upgrade 287 + // request. Use the responseHeader to specify cookies (Set-Cookie) and the 288 + // negotiated subprotocol (Sec-Websocket-Protocol). 289 + // 290 + // The connection buffers IO to the underlying network connection. The 291 + // readBufSize and writeBufSize parameters specify the size of the buffers to 292 + // use. Messages can be larger than the buffers. 293 + // 294 + // If the request is not a valid WebSocket handshake, then Upgrade returns an 295 + // error of type HandshakeError. Applications should handle this error by 296 + // replying to the client with an HTTP error response. 297 + func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) { 298 + u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize} 299 + u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) { 300 + // don't return errors to maintain backwards compatibility 301 + } 302 + u.CheckOrigin = func(r *http.Request) bool { 303 + // allow all connections by default 304 + return true 305 + } 306 + return u.Upgrade(w, r, responseHeader) 307 + } 308 + 309 + // Subprotocols returns the subprotocols requested by the client in the 310 + // Sec-Websocket-Protocol header. 311 + func Subprotocols(r *http.Request) []string { 312 + h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol")) 313 + if h == "" { 314 + return nil 315 + } 316 + protocols := strings.Split(h, ",") 317 + for i := range protocols { 318 + protocols[i] = strings.TrimSpace(protocols[i]) 319 + } 320 + return protocols 321 + } 322 + 323 + // IsWebSocketUpgrade returns true if the client requested upgrade to the 324 + // WebSocket protocol. 325 + func IsWebSocketUpgrade(r *http.Request) bool { 326 + return tokenListContainsValue(r.Header, "Connection", "upgrade") && 327 + tokenListContainsValue(r.Header, "Upgrade", "websocket") 328 + } 329 + 330 + // bufioReaderSize size returns the size of a bufio.Reader. 331 + func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int { 332 + // This code assumes that peek on a reset reader returns 333 + // bufio.Reader.buf[:0]. 334 + // TODO: Use bufio.Reader.Size() after Go 1.10 335 + br.Reset(originalReader) 336 + if p, err := br.Peek(0); err == nil { 337 + return cap(p) 338 + } 339 + return 0 340 + } 341 + 342 + // writeHook is an io.Writer that records the last slice passed to it vio 343 + // io.Writer.Write. 344 + type writeHook struct { 345 + p []byte 346 + } 347 + 348 + func (wh *writeHook) Write(p []byte) (int, error) { 349 + wh.p = p 350 + return len(p), nil 351 + } 352 + 353 + // bufioWriterBuffer grabs the buffer from a bufio.Writer. 354 + func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte { 355 + // This code assumes that bufio.Writer.buf[:1] is passed to the 356 + // bufio.Writer's underlying writer. 357 + var wh writeHook 358 + bw.Reset(&wh) 359 + bw.WriteByte(0) 360 + bw.Flush() 361 + 362 + bw.Reset(originalWriter) 363 + 364 + return wh.p[:cap(wh.p)] 365 + }
+21
vendor/github.com/gorilla/websocket/tls_handshake.go
··· 1 + //go:build go1.17 2 + // +build go1.17 3 + 4 + package websocket 5 + 6 + import ( 7 + "context" 8 + "crypto/tls" 9 + ) 10 + 11 + func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error { 12 + if err := tlsConn.HandshakeContext(ctx); err != nil { 13 + return err 14 + } 15 + if !cfg.InsecureSkipVerify { 16 + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { 17 + return err 18 + } 19 + } 20 + return nil 21 + }
+21
vendor/github.com/gorilla/websocket/tls_handshake_116.go
··· 1 + //go:build !go1.17 2 + // +build !go1.17 3 + 4 + package websocket 5 + 6 + import ( 7 + "context" 8 + "crypto/tls" 9 + ) 10 + 11 + func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error { 12 + if err := tlsConn.Handshake(); err != nil { 13 + return err 14 + } 15 + if !cfg.InsecureSkipVerify { 16 + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { 17 + return err 18 + } 19 + } 20 + return nil 21 + }
+298
vendor/github.com/gorilla/websocket/util.go
··· 1 + // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 + // Use of this source code is governed by a BSD-style 3 + // license that can be found in the LICENSE file. 4 + 5 + package websocket 6 + 7 + import ( 8 + "crypto/rand" 9 + "crypto/sha1" 10 + "encoding/base64" 11 + "io" 12 + "net/http" 13 + "strings" 14 + "unicode/utf8" 15 + ) 16 + 17 + var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") 18 + 19 + func computeAcceptKey(challengeKey string) string { 20 + h := sha1.New() 21 + h.Write([]byte(challengeKey)) 22 + h.Write(keyGUID) 23 + return base64.StdEncoding.EncodeToString(h.Sum(nil)) 24 + } 25 + 26 + func generateChallengeKey() (string, error) { 27 + p := make([]byte, 16) 28 + if _, err := io.ReadFull(rand.Reader, p); err != nil { 29 + return "", err 30 + } 31 + return base64.StdEncoding.EncodeToString(p), nil 32 + } 33 + 34 + // Token octets per RFC 2616. 35 + var isTokenOctet = [256]bool{ 36 + '!': true, 37 + '#': true, 38 + '$': true, 39 + '%': true, 40 + '&': true, 41 + '\'': true, 42 + '*': true, 43 + '+': true, 44 + '-': true, 45 + '.': true, 46 + '0': true, 47 + '1': true, 48 + '2': true, 49 + '3': true, 50 + '4': true, 51 + '5': true, 52 + '6': true, 53 + '7': true, 54 + '8': true, 55 + '9': true, 56 + 'A': true, 57 + 'B': true, 58 + 'C': true, 59 + 'D': true, 60 + 'E': true, 61 + 'F': true, 62 + 'G': true, 63 + 'H': true, 64 + 'I': true, 65 + 'J': true, 66 + 'K': true, 67 + 'L': true, 68 + 'M': true, 69 + 'N': true, 70 + 'O': true, 71 + 'P': true, 72 + 'Q': true, 73 + 'R': true, 74 + 'S': true, 75 + 'T': true, 76 + 'U': true, 77 + 'W': true, 78 + 'V': true, 79 + 'X': true, 80 + 'Y': true, 81 + 'Z': true, 82 + '^': true, 83 + '_': true, 84 + '`': true, 85 + 'a': true, 86 + 'b': true, 87 + 'c': true, 88 + 'd': true, 89 + 'e': true, 90 + 'f': true, 91 + 'g': true, 92 + 'h': true, 93 + 'i': true, 94 + 'j': true, 95 + 'k': true, 96 + 'l': true, 97 + 'm': true, 98 + 'n': true, 99 + 'o': true, 100 + 'p': true, 101 + 'q': true, 102 + 'r': true, 103 + 's': true, 104 + 't': true, 105 + 'u': true, 106 + 'v': true, 107 + 'w': true, 108 + 'x': true, 109 + 'y': true, 110 + 'z': true, 111 + '|': true, 112 + '~': true, 113 + } 114 + 115 + // skipSpace returns a slice of the string s with all leading RFC 2616 linear 116 + // whitespace removed. 117 + func skipSpace(s string) (rest string) { 118 + i := 0 119 + for ; i < len(s); i++ { 120 + if b := s[i]; b != ' ' && b != '\t' { 121 + break 122 + } 123 + } 124 + return s[i:] 125 + } 126 + 127 + // nextToken returns the leading RFC 2616 token of s and the string following 128 + // the token. 129 + func nextToken(s string) (token, rest string) { 130 + i := 0 131 + for ; i < len(s); i++ { 132 + if !isTokenOctet[s[i]] { 133 + break 134 + } 135 + } 136 + return s[:i], s[i:] 137 + } 138 + 139 + // nextTokenOrQuoted returns the leading token or quoted string per RFC 2616 140 + // and the string following the token or quoted string. 141 + func nextTokenOrQuoted(s string) (value string, rest string) { 142 + if !strings.HasPrefix(s, "\"") { 143 + return nextToken(s) 144 + } 145 + s = s[1:] 146 + for i := 0; i < len(s); i++ { 147 + switch s[i] { 148 + case '"': 149 + return s[:i], s[i+1:] 150 + case '\\': 151 + p := make([]byte, len(s)-1) 152 + j := copy(p, s[:i]) 153 + escape := true 154 + for i = i + 1; i < len(s); i++ { 155 + b := s[i] 156 + switch { 157 + case escape: 158 + escape = false 159 + p[j] = b 160 + j++ 161 + case b == '\\': 162 + escape = true 163 + case b == '"': 164 + return string(p[:j]), s[i+1:] 165 + default: 166 + p[j] = b 167 + j++ 168 + } 169 + } 170 + return "", "" 171 + } 172 + } 173 + return "", "" 174 + } 175 + 176 + // equalASCIIFold returns true if s is equal to t with ASCII case folding as 177 + // defined in RFC 4790. 178 + func equalASCIIFold(s, t string) bool { 179 + for s != "" && t != "" { 180 + sr, size := utf8.DecodeRuneInString(s) 181 + s = s[size:] 182 + tr, size := utf8.DecodeRuneInString(t) 183 + t = t[size:] 184 + if sr == tr { 185 + continue 186 + } 187 + if 'A' <= sr && sr <= 'Z' { 188 + sr = sr + 'a' - 'A' 189 + } 190 + if 'A' <= tr && tr <= 'Z' { 191 + tr = tr + 'a' - 'A' 192 + } 193 + if sr != tr { 194 + return false 195 + } 196 + } 197 + return s == t 198 + } 199 + 200 + // tokenListContainsValue returns true if the 1#token header with the given 201 + // name contains a token equal to value with ASCII case folding. 202 + func tokenListContainsValue(header http.Header, name string, value string) bool { 203 + headers: 204 + for _, s := range header[name] { 205 + for { 206 + var t string 207 + t, s = nextToken(skipSpace(s)) 208 + if t == "" { 209 + continue headers 210 + } 211 + s = skipSpace(s) 212 + if s != "" && s[0] != ',' { 213 + continue headers 214 + } 215 + if equalASCIIFold(t, value) { 216 + return true 217 + } 218 + if s == "" { 219 + continue headers 220 + } 221 + s = s[1:] 222 + } 223 + } 224 + return false 225 + } 226 + 227 + // parseExtensions parses WebSocket extensions from a header. 228 + func parseExtensions(header http.Header) []map[string]string { 229 + // From RFC 6455: 230 + // 231 + // Sec-WebSocket-Extensions = extension-list 232 + // extension-list = 1#extension 233 + // extension = extension-token *( ";" extension-param ) 234 + // extension-token = registered-token 235 + // registered-token = token 236 + // extension-param = token [ "=" (token | quoted-string) ] 237 + // ;When using the quoted-string syntax variant, the value 238 + // ;after quoted-string unescaping MUST conform to the 239 + // ;'token' ABNF. 240 + 241 + var result []map[string]string 242 + headers: 243 + for _, s := range header["Sec-Websocket-Extensions"] { 244 + for { 245 + var t string 246 + t, s = nextToken(skipSpace(s)) 247 + if t == "" { 248 + continue headers 249 + } 250 + ext := map[string]string{"": t} 251 + for { 252 + s = skipSpace(s) 253 + if !strings.HasPrefix(s, ";") { 254 + break 255 + } 256 + var k string 257 + k, s = nextToken(skipSpace(s[1:])) 258 + if k == "" { 259 + continue headers 260 + } 261 + s = skipSpace(s) 262 + var v string 263 + if strings.HasPrefix(s, "=") { 264 + v, s = nextTokenOrQuoted(skipSpace(s[1:])) 265 + s = skipSpace(s) 266 + } 267 + if s != "" && s[0] != ',' && s[0] != ';' { 268 + continue headers 269 + } 270 + ext[k] = v 271 + } 272 + if s != "" && s[0] != ',' { 273 + continue headers 274 + } 275 + result = append(result, ext) 276 + if s == "" { 277 + continue headers 278 + } 279 + s = s[1:] 280 + } 281 + } 282 + return result 283 + } 284 + 285 + // isValidChallengeKey checks if the argument meets RFC6455 specification. 286 + func isValidChallengeKey(s string) bool { 287 + // From RFC6455: 288 + // 289 + // A |Sec-WebSocket-Key| header field with a base64-encoded (see 290 + // Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in 291 + // length. 292 + 293 + if s == "" { 294 + return false 295 + } 296 + decoded, err := base64.StdEncoding.DecodeString(s) 297 + return err == nil && len(decoded) == 16 298 + }
+473
vendor/github.com/gorilla/websocket/x_net_proxy.go
··· 1 + // Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT. 2 + //go:generate bundle -o x_net_proxy.go golang.org/x/net/proxy 3 + 4 + // Package proxy provides support for a variety of protocols to proxy network 5 + // data. 6 + // 7 + 8 + package websocket 9 + 10 + import ( 11 + "errors" 12 + "io" 13 + "net" 14 + "net/url" 15 + "os" 16 + "strconv" 17 + "strings" 18 + "sync" 19 + ) 20 + 21 + type proxy_direct struct{} 22 + 23 + // Direct is a direct proxy: one that makes network connections directly. 24 + var proxy_Direct = proxy_direct{} 25 + 26 + func (proxy_direct) Dial(network, addr string) (net.Conn, error) { 27 + return net.Dial(network, addr) 28 + } 29 + 30 + // A PerHost directs connections to a default Dialer unless the host name 31 + // requested matches one of a number of exceptions. 32 + type proxy_PerHost struct { 33 + def, bypass proxy_Dialer 34 + 35 + bypassNetworks []*net.IPNet 36 + bypassIPs []net.IP 37 + bypassZones []string 38 + bypassHosts []string 39 + } 40 + 41 + // NewPerHost returns a PerHost Dialer that directs connections to either 42 + // defaultDialer or bypass, depending on whether the connection matches one of 43 + // the configured rules. 44 + func proxy_NewPerHost(defaultDialer, bypass proxy_Dialer) *proxy_PerHost { 45 + return &proxy_PerHost{ 46 + def: defaultDialer, 47 + bypass: bypass, 48 + } 49 + } 50 + 51 + // Dial connects to the address addr on the given network through either 52 + // defaultDialer or bypass. 53 + func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) { 54 + host, _, err := net.SplitHostPort(addr) 55 + if err != nil { 56 + return nil, err 57 + } 58 + 59 + return p.dialerForRequest(host).Dial(network, addr) 60 + } 61 + 62 + func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer { 63 + if ip := net.ParseIP(host); ip != nil { 64 + for _, net := range p.bypassNetworks { 65 + if net.Contains(ip) { 66 + return p.bypass 67 + } 68 + } 69 + for _, bypassIP := range p.bypassIPs { 70 + if bypassIP.Equal(ip) { 71 + return p.bypass 72 + } 73 + } 74 + return p.def 75 + } 76 + 77 + for _, zone := range p.bypassZones { 78 + if strings.HasSuffix(host, zone) { 79 + return p.bypass 80 + } 81 + if host == zone[1:] { 82 + // For a zone ".example.com", we match "example.com" 83 + // too. 84 + return p.bypass 85 + } 86 + } 87 + for _, bypassHost := range p.bypassHosts { 88 + if bypassHost == host { 89 + return p.bypass 90 + } 91 + } 92 + return p.def 93 + } 94 + 95 + // AddFromString parses a string that contains comma-separated values 96 + // specifying hosts that should use the bypass proxy. Each value is either an 97 + // IP address, a CIDR range, a zone (*.example.com) or a host name 98 + // (localhost). A best effort is made to parse the string and errors are 99 + // ignored. 100 + func (p *proxy_PerHost) AddFromString(s string) { 101 + hosts := strings.Split(s, ",") 102 + for _, host := range hosts { 103 + host = strings.TrimSpace(host) 104 + if len(host) == 0 { 105 + continue 106 + } 107 + if strings.Contains(host, "/") { 108 + // We assume that it's a CIDR address like 127.0.0.0/8 109 + if _, net, err := net.ParseCIDR(host); err == nil { 110 + p.AddNetwork(net) 111 + } 112 + continue 113 + } 114 + if ip := net.ParseIP(host); ip != nil { 115 + p.AddIP(ip) 116 + continue 117 + } 118 + if strings.HasPrefix(host, "*.") { 119 + p.AddZone(host[1:]) 120 + continue 121 + } 122 + p.AddHost(host) 123 + } 124 + } 125 + 126 + // AddIP specifies an IP address that will use the bypass proxy. Note that 127 + // this will only take effect if a literal IP address is dialed. A connection 128 + // to a named host will never match an IP. 129 + func (p *proxy_PerHost) AddIP(ip net.IP) { 130 + p.bypassIPs = append(p.bypassIPs, ip) 131 + } 132 + 133 + // AddNetwork specifies an IP range that will use the bypass proxy. Note that 134 + // this will only take effect if a literal IP address is dialed. A connection 135 + // to a named host will never match. 136 + func (p *proxy_PerHost) AddNetwork(net *net.IPNet) { 137 + p.bypassNetworks = append(p.bypassNetworks, net) 138 + } 139 + 140 + // AddZone specifies a DNS suffix that will use the bypass proxy. A zone of 141 + // "example.com" matches "example.com" and all of its subdomains. 142 + func (p *proxy_PerHost) AddZone(zone string) { 143 + if strings.HasSuffix(zone, ".") { 144 + zone = zone[:len(zone)-1] 145 + } 146 + if !strings.HasPrefix(zone, ".") { 147 + zone = "." + zone 148 + } 149 + p.bypassZones = append(p.bypassZones, zone) 150 + } 151 + 152 + // AddHost specifies a host name that will use the bypass proxy. 153 + func (p *proxy_PerHost) AddHost(host string) { 154 + if strings.HasSuffix(host, ".") { 155 + host = host[:len(host)-1] 156 + } 157 + p.bypassHosts = append(p.bypassHosts, host) 158 + } 159 + 160 + // A Dialer is a means to establish a connection. 161 + type proxy_Dialer interface { 162 + // Dial connects to the given address via the proxy. 163 + Dial(network, addr string) (c net.Conn, err error) 164 + } 165 + 166 + // Auth contains authentication parameters that specific Dialers may require. 167 + type proxy_Auth struct { 168 + User, Password string 169 + } 170 + 171 + // FromEnvironment returns the dialer specified by the proxy related variables in 172 + // the environment. 173 + func proxy_FromEnvironment() proxy_Dialer { 174 + allProxy := proxy_allProxyEnv.Get() 175 + if len(allProxy) == 0 { 176 + return proxy_Direct 177 + } 178 + 179 + proxyURL, err := url.Parse(allProxy) 180 + if err != nil { 181 + return proxy_Direct 182 + } 183 + proxy, err := proxy_FromURL(proxyURL, proxy_Direct) 184 + if err != nil { 185 + return proxy_Direct 186 + } 187 + 188 + noProxy := proxy_noProxyEnv.Get() 189 + if len(noProxy) == 0 { 190 + return proxy 191 + } 192 + 193 + perHost := proxy_NewPerHost(proxy, proxy_Direct) 194 + perHost.AddFromString(noProxy) 195 + return perHost 196 + } 197 + 198 + // proxySchemes is a map from URL schemes to a function that creates a Dialer 199 + // from a URL with such a scheme. 200 + var proxy_proxySchemes map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error) 201 + 202 + // RegisterDialerType takes a URL scheme and a function to generate Dialers from 203 + // a URL with that scheme and a forwarding Dialer. Registered schemes are used 204 + // by FromURL. 205 + func proxy_RegisterDialerType(scheme string, f func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) { 206 + if proxy_proxySchemes == nil { 207 + proxy_proxySchemes = make(map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) 208 + } 209 + proxy_proxySchemes[scheme] = f 210 + } 211 + 212 + // FromURL returns a Dialer given a URL specification and an underlying 213 + // Dialer for it to make network requests. 214 + func proxy_FromURL(u *url.URL, forward proxy_Dialer) (proxy_Dialer, error) { 215 + var auth *proxy_Auth 216 + if u.User != nil { 217 + auth = new(proxy_Auth) 218 + auth.User = u.User.Username() 219 + if p, ok := u.User.Password(); ok { 220 + auth.Password = p 221 + } 222 + } 223 + 224 + switch u.Scheme { 225 + case "socks5": 226 + return proxy_SOCKS5("tcp", u.Host, auth, forward) 227 + } 228 + 229 + // If the scheme doesn't match any of the built-in schemes, see if it 230 + // was registered by another package. 231 + if proxy_proxySchemes != nil { 232 + if f, ok := proxy_proxySchemes[u.Scheme]; ok { 233 + return f(u, forward) 234 + } 235 + } 236 + 237 + return nil, errors.New("proxy: unknown scheme: " + u.Scheme) 238 + } 239 + 240 + var ( 241 + proxy_allProxyEnv = &proxy_envOnce{ 242 + names: []string{"ALL_PROXY", "all_proxy"}, 243 + } 244 + proxy_noProxyEnv = &proxy_envOnce{ 245 + names: []string{"NO_PROXY", "no_proxy"}, 246 + } 247 + ) 248 + 249 + // envOnce looks up an environment variable (optionally by multiple 250 + // names) once. It mitigates expensive lookups on some platforms 251 + // (e.g. Windows). 252 + // (Borrowed from net/http/transport.go) 253 + type proxy_envOnce struct { 254 + names []string 255 + once sync.Once 256 + val string 257 + } 258 + 259 + func (e *proxy_envOnce) Get() string { 260 + e.once.Do(e.init) 261 + return e.val 262 + } 263 + 264 + func (e *proxy_envOnce) init() { 265 + for _, n := range e.names { 266 + e.val = os.Getenv(n) 267 + if e.val != "" { 268 + return 269 + } 270 + } 271 + } 272 + 273 + // SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address 274 + // with an optional username and password. See RFC 1928 and RFC 1929. 275 + func proxy_SOCKS5(network, addr string, auth *proxy_Auth, forward proxy_Dialer) (proxy_Dialer, error) { 276 + s := &proxy_socks5{ 277 + network: network, 278 + addr: addr, 279 + forward: forward, 280 + } 281 + if auth != nil { 282 + s.user = auth.User 283 + s.password = auth.Password 284 + } 285 + 286 + return s, nil 287 + } 288 + 289 + type proxy_socks5 struct { 290 + user, password string 291 + network, addr string 292 + forward proxy_Dialer 293 + } 294 + 295 + const proxy_socks5Version = 5 296 + 297 + const ( 298 + proxy_socks5AuthNone = 0 299 + proxy_socks5AuthPassword = 2 300 + ) 301 + 302 + const proxy_socks5Connect = 1 303 + 304 + const ( 305 + proxy_socks5IP4 = 1 306 + proxy_socks5Domain = 3 307 + proxy_socks5IP6 = 4 308 + ) 309 + 310 + var proxy_socks5Errors = []string{ 311 + "", 312 + "general failure", 313 + "connection forbidden", 314 + "network unreachable", 315 + "host unreachable", 316 + "connection refused", 317 + "TTL expired", 318 + "command not supported", 319 + "address type not supported", 320 + } 321 + 322 + // Dial connects to the address addr on the given network via the SOCKS5 proxy. 323 + func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) { 324 + switch network { 325 + case "tcp", "tcp6", "tcp4": 326 + default: 327 + return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network) 328 + } 329 + 330 + conn, err := s.forward.Dial(s.network, s.addr) 331 + if err != nil { 332 + return nil, err 333 + } 334 + if err := s.connect(conn, addr); err != nil { 335 + conn.Close() 336 + return nil, err 337 + } 338 + return conn, nil 339 + } 340 + 341 + // connect takes an existing connection to a socks5 proxy server, 342 + // and commands the server to extend that connection to target, 343 + // which must be a canonical address with a host and port. 344 + func (s *proxy_socks5) connect(conn net.Conn, target string) error { 345 + host, portStr, err := net.SplitHostPort(target) 346 + if err != nil { 347 + return err 348 + } 349 + 350 + port, err := strconv.Atoi(portStr) 351 + if err != nil { 352 + return errors.New("proxy: failed to parse port number: " + portStr) 353 + } 354 + if port < 1 || port > 0xffff { 355 + return errors.New("proxy: port number out of range: " + portStr) 356 + } 357 + 358 + // the size here is just an estimate 359 + buf := make([]byte, 0, 6+len(host)) 360 + 361 + buf = append(buf, proxy_socks5Version) 362 + if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 { 363 + buf = append(buf, 2 /* num auth methods */, proxy_socks5AuthNone, proxy_socks5AuthPassword) 364 + } else { 365 + buf = append(buf, 1 /* num auth methods */, proxy_socks5AuthNone) 366 + } 367 + 368 + if _, err := conn.Write(buf); err != nil { 369 + return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error()) 370 + } 371 + 372 + if _, err := io.ReadFull(conn, buf[:2]); err != nil { 373 + return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error()) 374 + } 375 + if buf[0] != 5 { 376 + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0]))) 377 + } 378 + if buf[1] == 0xff { 379 + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication") 380 + } 381 + 382 + // See RFC 1929 383 + if buf[1] == proxy_socks5AuthPassword { 384 + buf = buf[:0] 385 + buf = append(buf, 1 /* password protocol version */) 386 + buf = append(buf, uint8(len(s.user))) 387 + buf = append(buf, s.user...) 388 + buf = append(buf, uint8(len(s.password))) 389 + buf = append(buf, s.password...) 390 + 391 + if _, err := conn.Write(buf); err != nil { 392 + return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) 393 + } 394 + 395 + if _, err := io.ReadFull(conn, buf[:2]); err != nil { 396 + return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) 397 + } 398 + 399 + if buf[1] != 0 { 400 + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password") 401 + } 402 + } 403 + 404 + buf = buf[:0] 405 + buf = append(buf, proxy_socks5Version, proxy_socks5Connect, 0 /* reserved */) 406 + 407 + if ip := net.ParseIP(host); ip != nil { 408 + if ip4 := ip.To4(); ip4 != nil { 409 + buf = append(buf, proxy_socks5IP4) 410 + ip = ip4 411 + } else { 412 + buf = append(buf, proxy_socks5IP6) 413 + } 414 + buf = append(buf, ip...) 415 + } else { 416 + if len(host) > 255 { 417 + return errors.New("proxy: destination host name too long: " + host) 418 + } 419 + buf = append(buf, proxy_socks5Domain) 420 + buf = append(buf, byte(len(host))) 421 + buf = append(buf, host...) 422 + } 423 + buf = append(buf, byte(port>>8), byte(port)) 424 + 425 + if _, err := conn.Write(buf); err != nil { 426 + return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) 427 + } 428 + 429 + if _, err := io.ReadFull(conn, buf[:4]); err != nil { 430 + return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) 431 + } 432 + 433 + failure := "unknown error" 434 + if int(buf[1]) < len(proxy_socks5Errors) { 435 + failure = proxy_socks5Errors[buf[1]] 436 + } 437 + 438 + if len(failure) > 0 { 439 + return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure) 440 + } 441 + 442 + bytesToDiscard := 0 443 + switch buf[3] { 444 + case proxy_socks5IP4: 445 + bytesToDiscard = net.IPv4len 446 + case proxy_socks5IP6: 447 + bytesToDiscard = net.IPv6len 448 + case proxy_socks5Domain: 449 + _, err := io.ReadFull(conn, buf[:1]) 450 + if err != nil { 451 + return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error()) 452 + } 453 + bytesToDiscard = int(buf[0]) 454 + default: 455 + return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr) 456 + } 457 + 458 + if cap(buf) < bytesToDiscard { 459 + buf = make([]byte, bytesToDiscard) 460 + } else { 461 + buf = buf[:bytesToDiscard] 462 + } 463 + if _, err := io.ReadFull(conn, buf); err != nil { 464 + return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error()) 465 + } 466 + 467 + // Also need to discard the port number 468 + if _, err := io.ReadFull(conn, buf[:2]); err != nil { 469 + return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error()) 470 + } 471 + 472 + return nil 473 + }
+3
vendor/modules.txt
··· 70 70 # github.com/go-logfmt/logfmt v0.6.0 71 71 ## explicit; go 1.17 72 72 github.com/go-logfmt/logfmt 73 + # github.com/gorilla/websocket v1.5.3 74 + ## explicit; go 1.12 75 + github.com/gorilla/websocket 73 76 # github.com/kr/fs v0.1.0 74 77 ## explicit 75 78 github.com/kr/fs