Stateless auth proxy that converts AT Protocol native apps from public to confidential OAuth clients. Deploy once, get 180-day refresh tokens instead of 24-hour ones.
1package main
2
3import (
4 "crypto/ecdsa"
5 "crypto/elliptic"
6 "crypto/rand"
7 "crypto/x509"
8 "encoding/json"
9 "encoding/pem"
10 "testing"
11)
12
// generateTestPEM creates a fresh P-256 ECDSA private key and returns it as
// an unencrypted PKCS#8 "PRIVATE KEY" PEM block. Any failure aborts the
// calling test through t.Fatalf.
func generateTestPEM(t *testing.T) string {
	t.Helper()

	priv, genErr := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if genErr != nil {
		t.Fatalf("failed to generate test key: %v", genErr)
	}

	pkcs8, marshalErr := x509.MarshalPKCS8PrivateKey(priv)
	if marshalErr != nil {
		t.Fatalf("failed to marshal test key: %v", marshalErr)
	}

	return string(pem.EncodeToMemory(&pem.Block{
		Type:  "PRIVATE KEY",
		Bytes: pkcs8,
	}))
}
26
27func TestParsePrivateKey(t *testing.T) {
28 t.Run("valid P-256 PEM", func(t *testing.T) {
29 pemData := generateTestPEM(t)
30 key, err := ParsePrivateKey(pemData)
31 if err != nil {
32 t.Fatalf("unexpected error: %v", err)
33 }
34 if key.Curve != elliptic.P256() {
35 t.Fatalf("expected P-256 curve, got %s", key.Curve.Params().Name)
36 }
37 })
38
39 t.Run("invalid PEM", func(t *testing.T) {
40 _, err := ParsePrivateKey("not a pem")
41 if err == nil {
42 t.Fatal("expected error for invalid PEM")
43 }
44 })
45
46 t.Run("empty PEM", func(t *testing.T) {
47 _, err := ParsePrivateKey("")
48 if err == nil {
49 t.Fatal("expected error for empty PEM")
50 }
51 })
52
53 t.Run("wrong key type", func(t *testing.T) {
54 // Generate a P-384 key (wrong curve)
55 key, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
56 if err != nil {
57 t.Fatalf("failed to generate P-384 key: %v", err)
58 }
59 der, err := x509.MarshalPKCS8PrivateKey(key)
60 if err != nil {
61 t.Fatalf("failed to marshal key: %v", err)
62 }
63 block := &pem.Block{Type: "PRIVATE KEY", Bytes: der}
64 pemData := string(pem.EncodeToMemory(block))
65
66 _, err = ParsePrivateKey(pemData)
67 if err == nil {
68 t.Fatal("expected error for non-P-256 key")
69 }
70 })
71}
72
73func TestBuildJWKS(t *testing.T) {
74 pemData := generateTestPEM(t)
75 key, err := ParsePrivateKey(pemData)
76 if err != nil {
77 t.Fatalf("failed to parse key: %v", err)
78 }
79
80 kid := "test-key-1"
81 jwksBytes, err := BuildJWKS([]keyEntry{{privateKey: key, kid: kid}})
82 if err != nil {
83 t.Fatalf("unexpected error: %v", err)
84 }
85
86 var jwks struct {
87 Keys []struct {
88 Kty string `json:"kty"`
89 Crv string `json:"crv"`
90 Kid string `json:"kid"`
91 Use string `json:"use"`
92 Alg string `json:"alg"`
93 X string `json:"x"`
94 Y string `json:"y"`
95 } `json:"keys"`
96 }
97 if err := json.Unmarshal(jwksBytes, &jwks); err != nil {
98 t.Fatalf("failed to unmarshal JWKS: %v", err)
99 }
100
101 if len(jwks.Keys) != 1 {
102 t.Fatalf("expected 1 key, got %d", len(jwks.Keys))
103 }
104
105 k := jwks.Keys[0]
106 if k.Kty != "EC" {
107 t.Errorf("expected kty=EC, got %s", k.Kty)
108 }
109 if k.Crv != "P-256" {
110 t.Errorf("expected crv=P-256, got %s", k.Crv)
111 }
112 if k.Kid != kid {
113 t.Errorf("expected kid=%s, got %s", kid, k.Kid)
114 }
115 if k.Use != "sig" {
116 t.Errorf("expected use=sig, got %s", k.Use)
117 }
118 if k.Alg != "ES256" {
119 t.Errorf("expected alg=ES256, got %s", k.Alg)
120 }
121 if k.X == "" || k.Y == "" {
122 t.Error("expected non-empty x and y coordinates")
123 }
124}
125
126func TestBuildJWKS_MultipleKeys(t *testing.T) {
127 pem1 := generateTestPEM(t)
128 key1, err := ParsePrivateKey(pem1)
129 if err != nil {
130 t.Fatalf("failed to parse key1: %v", err)
131 }
132
133 pem2 := generateTestPEM(t)
134 key2, err := ParsePrivateKey(pem2)
135 if err != nil {
136 t.Fatalf("failed to parse key2: %v", err)
137 }
138
139 jwksBytes, err := BuildJWKS([]keyEntry{
140 {privateKey: key1, kid: "new-key"},
141 {privateKey: key2, kid: "old-key"},
142 })
143 if err != nil {
144 t.Fatalf("unexpected error: %v", err)
145 }
146
147 var jwks struct {
148 Keys []struct {
149 Kid string `json:"kid"`
150 Kty string `json:"kty"`
151 } `json:"keys"`
152 }
153 if err := json.Unmarshal(jwksBytes, &jwks); err != nil {
154 t.Fatalf("failed to unmarshal JWKS: %v", err)
155 }
156
157 if len(jwks.Keys) != 2 {
158 t.Fatalf("expected 2 keys, got %d", len(jwks.Keys))
159 }
160
161 kids := map[string]bool{}
162 for _, k := range jwks.Keys {
163 kids[k.Kid] = true
164 }
165 if !kids["new-key"] {
166 t.Error("expected kid=new-key in JWKS")
167 }
168 if !kids["old-key"] {
169 t.Error("expected kid=old-key in JWKS")
170 }
171}
172
173func TestBuildJWKS_NoPrivateKey(t *testing.T) {
174 pemData := generateTestPEM(t)
175 key, err := ParsePrivateKey(pemData)
176 if err != nil {
177 t.Fatalf("failed to parse key: %v", err)
178 }
179
180 jwksBytes, err := BuildJWKS([]keyEntry{{privateKey: key, kid: "test-key"}})
181 if err != nil {
182 t.Fatalf("unexpected error: %v", err)
183 }
184
185 // Verify no private key material (d) is present
186 var raw map[string]json.RawMessage
187 if err := json.Unmarshal(jwksBytes, &raw); err != nil {
188 t.Fatalf("failed to unmarshal: %v", err)
189 }
190
191 var keys []map[string]json.RawMessage
192 if err := json.Unmarshal(raw["keys"], &keys); err != nil {
193 t.Fatalf("failed to unmarshal keys: %v", err)
194 }
195
196 if _, hasD := keys[0]["d"]; hasD {
197 t.Fatal("JWKS must not contain private key material (d)")
198 }
199}
200
201func TestNewSigner(t *testing.T) {
202 pemData := generateTestPEM(t)
203 key, err := ParsePrivateKey(pemData)
204 if err != nil {
205 t.Fatalf("failed to parse key: %v", err)
206 }
207
208 kid := "signer-key-1"
209 signer, err := NewSigner(key, kid)
210 if err != nil {
211 t.Fatalf("unexpected error: %v", err)
212 }
213
214 if signer.KeyID() != kid {
215 t.Errorf("expected kid=%s, got %s", kid, signer.KeyID())
216 }
217 if signer.Algorithm().String() != "ES256" {
218 t.Errorf("expected alg=ES256, got %s", signer.Algorithm())
219 }
220}