Coffee journaling on ATProto (alpha) alpha.arabica.social
coffee
17
fork

Configure Feed

Select the types of activity you want to include in your feed.

fix: switch to go native csrf

pdewey 66d691ef b6979bc1

+28 -715
+3 -3
internal/bff/render.go
··· 138 138 } 139 139 140 140 // RenderTemplate renders a template with layout 141 - func RenderTemplate(w http.ResponseWriter, tmpl string, data *PageData) error { 141 + func RenderTemplate(w http.ResponseWriter, r *http.Request, tmpl string, data *PageData) error { 142 142 t, err := parsePageTemplate(tmpl) 143 143 if err != nil { 144 144 return err ··· 147 147 } 148 148 149 149 // RenderTemplateWithProfile renders a template with layout and user profile 150 - func RenderTemplateWithProfile(w http.ResponseWriter, tmpl string, data *PageData, userProfile *UserProfile) error { 150 + func RenderTemplateWithProfile(w http.ResponseWriter, r *http.Request, tmpl string, data *PageData, userProfile *UserProfile) error { 151 151 data.UserProfile = userProfile 152 - return RenderTemplate(w, tmpl, data) 152 + return RenderTemplate(w, r, tmpl, data) 153 153 } 154 154 155 155 // RenderHome renders the home page
+2 -2
internal/handlers/handlers.go
··· 1133 1133 UserProfile: userProfile, 1134 1134 } 1135 1135 1136 - if err := bff.RenderTemplate(w, "about.tmpl", data); err != nil { 1136 + if err := bff.RenderTemplate(w, r, "about.tmpl", data); err != nil { 1137 1137 http.Error(w, "Failed to render page", http.StatusInternalServerError) 1138 1138 log.Error().Err(err).Msg("Failed to render about page") 1139 1139 } ··· 1157 1157 UserProfile: userProfile, 1158 1158 } 1159 1159 1160 - if err := bff.RenderTemplate(w, "terms.tmpl", data); err != nil { 1160 + if err := bff.RenderTemplate(w, r, "terms.tmpl", data); err != nil { 1161 1161 http.Error(w, "Failed to render page", http.StatusInternalServerError) 1162 1162 log.Error().Err(err).Msg("Failed to render terms page") 1163 1163 }
-157
internal/middleware/csrf.go
··· 1 - package middleware 2 - 3 - import ( 4 - "crypto/rand" 5 - "crypto/subtle" 6 - "encoding/base64" 7 - "net/http" 8 - "strings" 9 - 10 - "github.com/rs/zerolog/log" 11 - ) 12 - 13 - const ( 14 - // CSRFTokenCookieName is the name of the cookie that stores the CSRF token 15 - CSRFTokenCookieName = "csrf_token" 16 - // CSRFTokenHeaderName is the HTTP header name for submitting the CSRF token 17 - CSRFTokenHeaderName = "X-CSRF-Token" 18 - // CSRFTokenFormField is the form field name for submitting the CSRF token 19 - CSRFTokenFormField = "csrf_token" 20 - // CSRFTokenLength is the number of random bytes used to generate the token 21 - CSRFTokenLength = 32 22 - ) 23 - 24 - // CSRFConfig holds CSRF middleware configuration 25 - type CSRFConfig struct { 26 - // SecureCookie sets the Secure flag on the CSRF cookie 27 - SecureCookie bool 28 - 29 - // ExemptPaths are paths that skip CSRF validation (e.g., OAuth callback) 30 - ExemptPaths []string 31 - 32 - // ExemptMethods are HTTP methods that skip CSRF validation 33 - // Default: GET, HEAD, OPTIONS, TRACE 34 - ExemptMethods []string 35 - } 36 - 37 - // DefaultCSRFConfig returns default configuration 38 - func DefaultCSRFConfig() *CSRFConfig { 39 - return &CSRFConfig{ 40 - SecureCookie: false, 41 - ExemptPaths: []string{"/oauth/callback"}, 42 - ExemptMethods: []string{"GET", "HEAD", "OPTIONS", "TRACE"}, 43 - } 44 - } 45 - 46 - // generateCSRFToken creates a cryptographically secure random token 47 - func generateCSRFToken() (string, error) { 48 - bytes := make([]byte, CSRFTokenLength) 49 - if _, err := rand.Read(bytes); err != nil { 50 - return "", err 51 - } 52 - return base64.URLEncoding.EncodeToString(bytes), nil 53 - } 54 - 55 - // CSRFMiddleware provides CSRF protection using double-submit cookie pattern 56 - func CSRFMiddleware(config *CSRFConfig) func(http.Handler) http.Handler { 57 - if config == nil { 58 - config = DefaultCSRFConfig() 59 - } 60 - 61 - // Build exempt method set for fast lookup 62 - exemptMethods := make(map[string]bool) 63 - for _, m := range config.ExemptMethods { 64 - exemptMethods[strings.ToUpper(m)] = true 65 - } 66 - 67 - return func(next http.Handler) http.Handler { 68 - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 69 - // Get or generate CSRF token 70 - var token string 71 - cookie, err := r.Cookie(CSRFTokenCookieName) 72 - if err == nil && cookie.Value != "" { 73 - token = cookie.Value 74 - } else { 75 - // Generate new token 76 - token, err = generateCSRFToken() 77 - if err != nil { 78 - log.Error().Err(err).Msg("Failed to generate CSRF token") 79 - http.Error(w, "Internal server error", http.StatusInternalServerError) 80 - return 81 - } 82 - 83 - // Set cookie 84 - http.SetCookie(w, &http.Cookie{ 85 - Name: CSRFTokenCookieName, 86 - Value: token, 87 - Path: "/", 88 - HttpOnly: false, // JS needs to read this 89 - Secure: config.SecureCookie, 90 - SameSite: http.SameSiteStrictMode, 91 - MaxAge: 86400, // 24 hours 92 - }) 93 - } 94 - 95 - // Store token in response header for JS to access 96 - // This is an alternative to reading from cookie 97 - w.Header().Set(CSRFTokenHeaderName, token) 98 - 99 - // Check if method requires validation 100 - if exemptMethods[r.Method] { 101 - next.ServeHTTP(w, r) 102 - return 103 - } 104 - 105 - // Check if path is exempt 106 - for _, path := range config.ExemptPaths { 107 - if r.URL.Path == path || strings.HasPrefix(r.URL.Path, path) { 108 - next.ServeHTTP(w, r) 109 - return 110 - } 111 - } 112 - 113 - // Validate CSRF token 114 - // Try header first (JavaScript requests) 115 - submittedToken := r.Header.Get(CSRFTokenHeaderName) 116 - 117 - // Fall back to form field (traditional forms) 118 - if submittedToken == "" { 119 - submittedToken = r.FormValue(CSRFTokenFormField) 120 - } 121 - 122 - // Validate token 123 - if submittedToken == "" { 124 - log.Warn(). 125 - Str("client_ip", getClientIP(r)). 126 - Str("path", r.URL.Path). 127 - Str("method", r.Method). 128 - Msg("CSRF token missing") 129 - http.Error(w, "CSRF token missing", http.StatusForbidden) 130 - return 131 - } 132 - 133 - // Constant-time comparison to prevent timing attacks 134 - if subtle.ConstantTimeCompare([]byte(token), []byte(submittedToken)) != 1 { 135 - log.Warn(). 136 - Str("client_ip", getClientIP(r)). 137 - Str("path", r.URL.Path). 138 - Str("method", r.Method). 139 - Msg("CSRF token invalid") 140 - http.Error(w, "CSRF token invalid", http.StatusForbidden) 141 - return 142 - } 143 - 144 - next.ServeHTTP(w, r) 145 - }) 146 - } 147 - } 148 - 149 - // GetCSRFToken extracts the CSRF token from request cookies 150 - // Used by template rendering to include token in forms 151 - func GetCSRFToken(r *http.Request) string { 152 - cookie, err := r.Cookie(CSRFTokenCookieName) 153 - if err != nil { 154 - return "" 155 - } 156 - return cookie.Value 157 - }
-455
internal/middleware/csrf_test.go
··· 1 - package middleware 2 - 3 - import ( 4 - "net/http" 5 - "net/http/httptest" 6 - "strings" 7 - "testing" 8 - ) 9 - 10 - func TestCSRFTokenGeneration(t *testing.T) { 11 - // Test that tokens are generated correctly 12 - token, err := generateCSRFToken() 13 - if err != nil { 14 - t.Fatalf("Failed to generate token: %v", err) 15 - } 16 - if len(token) == 0 { 17 - t.Error("Generated token is empty") 18 - } 19 - 20 - // Test uniqueness 21 - token2, _ := generateCSRFToken() 22 - if token == token2 { 23 - t.Error("Tokens should be unique") 24 - } 25 - } 26 - 27 - func TestCSRFMiddleware_SetsTokenCookie(t *testing.T) { 28 - handler := CSRFMiddleware(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 29 - w.WriteHeader(http.StatusOK) 30 - })) 31 - 32 - req := httptest.NewRequest("GET", "/", nil) 33 - rec := httptest.NewRecorder() 34 - 35 - handler.ServeHTTP(rec, req) 36 - 37 - // Check cookie is set 38 - cookies := rec.Result().Cookies() 39 - var csrfCookie *http.Cookie 40 - for _, c := range cookies { 41 - if c.Name == CSRFTokenCookieName { 42 - csrfCookie = c 43 - break 44 - } 45 - } 46 - 47 - if csrfCookie == nil { 48 - t.Error("CSRF cookie not set") 49 - } 50 - if csrfCookie.Value == "" { 51 - t.Error("CSRF cookie value is empty") 52 - } 53 - // Verify cookie settings 54 - if csrfCookie.HttpOnly { 55 - t.Error("CSRF cookie should not be HttpOnly (JS needs to read it)") 56 - } 57 - if csrfCookie.SameSite != http.SameSiteStrictMode { 58 - t.Error("CSRF cookie should have SameSite=Strict") 59 - } 60 - } 61 - 62 - func TestCSRFMiddleware_SetsResponseHeader(t *testing.T) { 63 - handler := CSRFMiddleware(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 64 - w.WriteHeader(http.StatusOK) 65 - })) 66 - 67 - req := httptest.NewRequest("GET", "/", nil) 68 - rec := httptest.NewRecorder() 69 - 70 - handler.ServeHTTP(rec, req) 71 - 72 - // Check response header is set 73 - headerToken := rec.Header().Get(CSRFTokenHeaderName) 74 - if headerToken == "" { 75 - t.Error("CSRF token not set in response header") 76 - } 77 - } 78 - 79 - func TestCSRFMiddleware_GETRequestsExempt(t *testing.T) { 80 - handler := CSRFMiddleware(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 81 - w.WriteHeader(http.StatusOK) 82 - })) 83 - 84 - // GET without token should succeed 85 - req := httptest.NewRequest("GET", "/some-page", nil) 86 - rec := httptest.NewRecorder() 87 - 88 - handler.ServeHTTP(rec, req) 89 - 90 - if rec.Code != http.StatusOK { 91 - t.Errorf("GET request should succeed, got status %d", rec.Code) 92 - } 93 - } 94 - 95 - func TestCSRFMiddleware_HEADRequestsExempt(t *testing.T) { 96 - handler := CSRFMiddleware(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 97 - w.WriteHeader(http.StatusOK) 98 - })) 99 - 100 - req := httptest.NewRequest("HEAD", "/some-page", nil) 101 - rec := httptest.NewRecorder() 102 - 103 - handler.ServeHTTP(rec, req) 104 - 105 - if rec.Code != http.StatusOK { 106 - t.Errorf("HEAD request should succeed, got status %d", rec.Code) 107 - } 108 - } 109 - 110 - func TestCSRFMiddleware_OPTIONSRequestsExempt(t *testing.T) { 111 - handler := CSRFMiddleware(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 112 - w.WriteHeader(http.StatusOK) 113 - })) 114 - 115 - req := httptest.NewRequest("OPTIONS", "/some-page", nil) 116 - rec := httptest.NewRecorder() 117 - 118 - handler.ServeHTTP(rec, req) 119 - 120 - if rec.Code != http.StatusOK { 121 - t.Errorf("OPTIONS request should succeed, got status %d", rec.Code) 122 - } 123 - } 124 - 125 - func TestCSRFMiddleware_POSTWithoutToken_Fails(t *testing.T) { 126 - handler := CSRFMiddleware(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 127 - w.WriteHeader(http.StatusOK) 128 - })) 129 - 130 - // POST without token should fail 131 - req := httptest.NewRequest("POST", "/api/beans", strings.NewReader("{}")) 132 - req.Header.Set("Content-Type", "application/json") 133 - rec := httptest.NewRecorder() 134 - 135 - handler.ServeHTTP(rec, req) 136 - 137 - if rec.Code != http.StatusForbidden { 138 - t.Errorf("POST without CSRF token should return 403, got %d", rec.Code) 139 - } 140 - } 141 - 142 - func TestCSRFMiddleware_POSTWithValidToken_Succeeds(t *testing.T) { 143 - handler := CSRFMiddleware(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 144 - w.WriteHeader(http.StatusOK) 145 - })) 146 - 147 - // First, get a token via GET 148 - getReq := httptest.NewRequest("GET", "/", nil) 149 - getRec := httptest.NewRecorder() 150 - handler.ServeHTTP(getRec, getReq) 151 - 152 - var token string 153 - for _, c := range getRec.Result().Cookies() { 154 - if c.Name == CSRFTokenCookieName { 155 - token = c.Value 156 - break 157 - } 158 - } 159 - 160 - if token == "" { 161 - t.Fatal("No CSRF token cookie was set") 162 - } 163 - 164 - // Now POST with valid token 165 - postReq := httptest.NewRequest("POST", "/api/beans", strings.NewReader("{}")) 166 - postReq.Header.Set("Content-Type", "application/json") 167 - postReq.Header.Set(CSRFTokenHeaderName, token) 168 - postReq.AddCookie(&http.Cookie{Name: CSRFTokenCookieName, Value: token}) 169 - postRec := httptest.NewRecorder() 170 - 171 - handler.ServeHTTP(postRec, postReq) 172 - 173 - if postRec.Code != http.StatusOK { 174 - t.Errorf("POST with valid CSRF token should succeed, got %d", postRec.Code) 175 - } 176 - } 177 - 178 - func TestCSRFMiddleware_POSTWithInvalidToken_Fails(t *testing.T) { 179 - handler := CSRFMiddleware(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 180 - w.WriteHeader(http.StatusOK) 181 - })) 182 - 183 - // First, get a token 184 - getReq := httptest.NewRequest("GET", "/", nil) 185 - getRec := httptest.NewRecorder() 186 - handler.ServeHTTP(getRec, getReq) 187 - 188 - var token string 189 - for _, c := range getRec.Result().Cookies() { 190 - if c.Name == CSRFTokenCookieName { 191 - token = c.Value 192 - break 193 - } 194 - } 195 - 196 - // POST with wrong token 197 - postReq := httptest.NewRequest("POST", "/api/beans", strings.NewReader("{}")) 198 - postReq.Header.Set("Content-Type", "application/json") 199 - postReq.Header.Set(CSRFTokenHeaderName, "wrong-token") 200 - postReq.AddCookie(&http.Cookie{Name: CSRFTokenCookieName, Value: token}) 201 - postRec := httptest.NewRecorder() 202 - 203 - handler.ServeHTTP(postRec, postReq) 204 - 205 - if postRec.Code != http.StatusForbidden { 206 - t.Errorf("POST with invalid CSRF token should return 403, got %d", postRec.Code) 207 - } 208 - } 209 - 210 - func TestCSRFMiddleware_ExemptPath(t *testing.T) { 211 - config := &CSRFConfig{ 212 - ExemptPaths: []string{"/oauth/callback"}, 213 - ExemptMethods: []string{"GET", "HEAD", "OPTIONS", "TRACE"}, 214 - } 215 - handler := CSRFMiddleware(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 216 - w.WriteHeader(http.StatusOK) 217 - })) 218 - 219 - // POST to exempt path without token should succeed 220 - req := httptest.NewRequest("POST", "/oauth/callback", nil) 221 - rec := httptest.NewRecorder() 222 - 223 - handler.ServeHTTP(rec, req) 224 - 225 - if rec.Code != http.StatusOK { 226 - t.Errorf("Exempt path should succeed without token, got %d", rec.Code) 227 - } 228 - } 229 - 230 - func TestCSRFMiddleware_ExemptPathPrefix(t *testing.T) { 231 - config := &CSRFConfig{ 232 - ExemptPaths: []string{"/oauth/"}, 233 - ExemptMethods: []string{"GET", "HEAD", "OPTIONS", "TRACE"}, 234 - } 235 - handler := CSRFMiddleware(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 236 - w.WriteHeader(http.StatusOK) 237 - })) 238 - 239 - // POST to path with exempt prefix without token should succeed 240 - req := httptest.NewRequest("POST", "/oauth/callback?code=123", nil) 241 - rec := httptest.NewRecorder() 242 - 243 - handler.ServeHTTP(rec, req) 244 - 245 - if rec.Code != http.StatusOK { 246 - t.Errorf("Exempt path prefix should succeed without token, got %d", rec.Code) 247 - } 248 - } 249 - 250 - func TestCSRFMiddleware_FormField(t *testing.T) { 251 - handler := CSRFMiddleware(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 252 - w.WriteHeader(http.StatusOK) 253 - })) 254 - 255 - // Get token first 256 - getReq := httptest.NewRequest("GET", "/", nil) 257 - getRec := httptest.NewRecorder() 258 - handler.ServeHTTP(getRec, getReq) 259 - 260 - var token string 261 - for _, c := range getRec.Result().Cookies() { 262 - if c.Name == CSRFTokenCookieName { 263 - token = c.Value 264 - break 265 - } 266 - } 267 - 268 - // POST with form field instead of header 269 - formData := "csrf_token=" + token + "&name=test" 270 - postReq := httptest.NewRequest("POST", "/api/beans", strings.NewReader(formData)) 271 - postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") 272 - postReq.AddCookie(&http.Cookie{Name: CSRFTokenCookieName, Value: token}) 273 - postRec := httptest.NewRecorder() 274 - 275 - handler.ServeHTTP(postRec, postReq) 276 - 277 - if postRec.Code != http.StatusOK { 278 - t.Errorf("POST with form field CSRF token should succeed, got %d", postRec.Code) 279 - } 280 - } 281 - 282 - func TestCSRFMiddleware_DELETE(t *testing.T) { 283 - handler := CSRFMiddleware(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 284 - w.WriteHeader(http.StatusOK) 285 - })) 286 - 287 - // DELETE without token should fail 288 - req := httptest.NewRequest("DELETE", "/api/beans/123", nil) 289 - rec := httptest.NewRecorder() 290 - 291 - handler.ServeHTTP(rec, req) 292 - 293 - if rec.Code != http.StatusForbidden { 294 - t.Errorf("DELETE without CSRF token should return 403, got %d", rec.Code) 295 - } 296 - } 297 - 298 - func TestCSRFMiddleware_PUT(t *testing.T) { 299 - handler := CSRFMiddleware(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 300 - w.WriteHeader(http.StatusOK) 301 - })) 302 - 303 - // PUT without token should fail 304 - req := httptest.NewRequest("PUT", "/api/beans/123", strings.NewReader("{}")) 305 - rec := httptest.NewRecorder() 306 - 307 - handler.ServeHTTP(rec, req) 308 - 309 - if rec.Code != http.StatusForbidden { 310 - t.Errorf("PUT without CSRF token should return 403, got %d", rec.Code) 311 - } 312 - } 313 - 314 - func TestCSRFMiddleware_DELETE_WithValidToken(t *testing.T) { 315 - handler := CSRFMiddleware(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 316 - w.WriteHeader(http.StatusOK) 317 - })) 318 - 319 - // Get token first 320 - getReq := httptest.NewRequest("GET", "/", nil) 321 - getRec := httptest.NewRecorder() 322 - handler.ServeHTTP(getRec, getReq) 323 - 324 - var token string 325 - for _, c := range getRec.Result().Cookies() { 326 - if c.Name == CSRFTokenCookieName { 327 - token = c.Value 328 - break 329 - } 330 - } 331 - 332 - // DELETE with valid token 333 - req := httptest.NewRequest("DELETE", "/api/beans/123", nil) 334 - req.Header.Set(CSRFTokenHeaderName, token) 335 - req.AddCookie(&http.Cookie{Name: CSRFTokenCookieName, Value: token}) 336 - rec := httptest.NewRecorder() 337 - 338 - handler.ServeHTTP(rec, req) 339 - 340 - if rec.Code != http.StatusOK { 341 - t.Errorf("DELETE with valid CSRF token should succeed, got %d", rec.Code) 342 - } 343 - } 344 - 345 - func TestCSRFMiddleware_PUT_WithValidToken(t *testing.T) { 346 - handler := CSRFMiddleware(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 347 - w.WriteHeader(http.StatusOK) 348 - })) 349 - 350 - // Get token first 351 - getReq := httptest.NewRequest("GET", "/", nil) 352 - getRec := httptest.NewRecorder() 353 - handler.ServeHTTP(getRec, getReq) 354 - 355 - var token string 356 - for _, c := range getRec.Result().Cookies() { 357 - if c.Name == CSRFTokenCookieName { 358 - token = c.Value 359 - break 360 - } 361 - } 362 - 363 - // PUT with valid token 364 - req := httptest.NewRequest("PUT", "/api/beans/123", strings.NewReader("{}")) 365 - req.Header.Set("Content-Type", "application/json") 366 - req.Header.Set(CSRFTokenHeaderName, token) 367 - req.AddCookie(&http.Cookie{Name: CSRFTokenCookieName, Value: token}) 368 - rec := httptest.NewRecorder() 369 - 370 - handler.ServeHTTP(rec, req) 371 - 372 - if rec.Code != http.StatusOK { 373 - t.Errorf("PUT with valid CSRF token should succeed, got %d", rec.Code) 374 - } 375 - } 376 - 377 - func TestCSRFMiddleware_SecureCookie(t *testing.T) { 378 - config := &CSRFConfig{ 379 - SecureCookie: true, 380 - ExemptMethods: []string{"GET", "HEAD", "OPTIONS", "TRACE"}, 381 - } 382 - handler := CSRFMiddleware(config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 383 - w.WriteHeader(http.StatusOK) 384 - })) 385 - 386 - req := httptest.NewRequest("GET", "/", nil) 387 - rec := httptest.NewRecorder() 388 - 389 - handler.ServeHTTP(rec, req) 390 - 391 - // Check cookie has Secure flag 392 - cookies := rec.Result().Cookies() 393 - var csrfCookie *http.Cookie 394 - for _, c := range cookies { 395 - if c.Name == CSRFTokenCookieName { 396 - csrfCookie = c 397 - break 398 - } 399 - } 400 - 401 - if csrfCookie == nil { 402 - t.Fatal("CSRF cookie not set") 403 - } 404 - if !csrfCookie.Secure { 405 - t.Error("CSRF cookie should have Secure flag when SecureCookie=true") 406 - } 407 - } 408 - 409 - func TestGetCSRFToken(t *testing.T) { 410 - // Create request with CSRF cookie 411 - req := httptest.NewRequest("GET", "/", nil) 412 - req.AddCookie(&http.Cookie{Name: CSRFTokenCookieName, Value: "test-token-123"}) 413 - 414 - token := GetCSRFToken(req) 415 - if token != "test-token-123" { 416 - t.Errorf("GetCSRFToken returned wrong value: %s", token) 417 - } 418 - } 419 - 420 - func TestGetCSRFToken_NoCookie(t *testing.T) { 421 - req := httptest.NewRequest("GET", "/", nil) 422 - 423 - token := GetCSRFToken(req) 424 - if token != "" { 425 - t.Errorf("GetCSRFToken should return empty string when no cookie: %s", token) 426 - } 427 - } 428 - 429 - func TestCSRFMiddleware_ReusesExistingToken(t *testing.T) { 430 - handler := CSRFMiddleware(nil)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 431 - w.WriteHeader(http.StatusOK) 432 - })) 433 - 434 - // Make request with existing token 435 - existingToken := "existing-token-abc123" 436 - req := httptest.NewRequest("GET", "/", nil) 437 - req.AddCookie(&http.Cookie{Name: CSRFTokenCookieName, Value: existingToken}) 438 - rec := httptest.NewRecorder() 439 - 440 - handler.ServeHTTP(rec, req) 441 - 442 - // Should not set a new cookie (only set when no token exists) 443 - cookies := rec.Result().Cookies() 444 - for _, c := range cookies { 445 - if c.Name == CSRFTokenCookieName { 446 - t.Error("Should not set new CSRF cookie when valid one already exists") 447 - } 448 - } 449 - 450 - // Response header should contain the existing token 451 - headerToken := rec.Header().Get(CSRFTokenHeaderName) 452 - if headerToken != existingToken { 453 - t.Errorf("Response header should contain existing token, got: %s", headerToken) 454 - } 455 - }
+23 -28
internal/routing/routing.go
··· 22 22 h := cfg.Handlers 23 23 mux := http.NewServeMux() 24 24 25 - // OAuth routes 25 + // Create CrossOriginProtection for CSRF protection 26 + cop := http.NewCrossOriginProtection() 27 + 28 + // OAuth routes (no CSRF protection needed for GET and callback) 26 29 mux.HandleFunc("GET /login", h.HandleLogin) 27 - mux.HandleFunc("POST /auth/login", h.HandleLoginSubmit) 30 + mux.Handle("POST /auth/login", cop.Handler(http.HandlerFunc(h.HandleLoginSubmit))) 28 31 mux.HandleFunc("GET /oauth/callback", h.HandleOAuthCallback) 29 - mux.HandleFunc("POST /logout", h.HandleLogout) 32 + mux.Handle("POST /logout", cop.Handler(http.HandlerFunc(h.HandleLogout))) 30 33 mux.HandleFunc("GET /client-metadata.json", h.HandleClientMetadata) 31 34 mux.HandleFunc("GET /.well-known/oauth-client-metadata", h.HandleWellKnownOAuth) 32 35 ··· 54 57 mux.HandleFunc("GET /brews", h.HandleBrewList) 55 58 mux.HandleFunc("GET /brews/new", h.HandleBrewNew) 56 59 mux.HandleFunc("GET /brews/{id}", h.HandleBrewEdit) 57 - mux.HandleFunc("POST /brews", h.HandleBrewCreate) 58 - mux.HandleFunc("PUT /brews/{id}", h.HandleBrewUpdate) 59 - mux.HandleFunc("DELETE /brews/{id}", h.HandleBrewDelete) 60 + mux.Handle("POST /brews", cop.Handler(http.HandlerFunc(h.HandleBrewCreate))) 61 + mux.Handle("PUT /brews/{id}", cop.Handler(http.HandlerFunc(h.HandleBrewUpdate))) 62 + mux.Handle("DELETE /brews/{id}", cop.Handler(http.HandlerFunc(h.HandleBrewDelete))) 60 63 mux.HandleFunc("GET /brews/export", h.HandleBrewExport) 61 64 62 65 // API routes for CRUD operations 63 - mux.HandleFunc("POST /api/beans", h.HandleBeanCreate) 64 - mux.HandleFunc("PUT /api/beans/{id}", h.HandleBeanUpdate) 65 - mux.HandleFunc("DELETE /api/beans/{id}", h.HandleBeanDelete) 66 + mux.Handle("POST /api/beans", cop.Handler(http.HandlerFunc(h.HandleBeanCreate))) 67 + mux.Handle("PUT /api/beans/{id}", cop.Handler(http.HandlerFunc(h.HandleBeanUpdate))) 68 + mux.Handle("DELETE /api/beans/{id}", cop.Handler(http.HandlerFunc(h.HandleBeanDelete))) 66 69 67 - mux.HandleFunc("POST /api/roasters", h.HandleRoasterCreate) 68 - mux.HandleFunc("PUT /api/roasters/{id}", h.HandleRoasterUpdate) 69 - mux.HandleFunc("DELETE /api/roasters/{id}", h.HandleRoasterDelete) 70 + mux.Handle("POST /api/roasters", cop.Handler(http.HandlerFunc(h.HandleRoasterCreate))) 71 + mux.Handle("PUT /api/roasters/{id}", cop.Handler(http.HandlerFunc(h.HandleRoasterUpdate))) 72 + mux.Handle("DELETE /api/roasters/{id}", cop.Handler(http.HandlerFunc(h.HandleRoasterDelete))) 70 73 71 - mux.HandleFunc("POST /api/grinders", h.HandleGrinderCreate) 72 - mux.HandleFunc("PUT /api/grinders/{id}", h.HandleGrinderUpdate) 73 - mux.HandleFunc("DELETE /api/grinders/{id}", h.HandleGrinderDelete) 74 + mux.Handle("POST /api/grinders", cop.Handler(http.HandlerFunc(h.HandleGrinderCreate))) 75 + mux.Handle("PUT /api/grinders/{id}", cop.Handler(http.HandlerFunc(h.HandleGrinderUpdate))) 76 + mux.Handle("DELETE /api/grinders/{id}", cop.Handler(http.HandlerFunc(h.HandleGrinderDelete))) 74 77 75 - mux.HandleFunc("POST /api/brewers", h.HandleBrewerCreate) 76 - mux.HandleFunc("PUT /api/brewers/{id}", h.HandleBrewerUpdate) 77 - mux.HandleFunc("DELETE /api/brewers/{id}", h.HandleBrewerDelete) 78 + mux.Handle("POST /api/brewers", cop.Handler(http.HandlerFunc(h.HandleBrewerCreate))) 79 + mux.Handle("PUT /api/brewers/{id}", cop.Handler(http.HandlerFunc(h.HandleBrewerUpdate))) 80 + mux.Handle("DELETE /api/brewers/{id}", cop.Handler(http.HandlerFunc(h.HandleBrewerDelete))) 78 81 79 82 // Profile routes (public user profiles) 80 83 mux.HandleFunc("GET /profile/{actor}", h.HandleProfile) ··· 92 95 // 1. Limit request body size (innermost - runs first on request) 93 96 handler = middleware.LimitBodyMiddleware(handler) 94 97 95 - // 2. Apply CSRF protection (validates tokens on state-changing requests) 96 - csrfConfig := &middleware.CSRFConfig{ 97 - SecureCookie: false, // Set true when using HTTPS 98 - ExemptPaths: []string{"/oauth/callback"}, 99 - ExemptMethods: []string{"GET", "HEAD", "OPTIONS", "TRACE"}, 100 - } 101 - handler = middleware.CSRFMiddleware(csrfConfig)(handler) 102 - 103 - // 3. Apply OAuth middleware to add auth context 98 + // 2. Apply OAuth middleware to add auth context 104 99 handler = cfg.OAuthManager.AuthMiddleware(handler) 105 100 106 - // 4. Apply rate limiting 101 + // 3. Apply rate limiting 107 102 rateLimitConfig := middleware.NewDefaultRateLimitConfig() 108 103 handler = middleware.RateLimitMiddleware(rateLimitConfig)(handler) 109 104
-2
templates/home.tmpl
··· 25 25 </div> 26 26 <div class="text-center mt-6"> 27 27 <form action="/logout" method="POST" class="inline-block"> 28 - <input type="hidden" name="csrf_token" class="csrf-token-field"> 29 28 <button type="submit" 30 29 class="bg-brown-400 text-brown-900 py-3 px-8 rounded-lg hover:bg-brown-500 transition-all text-lg font-medium shadow-md hover:shadow-lg"> 31 30 Logout ··· 37 36 <div> 38 37 <p class="text-brown-800 mb-6 text-center text-lg">Please log in with your AT Protocol handle to start tracking your brews.</p> 39 38 <form method="POST" action="/auth/login" class="max-w-md mx-auto"> 40 - <input type="hidden" name="csrf_token" class="csrf-token-field"> 41 39 <div class="relative"> 42 40 <label for="handle" class="block text-sm font-medium text-brown-900 mb-2">Your Handle</label> 43 41 <input
-2
templates/layout.tmpl
··· 18 18 <script src="https://cdn.jsdelivr.net/npm/alpinejs@3.15.3/dist/cdn.min.js" defer crossorigin="anonymous"></script> 19 19 <!-- HTMX for dynamic content loading --> 20 20 <script src="https://cdn.jsdelivr.net/npm/htmx.org@2.0.8/dist/htmx.min.js" crossorigin="anonymous"></script> 21 - <script src="/static/js/csrf.js"></script> 22 21 {{if .IsAuthenticated}} 23 22 <script src="/static/js/data-cache.js"></script> 24 23 {{end}} ··· 88 87 </a> 89 88 <div class="border-t border-brown-100 mt-1 pt-1"> 90 89 <form action="/logout" method="POST"> 91 - <input type="hidden" name="csrf_token" class="csrf-token-field"> 92 90 <button type="submit" class="w-full text-left px-4 py-2 text-sm text-brown-700 hover:bg-brown-50 transition-colors"> 93 91 Logout 94 92 </button>
-66
web/static/js/csrf.js
··· 1 - /** 2 - * CSRF Token Helper 3 - * 4 - * Provides functions to get the CSRF token from the cookie and 5 - * automatically configures HTMX to include the token on all requests. 6 - * 7 - * Usage: 8 - * // Get token manually for fetch calls 9 - * const token = getCSRFToken(); 10 - * 11 - * // Manual fetch with CSRF header 12 - * fetch('/api/beans', { 13 - * method: 'POST', 14 - * headers: { 15 - * 'Content-Type': 'application/json', 16 - * 'X-CSRF-Token': getCSRFToken() 17 - * }, 18 - * body: JSON.stringify(data) 19 - * }); 20 - */ 21 - 22 - /** 23 - * Get CSRF token from cookie 24 - * @returns {string} The CSRF token or empty string if not found 25 - */ 26 - function getCSRFToken() { 27 - const name = 'csrf_token='; 28 - const decodedCookie = decodeURIComponent(document.cookie); 29 - const cookies = decodedCookie.split(';'); 30 - 31 - for (let cookie of cookies) { 32 - cookie = cookie.trim(); 33 - if (cookie.indexOf(name) === 0) { 34 - return cookie.substring(name.length); 35 - } 36 - } 37 - return ''; 38 - } 39 - 40 - /** 41 - * Configure HTMX to automatically include CSRF token on all requests 42 - * This handles all HTMX requests (hx-get, hx-post, hx-put, hx-delete, etc.) 43 - */ 44 - document.addEventListener('DOMContentLoaded', function() { 45 - // Add CSRF token header to all HTMX requests 46 - document.body.addEventListener('htmx:configRequest', function(event) { 47 - const token = getCSRFToken(); 48 - if (token) { 49 - event.detail.headers['X-CSRF-Token'] = token; 50 - } 51 - }); 52 - 53 - // Populate hidden CSRF token fields in forms 54 - const token = getCSRFToken(); 55 - if (token) { 56 - document.querySelectorAll('.csrf-token-field').forEach(function(field) { 57 - field.value = token; 58 - }); 59 - } 60 - }); 61 - 62 - // Export for use in other modules (if using module system) 63 - // For non-module scripts, getCSRFToken is available as a global 64 - if (typeof window !== 'undefined') { 65 - window.getCSRFToken = getCSRFToken; 66 - }