this repo has no description
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat(toolbox): add ssh tunnel and vault client plumbing

Khue Doan 3e000cf6 8832eb18

+420 -4
+39
toolbox/cmd/root.go
··· 2 2 3 3 import ( 4 4 "os" 5 + "path/filepath" 5 6 7 + "github.com/charmbracelet/log" 6 8 "github.com/spf13/cobra" 7 9 ) 10 + 11 + var ( 12 + hostsFile string 13 + host string 14 + sshUser string 15 + sshKey string 16 + sshKnownHosts string 17 + ) 18 + 19 + func init() { 20 + log.SetReportTimestamp(false) 21 + 22 + rootCmd.PersistentFlags().StringVar(&hostsFile, "hosts-file", "", "Path to hosts.json file") 23 + rootCmd.PersistentFlags().StringVar(&host, "host", "", "Host name to connect to (e.g., kube-1)") 24 + rootCmd.PersistentFlags().StringVar(&sshUser, "ssh-user", "root", "SSH user") 25 + rootCmd.PersistentFlags().StringVar(&sshKey, "ssh-key", defaultSSHKey(), "Path to SSH private key") 26 + rootCmd.PersistentFlags().StringVar(&sshKnownHosts, "ssh-known-hosts", defaultKnownHostsFile(), "Path to SSH known_hosts file") 27 + } 8 28 9 29 var rootCmd = &cobra.Command{ 10 30 Use: "toolbox", 11 31 Short: "CLI tools for managing cloudlab infrastructure", 32 + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { 33 + return nil 34 + }, 12 35 } 13 36 14 37 func Execute() { ··· 16 39 os.Exit(1) 17 40 } 18 41 } 42 + 43 + func defaultSSHKey() string { 44 + home, err := os.UserHomeDir() 45 + if err != nil { 46 + return "" 47 + } 48 + return filepath.Join(home, ".ssh", "id_ed25519") 49 + } 50 + 51 + func defaultKnownHostsFile() string { 52 + home, err := os.UserHomeDir() 53 + if err != nil { 54 + return "" 55 + } 56 + return filepath.Join(home, ".ssh", "known_hosts") 57 + }
+24
toolbox/go.mod
··· 3 3 go 1.25.5 4 4 5 5 require ( 6 + github.com/charmbracelet/log v0.4.2 7 + github.com/hashicorp/vault/api v1.22.0 6 8 github.com/spf13/cobra v1.10.2 7 9 golang.org/x/crypto v0.47.0 8 10 ) 11 + 12 + require ( 13 + github.com/cenkalti/backoff/v4 v4.3.0 // indirect 14 + github.com/go-jose/go-jose/v4 v4.1.1 // indirect 15 + github.com/go-logfmt/logfmt v0.6.0 // indirect 16 + github.com/hashicorp/errwrap v1.1.0 // indirect 17 + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect 18 + github.com/hashicorp/go-multierror v1.1.1 // indirect 19 + github.com/hashicorp/go-retryablehttp v0.7.8 // indirect 20 + github.com/hashicorp/go-rootcerts v1.0.2 // indirect 21 + github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 // indirect 22 + github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect 23 + github.com/hashicorp/go-sockaddr v1.0.7 // indirect 24 + github.com/hashicorp/hcl v1.0.1-vault-7 // indirect 25 + github.com/mitchellh/go-homedir v1.1.0 // indirect 26 + github.com/mitchellh/mapstructure v1.5.0 // indirect 27 + github.com/ryanuber/go-glob v1.0.0 // indirect 28 + golang.org/x/net v0.48.0 // indirect 29 + golang.org/x/sys v0.40.0 // indirect 30 + golang.org/x/text v0.33.0 // indirect 31 + golang.org/x/time v0.12.0 // indirect 32 + )
+135
toolbox/internal/cluster/client.go
··· 1 + package cluster 2 + 3 + import ( 4 + "context" 5 + "encoding/base64" 6 + "fmt" 7 + "strings" 8 + "time" 9 + 10 + "github.com/hashicorp/vault/api" 11 + ) 12 + 13 + const ( 14 + vaultNamespace = "vault" 15 + vaultService = "svc/vault" 16 + vaultPort = 8200 17 + ) 18 + 19 + type ClientConfig struct { 20 + HostsFile string 21 + Host string 22 + SSHUser string 23 + SSHKey string 24 + SSHKnownHosts string 25 + Timeout time.Duration 26 + } 27 + 28 + type Client struct { 29 + conn *Connector 30 + vault *api.Client 31 + } 32 + 33 + func NewClient(ctx context.Context, cfg ClientConfig) (*Client, error) { 34 + if cfg.Timeout == 0 { 35 + cfg.Timeout = 30 * time.Second 36 + } 37 + 38 + connectCtx, cancel := context.WithTimeout(ctx, cfg.Timeout) 39 + defer cancel() 40 + 41 + hostAddr, err := LoadHost(cfg.HostsFile, cfg.Host) 42 + if err != nil { 43 + return nil, fmt.Errorf("load host: %w", err) 44 + } 45 + 46 + conn, err := Connect(SSHConfig{ 47 + Host: hostAddr, 48 + User: cfg.SSHUser, 49 + KeyPath: cfg.SSHKey, 50 + KnownHostsPath: cfg.SSHKnownHosts, 51 + Timeout: cfg.Timeout, 52 + }) 53 + if err != nil { 54 + return nil, fmt.Errorf("connect: %w", err) 55 + } 56 + 57 + token, err := getVaultToken(connectCtx, conn) 58 + if err != nil { 59 + conn.Close() 60 + return nil, fmt.Errorf("get vault token: %w", err) 61 + } 62 + 63 + vaultTunnel, err := conn.Forward(connectCtx, ServiceConfig{ 64 + Namespace: vaultNamespace, 65 + Name: vaultService, 66 + Port: vaultPort, 67 + }) 68 + if err != nil { 69 + conn.Close() 70 + return nil, fmt.Errorf("forward vault: %w", err) 71 + } 72 + 73 + vaultClient, err := newVaultClient(vaultTunnel.LocalAddr, token) 74 + if err != nil { 75 + conn.Close() 76 + return nil, fmt.Errorf("create vault client: %w", err) 77 + } 78 + 79 + return &Client{ 80 + conn: conn, 81 + vault: vaultClient, 82 + }, nil 83 + } 84 + 85 + func (c *Client) Vault() *api.Client { 86 + return c.vault 87 + } 88 + 89 + func (c *Client) Forward(ctx context.Context, svc ServiceConfig) (*ServiceTunnel, error) { 90 + return c.conn.Forward(ctx, svc) 91 + } 92 + 93 + func (c *Client) RunCommand(cmd string) ([]byte, error) { 94 + return c.conn.RunCommand(cmd) 95 + } 96 + 97 + func (c *Client) RunCommandContext(ctx context.Context, cmd string) ([]byte, error) { 98 + return c.conn.RunCommandContext(ctx, cmd) 99 + } 100 + 101 + func (c *Client) Close() error { 102 + return c.conn.Close() 103 + } 104 + 105 + func getVaultToken(ctx context.Context, conn *Connector) (string, error) { 106 + cmd := fmt.Sprintf( 107 + `kubectl get secret vault-unseal-keys -n %s -o template='{{ index .data "vault-root" }}'`, 108 + vaultNamespace, 109 + ) 110 + 111 + output, err := conn.RunCommandContext(ctx, cmd) 112 + if err != nil { 113 + return "", err 114 + } 115 + 116 + token, err := base64.StdEncoding.DecodeString(strings.TrimSpace(string(output))) 117 + if err != nil { 118 + return "", fmt.Errorf("decode token: %w", err) 119 + } 120 + 121 + return string(token), nil 122 + } 123 + 124 + func newVaultClient(addr, token string) (*api.Client, error) { 125 + config := api.DefaultConfig() 126 + config.Address = "http://" + addr 127 + 128 + client, err := api.NewClient(config) 129 + if err != nil { 130 + return nil, fmt.Errorf("create client: %w", err) 131 + } 132 + client.SetToken(token) 133 + 134 + return client, nil 135 + }
+222 -4
toolbox/internal/cluster/cluster.go
··· 3 3 import ( 4 4 "context" 5 5 "encoding/json" 6 + "errors" 6 7 "fmt" 8 + "io" 9 + "net" 7 10 "os" 8 11 "strings" 12 + "sync" 9 13 "time" 10 14 11 15 "golang.org/x/crypto/ssh" ··· 13 17 ) 14 18 15 19 const ( 16 - defaultSSHPort = 22 17 - defaultSSHTimeout = 10 * time.Second 20 + defaultSSHPort = 22 21 + defaultSSHTimeout = 10 * time.Second 22 + healthCheckInterval = 500 * time.Millisecond 18 23 ) 19 24 20 25 type SSHConfig struct { ··· 23 28 KeyPath string 24 29 KnownHostsPath string 25 30 Timeout time.Duration 31 + } 32 + 33 + type ServiceConfig struct { 34 + Namespace string 35 + Name string 36 + Port int 26 37 } 27 38 28 39 type Connector struct { 29 40 sshClient *ssh.Client 41 + tunnels []*tunnel 42 + mu sync.Mutex 43 + } 44 + 45 + type ServiceTunnel struct { 46 + LocalAddr string 47 + } 48 + 49 + type tunnel struct { 50 + listener net.Listener 51 + session *ssh.Session 52 + localPort int 53 + remotePort int 54 + done chan struct{} 55 + closeOnce sync.Once 30 56 } 31 57 32 58 func Connect(cfg SSHConfig) (*Connector, error) { ··· 39 65 return nil, fmt.Errorf("ssh connect: %w", err) 40 66 } 41 67 42 - return &Connector{sshClient: sshClient}, nil 68 + return &Connector{ 69 + sshClient: sshClient, 70 + tunnels: make([]*tunnel, 0), 71 + }, nil 72 + } 73 + 74 + func (c *Connector) Forward(ctx context.Context, svc ServiceConfig) (*ServiceTunnel, error) { 75 + c.mu.Lock() 76 + defer c.mu.Unlock() 77 + 78 + session, err := c.sshClient.NewSession() 79 + if err != nil { 80 + return nil, fmt.Errorf("create session: %w", err) 81 + } 82 + 83 + cmd := fmt.Sprintf( 84 + "exec kubectl port-forward %s -n %s %d:%d", 85 + svc.Name, svc.Namespace, svc.Port, svc.Port, 86 + ) 87 + 88 + if err := session.Start(cmd); err != nil { 89 + session.Close() 90 + return nil, fmt.Errorf("start port-forward: %w", err) 91 + } 92 + 93 + listener, err := net.Listen("tcp", "127.0.0.1:0") 94 + if err != nil { 95 + session.Signal(ssh.SIGTERM) 96 + session.Close() 97 + return nil, fmt.Errorf("listen: %w", err) 98 + } 99 + 100 + localPort := listener.Addr().(*net.TCPAddr).Port 101 + done := make(chan struct{}) 102 + 103 + t := &tunnel{ 104 + listener: listener, 105 + session: session, 106 + localPort: localPort, 107 + remotePort: svc.Port, 108 + done: done, 109 + } 110 + 111 + go c.runTunnel(t) 112 + 113 + c.tunnels = append(c.tunnels, t) 114 + 115 + localAddr := fmt.Sprintf("127.0.0.1:%d", localPort) 116 + if err := waitForTunnel(ctx, c.sshClient, localAddr, svc.Port); err != nil { 117 + closeErr := c.closeTunnel(t) 118 + c.tunnels = c.tunnels[:len(c.tunnels)-1] 119 + if closeErr != nil { 120 + return nil, fmt.Errorf("service not reachable: %w (cleanup: %v)", err, closeErr) 121 + } 122 + return nil, fmt.Errorf("service not reachable: %w", err) 123 + } 124 + 125 + return &ServiceTunnel{LocalAddr: localAddr}, nil 43 126 } 44 127 45 128 func (c *Connector) RunCommand(cmd string) ([]byte, error) { ··· 77 160 } 78 161 79 162 func (c *Connector) Close() error { 163 + c.mu.Lock() 164 + defer c.mu.Unlock() 165 + 166 + var errs []error 167 + 168 + for _, t := range c.tunnels { 169 + if err := c.closeTunnel(t); err != nil { 170 + errs = append(errs, err) 171 + } 172 + } 173 + c.tunnels = nil 174 + 80 175 if c.sshClient != nil { 81 176 if err := c.sshClient.Close(); err != nil { 82 - return fmt.Errorf("close ssh: %w", err) 177 + errs = append(errs, fmt.Errorf("close ssh: %w", err)) 83 178 } 179 + } 180 + 181 + if len(errs) > 0 { 182 + return errors.Join(errs...) 84 183 } 85 184 return nil 86 185 } 87 186 187 + func (c *Connector) closeTunnel(t *tunnel) error { 188 + var closeErr error 189 + 190 + t.closeOnce.Do(func() { 191 + var errs []error 192 + 193 + close(t.done) 194 + 195 + if t.listener != nil { 196 + if err := t.listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { 197 + errs = append(errs, fmt.Errorf("close listener: %w", err)) 198 + } 199 + t.listener = nil 200 + } 201 + 202 + if t.session != nil { 203 + _ = t.session.Signal(ssh.SIGTERM) 204 + if err := t.session.Close(); err != nil && !errors.Is(err, io.EOF) { 205 + errs = append(errs, fmt.Errorf("close session: %w", err)) 206 + } 207 + t.session = nil 208 + } 209 + 210 + if len(errs) > 0 { 211 + closeErr = errors.Join(errs...) 212 + } 213 + }) 214 + 215 + return closeErr 216 + } 217 + 218 + func (c *Connector) runTunnel(t *tunnel) { 219 + for { 220 + select { 221 + case <-t.done: 222 + return 223 + default: 224 + } 225 + 226 + localConn, err := t.listener.Accept() 227 + if err != nil { 228 + if errors.Is(err, net.ErrClosed) { 229 + return 230 + } 231 + continue 232 + } 233 + 234 + go c.handleTunnelConn(localConn, t.remotePort) 235 + } 236 + } 237 + 238 + func (c *Connector) handleTunnelConn(localConn net.Conn, remotePort int) { 239 + defer localConn.Close() 240 + 241 + remoteAddr := fmt.Sprintf("127.0.0.1:%d", remotePort) 242 + remoteConn, err := c.sshClient.Dial("tcp", remoteAddr) 243 + if err != nil { 244 + return 245 + } 246 + defer remoteConn.Close() 247 + 248 + done := make(chan struct{}, 2) 249 + 250 + go func() { 251 + io.Copy(remoteConn, localConn) 252 + done <- struct{}{} 253 + }() 254 + 255 + go func() { 256 + io.Copy(localConn, remoteConn) 257 + done <- struct{}{} 258 + }() 259 + 260 + <-done 261 + } 262 + 263 + func waitForTunnel(ctx context.Context, sshClient *ssh.Client, localAddr string, remotePort int) error { 264 + dialer := &net.Dialer{Timeout: 2 * time.Second} 265 + remoteAddr := fmt.Sprintf("127.0.0.1:%d", remotePort) 266 + var lastErr error 267 + 268 + for { 269 + select { 270 + case <-ctx.Done(): 271 + if lastErr != nil { 272 + return fmt.Errorf("%w (last check: %v)", ctx.Err(), lastErr) 273 + } 274 + return ctx.Err() 275 + default: 276 + } 277 + 278 + localConn, err := dialer.DialContext(ctx, "tcp", localAddr) 279 + if err != nil { 280 + lastErr = fmt.Errorf("dial local %s: %w", localAddr, err) 281 + } else { 282 + localConn.Close() 283 + 284 + // Ensure the SSH-side endpoint is also reachable so we don't report 285 + // readiness while the remote port-forward is still starting. 286 + remoteConn, err := sshClient.Dial("tcp", remoteAddr) 287 + if err == nil { 288 + remoteConn.Close() 289 + return nil 290 + } 291 + lastErr = fmt.Errorf("dial remote %s: %w", remoteAddr, err) 292 + } 293 + 294 + select { 295 + case <-ctx.Done(): 296 + if lastErr != nil { 297 + return fmt.Errorf("%w (last check: %v)", ctx.Err(), lastErr) 298 + } 299 + return ctx.Err() 300 + case <-time.After(healthCheckInterval): 301 + } 302 + } 303 + } 304 + 88 305 type HostInfo struct { 89 306 IPv6Address string `json:"ipv6_address"` 90 307 } ··· 143 360 144 361 addr := fmt.Sprintf("%s:%d", cfg.Host, defaultSSHPort) 145 362 if strings.Contains(cfg.Host, ":") { 363 + // IPv6 addresses need brackets 146 364 addr = fmt.Sprintf("[%s]:%d", cfg.Host, defaultSSHPort) 147 365 } 148 366