this repo has no description
1package main
2
3import (
4 "database/sql"
5 "encoding/base64"
6 "encoding/json"
7 "fmt"
8 "log"
9 "net/http"
10 "os"
11 "strings"
12 "sync"
13 "time"
14
15 "github.com/gorilla/websocket"
16 sqlite3 "github.com/mattn/go-sqlite3"
17)
18
19var upgrader = websocket.Upgrader{
20 // Allow all cross-origin connections
21 CheckOrigin: func(r *http.Request) bool {
22 return true
23 },
24}
25
26// Environment configuration
27var requireAuth = os.Getenv("REQUIRE_AUTH") == "true"
28
29// Room management
30type Room struct {
31 clients map[*websocket.Conn]bool
32 mu sync.Mutex
33}
34
35var rooms = make(map[string]*Room)
36var roomsMu sync.Mutex
37
38// Get or create a room
39func getRoom(roomID string) *Room {
40 roomsMu.Lock()
41 defer roomsMu.Unlock()
42
43 if room, exists := rooms[roomID]; exists {
44 return room
45 }
46
47 room := &Room{
48 clients: make(map[*websocket.Conn]bool),
49 }
50 rooms[roomID] = room
51 return room
52}
53
54// Add client to room
55func addClientToRoom(roomID string, conn *websocket.Conn) {
56 room := getRoom(roomID)
57
58 room.mu.Lock()
59 defer room.mu.Unlock()
60
61 room.clients[conn] = true
62 log.Printf("Client added to room %s. Total clients: %d", roomID, len(room.clients))
63}
64
65// Remove client from room
66func removeClientFromRoom(roomID string, conn *websocket.Conn) {
67 roomsMu.Lock()
68 room, exists := rooms[roomID]
69 roomsMu.Unlock()
70
71 if !exists {
72 return
73 }
74
75 room.mu.Lock()
76 defer room.mu.Unlock()
77
78 delete(room.clients, conn)
79 log.Printf("Client removed from room %s. Remaining clients: %d", roomID, len(room.clients))
80
81 // Clean up empty rooms
82 if len(room.clients) == 0 {
83 roomsMu.Lock()
84 delete(rooms, roomID)
85 roomsMu.Unlock()
86 log.Printf("Room %s removed (empty)", roomID)
87 }
88}
89
90// Broadcast changes to all clients in a room except sender
91func broadcastToRoom(roomID string, sender *websocket.Conn, message []byte) {
92 roomsMu.Lock()
93 room, exists := rooms[roomID]
94 roomsMu.Unlock()
95
96 if !exists {
97 log.Printf("Room %s not found for broadcasting", roomID)
98 return
99 }
100
101 room.mu.Lock()
102 defer room.mu.Unlock()
103
104 clientCount := 0
105 for client := range room.clients {
106 if client != sender {
107 clientCount++
108 err := client.WriteMessage(websocket.TextMessage, message)
109 if err != nil {
110 log.Printf("Error broadcasting to client: %v", err)
111 }
112 }
113 }
114 log.Printf("Broadcasted changes to %d clients in room %s", clientCount, roomID)
115}
116
117func handleWebSocket(w http.ResponseWriter, r *http.Request) {
118 // Set CORS headers for the WebSocket handshake
119 w.Header().Set("Access-Control-Allow-Origin", "*")
120
121 // Extract room ID and publicKey from query parameters
122 roomID := r.URL.Query().Get("room")
123 publicKey := r.URL.Query().Get("publicKey")
124
125 if roomID == "" {
126 http.Error(w, "Missing room parameter", http.StatusBadRequest)
127 return
128 }
129
130 if publicKey == "" {
131 http.Error(w, "Missing publicKey parameter", http.StatusBadRequest)
132 return
133 }
134
135 // Get or create the room in the auth database
136 err := GetOrCreateRoom(roomID)
137 if err != nil {
138 log.Printf("Error with auth database for room %s: %v", roomID, err)
139 http.Error(w, "Server error", http.StatusInternalServerError)
140 return
141 }
142
143 // Determine access level
144 access := determineAccess(roomID, publicKey)
145
146 // Upgrade WebSocket
147 conn, err := upgrader.Upgrade(w, r, nil)
148 if err != nil {
149 log.Println("Error upgrading connection:", err)
150 return
151 }
152
153 // Send immediate room status
154 sendRoomStatus(conn, access)
155
156 // Close connection if no read access
157 if access == "none" || access == "no_room" {
158 log.Printf("Closing connection for user with access: %s", access)
159 conn.Close()
160 return
161 }
162
163 // Add client to room only if authorized
164 addClientToRoom(roomID, conn)
165
166 // Create database connection for this room
167 dbPath := "./rooms/" + roomID + ".db"
168 db, err := sql.Open("sqlite3_with_extensions", dbPath)
169 if err != nil {
170 log.Println("Error opening database:", err)
171 conn.Close()
172 removeClientFromRoom(roomID, conn)
173 return
174 }
175 defer db.Close()
176
177 // Ensure the database has the necessary tables
178 setupDatabase(db)
179
180 // Set up ping/pong to keep connection alive
181 conn.SetPingHandler(func(string) error {
182 conn.WriteControl(websocket.PongMessage, []byte{}, time.Now().Add(time.Second))
183 return nil
184 })
185
186 // Clean up when connection closes
187 defer func() {
188 conn.Close()
189 removeClientFromRoom(roomID, conn)
190 }()
191
192 // Handle incoming messages
193 for {
194 _, message, err := conn.ReadMessage()
195 if err != nil {
196 log.Println("Read error:", err)
197 break
198 }
199
200 var msg map[string]interface{}
201 if err := json.Unmarshal(message, &msg); err != nil {
202 log.Println("Error parsing message:", err)
203 continue
204 }
205
206 msgType, ok := msg["type"].(string)
207 if !ok {
208 log.Println("Message missing 'type' field")
209 continue
210 }
211
212 switch msgType {
213 case "sync_request":
214 log.Printf("Client in room %s requested sync", roomID)
215 var syncMsg SyncRequestMessage
216 syncData, _ := json.Marshal(msg)
217 json.Unmarshal(syncData, &syncMsg)
218
219 log.Printf("Sync request with site_id: %s, contiguous_up_to: %d.%d, max_version_seen: %d.%d",
220 syncMsg.SiteID, syncMsg.ContiguousUpTo.DBVersion, syncMsg.ContiguousUpTo.ColVersion,
221 syncMsg.MaxVersionSeen.DBVersion, syncMsg.MaxVersionSeen.ColVersion)
222
223 // Verify publicKey matches
224 if syncMsg.PublicKey != publicKey {
225 log.Printf("PublicKey mismatch in sync request")
226 sendError(conn, "AUTH_FAILED", "PublicKey mismatch")
227 continue
228 }
229
230 // Get current server version pair
231 serverMaxVersionPair, err := getLatestDBVersionCol(db)
232 if err != nil {
233 log.Printf("Error getting server's latest version pair: %v", err)
234 serverMaxVersionPair = VersionColPair{DBVersion: 0, ColVersion: 0}
235 }
236 log.Printf("Server max version: %d.%d, client max version seen: %d.%d",
237 serverMaxVersionPair.DBVersion, serverMaxVersionPair.ColVersion,
238 syncMsg.MaxVersionSeen.DBVersion, syncMsg.MaxVersionSeen.ColVersion)
239
240 // Check if client has newer data than server
241 if compareVersionPairs(syncMsg.MaxVersionSeen, serverMaxVersionPair) > 0 {
242 log.Printf("Client has higher version (%d.%d) than server (%d.%d). Requesting changes.",
243 syncMsg.MaxVersionSeen.DBVersion, syncMsg.MaxVersionSeen.ColVersion,
244 serverMaxVersionPair.DBVersion, serverMaxVersionPair.ColVersion)
245
246 // Send a request_changes message to the client
247 requestChangesMsg := RequestChangesMessage{
248 Type: "request_changes",
249 RoomID: roomID,
250 Version: serverMaxVersionPair,
251 }
252
253 requestJSON, _ := json.Marshal(requestChangesMsg)
254 conn.WriteMessage(websocket.TextMessage, requestJSON)
255 }
256
257 // Always send sync_response with any changes the server has for the client
258 changes := getChangesForSyncRequest(db, syncMsg)
259
260 response := SyncResponseMessage{
261 Type: "sync_response",
262 CurrentMaxVersion: serverMaxVersionPair,
263 Changes: changes,
264 }
265 responseJSON, _ := json.Marshal(response)
266 conn.WriteMessage(websocket.TextMessage, responseJSON)
267
268 case "pull":
269 // Legacy support for v1 protocol
270 log.Printf("Client in room %s requested pull (legacy)", roomID)
271 var pullMsg PullMessage
272 pullData, _ := json.Marshal(msg)
273 json.Unmarshal(pullData, &pullMsg)
274
275 log.Printf("Pull request with site_id: %s, version: %d", pullMsg.SiteID, pullMsg.Version)
276
277 // Get the server's latest db_version
278 serverLatestVersion, err := getLatestDBVersion(db)
279 if err != nil {
280 log.Printf("Error getting server's latest db_version: %v", err)
281 serverLatestVersion = 0
282 }
283
284 // Check if client has a higher version than the server
285 if pullMsg.Version > serverLatestVersion {
286 log.Printf("Client has higher version (%d) than server (%d). Requesting changes.",
287 pullMsg.Version, serverLatestVersion)
288
289 // Send a request_changes message to the client
290 requestChangesMsg := RequestChangesMessage{
291 Type: "request_changes",
292 RoomID: roomID,
293 Version: VersionColPair{DBVersion: serverLatestVersion, ColVersion: 0},
294 }
295
296 requestJSON, _ := json.Marshal(requestChangesMsg)
297 conn.WriteMessage(websocket.TextMessage, requestJSON)
298 } else {
299 // Only send changes if we're newer
300 changes := getChangesFromDB(db, pullMsg.SiteID, pullMsg.Version)
301 response := map[string]interface{}{
302 "type": "changes",
303 "data": changes,
304 }
305 responseJSON, _ := json.Marshal(response)
306 conn.WriteMessage(websocket.TextMessage, responseJSON)
307 }
308
309 case "changes":
310 log.Printf("Received changes from client in room %s", roomID)
311
312 // Check write permission
313 if access != "write" {
314 log.Printf("User has no write permission for room %s", roomID)
315 sendError(conn, "PERMISSION_DENIED", "No write permission")
316 continue
317 }
318
319 // For signed changes, verify the signature
320 if msgPublicKey, hasKey := msg["publicKey"].(string); hasKey {
321 log.Printf("Changes are authenticated with public key: %s...", msgPublicKey[:20])
322
323 // Verify publicKey matches connection
324 if msgPublicKey != publicKey {
325 log.Printf("PublicKey mismatch in changes message")
326 sendError(conn, "AUTH_FAILED", "PublicKey mismatch")
327 continue
328 }
329
330 // TODO: Verify signature when ECDSA is implemented
331 if signature, hasSig := msg["signature"].(string); hasSig && requireAuth {
332 dataStr := ""
333 if data, ok := msg["data"]; ok {
334 dataBytes, _ := json.Marshal(data)
335 dataStr = string(dataBytes)
336 }
337
338 isValid, err := VerifySignature(msgPublicKey, "changes:"+dataStr, signature)
339 if err != nil || !isValid {
340 log.Printf("Signature verification failed")
341 sendError(conn, "AUTH_FAILED", "Invalid signature")
342 continue
343 }
344 }
345 }
346
347 if data, ok := msg["data"].([]interface{}); ok {
348 log.Printf("Processing %d changes from client", len(data))
349 applyChangesToDB(db, data)
350
351 log.Printf("Broadcasting changes to other clients in room %s", roomID)
352 broadcastToRoom(roomID, conn, message)
353 log.Printf("Broadcast completed for room %s", roomID)
354 }
355 }
356 }
357}
358
// setupDatabase creates the todos table if it is missing and then calls
// crsql_as_crr('todos') on it, in a single multi-statement Exec. A failure is
// logged, not returned.
func setupDatabase(db *sql.DB) {
	const schema = `
	CREATE TABLE IF NOT EXISTS todos (
		id BLOB PRIMARY KEY NOT NULL,
		description TEXT,
		project text,
		tags text,
		due text,
		wait text,
		priority text,
		urgency real,
		completed INTEGER NOT NULL DEFAULT 0
	);
	SELECT crsql_as_crr('todos');
	`

	if _, err := db.Exec(schema); err != nil {
		log.Println("Error setting up database:", err)
	}
}
379
380func sendInitialChanges(conn *websocket.Conn, db *sql.DB) {
381 // Query all changes
382 rows, err := db.Query("SELECT * FROM crsql_changes")
383 if err != nil {
384 log.Println("Error querying changes:", err)
385 return
386 }
387 defer rows.Close()
388
389 var changes []map[string]interface{}
390
391 for rows.Next() {
392 var tableName string
393 var pk []byte
394 var columnName string
395 var value interface{}
396 var colVersion, dbVersion int64
397 var siteID []byte
398 var cl, seq int64
399
400 if err := rows.Scan(&tableName, &pk, &columnName, &value, &colVersion, &dbVersion, &siteID, &cl, &seq); err != nil {
401 log.Println("Error scanning row:", err)
402 continue
403 }
404
405 change := map[string]interface{}{
406 "TableName": tableName,
407 "PK": encodeToBase64(pk),
408 "ColumnName": columnName,
409 "Value": value,
410 "ColVersion": colVersion,
411 "DBVersion": dbVersion,
412 "SiteID": encodeToBase64(siteID),
413 "CL": cl,
414 "Seq": seq,
415 }
416
417 changes = append(changes, change)
418 }
419
420 response := map[string]interface{}{
421 "type": "changes",
422 "data": changes,
423 }
424
425 responseJSON, _ := json.Marshal(response)
426 conn.WriteMessage(websocket.TextMessage, responseJSON)
427}
428
// getLatestDBVersion returns MAX(db_version) over crsql_changes (legacy v1
// protocol). An empty table (MAX is NULL) yields 0 with a nil error; a query
// or scan failure yields 0 and that error.
func getLatestDBVersion(db *sql.DB) (int, error) {
	var maxVersion sql.NullInt64
	err := db.QueryRow("SELECT MAX(db_version) FROM crsql_changes").Scan(&maxVersion)

	switch {
	case err != nil:
		return 0, err
	case !maxVersion.Valid:
		return 0, nil
	default:
		return int(maxVersion.Int64), nil
	}
}
446
447// getLatestDBVersionCol gets the latest (db_version, col_version) pair from the database
448func getLatestDBVersionCol(db *sql.DB) (VersionColPair, error) {
449 // First check if there are any changes at all
450 var count int
451 err := db.QueryRow("SELECT COUNT(*) FROM crsql_changes").Scan(&count)
452 if err != nil {
453 return VersionColPair{DBVersion: 0, ColVersion: 0}, err
454 }
455
456 // If no changes exist, return 0,0
457 if count == 0 {
458 return VersionColPair{DBVersion: 0, ColVersion: 0}, nil
459 }
460
461 // Query the maximum db_version and its maximum col_version
462 row := db.QueryRow(`
463 SELECT db_version, col_version
464 FROM crsql_changes
465 ORDER BY db_version DESC, col_version DESC
466 LIMIT 1
467 `)
468
469 var dbVersion, colVersion sql.NullInt64
470 if err := row.Scan(&dbVersion, &colVersion); err != nil {
471 if err == sql.ErrNoRows {
472 return VersionColPair{DBVersion: 0, ColVersion: 0}, nil
473 }
474 return VersionColPair{DBVersion: 0, ColVersion: 0}, err
475 }
476
477 // If there are no rows or the values are null, return 0,0
478 if !dbVersion.Valid || !colVersion.Valid {
479 return VersionColPair{DBVersion: 0, ColVersion: 0}, nil
480 }
481
482 return VersionColPair{DBVersion: int(dbVersion.Int64), ColVersion: int(colVersion.Int64)}, nil
483}
484
485// compareVersionPairs compares two version pairs, returns -1, 0, or 1
486func compareVersionPairs(a, b VersionColPair) int {
487 if a.DBVersion != b.DBVersion {
488 if a.DBVersion < b.DBVersion {
489 return -1
490 }
491 return 1
492 }
493 if a.ColVersion != b.ColVersion {
494 if a.ColVersion < b.ColVersion {
495 return -1
496 }
497 return 1
498 }
499 return 0
500}
501
502func getChangesFromDB(db *sql.DB, siteID string, version int) []map[string]interface{} {
503 // Decode the site_id from base64
504 decodedSiteID, err := decodeBase64(siteID)
505 if err != nil {
506 log.Printf("Error decoding site_id '%s': %v", siteID, err)
507 decodedSiteID = []byte{}
508 }
509
510 log.Printf("Querying for changes with site_id != %v AND db_version > %d", decodedSiteID, version)
511 query := "SELECT * FROM crsql_changes WHERE site_id != ? AND db_version > ?"
512
513 rows, err := db.Query(query, decodedSiteID, version)
514 if err != nil {
515 log.Println("Error querying changes:", err)
516 return nil
517 }
518 defer rows.Close()
519
520 var changes []map[string]interface{}
521
522 for rows.Next() {
523 var tableName string
524 var pk []byte
525 var columnName string
526 var value interface{}
527 var colVersion, dbVersion int64
528 var siteID []byte
529 var cl, seq int64
530
531 if err := rows.Scan(&tableName, &pk, &columnName, &value, &colVersion, &dbVersion, &siteID, &cl, &seq); err != nil {
532 log.Println("Error scanning row:", err)
533 continue
534 }
535
536 change := map[string]interface{}{
537 "TableName": tableName,
538 "PK": encodeToBase64(pk),
539 "ColumnName": columnName,
540 "Value": value,
541 "ColVersion": colVersion,
542 "DBVersion": dbVersion,
543 "SiteID": encodeToBase64(siteID),
544 "CL": cl,
545 "Seq": seq,
546 }
547
548 changes = append(changes, change)
549 }
550
551 log.Printf("Returning %d changes to client for sync", len(changes))
552 return changes
553}
554
555func applyChangesToDB(db *sql.DB, changes []interface{}) {
556 tx, err := db.Begin()
557 if err != nil {
558 log.Println("Error starting transaction:", err)
559 return
560 }
561
562 for _, change := range changes {
563 if changeMap, ok := change.(map[string]interface{}); ok {
564 tableName := changeMap["TableName"].(string)
565 pk, _ := decodeBase64(changeMap["PK"].(string))
566 columnName := changeMap["ColumnName"].(string)
567 value := changeMap["Value"]
568 colVersion := int64(changeMap["ColVersion"].(float64))
569 dbVersion := int64(changeMap["DBVersion"].(float64))
570 siteID, _ := decodeBase64(changeMap["SiteID"].(string))
571 cl := int64(changeMap["CL"].(float64))
572 seq := int64(changeMap["Seq"].(float64))
573
574 _, err := tx.Exec(
575 `INSERT INTO crsql_changes VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
576 tableName, pk, columnName, value, colVersion, dbVersion, siteID, cl, seq,
577 )
578 if err != nil {
579 log.Println("Error inserting change:", err)
580 }
581 }
582 }
583
584 if err := tx.Commit(); err != nil {
585 log.Println("Error committing transaction:", err)
586 tx.Rollback()
587 }
588}
589
// encodeToBase64 returns the standard (padded, RFC 4648) base64 encoding of
// data. nil or empty input yields "".
func encodeToBase64(data []byte) string {
	out := make([]byte, base64.StdEncoding.EncodedLen(len(data)))
	base64.StdEncoding.Encode(out, data)
	return string(out)
}
593
// decodeBase64 decodes standard (padded, RFC 4648) base64. On malformed input
// it returns the bytes decoded before the error, together with the error.
func decodeBase64(encoded string) ([]byte, error) {
	out := make([]byte, base64.StdEncoding.DecodedLen(len(encoded)))
	n, err := base64.StdEncoding.Decode(out, []byte(encoded))
	return out[:n], err
}
597
// createDirIfNotExists ensures path exists as a directory, creating any missing
// parents with mode 0755.
//
// It returns nil if the directory already exists. It returns an error if path
// exists but is not a directory, or if creation fails. os.MkdirAll gives
// exactly these semantics in one call, and avoids the race between a separate
// existence check and the create.
func createDirIfNotExists(path string) error {
	return os.MkdirAll(path, 0755)
}
611
612
// Wire-format types exchanged with sync clients and the /auth/verify
// endpoint. JSON field names are fixed by the struct tags.
type (
	// PullMessage is the legacy (v1) "pull" request: the client's site ID and
	// the db_version it has already seen.
	PullMessage struct {
		Type    string `json:"type"`
		SiteID  string `json:"site_id"`
		Version int    `json:"version"`
	}

	// RequestChangesMessage is sent by the server to ask a client to upload
	// its changes. Version carries the server's current version.
	RequestChangesMessage struct {
		Type    string         `json:"type"`
		RoomID  string         `json:"room_id"`
		Version VersionColPair `json:"version"`
	}

	// VersionColPair is a (db_version, col_version) position. Pairs compare by
	// db_version first, then col_version.
	VersionColPair struct {
		DBVersion  int `json:"db_version"`
		ColVersion int `json:"col_version"`
	}

	// MissingVersionRange lists, for one db_version, the col_versions a
	// client reports as missing.
	MissingVersionRange struct {
		DBVersion   int   `json:"db_version"`
		ColVersions []int `json:"col_versions"`
	}

	// VersionRange is a legacy start/end version range, kept for
	// compatibility.
	VersionRange struct {
		Start int `json:"start"`
		End   int `json:"end"`
	}

	// SyncRequestMessage is the v2 "sync_request": the client's site ID, its
	// contiguous and maximum seen versions, the versions it is missing, and
	// the public key it connected with.
	SyncRequestMessage struct {
		Type           string                `json:"type"`
		SiteID         string                `json:"site_id"`
		ContiguousUpTo VersionColPair        `json:"contiguous_up_to"`
		MissingRanges  []MissingVersionRange `json:"missing_ranges"`
		MaxVersionSeen VersionColPair        `json:"max_version_seen"`
		PublicKey      string                `json:"publicKey"`
	}

	// SyncResponseMessage answers a sync_request with the server's current
	// maximum version and the changes selected for the client.
	SyncResponseMessage struct {
		Type              string                   `json:"type"`
		CurrentMaxVersion VersionColPair           `json:"current_max_version"`
		Changes           []map[string]interface{} `json:"changes"`
	}

	// RoomStatusMessage reports the caller's access level right after the
	// WebSocket upgrade.
	RoomStatusMessage struct {
		Type   string `json:"type"`
		Access string `json:"access"`
	}

	// ErrorMessage is a structured error sent over the WebSocket.
	ErrorMessage struct {
		Type    string `json:"type"`
		Code    string `json:"code"`
		Message string `json:"message"`
	}

	// AuthRequest is the JSON body accepted by /auth/verify.
	AuthRequest struct {
		RoomID    string `json:"roomId"`
		PublicKey string `json:"publicKey"`
		Data      string `json:"data"`
		Signature string `json:"signature"`
	}

	// AuthResponse is the JSON body returned by /auth/verify.
	AuthResponse struct {
		Authenticated bool   `json:"authenticated"`
		Message       string `json:"message"`
	}
)
688
689// handleAuthVerify handles authentication verification requests
690func handleAuthVerify(w http.ResponseWriter, r *http.Request) {
691 // Set CORS headers
692 w.Header().Set("Access-Control-Allow-Origin", "*")
693 w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
694 w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
695
696 // Handle preflight OPTIONS request
697 if r.Method == http.MethodOptions {
698 w.WriteHeader(http.StatusOK)
699 return
700 }
701
702 // Only allow POST requests
703 if r.Method != http.MethodPost {
704 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
705 return
706 }
707
708 // Parse the request body
709 var req AuthRequest
710 decoder := json.NewDecoder(r.Body)
711 if err := decoder.Decode(&req); err != nil {
712 log.Printf("Error decoding auth request: %v", err)
713 http.Error(w, "Invalid request format", http.StatusBadRequest)
714 return
715 }
716
717 // Validate the request
718 if req.RoomID == "" || req.PublicKey == "" || req.Signature == "" || req.Data == "" {
719 http.Error(w, "Room ID, public key, data, and signature are required", http.StatusBadRequest)
720 return
721 }
722
723 // Verify the signature
724 isValid, err := VerifySignature(req.PublicKey, req.Data, req.Signature)
725 if err != nil {
726 log.Printf("Error verifying signature: %v", err)
727 http.Error(w, "Failed to verify signature", http.StatusInternalServerError)
728 return
729 }
730
731 if !isValid {
732 // Return failure response
733 response := AuthResponse{
734 Authenticated: false,
735 Message: "Signature verification failed",
736 }
737
738 w.Header().Set("Content-Type", "application/json")
739 w.WriteHeader(http.StatusUnauthorized)
740 json.NewEncoder(w).Encode(response)
741 return
742 }
743
744 // Check if the public key has write permission for the room
745 hasWritePerm, err := CheckKeyPermission(req.RoomID, req.PublicKey, "write")
746 if err != nil {
747 log.Printf("Error checking key permission: %v", err)
748 http.Error(w, "Failed to check permissions", http.StatusInternalServerError)
749 return
750 }
751
752 if !hasWritePerm {
753 // Return failure response
754 response := AuthResponse{
755 Authenticated: false,
756 Message: "Public key doesn't have write permission for this room",
757 }
758
759 w.Header().Set("Content-Type", "application/json")
760 w.WriteHeader(http.StatusForbidden)
761 json.NewEncoder(w).Encode(response)
762 return
763 }
764
765 // Return success response
766 response := AuthResponse{
767 Authenticated: true,
768 Message: "Authentication successful",
769 }
770
771 w.Header().Set("Content-Type", "application/json")
772 w.WriteHeader(http.StatusOK)
773 json.NewEncoder(w).Encode(response)
774}
775
776// determineAccess determines the access level for a user connecting to a room
777func determineAccess(roomID, publicKey string) string {
778 logKey := publicKey
779 if len(logKey) > 20 {
780 logKey = publicKey[:20] + "..."
781 }
782 log.Printf("Determining access for publicKey %s in room %s", logKey, roomID)
783
784 // Check if room exists
785 roomExists, err := CheckRoomExists(roomID)
786 if err != nil {
787 log.Printf("Error checking room existence: %v", err)
788 return "none"
789 }
790
791 if !roomExists {
792 log.Printf("Room %s does not exist", roomID)
793 if !requireAuth {
794 // Development mode - auto-create room and grant permissions
795 log.Printf("Auto-creating room %s and granting permissions", roomID)
796 if err := GetOrCreateRoom(roomID); err != nil {
797 log.Printf("Error auto-creating room: %v", err)
798 return "no_room"
799 }
800 if err := AutoGrantInvitePermissions(roomID, publicKey); err != nil {
801 log.Printf("Error auto-granting permissions: %v", err)
802 return "no_room"
803 }
804 return "write"
805 }
806 return "no_room"
807 }
808
809 // Room exists - check permissions
810 if !requireAuth {
811 // Development mode - auto-grant permissions for existing rooms
812 log.Printf("Auto-granting permissions for existing room %s", roomID)
813 if err := AutoGrantInvitePermissions(roomID, publicKey); err != nil {
814 log.Printf("Error auto-granting permissions: %v", err)
815 }
816 return "write"
817 }
818
819 // Production mode - check actual permissions
820 hasRead, err := CheckKeyPermission(roomID, publicKey, "read")
821 if err != nil {
822 log.Printf("Error checking read permission: %v", err)
823 return "none"
824 }
825
826 if !hasRead {
827 return "none"
828 }
829
830 hasWrite, err := CheckKeyPermission(roomID, publicKey, "write")
831 if err != nil {
832 log.Printf("Error checking write permission: %v", err)
833 return "read"
834 }
835
836 if hasWrite {
837 return "write"
838 }
839
840 return "read"
841}
842
843// sendRoomStatus sends immediate room status to client
844func sendRoomStatus(conn *websocket.Conn, access string) {
845 status := RoomStatusMessage{
846 Type: "room_status",
847 Access: access,
848 }
849
850 statusJSON, _ := json.Marshal(status)
851 conn.WriteMessage(websocket.TextMessage, statusJSON)
852 log.Printf("Sent room status: %s", access)
853}
854
855// sendError sends structured error response to client
856func sendError(conn *websocket.Conn, code, message string) {
857 errorMsg := ErrorMessage{
858 Type: "error",
859 Code: code,
860 Message: message,
861 }
862
863 errorJSON, _ := json.Marshal(errorMsg)
864 conn.WriteMessage(websocket.TextMessage, errorJSON)
865 log.Printf("Sent error: %s - %s", code, message)
866}
867
868// getChangesForSyncRequest gets changes for missing ranges + new changes
869func getChangesForSyncRequest(db *sql.DB, syncMsg SyncRequestMessage) []map[string]interface{} {
870 decodedSiteID, err := decodeBase64(syncMsg.SiteID)
871 if err != nil {
872 log.Printf("Error decoding site_id '%s': %v", syncMsg.SiteID, err)
873 decodedSiteID = []byte{}
874 }
875
876 // Build query for missing (db_version, col_version) pairs + new changes
877 var queryParts []string
878 var args []interface{}
879
880 // Add site_id filter
881 baseCond := "site_id != ?"
882 args = append(args, decodedSiteID)
883
884 // Add missing ranges - for each db_version, get specific missing col_versions
885 for _, vrange := range syncMsg.MissingRanges {
886 if len(vrange.ColVersions) > 0 {
887 // Create placeholders for the col_versions
888 placeholders := make([]string, len(vrange.ColVersions))
889 for i := range placeholders {
890 placeholders[i] = "?"
891 args = append(args, vrange.ColVersions[i])
892 }
893
894 // Add condition for this db_version with specific col_versions
895 queryParts = append(queryParts,
896 fmt.Sprintf("(db_version = ? AND col_version IN (%s))",
897 strings.Join(placeholders, ",")))
898 args = append(args, vrange.DBVersion)
899 }
900 }
901
902 // Add new changes beyond max version seen
903 // This includes: db_version > max_db_version OR (db_version = max_db_version AND col_version > max_col_version)
904 queryParts = append(queryParts,
905 "(db_version > ? OR (db_version = ? AND col_version > ?))")
906 args = append(args, syncMsg.MaxVersionSeen.DBVersion,
907 syncMsg.MaxVersionSeen.DBVersion, syncMsg.MaxVersionSeen.ColVersion)
908
909 // Combine all conditions
910 whereClause := baseCond
911 if len(queryParts) > 0 {
912 whereClause += " AND (" + strings.Join(queryParts, " OR ") + ")"
913 }
914
915 query := "SELECT * FROM crsql_changes WHERE " + whereClause + " ORDER BY db_version ASC, col_version ASC"
916
917 log.Printf("Executing sync query with %d missing ranges and max_version %d.%d",
918 len(syncMsg.MissingRanges), syncMsg.MaxVersionSeen.DBVersion, syncMsg.MaxVersionSeen.ColVersion)
919
920 rows, err := db.Query(query, args...)
921 if err != nil {
922 log.Printf("Error querying changes for sync: %v", err)
923 return nil
924 }
925 defer rows.Close()
926
927 var changes []map[string]interface{}
928
929 for rows.Next() {
930 var tableName string
931 var pk []byte
932 var columnName string
933 var value interface{}
934 var colVersion, dbVersion int64
935 var siteID []byte
936 var cl, seq int64
937
938 if err := rows.Scan(&tableName, &pk, &columnName, &value, &colVersion, &dbVersion, &siteID, &cl, &seq); err != nil {
939 log.Printf("Error scanning row: %v", err)
940 continue
941 }
942
943 change := map[string]interface{}{
944 "TableName": tableName,
945 "PK": encodeToBase64(pk),
946 "ColumnName": columnName,
947 "Value": value,
948 "ColVersion": colVersion,
949 "DBVersion": dbVersion,
950 "SiteID": encodeToBase64(siteID),
951 "CL": cl,
952 "Seq": seq,
953 }
954
955 changes = append(changes, change)
956 }
957
958 log.Printf("Returning %d changes for sync request", len(changes))
959 return changes
960}
961
962func main() {
963 // Log startup configuration
964 log.Printf("Starting server with REQUIRE_AUTH=%v", requireAuth)
965
966 // Register SQLite with CR-SQLite extension
967 sql.Register("sqlite3_with_extensions", &sqlite3.SQLiteDriver{
968 Extensions: []string{"../db/crsqlite"},
969 })
970
971 // Initialize the auth database
972 log.Println("Initializing authentication database...")
973 if err := InitAuthDB(); err != nil {
974 log.Printf("Warning: Failed to initialize auth database: %v", err)
975 } else {
976 log.Println("Authentication database initialized successfully")
977 }
978 // Close auth database connection when the server exits
979 defer CloseAuthDB()
980
981 // Create directory for room databases
982 if err := createDirIfNotExists("./rooms"); err != nil {
983 log.Printf("Warning: Failed to create rooms directory: %v", err)
984 } else {
985 log.Println("Rooms directory created or verified")
986 }
987
988 http.HandleFunc("/sync", handleWebSocket)
989 http.HandleFunc("/auth/verify", handleAuthVerify)
990 http.HandleFunc("/auth/register-key", HandleKeyRegistration)
991
992 log.Println("WebSocket server started on :8080")
993 log.Fatal(http.ListenAndServe(":8080", nil))
994}
995