Stateless auth proxy that converts AT Protocol native apps from public to confidential OAuth clients. Deploy it once to get 180-day refresh tokens instead of 24-hour ones.
9
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 240 lines 6.0 kB view raw
1package main 2 3import ( 4 "context" 5 "fmt" 6 "net" 7 "net/http" 8 "net/url" 9 "strconv" 10 "strings" 11 "sync" 12 "time" 13) 14 15var reservedIPv4Nets = []net.IPNet{ 16 ipv4Net(0, 0, 0, 0, 8), 17 ipv4Net(10, 0, 0, 0, 8), 18 ipv4Net(100, 64, 0, 0, 10), 19 ipv4Net(127, 0, 0, 0, 8), 20 ipv4Net(169, 254, 0, 0, 16), 21 ipv4Net(172, 16, 0, 0, 12), 22 ipv4Net(192, 0, 0, 0, 24), 23 ipv4Net(192, 0, 2, 0, 24), 24 ipv4Net(192, 88, 99, 0, 24), 25 ipv4Net(192, 168, 0, 0, 16), 26 ipv4Net(198, 18, 0, 0, 15), 27 ipv4Net(198, 51, 100, 0, 24), 28 ipv4Net(203, 0, 113, 0, 24), 29 ipv4Net(224, 0, 0, 0, 4), 30 ipv4Net(240, 0, 0, 0, 4), 31} 32 33var globalUnicastIPv6Net = net.IPNet{ 34 IP: net.IP{0x20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 35 Mask: net.CIDRMask(3, 128), 36} 37 38var allowedHosts sync.Map 39 40func allowHost(host string) { 41 allowedHosts.Store(strings.TrimSpace(strings.ToLower(host)), true) 42} 43 44func clearAllowedHosts() { 45 allowedHosts = sync.Map{} 46} 47 48func ValidateTokenEndpoint(endpoint string) error { 49 _, err := validateEndpointURL(endpoint) 50 return err 51} 52 53func ValidatePAREndpoint(endpoint string) error { 54 _, err := validateEndpointURL(endpoint) 55 return err 56} 57 58func ValidateIssuer(issuer string) (*url.URL, error) { 59 return validatePublicURL(issuer, true) 60} 61 62func validateEndpointURL(rawURL string) (*url.URL, error) { 63 return validatePublicURL(rawURL, false) 64} 65 66func validatePublicURL(rawURL string, requireOrigin bool) (*url.URL, error) { 67 u, err := url.Parse(rawURL) 68 if err != nil { 69 return nil, fmt.Errorf("invalid URL: %w", err) 70 } 71 72 if u.Scheme != "https" { 73 return nil, fmt.Errorf("URL must use HTTPS") 74 } 75 if u.User != nil { 76 return nil, fmt.Errorf("URL must not include userinfo") 77 } 78 if u.Hostname() == "" { 79 return nil, fmt.Errorf("URL must have a hostname") 80 } 81 if !isAllowedHost(u.Hostname()) && isBlockedHostname(u.Hostname()) { 82 return nil, fmt.Errorf("URL must not target 
localhost") 83 } 84 if u.RawQuery != "" { 85 return nil, fmt.Errorf("URL must not include a query string") 86 } 87 if u.Fragment != "" { 88 return nil, fmt.Errorf("URL must not include a fragment") 89 } 90 91 if requireOrigin { 92 if u.Path != "" && u.Path != "/" { 93 return nil, fmt.Errorf("issuer must be an origin URL without a path") 94 } 95 if u.Port() == "443" { 96 return nil, fmt.Errorf("issuer must not include the default HTTPS port") 97 } 98 } 99 100 if ip := net.ParseIP(u.Hostname()); ip != nil && !IsPublicIPAddress(ip) && !isAllowedHost(u.Hostname()) { 101 return nil, fmt.Errorf("URL must not target a private or reserved IP address") 102 } 103 104 return u, nil 105} 106 107func newPublicHTTPClient(timeout time.Duration) *http.Client { 108 return &http.Client{ 109 Timeout: timeout, 110 Transport: newPublicOnlyTransport(), 111 CheckRedirect: func(req *http.Request, via []*http.Request) error { 112 if len(via) >= 5 { 113 return fmt.Errorf("too many redirects") 114 } 115 _, err := validateEndpointURL(req.URL.String()) 116 return err 117 }, 118 } 119} 120 121func newPublicOnlyTransport() *http.Transport { 122 transport := http.DefaultTransport.(*http.Transport).Clone() 123 dialer := &net.Dialer{ 124 Timeout: 10 * time.Second, 125 KeepAlive: 30 * time.Second, 126 } 127 128 transport.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) { 129 return dialPublicContext(ctx, dialer, network, address) 130 } 131 transport.ForceAttemptHTTP2 = true 132 transport.MaxIdleConns = 100 133 transport.IdleConnTimeout = 90 * time.Second 134 transport.TLSHandshakeTimeout = 10 * time.Second 135 transport.ResponseHeaderTimeout = 10 * time.Second 136 transport.ExpectContinueTimeout = time.Second 137 return transport 138} 139 140func dialPublicContext(ctx context.Context, dialer *net.Dialer, network string, address string) (net.Conn, error) { 141 host, port, err := net.SplitHostPort(address) 142 if err != nil { 143 return nil, fmt.Errorf("invalid 
address %q: %w", address, err) 144 } 145 146 if !isAllowedHost(host) && isBlockedHostname(host) { 147 return nil, fmt.Errorf("blocked host %q", host) 148 } 149 150 portNum, err := strconv.Atoi(port) 151 if err != nil || portNum < 1 || portNum > 65535 { 152 return nil, fmt.Errorf("invalid port %q", port) 153 } 154 155 if ip := net.ParseIP(host); ip != nil { 156 if !IsPublicIPAddress(ip) && !isAllowedHost(host) { 157 return nil, fmt.Errorf("blocked IP address %s", ip) 158 } 159 return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port)) 160 } 161 162 resolved, err := net.DefaultResolver.LookupIPAddr(ctx, host) 163 if err != nil { 164 return nil, fmt.Errorf("failed to resolve host %q: %w", host, err) 165 } 166 if len(resolved) == 0 { 167 return nil, fmt.Errorf("host %q did not resolve to any IPs", host) 168 } 169 170 addresses := make([]string, 0, len(resolved)) 171 for _, addr := range resolved { 172 if !IsPublicIPAddress(addr.IP) && !isAllowedHost(host) { 173 return nil, fmt.Errorf("host %q resolves to non-public IP %s", host, addr.IP) 174 } 175 addresses = append(addresses, net.JoinHostPort(addr.IP.String(), port)) 176 } 177 178 var lastErr error 179 for _, dialAddress := range addresses { 180 conn, err := dialer.DialContext(ctx, network, dialAddress) 181 if err == nil { 182 return conn, nil 183 } 184 lastErr = err 185 } 186 187 return nil, fmt.Errorf("failed to connect to %q: %w", host, lastErr) 188} 189 190func ipv4Net(a, b, c, d byte, subnetPrefixLen int) net.IPNet { 191 return net.IPNet{ 192 IP: net.IPv4(a, b, c, d), 193 Mask: net.CIDRMask(96+subnetPrefixLen, 128), 194 } 195} 196 197func IsPublicIPAddress(ip net.IP) bool { 198 if ip4 := ip.To4(); ip4 != nil { 199 for _, reserved := range reservedIPv4Nets { 200 if reserved.Contains(ip4) { 201 return false 202 } 203 } 204 return true 205 } 206 207 return globalUnicastIPv6Net.Contains(ip) 208} 209 210func isPrivateHost(host string) bool { 211 if isBlockedHostname(host) { 212 return true 213 } 214 
215 ip := net.ParseIP(host) 216 if ip == nil { 217 return false 218 } 219 220 return !IsPublicIPAddress(ip) 221} 222 223func isPrivateIP(ip net.IP) bool { 224 return !IsPublicIPAddress(ip) 225} 226 227func isBlockedHostname(host string) bool { 228 host = strings.TrimSpace(strings.ToLower(host)) 229 return host == "localhost" || strings.HasSuffix(host, ".localhost") 230} 231 232func isAllowedHost(host string) bool { 233 host = strings.TrimSpace(strings.ToLower(host)) 234 if host == "" { 235 return false 236 } 237 238 _, ok := allowedHosts.Load(host) 239 return ok 240}