Stateless auth proxy that converts AT Protocol native apps from public to confidential OAuth clients. Deploy once, get 180-day refresh tokens instead of 24-hour ones.
1package main
2
3import (
4 "context"
5 "fmt"
6 "net"
7 "net/http"
8 "net/url"
9 "strconv"
10 "strings"
11 "sync"
12 "time"
13)
14
15var reservedIPv4Nets = []net.IPNet{
16 ipv4Net(0, 0, 0, 0, 8),
17 ipv4Net(10, 0, 0, 0, 8),
18 ipv4Net(100, 64, 0, 0, 10),
19 ipv4Net(127, 0, 0, 0, 8),
20 ipv4Net(169, 254, 0, 0, 16),
21 ipv4Net(172, 16, 0, 0, 12),
22 ipv4Net(192, 0, 0, 0, 24),
23 ipv4Net(192, 0, 2, 0, 24),
24 ipv4Net(192, 88, 99, 0, 24),
25 ipv4Net(192, 168, 0, 0, 16),
26 ipv4Net(198, 18, 0, 0, 15),
27 ipv4Net(198, 51, 100, 0, 24),
28 ipv4Net(203, 0, 113, 0, 24),
29 ipv4Net(224, 0, 0, 0, 4),
30 ipv4Net(240, 0, 0, 0, 4),
31}
32
// globalUnicastIPv6Net is 2000::/3, the IPv6 global unicast block. It is the
// only range in which a non-IPv4 address is considered public.
var globalUnicastIPv6Net = net.IPNet{
	IP:   net.ParseIP("2000::"),
	Mask: net.CIDRMask(3, 128),
}
37
// allowedHosts is the set of hosts that are exempt from the localhost and
// non-public-address checks. Keys are normalised by allowHost (trimmed and
// lower-cased); values are always true. A sync.Map is used so lookups from
// concurrent requests can proceed while hosts are registered.
var allowedHosts sync.Map

// allowHost registers host, trimmed of surrounding whitespace and lower-cased,
// as exempt from the public-address checks.
func allowHost(host string) {
	allowedHosts.Store(strings.TrimSpace(strings.ToLower(host)), true)
}

// clearAllowedHosts removes every registered host.
//
// Entries are deleted in place rather than by assigning a fresh sync.Map:
// a sync.Map must not be copied after first use. Overwriting the variable also
// races with any concurrent Load/Store and can reset an internal mutex that
// another goroutine holds. Hosts stored concurrently with a clear may or may
// not survive it.
func clearAllowedHosts() {
	allowedHosts.Range(func(key, _ any) bool {
		allowedHosts.Delete(key)
		return true
	})
}
47
48func ValidateTokenEndpoint(endpoint string) error {
49 _, err := validateEndpointURL(endpoint)
50 return err
51}
52
53func ValidatePAREndpoint(endpoint string) error {
54 _, err := validateEndpointURL(endpoint)
55 return err
56}
57
58func ValidateIssuer(issuer string) (*url.URL, error) {
59 return validatePublicURL(issuer, true)
60}
61
62func validateEndpointURL(rawURL string) (*url.URL, error) {
63 return validatePublicURL(rawURL, false)
64}
65
66func validatePublicURL(rawURL string, requireOrigin bool) (*url.URL, error) {
67 u, err := url.Parse(rawURL)
68 if err != nil {
69 return nil, fmt.Errorf("invalid URL: %w", err)
70 }
71
72 if u.Scheme != "https" {
73 return nil, fmt.Errorf("URL must use HTTPS")
74 }
75 if u.User != nil {
76 return nil, fmt.Errorf("URL must not include userinfo")
77 }
78 if u.Hostname() == "" {
79 return nil, fmt.Errorf("URL must have a hostname")
80 }
81 if !isAllowedHost(u.Hostname()) && isBlockedHostname(u.Hostname()) {
82 return nil, fmt.Errorf("URL must not target localhost")
83 }
84 if u.RawQuery != "" {
85 return nil, fmt.Errorf("URL must not include a query string")
86 }
87 if u.Fragment != "" {
88 return nil, fmt.Errorf("URL must not include a fragment")
89 }
90
91 if requireOrigin {
92 if u.Path != "" && u.Path != "/" {
93 return nil, fmt.Errorf("issuer must be an origin URL without a path")
94 }
95 if u.Port() == "443" {
96 return nil, fmt.Errorf("issuer must not include the default HTTPS port")
97 }
98 }
99
100 if ip := net.ParseIP(u.Hostname()); ip != nil && !IsPublicIPAddress(ip) && !isAllowedHost(u.Hostname()) {
101 return nil, fmt.Errorf("URL must not target a private or reserved IP address")
102 }
103
104 return u, nil
105}
106
107func newPublicHTTPClient(timeout time.Duration) *http.Client {
108 return &http.Client{
109 Timeout: timeout,
110 Transport: newPublicOnlyTransport(),
111 CheckRedirect: func(req *http.Request, via []*http.Request) error {
112 if len(via) >= 5 {
113 return fmt.Errorf("too many redirects")
114 }
115 _, err := validateEndpointURL(req.URL.String())
116 return err
117 },
118 }
119}
120
121func newPublicOnlyTransport() *http.Transport {
122 transport := http.DefaultTransport.(*http.Transport).Clone()
123 dialer := &net.Dialer{
124 Timeout: 10 * time.Second,
125 KeepAlive: 30 * time.Second,
126 }
127
128 transport.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
129 return dialPublicContext(ctx, dialer, network, address)
130 }
131 transport.ForceAttemptHTTP2 = true
132 transport.MaxIdleConns = 100
133 transport.IdleConnTimeout = 90 * time.Second
134 transport.TLSHandshakeTimeout = 10 * time.Second
135 transport.ResponseHeaderTimeout = 10 * time.Second
136 transport.ExpectContinueTimeout = time.Second
137 return transport
138}
139
140func dialPublicContext(ctx context.Context, dialer *net.Dialer, network string, address string) (net.Conn, error) {
141 host, port, err := net.SplitHostPort(address)
142 if err != nil {
143 return nil, fmt.Errorf("invalid address %q: %w", address, err)
144 }
145
146 if !isAllowedHost(host) && isBlockedHostname(host) {
147 return nil, fmt.Errorf("blocked host %q", host)
148 }
149
150 portNum, err := strconv.Atoi(port)
151 if err != nil || portNum < 1 || portNum > 65535 {
152 return nil, fmt.Errorf("invalid port %q", port)
153 }
154
155 if ip := net.ParseIP(host); ip != nil {
156 if !IsPublicIPAddress(ip) && !isAllowedHost(host) {
157 return nil, fmt.Errorf("blocked IP address %s", ip)
158 }
159 return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
160 }
161
162 resolved, err := net.DefaultResolver.LookupIPAddr(ctx, host)
163 if err != nil {
164 return nil, fmt.Errorf("failed to resolve host %q: %w", host, err)
165 }
166 if len(resolved) == 0 {
167 return nil, fmt.Errorf("host %q did not resolve to any IPs", host)
168 }
169
170 addresses := make([]string, 0, len(resolved))
171 for _, addr := range resolved {
172 if !IsPublicIPAddress(addr.IP) && !isAllowedHost(host) {
173 return nil, fmt.Errorf("host %q resolves to non-public IP %s", host, addr.IP)
174 }
175 addresses = append(addresses, net.JoinHostPort(addr.IP.String(), port))
176 }
177
178 var lastErr error
179 for _, dialAddress := range addresses {
180 conn, err := dialer.DialContext(ctx, network, dialAddress)
181 if err == nil {
182 return conn, nil
183 }
184 lastErr = err
185 }
186
187 return nil, fmt.Errorf("failed to connect to %q: %w", host, lastErr)
188}
189
// ipv4Net returns the IPv4 network a.b.c.d/subnetPrefixLen in 16-byte form.
// Both IP and Mask have IPv6 length, and the IPv4 prefix is offset by the
// 96 leading bits of the IPv4-mapped representation.
func ipv4Net(a, b, c, d byte, subnetPrefixLen int) net.IPNet {
	const mappedPrefixBits = 8 * (net.IPv6len - net.IPv4len) // 96
	network := net.IPv4(a, b, c, d)
	mask := net.CIDRMask(mappedPrefixBits+subnetPrefixLen, 8*net.IPv6len)
	return net.IPNet{IP: network, Mask: mask}
}
196
197func IsPublicIPAddress(ip net.IP) bool {
198 if ip4 := ip.To4(); ip4 != nil {
199 for _, reserved := range reservedIPv4Nets {
200 if reserved.Contains(ip4) {
201 return false
202 }
203 }
204 return true
205 }
206
207 return globalUnicastIPv6Net.Contains(ip)
208}
209
210func isPrivateHost(host string) bool {
211 if isBlockedHostname(host) {
212 return true
213 }
214
215 ip := net.ParseIP(host)
216 if ip == nil {
217 return false
218 }
219
220 return !IsPublicIPAddress(ip)
221}
222
223func isPrivateIP(ip net.IP) bool {
224 return !IsPublicIPAddress(ip)
225}
226
// isBlockedHostname reports whether host is "localhost" or a name under the
// .localhost domain.
//
// Matching is case-insensitive and ignores surrounding whitespace. A single
// trailing dot (the fully-qualified form, e.g. "localhost." or
// "app.localhost.") is treated like the bare name, so the FQDN spelling cannot
// slip past URL validation.
func isBlockedHostname(host string) bool {
	host = strings.TrimSuffix(strings.TrimSpace(strings.ToLower(host)), ".")
	return host == "localhost" || strings.HasSuffix(host, ".localhost")
}
231
232func isAllowedHost(host string) bool {
233 host = strings.TrimSpace(strings.ToLower(host))
234 if host == "" {
235 return false
236 }
237
238 _, ok := allowedHosts.Load(host)
239 return ok
240}