A social RSS reader built on the AT Protocol. glean.at
glean atproto atmosphere rss feed social app
14
fork

Configure Feed

Select the types of activity you want to include in your feed.

Validate OAuth sessions and clear invalid cookies

+240
+18
internal/server/middleware.go
··· 1 1 package server 2 2 3 3 import ( 4 + "context" 4 5 "crypto/rand" 5 6 "encoding/hex" 6 7 "net/http" 7 8 "net/url" 8 9 "strings" 9 10 "time" 11 + 12 + "github.com/bluesky-social/indigo/atproto/syntax" 10 13 ) 11 14 12 15 func (s *Server) sessionMiddleware(next http.Handler) http.Handler { 13 16 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 14 17 user := s.getUserFromSession(r) 15 18 if user != nil { 19 + data := s.getSessionData(r) 20 + if data != nil && data.SessionID != "" && !s.isOAuthSessionValid(r.Context(), data) { 21 + s.clearUserSession(w) 22 + next.ServeHTTP(w, r) 23 + return 24 + } 16 25 ctx := contextWithUser(r.Context(), user) 17 26 r = r.WithContext(ctx) 18 27 } 19 28 next.ServeHTTP(w, r) 20 29 }) 30 + } 31 + 32 + func (s *Server) isOAuthSessionValid(ctx context.Context, data *sessionData) bool { 33 + did, err := syntax.ParseDID(data.DID) 34 + if err != nil { 35 + return false 36 + } 37 + _, err = s.oauthStore.GetSession(ctx, did, data.SessionID) 38 + return err == nil 21 39 } 22 40 23 41 func (s *Server) requireAuth(next http.Handler) http.Handler {
+222
internal/server/middleware_test.go
··· 1 + package server 2 + 3 + import ( 4 + "context" 5 + "net/http" 6 + "net/http/httptest" 7 + "os" 8 + "testing" 9 + 10 + oauth "github.com/bluesky-social/indigo/atproto/auth/oauth" 11 + "github.com/bluesky-social/indigo/atproto/syntax" 12 + "gotest.tools/v3/assert" 13 + 14 + "pkg.rbrt.fr/glean/internal/db" 15 + ) 16 + 17 + func setupTestServer(t *testing.T) (*Server, *db.Store) { 18 + t.Helper() 19 + f, err := os.CreateTemp("", "glean-test-*.db") 20 + assert.NilError(t, err) 21 + assert.NilError(t, f.Close()) 22 + path := f.Name() 23 + t.Cleanup(func() { 24 + for _, suffix := range []string{"", "_users", "_users-shm", "_users-wal", "_articles", "_articles-shm", "_articles-wal", "_recs", "_recs-shm", "_recs-wal"} { 25 + _ = os.Remove(path + suffix) 26 + } 27 + }) 28 + 29 + dbs, err := db.Open(path) 30 + assert.NilError(t, err) 31 + t.Cleanup(func() { _ = dbs.Close() }) 32 + 33 + s := &Server{ 34 + dbs: dbs, 35 + oauthStore: db.NewOAuthStore(dbs), 36 + sessionKey: []byte("test-session-key-32-bytes-long!"), 37 + } 38 + return s, dbs 39 + } 40 + 41 + func encodeTestSession(t *testing.T, s *Server, data sessionData) string { 42 + t.Helper() 43 + encoded, err := encodeSession(s.sessionKey, data) 44 + assert.NilError(t, err) 45 + return encoded 46 + } 47 + 48 + func seedOAuthSession(t *testing.T, s *Server, did, sessionID string) { 49 + t.Helper() 50 + parsedDID, err := syntax.ParseDID(did) 51 + assert.NilError(t, err) 52 + sessData := oauth.ClientSessionData{ 53 + AccountDID: parsedDID, 54 + SessionID: sessionID, 55 + HostURL: "https://example.com", 56 + AccessToken: "test-access-token", 57 + RefreshToken: "test-refresh-token", 58 + } 59 + err = s.oauthStore.SaveSession(context.Background(), sessData) 60 + assert.NilError(t, err) 61 + } 62 + 63 + func TestSessionMiddleware_ValidOAuthSession(t *testing.T) { 64 + s, dbs := setupTestServer(t) 65 + ctx := context.Background() 66 + 67 + _, err := dbs.Users.CreateUser(ctx, "did:test:user1") 68 + assert.NilError(t, err) 69 + 70 + seedOAuthSession(t, s, "did:test:user1", "session-1") 71 + 72 + cookieVal := encodeTestSession(t, s, sessionData{ 73 + DID: "did:test:user1", 74 + SessionID: "session-1", 75 + }) 76 + 77 + called := false 78 + handler := s.sessionMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 79 + called = true 80 + u := currentUser(r) 81 + assert.Assert(t, u != nil) 82 + assert.Equal(t, u.DID, "did:test:user1") 83 + w.WriteHeader(http.StatusOK) 84 + })) 85 + 86 + req := httptest.NewRequest(http.MethodGet, "/dashboard", nil) 87 + req.AddCookie(&http.Cookie{Name: "glean_session", Value: cookieVal}) 88 + rec := httptest.NewRecorder() 89 + 90 + handler.ServeHTTP(rec, req) 91 + assert.Assert(t, called) 92 + assert.Equal(t, rec.Code, http.StatusOK) 93 + } 94 + 95 + func TestSessionMiddleware_InvalidOAuthSession_ClearsCookie(t *testing.T) { 96 + s, dbs := setupTestServer(t) 97 + ctx := context.Background() 98 + 99 + _, err := dbs.Users.CreateUser(ctx, "did:test:user2") 100 + assert.NilError(t, err) 101 + 102 + cookieVal := encodeTestSession(t, s, sessionData{ 103 + DID: "did:test:user2", 104 + SessionID: "session-gone", 105 + }) 106 + 107 + called := false 108 + handler := s.sessionMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 109 + called = true 110 + u := currentUser(r) 111 + assert.Assert(t, u == nil) 112 + w.WriteHeader(http.StatusOK) 113 + })) 114 + 115 + req := httptest.NewRequest(http.MethodGet, "/dashboard", nil) 116 + req.AddCookie(&http.Cookie{Name: "glean_session", Value: cookieVal}) 117 + rec := httptest.NewRecorder() 118 + 119 + handler.ServeHTTP(rec, req) 120 + assert.Assert(t, called) 121 + assert.Equal(t, rec.Code, http.StatusOK) 122 + 123 + var cleared bool 124 + for _, c := range rec.Result().Cookies() { 125 + if c.Name == "glean_session" && c.MaxAge == -1 { 126 + cleared = true 127 + } 128 + } 129 + assert.Assert(t, cleared, "expected glean_session cookie to be cleared") 130 + } 131 + 132 + func TestSessionMiddleware_NonOAuthSession_PassesThrough(t *testing.T) { 133 + s, dbs := setupTestServer(t) 134 + ctx := context.Background() 135 + 136 + _, err := dbs.Users.CreateUser(ctx, "did:test:user3") 137 + assert.NilError(t, err) 138 + 139 + cookieVal := encodeTestSession(t, s, sessionData{ 140 + DID: "did:test:user3", 141 + }) 142 + 143 + called := false 144 + handler := s.sessionMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 145 + called = true 146 + u := currentUser(r) 147 + assert.Assert(t, u != nil) 148 + assert.Equal(t, u.DID, "did:test:user3") 149 + w.WriteHeader(http.StatusOK) 150 + })) 151 + 152 + req := httptest.NewRequest(http.MethodGet, "/dashboard", nil) 153 + req.AddCookie(&http.Cookie{Name: "glean_session", Value: cookieVal}) 154 + rec := httptest.NewRecorder() 155 + 156 + handler.ServeHTTP(rec, req) 157 + assert.Assert(t, called) 158 + assert.Equal(t, rec.Code, http.StatusOK) 159 + } 160 + 161 + func TestSessionMiddleware_OAuthSessionDeletedAfterLogin_ClearsCookie(t *testing.T) { 162 + s, dbs := setupTestServer(t) 163 + ctx := context.Background() 164 + 165 + _, err := dbs.Users.CreateUser(ctx, "did:test:user4") 166 + assert.NilError(t, err) 167 + 168 + seedOAuthSession(t, s, "did:test:user4", "session-4") 169 + 170 + cookieVal := encodeTestSession(t, s, sessionData{ 171 + DID: "did:test:user4", 172 + SessionID: "session-4", 173 + }) 174 + 175 + handler := s.sessionMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 176 + u := currentUser(r) 177 + if u == nil { 178 + w.WriteHeader(http.StatusUnauthorized) 179 + return 180 + } 181 + w.WriteHeader(http.StatusOK) 182 + })) 183 + 184 + req := httptest.NewRequest(http.MethodGet, "/dashboard", nil) 185 + req.AddCookie(&http.Cookie{Name: "glean_session", Value: cookieVal}) 186 + rec := httptest.NewRecorder() 187 + handler.ServeHTTP(rec, req) 188 + assert.Equal(t, rec.Code, http.StatusOK) 189 + 190 + parsedDID, err := syntax.ParseDID("did:test:user4") 191 + assert.NilError(t, err) 192 + err = s.oauthStore.DeleteSession(ctx, parsedDID, "session-4") 193 + assert.NilError(t, err) 194 + 195 + req2 := httptest.NewRequest(http.MethodGet, "/dashboard", nil) 196 + req2.AddCookie(&http.Cookie{Name: "glean_session", Value: cookieVal}) 197 + rec2 := httptest.NewRecorder() 198 + handler.ServeHTTP(rec2, req2) 199 + assert.Equal(t, rec2.Code, http.StatusUnauthorized) 200 + } 201 + 202 + func TestIsOAuthSessionValid(t *testing.T) { 203 + s, _ := setupTestServer(t) 204 + ctx := context.Background() 205 + 206 + seedOAuthSession(t, s, "did:test:valid", "sess-1") 207 + 208 + assert.Assert(t, s.isOAuthSessionValid(ctx, &sessionData{ 209 + DID: "did:test:valid", 210 + SessionID: "sess-1", 211 + })) 212 + 213 + assert.Assert(t, !s.isOAuthSessionValid(ctx, &sessionData{ 214 + DID: "did:test:valid", 215 + SessionID: "sess-nonexistent", 216 + })) 217 + 218 + assert.Assert(t, !s.isOAuthSessionValid(ctx, &sessionData{ 219 + DID: "did:test:nonexistent", 220 + SessionID: "sess-1", 221 + })) 222 + }