package main

import (
	"context"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"
)

// reservedIPv4Nets lists the IPv4 ranges that IsPublicIPAddress never
// treats as public. It covers:
//   - "this network", private-use, shared address space, loopback and link-local;
//   - IETF protocol assignments, documentation and 6to4 relay;
//   - benchmarking, multicast and 240.0.0.0/4 (which includes 255.255.255.255).
var reservedIPv4Nets = []net.IPNet{
	ipv4Net(0, 0, 0, 0, 8),
	ipv4Net(10, 0, 0, 0, 8),
	ipv4Net(100, 64, 0, 0, 10),
	ipv4Net(127, 0, 0, 0, 8),
	ipv4Net(169, 254, 0, 0, 16),
	ipv4Net(172, 16, 0, 0, 12),
	ipv4Net(192, 0, 0, 0, 24),
	ipv4Net(192, 0, 2, 0, 24),
	ipv4Net(192, 88, 99, 0, 24),
	ipv4Net(192, 168, 0, 0, 16),
	ipv4Net(198, 18, 0, 0, 15),
	ipv4Net(198, 51, 100, 0, 24),
	ipv4Net(203, 0, 113, 0, 24),
	ipv4Net(224, 0, 0, 0, 4),
	ipv4Net(240, 0, 0, 0, 4),
}

// globalUnicastIPv6Net is 2000::/3. IPv6 addresses outside it (loopback,
// link-local, unique-local, multicast, NAT64, ...) are not considered public.
var globalUnicastIPv6Net = net.IPNet{
	IP:   net.IP{0x20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
	Mask: net.CIDRMask(3, 128),
}

// allowedHosts holds hostnames, normalized by normalizeHost, that are
// exempt from the localhost and private-IP restrictions below.
var allowedHosts sync.Map

// normalizeHost canonicalizes a hostname for comparison. It trims
// surrounding whitespace, lower-cases the name, and drops trailing dots,
// so the fully-qualified form "localhost." compares equal to "localhost".
func normalizeHost(host string) string {
	return strings.TrimRight(strings.ToLower(strings.TrimSpace(host)), ".")
}

// allowHost exempts host from the localhost and private-IP checks.
func allowHost(host string) {
	allowedHosts.Store(normalizeHost(host), true)
}

// clearAllowedHosts removes every entry from the allow-list.
//
// Entries are deleted in place instead of reassigning the package variable.
// Overwriting a sync.Map while other goroutines use it is a data race.
// Hosts added concurrently with the clear may survive it.
func clearAllowedHosts() {
	allowedHosts.Range(func(key, _ interface{}) bool {
		allowedHosts.Delete(key)
		return true
	})
}

// ValidateTokenEndpoint reports whether endpoint is an acceptable public
// HTTPS endpoint URL (see validatePublicURL).
func ValidateTokenEndpoint(endpoint string) error {
	_, err := validateEndpointURL(endpoint)
	return err
}

// ValidatePAREndpoint reports whether endpoint is an acceptable public
// HTTPS endpoint URL (see validatePublicURL).
func ValidatePAREndpoint(endpoint string) error {
	_, err := validateEndpointURL(endpoint)
	return err
}

// ValidateIssuer parses issuer and checks that it is a public HTTPS origin:
// no path other than "/" and no explicit ":443" port.
func ValidateIssuer(issuer string) (*url.URL, error) {
	return validatePublicURL(issuer, true)
}

// validateEndpointURL parses rawURL and checks it with validatePublicURL,
// without the origin-only restriction.
func validateEndpointURL(rawURL string) (*url.URL, error) {
	return validatePublicURL(rawURL, false)
}

// validatePublicURL parses rawURL and rejects it unless all of these hold:
//   - the scheme is https, with no userinfo, query string or fragment;
//   - it has a hostname that is not localhost (or a *.localhost name);
//   - it is not a non-public IP literal.
//
// Hosts registered with allowHost bypass the localhost and IP checks.
// When requireOrigin is true, the URL must also have no path beyond "/"
// and must not spell out the default port 443.
func validatePublicURL(rawURL string, requireOrigin bool) (*url.URL, error) {
	u, err := url.Parse(rawURL)
	if err != nil {
		return nil, fmt.Errorf("invalid URL: %w", err)
	}
	if u.Scheme != "https" {
		return nil, fmt.Errorf("URL must use HTTPS")
	}
	if u.User != nil {
		return nil, fmt.Errorf("URL must not include userinfo")
	}
	host := u.Hostname()
	if host == "" {
		return nil, fmt.Errorf("URL must have a hostname")
	}
	allowed := isAllowedHost(host)
	if !allowed && isBlockedHostname(host) {
		return nil, fmt.Errorf("URL must not target localhost")
	}
	if u.RawQuery != "" {
		return nil, fmt.Errorf("URL must not include a query string")
	}
	if u.Fragment != "" {
		return nil, fmt.Errorf("URL must not include a fragment")
	}
	if requireOrigin {
		if u.Path != "" && u.Path != "/" {
			return nil, fmt.Errorf("issuer must be an origin URL without a path")
		}
		if u.Port() == "443" {
			return nil, fmt.Errorf("issuer must not include the default HTTPS port")
		}
	}
	// parseHostIP strips an IPv6 zone ("fe80::1%eth0"). Without that,
	// net.ParseIP returns nil and the literal would skip this check.
	if ip := parseHostIP(host); ip != nil && !allowed && !IsPublicIPAddress(ip) {
		return nil, fmt.Errorf("URL must not target a private or reserved IP address")
	}
	return u, nil
}

// parseHostIP parses host as an IP literal, ignoring any "%zone" suffix.
// It returns nil when host is not an IP literal.
func parseHostIP(host string) net.IP {
	if i := strings.IndexByte(host, '%'); i >= 0 {
		host = host[:i]
	}
	return net.ParseIP(host)
}

// newPublicHTTPClient returns a client with the given overall timeout. The
// client dials only public addresses (newPublicOnlyTransport), follows at
// most 5 redirects, and re-validates every redirect target with
// validateEndpointURL.
func newPublicHTTPClient(timeout time.Duration) *http.Client {
	return &http.Client{
		Timeout:   timeout,
		Transport: newPublicOnlyTransport(),
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			if len(via) >= 5 {
				return fmt.Errorf("too many redirects")
			}
			_, err := validateEndpointURL(req.URL.String())
			return err
		},
	}
}

// newPublicOnlyTransport clones http.DefaultTransport and routes every dial
// through dialPublicContext, so connections only reach public IPs (or
// allow-listed hosts).
//
// NOTE(review): the clone keeps DefaultTransport's Proxy setting
// (ProxyFromEnvironment). When a proxy is configured, only the proxy's
// address goes through dialPublicContext, and the target's resolved IP is
// not vetted here. Confirm this is acceptable, or set transport.Proxy = nil.
func newPublicOnlyTransport() *http.Transport {
	transport := http.DefaultTransport.(*http.Transport).Clone()
	dialer := &net.Dialer{
		Timeout:   10 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	transport.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
		return dialPublicContext(ctx, dialer, network, address)
	}
	transport.ForceAttemptHTTP2 = true
	transport.MaxIdleConns = 100
	transport.IdleConnTimeout = 90 * time.Second
	transport.TLSHandshakeTimeout = 10 * time.Second
	transport.ResponseHeaderTimeout = 10 * time.Second
	transport.ExpectContinueTimeout = time.Second
	return transport
}

// dialPublicContext dials address ("host:port") only if every IP the host
// resolves to is public, unless the host is allow-listed.
//
// It resolves the name once and dials the vetted IPs directly. A second DNS
// lookup by the dialer therefore cannot substitute a different address.
// Each resolved IP is tried in order; the last dial error is returned if
// none succeed.
func dialPublicContext(ctx context.Context, dialer *net.Dialer, network string, address string) (net.Conn, error) {
	host, port, err := net.SplitHostPort(address)
	if err != nil {
		return nil, fmt.Errorf("invalid address %q: %w", address, err)
	}
	allowed := isAllowedHost(host)
	if !allowed && isBlockedHostname(host) {
		return nil, fmt.Errorf("blocked host %q", host)
	}
	portNum, err := strconv.Atoi(port)
	if err != nil || portNum < 1 || portNum > 65535 {
		return nil, fmt.Errorf("invalid port %q", port)
	}
	if ip := net.ParseIP(host); ip != nil {
		if !allowed && !IsPublicIPAddress(ip) {
			return nil, fmt.Errorf("blocked IP address %s", ip)
		}
		return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
	}
	resolved, err := net.DefaultResolver.LookupIPAddr(ctx, host)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve host %q: %w", host, err)
	}
	if len(resolved) == 0 {
		return nil, fmt.Errorf("host %q did not resolve to any IPs", host)
	}
	// Vet every resolved address before dialing any of them. One private
	// answer rejects the host outright.
	addresses := make([]string, 0, len(resolved))
	for _, addr := range resolved {
		if !allowed && !IsPublicIPAddress(addr.IP) {
			return nil, fmt.Errorf("host %q resolves to non-public IP %s", host, addr.IP)
		}
		// IPAddr.String keeps the IPv6 zone, which link-local addresses need.
		addresses = append(addresses, net.JoinHostPort(addr.String(), port))
	}
	var lastErr error
	for _, dialAddress := range addresses {
		conn, err := dialer.DialContext(ctx, network, dialAddress)
		if err == nil {
			return conn, nil
		}
		lastErr = err
		if ctx.Err() != nil {
			// Cancelled or timed out: further attempts would fail the same way.
			break
		}
	}
	return nil, fmt.Errorf("failed to connect to %q: %w", host, lastErr)
}

// ipv4Net builds the IPv4 network a.b.c.d/subnetPrefixLen. The mask is
// 16 bytes wide to match net.IPv4's 16-byte form, so 96 bits are added to
// the prefix length.
func ipv4Net(a, b, c, d byte, subnetPrefixLen int) net.IPNet {
	return net.IPNet{
		IP:   net.IPv4(a, b, c, d),
		Mask: net.CIDRMask(96+subnetPrefixLen, 128),
	}
}

// IsPublicIPAddress reports whether ip is a publicly routable address.
// IPv4 (including IPv4-mapped IPv6) must fall outside reservedIPv4Nets.
// Other IPv6 addresses must lie in 2000::/3. A nil ip is not public.
func IsPublicIPAddress(ip net.IP) bool {
	if ip4 := ip.To4(); ip4 != nil {
		for _, reserved := range reservedIPv4Nets {
			if reserved.Contains(ip4) {
				return false
			}
		}
		return true
	}
	return globalUnicastIPv6Net.Contains(ip)
}

// isPrivateHost reports whether host is a blocked localhost name or a
// non-public IP literal (an IPv6 zone suffix is ignored). Other hostnames
// return false; they are not resolved here.
func isPrivateHost(host string) bool {
	if isBlockedHostname(host) {
		return true
	}
	ip := parseHostIP(host)
	if ip == nil {
		return false
	}
	return !IsPublicIPAddress(ip)
}

// isPrivateIP reports whether ip is not a public address.
func isPrivateIP(ip net.IP) bool {
	return !IsPublicIPAddress(ip)
}

// isBlockedHostname reports whether host is "localhost" or a subdomain of
// it, after normalization, so "LOCALHOST." and "api.localhost" both match.
func isBlockedHostname(host string) bool {
	host = normalizeHost(host)
	return host == "localhost" || strings.HasSuffix(host, ".localhost")
}

// isAllowedHost reports whether host was registered with allowHost.
// An empty host is never allowed.
func isAllowedHost(host string) bool {
	host = normalizeHost(host)
	if host == "" {
		return false
	}
	_, ok := allowedHosts.Load(host)
	return ok
}