The codebase that powers boop.cat
boop.cat
// Copyright 2025 boop.cat
// Licensed under the Apache License, Version 2.0
// See the LICENSE file for details.
4
5package middleware
6
7import (
8 "context"
9 "database/sql"
10 "net/http"
11
12 "boop-cat/db"
13 "github.com/gorilla/sessions"
14)
15
16var store *sessions.CookieStore
17
18func InitSessionStore(secret string, secure bool) {
19 store = sessions.NewCookieStore([]byte(secret))
20 store.Options = &sessions.Options{
21 Path: "/",
22 MaxAge: 86400 * 30,
23 HttpOnly: true,
24 Secure: secure,
25 SameSite: http.SameSiteLaxMode,
26 }
27}
28
29func GetSession(r *http.Request) (*sessions.Session, error) {
30 return store.Get(r, "fsd-session")
31}
32
33func WithUser(database *sql.DB) func(http.Handler) http.Handler {
34 return func(next http.Handler) http.Handler {
35 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
36 session, err := GetSession(r)
37 if err != nil {
38 next.ServeHTTP(w, r)
39 return
40 }
41
42 userID, ok := session.Values["userId"].(string)
43 if !ok || userID == "" {
44 next.ServeHTTP(w, r)
45 return
46 }
47
48 user, err := db.GetUserByID(database, userID)
49 if err == nil && user != nil && !user.Banned {
50
51 u := &db.User{
52 ID: user.ID,
53 Email: user.Email,
54 Username: user.Username,
55 EmailVerified: user.EmailVerified,
56 Banned: user.Banned,
57 }
58 ctx := context.WithValue(r.Context(), UserContextKey, u)
59 next.ServeHTTP(w, r.WithContext(ctx))
60 } else {
61 next.ServeHTTP(w, r)
62 }
63 })
64 }
65}
66
67func LoginUser(w http.ResponseWriter, r *http.Request, userID string) error {
68 session, _ := GetSession(r)
69 session.Values["userId"] = userID
70 return session.Save(r, w)
71}
72
73func LogoutUser(w http.ResponseWriter, r *http.Request) error {
74 session, _ := GetSession(r)
75 session.Values["userId"] = ""
76 session.Options.MaxAge = -1
77 return session.Save(r, w)
78}
79
80func RequireLogin(next http.Handler) http.Handler {
81 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
82 user := GetUser(r.Context())
83 if user == nil {
84 http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
85 return
86 }
87 next.ServeHTTP(w, r)
88 })
89}