this repo has no description
2
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 995 lines 28 kB view raw
1package main 2 3import ( 4 "database/sql" 5 "encoding/base64" 6 "encoding/json" 7 "fmt" 8 "log" 9 "net/http" 10 "os" 11 "strings" 12 "sync" 13 "time" 14 15 "github.com/gorilla/websocket" 16 sqlite3 "github.com/mattn/go-sqlite3" 17) 18 19var upgrader = websocket.Upgrader{ 20 // Allow all cross-origin connections 21 CheckOrigin: func(r *http.Request) bool { 22 return true 23 }, 24} 25 26// Environment configuration 27var requireAuth = os.Getenv("REQUIRE_AUTH") == "true" 28 29// Room management 30type Room struct { 31 clients map[*websocket.Conn]bool 32 mu sync.Mutex 33} 34 35var rooms = make(map[string]*Room) 36var roomsMu sync.Mutex 37 38// Get or create a room 39func getRoom(roomID string) *Room { 40 roomsMu.Lock() 41 defer roomsMu.Unlock() 42 43 if room, exists := rooms[roomID]; exists { 44 return room 45 } 46 47 room := &Room{ 48 clients: make(map[*websocket.Conn]bool), 49 } 50 rooms[roomID] = room 51 return room 52} 53 54// Add client to room 55func addClientToRoom(roomID string, conn *websocket.Conn) { 56 room := getRoom(roomID) 57 58 room.mu.Lock() 59 defer room.mu.Unlock() 60 61 room.clients[conn] = true 62 log.Printf("Client added to room %s. Total clients: %d", roomID, len(room.clients)) 63} 64 65// Remove client from room 66func removeClientFromRoom(roomID string, conn *websocket.Conn) { 67 roomsMu.Lock() 68 room, exists := rooms[roomID] 69 roomsMu.Unlock() 70 71 if !exists { 72 return 73 } 74 75 room.mu.Lock() 76 defer room.mu.Unlock() 77 78 delete(room.clients, conn) 79 log.Printf("Client removed from room %s. Remaining clients: %d", roomID, len(room.clients)) 80 81 // Clean up empty rooms 82 if len(room.clients) == 0 { 83 roomsMu.Lock() 84 delete(rooms, roomID) 85 roomsMu.Unlock() 86 log.Printf("Room %s removed (empty)", roomID) 87 } 88} 89 90// Broadcast changes to all clients in a room except sender 91func broadcastToRoom(roomID string, sender *websocket.Conn, message []byte) { 92 roomsMu.Lock() 93 room, exists := rooms[roomID] 94 roomsMu.Unlock() 95 96 if !exists { 97 log.Printf("Room %s not found for broadcasting", roomID) 98 return 99 } 100 101 room.mu.Lock() 102 defer room.mu.Unlock() 103 104 clientCount := 0 105 for client := range room.clients { 106 if client != sender { 107 clientCount++ 108 err := client.WriteMessage(websocket.TextMessage, message) 109 if err != nil { 110 log.Printf("Error broadcasting to client: %v", err) 111 } 112 } 113 } 114 log.Printf("Broadcasted changes to %d clients in room %s", clientCount, roomID) 115} 116 117func handleWebSocket(w http.ResponseWriter, r *http.Request) { 118 // Set CORS headers for the WebSocket handshake 119 w.Header().Set("Access-Control-Allow-Origin", "*") 120 121 // Extract room ID and publicKey from query parameters 122 roomID := r.URL.Query().Get("room") 123 publicKey := r.URL.Query().Get("publicKey") 124 125 if roomID == "" { 126 http.Error(w, "Missing room parameter", http.StatusBadRequest) 127 return 128 } 129 130 if publicKey == "" { 131 http.Error(w, "Missing publicKey parameter", http.StatusBadRequest) 132 return 133 } 134 135 // Get or create the room in the auth database 136 err := GetOrCreateRoom(roomID) 137 if err != nil { 138 log.Printf("Error with auth database for room %s: %v", roomID, err) 139 http.Error(w, "Server error", http.StatusInternalServerError) 140 return 141 } 142 143 // Determine access level 144 access := determineAccess(roomID, publicKey) 145 146 // Upgrade WebSocket 147 conn, err := upgrader.Upgrade(w, r, nil) 148 if err != nil { 149 log.Println("Error upgrading connection:", err) 150 return 151 } 152 153 // Send immediate room status 154 sendRoomStatus(conn, access) 155 156 // Close connection if no read access 157 if access == "none" || access == "no_room" { 158 log.Printf("Closing connection for user with access: %s", access) 159 conn.Close() 160 return 161 } 162 163 // Add client to room only if authorized 164 addClientToRoom(roomID, conn) 165 166 // Create database connection for this room 167 dbPath := "./rooms/" + roomID + ".db" 168 db, err := sql.Open("sqlite3_with_extensions", dbPath) 169 if err != nil { 170 log.Println("Error opening database:", err) 171 conn.Close() 172 removeClientFromRoom(roomID, conn) 173 return 174 } 175 defer db.Close() 176 177 // Ensure the database has the necessary tables 178 setupDatabase(db) 179 180 // Set up ping/pong to keep connection alive 181 conn.SetPingHandler(func(string) error { 182 conn.WriteControl(websocket.PongMessage, []byte{}, time.Now().Add(time.Second)) 183 return nil 184 }) 185 186 // Clean up when connection closes 187 defer func() { 188 conn.Close() 189 removeClientFromRoom(roomID, conn) 190 }() 191 192 // Handle incoming messages 193 for { 194 _, message, err := conn.ReadMessage() 195 if err != nil { 196 log.Println("Read error:", err) 197 break 198 } 199 200 var msg map[string]interface{} 201 if err := json.Unmarshal(message, &msg); err != nil { 202 log.Println("Error parsing message:", err) 203 continue 204 } 205 206 msgType, ok := msg["type"].(string) 207 if !ok { 208 log.Println("Message missing 'type' field") 209 continue 210 } 211 212 switch msgType { 213 case "sync_request": 214 log.Printf("Client in room %s requested sync", roomID) 215 var syncMsg SyncRequestMessage 216 syncData, _ := json.Marshal(msg) 217 json.Unmarshal(syncData, &syncMsg) 218 219 log.Printf("Sync request with site_id: %s, contiguous_up_to: %d.%d, max_version_seen: %d.%d", 220 syncMsg.SiteID, syncMsg.ContiguousUpTo.DBVersion, syncMsg.ContiguousUpTo.ColVersion, 221 syncMsg.MaxVersionSeen.DBVersion, syncMsg.MaxVersionSeen.ColVersion) 222 223 // Verify publicKey matches 224 if syncMsg.PublicKey != publicKey { 225 log.Printf("PublicKey mismatch in sync request") 226 sendError(conn, "AUTH_FAILED", "PublicKey mismatch") 227 continue 228 } 229 230 // Get current server version pair 231 serverMaxVersionPair, err := getLatestDBVersionCol(db) 232 if err != nil { 233 log.Printf("Error getting server's latest version pair: %v", err) 234 serverMaxVersionPair = VersionColPair{DBVersion: 0, ColVersion: 0} 235 } 236 log.Printf("Server max version: %d.%d, client max version seen: %d.%d", 237 serverMaxVersionPair.DBVersion, serverMaxVersionPair.ColVersion, 238 syncMsg.MaxVersionSeen.DBVersion, syncMsg.MaxVersionSeen.ColVersion) 239 240 // Check if client has newer data than server 241 if compareVersionPairs(syncMsg.MaxVersionSeen, serverMaxVersionPair) > 0 { 242 log.Printf("Client has higher version (%d.%d) than server (%d.%d). Requesting changes.", 243 syncMsg.MaxVersionSeen.DBVersion, syncMsg.MaxVersionSeen.ColVersion, 244 serverMaxVersionPair.DBVersion, serverMaxVersionPair.ColVersion) 245 246 // Send a request_changes message to the client 247 requestChangesMsg := RequestChangesMessage{ 248 Type: "request_changes", 249 RoomID: roomID, 250 Version: serverMaxVersionPair, 251 } 252 253 requestJSON, _ := json.Marshal(requestChangesMsg) 254 conn.WriteMessage(websocket.TextMessage, requestJSON) 255 } 256 257 // Always send sync_response with any changes the server has for the client 258 changes := getChangesForSyncRequest(db, syncMsg) 259 260 response := SyncResponseMessage{ 261 Type: "sync_response", 262 CurrentMaxVersion: serverMaxVersionPair, 263 Changes: changes, 264 } 265 responseJSON, _ := json.Marshal(response) 266 conn.WriteMessage(websocket.TextMessage, responseJSON) 267 268 case "pull": 269 // Legacy support for v1 protocol 270 log.Printf("Client in room %s requested pull (legacy)", roomID) 271 var pullMsg PullMessage 272 pullData, _ := json.Marshal(msg) 273 json.Unmarshal(pullData, &pullMsg) 274 275 log.Printf("Pull request with site_id: %s, version: %d", pullMsg.SiteID, pullMsg.Version) 276 277 // Get the server's latest db_version 278 serverLatestVersion, err := getLatestDBVersion(db) 279 if err != nil { 280 log.Printf("Error getting server's latest db_version: %v", err) 281 serverLatestVersion = 0 282 } 283 284 // Check if client has a higher version than the server 285 if pullMsg.Version > serverLatestVersion { 286 log.Printf("Client has higher version (%d) than server (%d). Requesting changes.", 287 pullMsg.Version, serverLatestVersion) 288 289 // Send a request_changes message to the client 290 requestChangesMsg := RequestChangesMessage{ 291 Type: "request_changes", 292 RoomID: roomID, 293 Version: VersionColPair{DBVersion: serverLatestVersion, ColVersion: 0}, 294 } 295 296 requestJSON, _ := json.Marshal(requestChangesMsg) 297 conn.WriteMessage(websocket.TextMessage, requestJSON) 298 } else { 299 // Only send changes if we're newer 300 changes := getChangesFromDB(db, pullMsg.SiteID, pullMsg.Version) 301 response := map[string]interface{}{ 302 "type": "changes", 303 "data": changes, 304 } 305 responseJSON, _ := json.Marshal(response) 306 conn.WriteMessage(websocket.TextMessage, responseJSON) 307 } 308 309 case "changes": 310 log.Printf("Received changes from client in room %s", roomID) 311 312 // Check write permission 313 if access != "write" { 314 log.Printf("User has no write permission for room %s", roomID) 315 sendError(conn, "PERMISSION_DENIED", "No write permission") 316 continue 317 } 318 319 // For signed changes, verify the signature 320 if msgPublicKey, hasKey := msg["publicKey"].(string); hasKey { 321 log.Printf("Changes are authenticated with public key: %s...", msgPublicKey[:20]) 322 323 // Verify publicKey matches connection 324 if msgPublicKey != publicKey { 325 log.Printf("PublicKey mismatch in changes message") 326 sendError(conn, "AUTH_FAILED", "PublicKey mismatch") 327 continue 328 } 329 330 // TODO: Verify signature when ECDSA is implemented 331 if signature, hasSig := msg["signature"].(string); hasSig && requireAuth { 332 dataStr := "" 333 if data, ok := msg["data"]; ok { 334 dataBytes, _ := json.Marshal(data) 335 dataStr = string(dataBytes) 336 } 337 338 isValid, err := VerifySignature(msgPublicKey, "changes:"+dataStr, signature) 339 if err != nil || !isValid { 340 log.Printf("Signature verification failed") 341 sendError(conn, "AUTH_FAILED", "Invalid signature") 342 continue 343 } 344 } 345 } 346 347 if data, ok := msg["data"].([]interface{}); ok { 348 log.Printf("Processing %d changes from client", len(data)) 349 applyChangesToDB(db, data) 350 351 log.Printf("Broadcasting changes to other clients in room %s", roomID) 352 broadcastToRoom(roomID, conn, message) 353 log.Printf("Broadcast completed for room %s", roomID) 354 } 355 } 356 } 357} 358 359func setupDatabase(db *sql.DB) { 360 // Create todos table if it doesn't exist 361 _, err := db.Exec(` 362 CREATE TABLE IF NOT EXISTS todos ( 363 id BLOB PRIMARY KEY NOT NULL, 364 description TEXT, 365 project text, 366 tags text, 367 due text, 368 wait text, 369 priority text, 370 urgency real, 371 completed INTEGER NOT NULL DEFAULT 0 372 ); 373 SELECT crsql_as_crr('todos'); 374 `) 375 if err != nil { 376 log.Println("Error setting up database:", err) 377 } 378} 379 380func sendInitialChanges(conn *websocket.Conn, db *sql.DB) { 381 // Query all changes 382 rows, err := db.Query("SELECT * FROM crsql_changes") 383 if err != nil { 384 log.Println("Error querying changes:", err) 385 return 386 } 387 defer rows.Close() 388 389 var changes []map[string]interface{} 390 391 for rows.Next() { 392 var tableName string 393 var pk []byte 394 var columnName string 395 var value interface{} 396 var colVersion, dbVersion int64 397 var siteID []byte 398 var cl, seq int64 399 400 if err := rows.Scan(&tableName, &pk, &columnName, &value, &colVersion, &dbVersion, &siteID, &cl, &seq); err != nil { 401 log.Println("Error scanning row:", err) 402 continue 403 } 404 405 change := map[string]interface{}{ 406 "TableName": tableName, 407 "PK": encodeToBase64(pk), 408 "ColumnName": columnName, 409 "Value": value, 410 "ColVersion": colVersion, 411 "DBVersion": dbVersion, 412 "SiteID": encodeToBase64(siteID), 413 "CL": cl, 414 "Seq": seq, 415 } 416 417 changes = append(changes, change) 418 } 419 420 response := map[string]interface{}{ 421 "type": "changes", 422 "data": changes, 423 } 424 425 responseJSON, _ := json.Marshal(response) 426 conn.WriteMessage(websocket.TextMessage, responseJSON) 427} 428 429// getLatestDBVersion gets the latest db_version from the database (legacy) 430func getLatestDBVersion(db *sql.DB) (int, error) { 431 // Query the maximum db_version 432 row := db.QueryRow("SELECT MAX(db_version) FROM crsql_changes") 433 434 var version sql.NullInt64 435 if err := row.Scan(&version); err != nil { 436 return 0, err 437 } 438 439 // If there are no rows or the value is null, return 0 440 if !version.Valid { 441 return 0, nil 442 } 443 444 return int(version.Int64), nil 445} 446 447// getLatestDBVersionCol gets the latest (db_version, col_version) pair from the database 448func getLatestDBVersionCol(db *sql.DB) (VersionColPair, error) { 449 // First check if there are any changes at all 450 var count int 451 err := db.QueryRow("SELECT COUNT(*) FROM crsql_changes").Scan(&count) 452 if err != nil { 453 return VersionColPair{DBVersion: 0, ColVersion: 0}, err 454 } 455 456 // If no changes exist, return 0,0 457 if count == 0 { 458 return VersionColPair{DBVersion: 0, ColVersion: 0}, nil 459 } 460 461 // Query the maximum db_version and its maximum col_version 462 row := db.QueryRow(` 463 SELECT db_version, col_version 464 FROM crsql_changes 465 ORDER BY db_version DESC, col_version DESC 466 LIMIT 1 467 `) 468 469 var dbVersion, colVersion sql.NullInt64 470 if err := row.Scan(&dbVersion, &colVersion); err != nil { 471 if err == sql.ErrNoRows { 472 return VersionColPair{DBVersion: 0, ColVersion: 0}, nil 473 } 474 return VersionColPair{DBVersion: 0, ColVersion: 0}, err 475 } 476 477 // If there are no rows or the values are null, return 0,0 478 if !dbVersion.Valid || !colVersion.Valid { 479 return VersionColPair{DBVersion: 0, ColVersion: 0}, nil 480 } 481 482 return VersionColPair{DBVersion: int(dbVersion.Int64), ColVersion: int(colVersion.Int64)}, nil 483} 484 485// compareVersionPairs compares two version pairs, returns -1, 0, or 1 486func compareVersionPairs(a, b VersionColPair) int { 487 if a.DBVersion != b.DBVersion { 488 if a.DBVersion < b.DBVersion { 489 return -1 490 } 491 return 1 492 } 493 if a.ColVersion != b.ColVersion { 494 if a.ColVersion < b.ColVersion { 495 return -1 496 } 497 return 1 498 } 499 return 0 500} 501 502func getChangesFromDB(db *sql.DB, siteID string, version int) []map[string]interface{} { 503 // Decode the site_id from base64 504 decodedSiteID, err := decodeBase64(siteID) 505 if err != nil { 506 log.Printf("Error decoding site_id '%s': %v", siteID, err) 507 decodedSiteID = []byte{} 508 } 509 510 log.Printf("Querying for changes with site_id != %v AND db_version > %d", decodedSiteID, version) 511 query := "SELECT * FROM crsql_changes WHERE site_id != ? AND db_version > ?" 512 513 rows, err := db.Query(query, decodedSiteID, version) 514 if err != nil { 515 log.Println("Error querying changes:", err) 516 return nil 517 } 518 defer rows.Close() 519 520 var changes []map[string]interface{} 521 522 for rows.Next() { 523 var tableName string 524 var pk []byte 525 var columnName string 526 var value interface{} 527 var colVersion, dbVersion int64 528 var siteID []byte 529 var cl, seq int64 530 531 if err := rows.Scan(&tableName, &pk, &columnName, &value, &colVersion, &dbVersion, &siteID, &cl, &seq); err != nil { 532 log.Println("Error scanning row:", err) 533 continue 534 } 535 536 change := map[string]interface{}{ 537 "TableName": tableName, 538 "PK": encodeToBase64(pk), 539 "ColumnName": columnName, 540 "Value": value, 541 "ColVersion": colVersion, 542 "DBVersion": dbVersion, 543 "SiteID": encodeToBase64(siteID), 544 "CL": cl, 545 "Seq": seq, 546 } 547 548 changes = append(changes, change) 549 } 550 551 log.Printf("Returning %d changes to client for sync", len(changes)) 552 return changes 553} 554 555func applyChangesToDB(db *sql.DB, changes []interface{}) { 556 tx, err := db.Begin() 557 if err != nil { 558 log.Println("Error starting transaction:", err) 559 return 560 } 561 562 for _, change := range changes { 563 if changeMap, ok := change.(map[string]interface{}); ok { 564 tableName := changeMap["TableName"].(string) 565 pk, _ := decodeBase64(changeMap["PK"].(string)) 566 columnName := changeMap["ColumnName"].(string) 567 value := changeMap["Value"] 568 colVersion := int64(changeMap["ColVersion"].(float64)) 569 dbVersion := int64(changeMap["DBVersion"].(float64)) 570 siteID, _ := decodeBase64(changeMap["SiteID"].(string)) 571 cl := int64(changeMap["CL"].(float64)) 572 seq := int64(changeMap["Seq"].(float64)) 573 574 _, err := tx.Exec( 575 `INSERT INTO crsql_changes VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, 576 tableName, pk, columnName, value, colVersion, dbVersion, siteID, cl, seq, 577 ) 578 if err != nil { 579 log.Println("Error inserting change:", err) 580 } 581 } 582 } 583 584 if err := tx.Commit(); err != nil { 585 log.Println("Error committing transaction:", err) 586 tx.Rollback() 587 } 588} 589 590func encodeToBase64(data []byte) string { 591 return base64.StdEncoding.EncodeToString(data) 592} 593 594func decodeBase64(encoded string) ([]byte, error) { 595 return base64.StdEncoding.DecodeString(encoded) 596} 597 598// createDirIfNotExists creates a directory if it doesn't already exist 599func createDirIfNotExists(path string) error { 600 // Check if the directory exists 601 if _, err := os.Stat(path); os.IsNotExist(err) { 602 // Directory does not exist, create it 603 return os.MkdirAll(path, 0755) 604 } else if err != nil { 605 // Some other error occurred 606 return err 607 } 608 // Directory already exists 609 return nil 610} 611 612 613// PullMessage represents the structure of a 'pull' type message from clients (legacy v1) 614type PullMessage struct { 615 Type string `json:"type"` 616 SiteID string `json:"site_id"` 617 Version int `json:"version"` 618} 619 620// RequestChangesMessage represents a request from server to client to send their changes 621type RequestChangesMessage struct { 622 Type string `json:"type"` 623 RoomID string `json:"room_id"` 624 Version VersionColPair `json:"version"` 625} 626 627// VersionColPair represents a (db_version, col_version) pair 628type VersionColPair struct { 629 DBVersion int `json:"db_version"` 630 ColVersion int `json:"col_version"` 631} 632 633// MissingVersionRange represents missing col_versions for a specific db_version 634type MissingVersionRange struct { 635 DBVersion int `json:"db_version"` 636 ColVersions []int `json:"col_versions"` 637} 638 639// VersionRange represents a missing version range (legacy - kept for compatibility) 640type VersionRange struct { 641 Start int `json:"start"` 642 End int `json:"end"` 643} 644 645// SyncRequestMessage represents the v2 sync request with version ranges 646type SyncRequestMessage struct { 647 Type string `json:"type"` 648 SiteID string `json:"site_id"` 649 ContiguousUpTo VersionColPair `json:"contiguous_up_to"` 650 MissingRanges []MissingVersionRange `json:"missing_ranges"` 651 MaxVersionSeen VersionColPair `json:"max_version_seen"` 652 PublicKey string `json:"publicKey"` 653} 654 655// SyncResponseMessage represents the server response to sync request 656type SyncResponseMessage struct { 657 Type string `json:"type"` 658 CurrentMaxVersion VersionColPair `json:"current_max_version"` 659 Changes []map[string]interface{} `json:"changes"` 660} 661 662// RoomStatusMessage represents immediate room status on connection 663type RoomStatusMessage struct { 664 Type string `json:"type"` 665 Access string `json:"access"` 666} 667 668// ErrorMessage represents structured error responses 669type ErrorMessage struct { 670 Type string `json:"type"` 671 Code string `json:"code"` 672 Message string `json:"message"` 673} 674 675// AuthRequest is the structure for authentication verification requests 676type AuthRequest struct { 677 RoomID string `json:"roomId"` 678 PublicKey string `json:"publicKey"` 679 Data string `json:"data"` 680 Signature string `json:"signature"` 681} 682 683// AuthResponse is the structure for authentication verification responses 684type AuthResponse struct { 685 Authenticated bool `json:"authenticated"` 686 Message string `json:"message"` 687} 688 689// handleAuthVerify handles authentication verification requests 690func handleAuthVerify(w http.ResponseWriter, r *http.Request) { 691 // Set CORS headers 692 w.Header().Set("Access-Control-Allow-Origin", "*") 693 w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") 694 w.Header().Set("Access-Control-Allow-Headers", "Content-Type") 695 696 // Handle preflight OPTIONS request 697 if r.Method == http.MethodOptions { 698 w.WriteHeader(http.StatusOK) 699 return 700 } 701 702 // Only allow POST requests 703 if r.Method != http.MethodPost { 704 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 705 return 706 } 707 708 // Parse the request body 709 var req AuthRequest 710 decoder := json.NewDecoder(r.Body) 711 if err := decoder.Decode(&req); err != nil { 712 log.Printf("Error decoding auth request: %v", err) 713 http.Error(w, "Invalid request format", http.StatusBadRequest) 714 return 715 } 716 717 // Validate the request 718 if req.RoomID == "" || req.PublicKey == "" || req.Signature == "" || req.Data == "" { 719 http.Error(w, "Room ID, public key, data, and signature are required", http.StatusBadRequest) 720 return 721 } 722 723 // Verify the signature 724 isValid, err := VerifySignature(req.PublicKey, req.Data, req.Signature) 725 if err != nil { 726 log.Printf("Error verifying signature: %v", err) 727 http.Error(w, "Failed to verify signature", http.StatusInternalServerError) 728 return 729 } 730 731 if !isValid { 732 // Return failure response 733 response := AuthResponse{ 734 Authenticated: false, 735 Message: "Signature verification failed", 736 } 737 738 w.Header().Set("Content-Type", "application/json") 739 w.WriteHeader(http.StatusUnauthorized) 740 json.NewEncoder(w).Encode(response) 741 return 742 } 743 744 // Check if the public key has write permission for the room 745 hasWritePerm, err := CheckKeyPermission(req.RoomID, req.PublicKey, "write") 746 if err != nil { 747 log.Printf("Error checking key permission: %v", err) 748 http.Error(w, "Failed to check permissions", http.StatusInternalServerError) 749 return 750 } 751 752 if !hasWritePerm { 753 // Return failure response 754 response := AuthResponse{ 755 Authenticated: false, 756 Message: "Public key doesn't have write permission for this room", 757 } 758 759 w.Header().Set("Content-Type", "application/json") 760 w.WriteHeader(http.StatusForbidden) 761 json.NewEncoder(w).Encode(response) 762 return 763 } 764 765 // Return success response 766 response := AuthResponse{ 767 Authenticated: true, 768 Message: "Authentication successful", 769 } 770 771 w.Header().Set("Content-Type", "application/json") 772 w.WriteHeader(http.StatusOK) 773 json.NewEncoder(w).Encode(response) 774} 775 776// determineAccess determines the access level for a user connecting to a room 777func determineAccess(roomID, publicKey string) string { 778 logKey := publicKey 779 if len(logKey) > 20 { 780 logKey = publicKey[:20] + "..." 781 } 782 log.Printf("Determining access for publicKey %s in room %s", logKey, roomID) 783 784 // Check if room exists 785 roomExists, err := CheckRoomExists(roomID) 786 if err != nil { 787 log.Printf("Error checking room existence: %v", err) 788 return "none" 789 } 790 791 if !roomExists { 792 log.Printf("Room %s does not exist", roomID) 793 if !requireAuth { 794 // Development mode - auto-create room and grant permissions 795 log.Printf("Auto-creating room %s and granting permissions", roomID) 796 if err := GetOrCreateRoom(roomID); err != nil { 797 log.Printf("Error auto-creating room: %v", err) 798 return "no_room" 799 } 800 if err := AutoGrantInvitePermissions(roomID, publicKey); err != nil { 801 log.Printf("Error auto-granting permissions: %v", err) 802 return "no_room" 803 } 804 return "write" 805 } 806 return "no_room" 807 } 808 809 // Room exists - check permissions 810 if !requireAuth { 811 // Development mode - auto-grant permissions for existing rooms 812 log.Printf("Auto-granting permissions for existing room %s", roomID) 813 if err := AutoGrantInvitePermissions(roomID, publicKey); err != nil { 814 log.Printf("Error auto-granting permissions: %v", err) 815 } 816 return "write" 817 } 818 819 // Production mode - check actual permissions 820 hasRead, err := CheckKeyPermission(roomID, publicKey, "read") 821 if err != nil { 822 log.Printf("Error checking read permission: %v", err) 823 return "none" 824 } 825 826 if !hasRead { 827 return "none" 828 } 829 830 hasWrite, err := CheckKeyPermission(roomID, publicKey, "write") 831 if err != nil { 832 log.Printf("Error checking write permission: %v", err) 833 return "read" 834 } 835 836 if hasWrite { 837 return "write" 838 } 839 840 return "read" 841} 842 843// sendRoomStatus sends immediate room status to client 844func sendRoomStatus(conn *websocket.Conn, access string) { 845 status := RoomStatusMessage{ 846 Type: "room_status", 847 Access: access, 848 } 849 850 statusJSON, _ := json.Marshal(status) 851 conn.WriteMessage(websocket.TextMessage, statusJSON) 852 log.Printf("Sent room status: %s", access) 853} 854 855// sendError sends structured error response to client 856func sendError(conn *websocket.Conn, code, message string) { 857 errorMsg := ErrorMessage{ 858 Type: "error", 859 Code: code, 860 Message: message, 861 } 862 863 errorJSON, _ := json.Marshal(errorMsg) 864 conn.WriteMessage(websocket.TextMessage, errorJSON) 865 log.Printf("Sent error: %s - %s", code, message) 866} 867 868// getChangesForSyncRequest gets changes for missing ranges + new changes 869func getChangesForSyncRequest(db *sql.DB, syncMsg SyncRequestMessage) []map[string]interface{} { 870 decodedSiteID, err := decodeBase64(syncMsg.SiteID) 871 if err != nil { 872 log.Printf("Error decoding site_id '%s': %v", syncMsg.SiteID, err) 873 decodedSiteID = []byte{} 874 } 875 876 // Build query for missing (db_version, col_version) pairs + new changes 877 var queryParts []string 878 var args []interface{} 879 880 // Add site_id filter 881 baseCond := "site_id != ?" 882 args = append(args, decodedSiteID) 883 884 // Add missing ranges - for each db_version, get specific missing col_versions 885 for _, vrange := range syncMsg.MissingRanges { 886 if len(vrange.ColVersions) > 0 { 887 // Create placeholders for the col_versions 888 placeholders := make([]string, len(vrange.ColVersions)) 889 for i := range placeholders { 890 placeholders[i] = "?" 891 args = append(args, vrange.ColVersions[i]) 892 } 893 894 // Add condition for this db_version with specific col_versions 895 queryParts = append(queryParts, 896 fmt.Sprintf("(db_version = ? AND col_version IN (%s))", 897 strings.Join(placeholders, ","))) 898 args = append(args, vrange.DBVersion) 899 } 900 } 901 902 // Add new changes beyond max version seen 903 // This includes: db_version > max_db_version OR (db_version = max_db_version AND col_version > max_col_version) 904 queryParts = append(queryParts, 905 "(db_version > ? OR (db_version = ? AND col_version > ?))") 906 args = append(args, syncMsg.MaxVersionSeen.DBVersion, 907 syncMsg.MaxVersionSeen.DBVersion, syncMsg.MaxVersionSeen.ColVersion) 908 909 // Combine all conditions 910 whereClause := baseCond 911 if len(queryParts) > 0 { 912 whereClause += " AND (" + strings.Join(queryParts, " OR ") + ")" 913 } 914 915 query := "SELECT * FROM crsql_changes WHERE " + whereClause + " ORDER BY db_version ASC, col_version ASC" 916 917 log.Printf("Executing sync query with %d missing ranges and max_version %d.%d", 918 len(syncMsg.MissingRanges), syncMsg.MaxVersionSeen.DBVersion, syncMsg.MaxVersionSeen.ColVersion) 919 920 rows, err := db.Query(query, args...) 921 if err != nil { 922 log.Printf("Error querying changes for sync: %v", err) 923 return nil 924 } 925 defer rows.Close() 926 927 var changes []map[string]interface{} 928 929 for rows.Next() { 930 var tableName string 931 var pk []byte 932 var columnName string 933 var value interface{} 934 var colVersion, dbVersion int64 935 var siteID []byte 936 var cl, seq int64 937 938 if err := rows.Scan(&tableName, &pk, &columnName, &value, &colVersion, &dbVersion, &siteID, &cl, &seq); err != nil { 939 log.Printf("Error scanning row: %v", err) 940 continue 941 } 942 943 change := map[string]interface{}{ 944 "TableName": tableName, 945 "PK": encodeToBase64(pk), 946 "ColumnName": columnName, 947 "Value": value, 948 "ColVersion": colVersion, 949 "DBVersion": dbVersion, 950 "SiteID": encodeToBase64(siteID), 951 "CL": cl, 952 "Seq": seq, 953 } 954 955 changes = append(changes, change) 956 } 957 958 log.Printf("Returning %d changes for sync request", len(changes)) 959 return changes 960} 961 962func main() { 963 // Log startup configuration 964 log.Printf("Starting server with REQUIRE_AUTH=%v", requireAuth) 965 966 // Register SQLite with CR-SQLite extension 967 sql.Register("sqlite3_with_extensions", &sqlite3.SQLiteDriver{ 968 Extensions: []string{"../db/crsqlite"}, 969 }) 970 971 // Initialize the auth database 972 log.Println("Initializing authentication database...") 973 if err := InitAuthDB(); err != nil { 974 log.Printf("Warning: Failed to initialize auth database: %v", err) 975 } else { 976 log.Println("Authentication database initialized successfully") 977 } 978 // Close auth database connection when the server exits 979 defer CloseAuthDB() 980 981 // Create directory for room databases 982 if err := createDirIfNotExists("./rooms"); err != nil { 983 log.Printf("Warning: Failed to create rooms directory: %v", err) 984 } else { 985 log.Println("Rooms directory created or verified") 986 } 987 988 http.HandleFunc("/sync", handleWebSocket) 989 http.HandleFunc("/auth/verify", handleAuthVerify) 990 http.HandleFunc("/auth/register-key", HandleKeyRegistration) 991 992 log.Println("WebSocket server started on :8080") 993 log.Fatal(http.ListenAndServe(":8080", nil)) 994} 995