this repo has no description
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat(api): add /r/{id} redirect shortlink endpoint

Add a public redirect handler that redirects users to the actual URL
associated with a link ID. The handler validates URL schemes to prevent
open redirect attacks and supports optional click tracking via HMAC
signatures.

Features:
- GET /r/{id} returns 302 redirect to stored URL
- Validates http/https schemes only (blocks javascript:, data:, etc.)
- Click tracking with valid signature via sig query param
- 400 for invalid IDs, 404 for non-existent links

+385
+64
internal/handler/api_v1_redirect.go
··· 1 + package handler 2 + 3 + import ( 4 + "context" 5 + "log" 6 + "net/http" 7 + "strconv" 8 + "strings" 9 + ) 10 + 11 + // APIv1RedirectHandler handles GET /r/{id} - the public shortlink redirect. 12 + // This redirects users to the actual URL associated with a link ID. 13 + // If a valid click signature is provided via the sig query parameter, 14 + // the click count is incremented asynchronously. 15 + func (h *Handler) APIv1RedirectHandler(w http.ResponseWriter, r *http.Request) { 16 + // Only allow GET and HEAD methods 17 + if r.Method != http.MethodGet && r.Method != http.MethodHead { 18 + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) 19 + return 20 + } 21 + 22 + ctx := r.Context() 23 + 24 + // Parse ID from path: /r/{id} 25 + path := r.URL.Path 26 + idStr := strings.TrimPrefix(path, "/r/") 27 + 28 + // Check if we got a valid ID string 29 + if idStr == "" || idStr == path { 30 + http.Error(w, "Invalid ID", http.StatusBadRequest) 31 + return 32 + } 33 + 34 + id, err := strconv.Atoi(idStr) 35 + if err != nil || id < 0 { 36 + http.Error(w, "Invalid ID", http.StatusBadRequest) 37 + return 38 + } 39 + 40 + // Only increment clicks if signature is valid 41 + // This prevents bots from inflating click counts by hitting URLs directly 42 + sig := r.URL.Query().Get("sig") 43 + if ValidateClickSignature(id, sig, h.Config.ClickSigningKey) { 44 + go h.Store.IncrementClicks(context.Background(), id) // Async 45 + } 46 + 47 + // Get the redirect URL 48 + redirectURL, err := h.Store.GetIRCLinkURL(ctx, id) 49 + if err != nil { 50 + http.NotFound(w, r) 51 + return 52 + } 53 + 54 + // Validate URL scheme to prevent open redirect attacks (javascript:, data:, etc.) 55 + if !strings.HasPrefix(redirectURL, "http://") && !strings.HasPrefix(redirectURL, "https://") { 56 + log.Printf("Blocked redirect to invalid scheme: url=%q path=%q remote_addr=%q user_agent=%q referer=%q", 57 + redirectURL, r.URL.Path, r.RemoteAddr, r.UserAgent(), r.Referer()) 58 + http.Error(w, "Invalid redirect URL", http.StatusBadRequest) 59 + return 60 + } 61 + 62 + log.Printf("id: [%d] Location: %s", id, redirectURL) 63 + http.Redirect(w, r, redirectURL, http.StatusFound) 64 + }
+321
internal/handler/api_v1_redirect_test.go
··· 1 + package handler 2 + 3 + import ( 4 + "context" 5 + "errors" 6 + "net/http" 7 + "net/http/httptest" 8 + "sync/atomic" 9 + "testing" 10 + "time" 11 + 12 + "tumble/internal/config" 13 + ) 14 + 15 + // mockRedirectStore is a mock implementation for testing redirect handler. 16 + type mockRedirectStore struct { 17 + mockAPIStore 18 + linkURL string 19 + linkURLErr error 20 + linkURLFn func(id int) (string, error) 21 + incrementCalled atomic.Int32 22 + incrementFn func(id int) error 23 + } 24 + 25 + func (m *mockRedirectStore) GetIRCLinkURL(ctx context.Context, id int) (string, error) { 26 + if m.linkURLFn != nil { 27 + return m.linkURLFn(id) 28 + } 29 + if m.linkURLErr != nil { 30 + return "", m.linkURLErr 31 + } 32 + return m.linkURL, nil 33 + } 34 + 35 + func (m *mockRedirectStore) IncrementClicks(ctx context.Context, id int) error { 36 + m.incrementCalled.Add(1) 37 + if m.incrementFn != nil { 38 + return m.incrementFn(id) 39 + } 40 + return nil 41 + } 42 + 43 + func TestAPIv1_RedirectHandler(t *testing.T) { 44 + tests := []struct { 45 + name string 46 + path string 47 + linkURL string 48 + linkURLErr error 49 + linkURLFn func(id int) (string, error) 50 + expectedStatus int 51 + expectedLocation string 52 + checkBody func(t *testing.T, body string) 53 + }{ 54 + { 55 + name: "valid ID redirects with 302", 56 + path: "/r/123", 57 + linkURL: "https://example.com/article", 58 + expectedStatus: http.StatusFound, 59 + expectedLocation: "https://example.com/article", 60 + }, 61 + { 62 + name: "valid ID with http scheme redirects", 63 + path: "/r/456", 64 + linkURL: "http://example.com/page", 65 + expectedStatus: http.StatusFound, 66 + expectedLocation: "http://example.com/page", 67 + }, 68 + { 69 + name: "invalid ID returns 400", 70 + path: "/r/abc", 71 + expectedStatus: http.StatusBadRequest, 72 + checkBody: func(t *testing.T, body string) { 73 + if body != "Invalid ID\n" { 74 + t.Errorf("expected 'Invalid ID\\n', got %q", body) 75 + } 76 + }, 77 + }, 78 + { 79 + name: "empty ID returns 400", 80 + path: "/r/", 81 + expectedStatus: http.StatusBadRequest, 82 + checkBody: func(t *testing.T, body string) { 83 + if body != "Invalid ID\n" { 84 + t.Errorf("expected 'Invalid ID\\n', got %q", body) 85 + } 86 + }, 87 + }, 88 + { 89 + name: "negative ID returns 400", 90 + path: "/r/-5", 91 + expectedStatus: http.StatusBadRequest, 92 + checkBody: func(t *testing.T, body string) { 93 + if body != "Invalid ID\n" { 94 + t.Errorf("expected 'Invalid ID\\n', got %q", body) 95 + } 96 + }, 97 + }, 98 + { 99 + name: "non-existent link returns 404", 100 + path: "/r/999", 101 + linkURLErr: errors.New("link not found"), 102 + expectedStatus: http.StatusNotFound, 103 + }, 104 + { 105 + name: "store returns not found error returns 404", 106 + path: "/r/888", 107 + linkURLFn: func(id int) (string, error) { 108 + return "", errors.New("record not found") 109 + }, 110 + expectedStatus: http.StatusNotFound, 111 + }, 112 + { 113 + name: "javascript scheme is blocked", 114 + path: "/r/123", 115 + linkURL: "javascript:alert(1)", 116 + expectedStatus: http.StatusBadRequest, 117 + checkBody: func(t *testing.T, body string) { 118 + if body != "Invalid redirect URL\n" { 119 + t.Errorf("expected 'Invalid redirect URL\\n', got %q", body) 120 + } 121 + }, 122 + }, 123 + { 124 + name: "data scheme is blocked", 125 + path: "/r/123", 126 + linkURL: "data:text/html,<script>alert(1)</script>", 127 + expectedStatus: http.StatusBadRequest, 128 + checkBody: func(t *testing.T, body string) { 129 + if body != "Invalid redirect URL\n" { 130 + t.Errorf("expected 'Invalid redirect URL\\n', got %q", body) 131 + } 132 + }, 133 + }, 134 + { 135 + name: "file scheme is blocked", 136 + path: "/r/123", 137 + linkURL: "file:///etc/passwd", 138 + expectedStatus: http.StatusBadRequest, 139 + checkBody: func(t *testing.T, body string) { 140 + if body != "Invalid redirect URL\n" { 141 + t.Errorf("expected 'Invalid redirect URL\\n', got %q", body) 142 + } 143 + }, 144 + }, 145 + } 146 + 147 + for _, tt := range tests { 148 + t.Run(tt.name, func(t *testing.T) { 149 + store := &mockRedirectStore{ 150 + linkURL: tt.linkURL, 151 + linkURLErr: tt.linkURLErr, 152 + linkURLFn: tt.linkURLFn, 153 + } 154 + handler := &Handler{ 155 + Store: store, 156 + Config: &config.Config{}, 157 + } 158 + 159 + req := httptest.NewRequest(http.MethodGet, tt.path, nil) 160 + w := httptest.NewRecorder() 161 + 162 + handler.APIv1RedirectHandler(w, req) 163 + 164 + if w.Code != tt.expectedStatus { 165 + t.Errorf("expected status %d, got %d. Body: %s", tt.expectedStatus, w.Code, w.Body.String()) 166 + } 167 + 168 + if tt.expectedLocation != "" { 169 + location := w.Header().Get("Location") 170 + if location != tt.expectedLocation { 171 + t.Errorf("expected Location %q, got %q", tt.expectedLocation, location) 172 + } 173 + } 174 + 175 + if tt.checkBody != nil { 176 + tt.checkBody(t, w.Body.String()) 177 + } 178 + }) 179 + } 180 + } 181 + 182 + func TestAPIv1_RedirectHandler_ClickTracking(t *testing.T) { 183 + tests := []struct { 184 + name string 185 + path string 186 + sigQueryParam string 187 + clickSigningKey string 188 + linkURL string 189 + expectIncrement bool 190 + expectedStatus int 191 + }{ 192 + { 193 + name: "valid signature increments clicks", 194 + path: "/r/123", 195 + clickSigningKey: "testsecret", 196 + linkURL: "https://example.com", 197 + expectIncrement: true, 198 + expectedStatus: http.StatusFound, 199 + }, 200 + { 201 + name: "invalid signature does not increment", 202 + path: "/r/123", 203 + sigQueryParam: "invalidsig", 204 + clickSigningKey: "testsecret", 205 + linkURL: "https://example.com", 206 + expectIncrement: false, 207 + expectedStatus: http.StatusFound, 208 + }, 209 + { 210 + name: "missing signature does not increment", 211 + path: "/r/123", 212 + sigQueryParam: "", 213 + clickSigningKey: "testsecret", 214 + linkURL: "https://example.com", 215 + expectIncrement: false, 216 + expectedStatus: http.StatusFound, 217 + }, 218 + { 219 + name: "no signing key configured does not increment", 220 + path: "/r/123", 221 + clickSigningKey: "", 222 + linkURL: "https://example.com", 223 + expectIncrement: false, 224 + expectedStatus: http.StatusFound, 225 + }, 226 + } 227 + 228 + for _, tt := range tests { 229 + t.Run(tt.name, func(t *testing.T) { 230 + store := &mockRedirectStore{ 231 + linkURL: tt.linkURL, 232 + } 233 + handler := &Handler{ 234 + Store: store, 235 + Config: &config.Config{ 236 + ClickSigningKey: tt.clickSigningKey, 237 + }, 238 + } 239 + 240 + // Generate valid signature if needed 241 + path := tt.path 242 + if tt.expectIncrement && tt.clickSigningKey != "" { 243 + // Generate a valid signature for this test 244 + sig := GenerateClickSignature(123, tt.clickSigningKey) 245 + path = tt.path + "?sig=" + sig 246 + } else if tt.sigQueryParam != "" { 247 + path = tt.path + "?sig=" + tt.sigQueryParam 248 + } 249 + 250 + req := httptest.NewRequest(http.MethodGet, path, nil) 251 + w := httptest.NewRecorder() 252 + 253 + handler.APIv1RedirectHandler(w, req) 254 + 255 + if w.Code != tt.expectedStatus { 256 + t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code) 257 + } 258 + 259 + // Give async goroutine time to execute 260 + time.Sleep(10 * time.Millisecond) 261 + 262 + incrementCount := store.incrementCalled.Load() 263 + if tt.expectIncrement && incrementCount == 0 { 264 + t.Error("expected IncrementClicks to be called, but it wasn't") 265 + } 266 + if !tt.expectIncrement && incrementCount > 0 { 267 + t.Error("expected IncrementClicks NOT to be called, but it was") 268 + } 269 + }) 270 + } 271 + } 272 + 273 + func TestAPIv1_RedirectHandler_MethodNotAllowed(t *testing.T) { 274 + store := &mockRedirectStore{ 275 + linkURL: "https://example.com", 276 + } 277 + handler := &Handler{ 278 + Store: store, 279 + Config: &config.Config{}, 280 + } 281 + 282 + methods := []string{http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch} 283 + 284 + for _, method := range methods { 285 + t.Run(method, func(t *testing.T) { 286 + req := httptest.NewRequest(method, "/r/123", nil) 287 + w := httptest.NewRecorder() 288 + 289 + handler.APIv1RedirectHandler(w, req) 290 + 291 + if w.Code != http.StatusMethodNotAllowed { 292 + t.Errorf("expected status %d for %s, got %d", http.StatusMethodNotAllowed, method, w.Code) 293 + } 294 + }) 295 + } 296 + } 297 + 298 + func TestAPIv1_RedirectHandler_HeadMethod(t *testing.T) { 299 + store := &mockRedirectStore{ 300 + linkURL: "https://example.com", 301 + } 302 + handler := &Handler{ 303 + Store: store, 304 + Config: &config.Config{}, 305 + } 306 + 307 + req := httptest.NewRequest(http.MethodHead, "/r/123", nil) 308 + w := httptest.NewRecorder() 309 + 310 + handler.APIv1RedirectHandler(w, req) 311 + 312 + // HEAD should work like GET for redirects 313 + if w.Code != http.StatusFound { 314 + t.Errorf("expected status %d for HEAD, got %d", http.StatusFound, w.Code) 315 + } 316 + 317 + location := w.Header().Get("Location") 318 + if location != "https://example.com" { 319 + t.Errorf("expected Location https://example.com, got %q", location) 320 + } 321 + }