Fork of https://github.com/xenova/microgpt.js
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

Optimize Zig port: remove @constCast, f64 inference fast path, readability improvements

- Eliminate all @constCast by changing [][]const *Value to [][]*Value
- Add f64 fast-path for inference (doLinearF64, doSoftmaxF64, doRmsnormF64, doGptF64)
- Break up dense Mersenne Twister expressions into readable intermediates
- Use std.mem.swap in mtShuffle
- Inline sum in doLinear to avoid temp allocation per row
- Pre-size ArrayLists (tokens, losses) where lengths are known
- Shrink Value.gen from u64 to u32

+764
+764
microgpt.zig
··· 1 + /// 2 + /// The most atomic way to train and inference a GPT in pure, dependency-free Zig. 3 + /// This file is the complete algorithm. 4 + /// Everything else is just efficiency. 5 + /// 6 + /// @karpathy (original Python), @xenova (JavaScript port), Zig port 7 + /// 8 + const std = @import("std"); 9 + const math = std.math; 10 + const Allocator = std.mem.Allocator; 11 + const Io = std.Io; 12 + const ArrayList = std.ArrayList; 13 + 14 + // ─── Mersenne Twister 19937 (matches Python's random module exactly) ─── 15 + 16 + var mt_state: [624]u32 = undefined; 17 + var mt_idx: u32 = 625; 18 + var gauss_next: ?f64 = null; 19 + 20 + fn mtSeed(n: u64) void { 21 + var key_buf: [4]u32 = undefined; 22 + var key_len: usize = 0; 23 + var v = n; 24 + if (v == 0) { 25 + key_buf[0] = 0; 26 + key_len = 1; 27 + } else { 28 + while (v > 0) : (v = v / 0x100000000) { 29 + key_buf[key_len] = @as(u32, @intCast(v & 0xFFFFFFFF)); 30 + key_len += 1; 31 + } 32 + } 33 + const key = key_buf[0..key_len]; 34 + 35 + mt_state[0] = 19650218; 36 + for (1..624) |idx| { 37 + const prev = mt_state[idx - 1]; 38 + const xor = prev ^ (prev >> 30); 39 + const product: u32 = @truncate(@as(u64, 1812433253) *% @as(u64, xor)); 40 + mt_state[idx] = product +% @as(u32, @intCast(idx)); 41 + } 42 + var i: usize = 1; 43 + var j: usize = 0; 44 + var k: usize = @max(624, key.len); 45 + while (k > 0) : (k -= 1) { 46 + if (i >= 624) { 47 + mt_state[0] = mt_state[623]; 48 + i = 1; 49 + } 50 + if (j >= key.len) j = 0; 51 + const prev1 = mt_state[i - 1]; 52 + const xor1: u32 = @truncate(@as(u64, 1664525) *% @as(u64, prev1 ^ (prev1 >> 30))); 53 + mt_state[i] = (mt_state[i] ^ xor1) +% key[j] +% @as(u32, @intCast(j)); 54 + i += 1; 55 + j += 1; 56 + } 57 + k = 623; 58 + while (k > 0) : (k -= 1) { 59 + if (i >= 624) { 60 + mt_state[0] = mt_state[623]; 61 + i = 1; 62 + } 63 + const prev2 = mt_state[i - 1]; 64 + const xor2: u32 = @truncate(@as(u64, 1566083941) *% @as(u64, prev2 ^ (prev2 >> 30))); 65 + mt_state[i] = (mt_state[i] ^ xor2) -% @as(u32, @intCast(i)); 66 + i += 1; 67 + } 68 + mt_state[0] = 0x80000000; 69 + mt_idx = 624; 70 + gauss_next = null; 71 + } 72 + 73 + fn mtInt32() u32 { 74 + if (mt_idx >= 624) { 75 + for (0..624) |k| { 76 + const y = (mt_state[k] & 0x80000000) | (mt_state[(k + 1) % 624] & 0x7FFFFFFF); 77 + mt_state[k] = mt_state[(k + 397) % 624] ^ (y >> 1) ^ (if (y & 1 != 0) @as(u32, 0x9908B0DF) else @as(u32, 0)); 78 + } 79 + mt_idx = 0; 80 + } 81 + var y = mt_state[mt_idx]; 82 + mt_idx += 1; 83 + y ^= y >> 11; 84 + y ^= (y << 7) & 0x9D2C5680; 85 + y ^= (y << 15) & 0xEFC60000; 86 + y ^= y >> 18; 87 + return y; 88 + } 89 + 90 + fn mtRandom() f64 { 91 + const a: f64 = @floatFromInt(mtInt32() >> 5); 92 + const b: f64 = @floatFromInt(mtInt32() >> 6); 93 + return (a * 67108864.0 + b) / 9007199254740992.0; 94 + } 95 + 96 + fn mtGauss(mu: f64, sigma: f64) f64 { 97 + var z: f64 = undefined; 98 + if (gauss_next) |gn| { 99 + z = gn; 100 + gauss_next = null; 101 + } else { 102 + const x2pi = mtRandom() * 2.0 * math.pi; 103 + const g2rad = @sqrt(-2.0 * @log(1.0 - mtRandom())); 104 + z = @cos(x2pi) * g2rad; 105 + gauss_next = @sin(x2pi) * g2rad; 106 + } 107 + return mu + z * sigma; 108 + } 109 + 110 + fn mtShuffle(comptime T: type, arr: []T) void { 111 + var i: usize = arr.len - 1; 112 + while (i > 0) : (i -= 1) { 113 + const k: u32 = 32 - @clz(@as(u32, @intCast(i + 1))); 114 + const shift: u5 = @intCast(32 -| k); // saturating subtract for the case k=32 115 + var r: usize = @intCast(mtInt32() >> shift); 116 + while (r > i) { 117 + r = @intCast(mtInt32() >> shift); 118 + } 119 + std.mem.swap(T, &arr[i], &arr[r]); 120 + } 121 + } 122 + 123 + fn mtChoices(weights: []const f64, alloc: Allocator) usize { 124 + const n = weights.len; 125 + const cum = alloc.alloc(f64, n) catch unreachable; 126 + cum[0] = weights[0]; 127 + for (1..n) |i| { 128 + cum[i] = cum[i - 1] + weights[i]; 129 + } 130 + const x = mtRandom() * cum[n - 1]; 131 + var lo: usize = 0; 132 + var hi: usize = n - 1; 133 + while (lo < hi) { 134 + const mid = (lo + hi) >> 1; 135 + if (x < cum[mid]) { 136 + hi = mid; 137 + } else { 138 + lo = mid + 1; 139 + } 140 + } 141 + return lo; 142 + } 143 + 144 + // ─── Value (Autograd) ─── 145 + 146 + const Value = struct { 147 + data: f64, 148 + grad: f64, 149 + c0: ?*Value, 150 + c1: ?*Value, 151 + lg0: f64, 152 + lg1: f64, 153 + nch: u2, 154 + gen: u32, 155 + 156 + fn create(alloc: Allocator, data: f64) *Value { 157 + const v = alloc.create(Value) catch unreachable; 158 + v.* = .{ .data = data, .grad = 0, .c0 = null, .c1 = null, .lg0 = 0, .lg1 = 0, .nch = 0, .gen = 0 }; 159 + return v; 160 + } 161 + 162 + fn bin(alloc: Allocator, data: f64, a: *Value, b: *Value, lg0: f64, lg1: f64) *Value { 163 + const v = alloc.create(Value) catch unreachable; 164 + v.* = .{ .data = data, .grad = 0, .c0 = a, .c1 = b, .lg0 = lg0, .lg1 = lg1, .nch = 2, .gen = 0 }; 165 + return v; 166 + } 167 + 168 + fn una(alloc: Allocator, data: f64, a: *Value, lg0: f64) *Value { 169 + const v = alloc.create(Value) catch unreachable; 170 + v.* = .{ .data = data, .grad = 0, .c0 = a, .c1 = null, .lg0 = lg0, .lg1 = 0, .nch = 1, .gen = 0 }; 171 + return v; 172 + } 173 + 174 + fn addV(self: *Value, other: *Value, alloc: Allocator) *Value { 175 + return bin(alloc, self.data + other.data, self, other, 1, 1); 176 + } 177 + 178 + fn addS(self: *Value, s: f64, alloc: Allocator) *Value { 179 + return una(alloc, self.data + s, self, 1); 180 + } 181 + 182 + fn mulV(self: *Value, other: *Value, alloc: Allocator) *Value { 183 + return bin(alloc, self.data * other.data, self, other, other.data, self.data); 184 + } 185 + 186 + fn mulS(self: *Value, s: f64, alloc: Allocator) *Value { 187 + return una(alloc, self.data * s, self, s); 188 + } 189 + 190 + fn powS(self: *Value, p: f64, alloc: Allocator) *Value { 191 + return una(alloc, math.pow(f64, self.data, p), self, p * math.pow(f64, self.data, p - 1)); 192 + } 193 + 194 + fn logV(self: *Value, alloc: Allocator) *Value { 195 + return una(alloc, @log(self.data), self, 1.0 / self.data); 196 + } 197 + 198 + fn expV(self: *Value, alloc: Allocator) *Value { 199 + const e = @exp(self.data); 200 + return una(alloc, e, self, e); 201 + } 202 + 203 + fn relu(self: *Value, alloc: Allocator) *Value { 204 + return una(alloc, @max(0.0, self.data), self, if (self.data > 0) @as(f64, 1.0) else @as(f64, 0.0)); 205 + } 206 + 207 + fn neg(self: *Value, alloc: Allocator) *Value { 208 + return una(alloc, -self.data, self, -1.0); 209 + } 210 + 211 + fn subS(self: *Value, s: f64, alloc: Allocator) *Value { 212 + return self.addS(-s, alloc); 213 + } 214 + 215 + fn divV(self: *Value, other: *Value, alloc: Allocator) *Value { 216 + return self.mulV(other.powS(-1, alloc), alloc); 217 + } 218 + 219 + fn divS(self: *Value, s: f64, alloc: Allocator) *Value { 220 + return self.mulS(1.0 / s, alloc); 221 + } 222 + }; 223 + 224 + var backward_gen: u32 = 0; 225 + 226 + fn doBackward(loss: *Value, alloc: Allocator) void { 227 + backward_gen += 1; 228 + const gen = backward_gen; 229 + 230 + var topo: ArrayList(*Value) = .empty; 231 + const Frame = struct { v: *Value, state: u8 }; 232 + var stack: ArrayList(Frame) = .empty; 233 + stack.append(alloc, .{ .v = loss, .state = 0 }) catch unreachable; 234 + 235 + while (stack.items.len > 0) { 236 + const frame = &stack.items[stack.items.len - 1]; 237 + if (frame.v.gen == gen and frame.state == 0) { 238 + _ = stack.pop(); 239 + continue; 240 + } 241 + if (frame.state == 0) { 242 + frame.v.gen = gen; 243 + if (frame.v.nch >= 1) { 244 + frame.state = 1; 245 + stack.append(alloc, .{ .v = frame.v.c0.?, .state = 0 }) catch unreachable; 246 + continue; 247 + } 248 + } 249 + if (frame.state <= 1) { 250 + if (frame.v.nch == 2) { 251 + frame.state = 2; 252 + stack.append(alloc, .{ .v = frame.v.c1.?, .state = 0 }) catch unreachable; 253 + continue; 254 + } 255 + } 256 + topo.append(alloc, frame.v) catch unreachable; 257 + _ = stack.pop(); 258 + } 259 + 260 + loss.grad = 1; 261 + var i: usize = topo.items.len; 262 + while (i > 0) { 263 + i -= 1; 264 + const v = topo.items[i]; 265 + const g = v.grad; 266 + if (v.nch >= 1) v.c0.?.grad += v.lg0 * g; 267 + if (v.nch == 2) v.c1.?.grad += v.lg1 * g; 268 + } 269 + } 270 + 271 + // ─── Helper functions ─── 272 + 273 + fn sumValues(vals: []*Value, alloc: Allocator) *Value { 274 + var result = vals[0]; 275 + for (vals[1..]) |v| { 276 + result = result.addV(v, alloc); 277 + } 278 + return result; 279 + } 280 + 281 + fn doLinear(x: []*Value, w: []const []*Value, alloc: Allocator) []*Value { 282 + const nout = w.len; 283 + const out = alloc.alloc(*Value, nout) catch unreachable; 284 + for (0..nout) |oi| { 285 + const row = w[oi]; 286 + var acc = row[0].mulV(x[0], alloc); 287 + for (1..x.len) |j| { 288 + acc = acc.addV(row[j].mulV(x[j], alloc), alloc); 289 + } 290 + out[oi] = acc; 291 + } 292 + return out; 293 + } 294 + 295 + fn doSoftmax(logits: []*Value, alloc: Allocator) []*Value { 296 + var max_val: f64 = -math.inf(f64); 297 + for (logits) |v| { 298 + if (v.data > max_val) max_val = v.data; 299 + } 300 + const n = logits.len; 301 + const exps = alloc.alloc(*Value, n) catch unreachable; 302 + for (0..n) |i| { 303 + exps[i] = logits[i].subS(max_val, alloc).expV(alloc); 304 + } 305 + const total = sumValues(exps, alloc); 306 + const result = alloc.alloc(*Value, n) catch unreachable; 307 + for (0..n) |i| { 308 + result[i] = exps[i].divV(total, alloc); 309 + } 310 + return result; 311 + } 312 + 313 + fn doRmsnorm(x: []*Value, alloc: Allocator) []*Value { 314 + const n = x.len; 315 + const sq = alloc.alloc(*Value, n) catch unreachable; 316 + for (0..n) |i| { 317 + sq[i] = x[i].mulV(x[i], alloc); 318 + } 319 + const ms = sumValues(sq, alloc).mulS(1.0 / @as(f64, @floatFromInt(n)), alloc); 320 + const s = ms.addS(1e-5, alloc).powS(-0.5, alloc); 321 + const result = alloc.alloc(*Value, n) catch unreachable; 322 + for (0..n) |i| { 323 + result[i] = x[i].mulV(s, alloc); 324 + } 325 + return result; 326 + } 327 + 328 + // ─── Model hyperparameters ─── 329 + const n_embd = 16; 330 + const n_head = 4; 331 + const n_layer = 1; 332 + const block_size = 16; 333 + const head_dim = n_embd / n_head; 334 + const attn_scale: f64 = 1.0 / @sqrt(@as(f64, head_dim)); 335 + 336 + const StateDict = struct { 337 + wte: [][]*Value, 338 + wpe: [][]*Value, 339 + lm_head: [][]*Value, 340 + attn_wq: [n_layer][][]*Value, 341 + attn_wk: [n_layer][][]*Value, 342 + attn_wv: [n_layer][][]*Value, 343 + attn_wo: [n_layer][][]*Value, 344 + mlp_fc1: [n_layer][][]*Value, 345 + mlp_fc2: [n_layer][][]*Value, 346 + }; 347 + 348 + fn makeMatrix(alloc: Allocator, nout: usize, nin: usize, stdev: f64) [][]*Value { 349 + const rows = alloc.alloc([]*Value, nout) catch unreachable; 350 + for (0..nout) |i| { 351 + const row = alloc.alloc(*Value, nin) catch unreachable; 352 + for (0..nin) |j| { 353 + row[j] = Value.create(alloc, mtGauss(0, stdev)); 354 + } 355 + rows[i] = row; 356 + } 357 + return rows; 358 + } 359 + 360 + fn doGpt( 361 + token_id: usize, 362 + pos_id: usize, 363 + keys: *[n_layer]ArrayList([]*Value), 364 + kv_values: *[n_layer]ArrayList([]*Value), 365 + sd: *const StateDict, 366 + alloc: Allocator, 367 + ) []*Value { 368 + const tok_emb = sd.wte[token_id]; 369 + const pos_emb = sd.wpe[pos_id]; 370 + var x = alloc.alloc(*Value, n_embd) catch unreachable; 371 + for (0..n_embd) |i| { 372 + x[i] = tok_emb[i].addV(pos_emb[i], alloc); 373 + } 374 + x = doRmsnorm(x, alloc); 375 + 376 + for (0..n_layer) |li| { 377 + // 1) Multi-head attention block 378 + const x_residual = x; 379 + var xn = doRmsnorm(x, alloc); 380 + 381 + const q = doLinear(xn, sd.attn_wq[li], alloc); 382 + const k_vec = doLinear(xn, sd.attn_wk[li], alloc); 383 + const v_vec = doLinear(xn, sd.attn_wv[li], alloc); 384 + 385 + keys[li].append(alloc, k_vec) catch unreachable; 386 + kv_values[li].append(alloc, v_vec) catch unreachable; 387 + 388 + const x_attn = alloc.alloc(*Value, n_embd) catch unreachable; 389 + for (0..n_head) |h| { 390 + const hs = h * head_dim; 391 + const q_h = q[hs .. hs + head_dim]; 392 + 393 + const num_t = keys[li].items.len; 394 + const attn_logits = alloc.alloc(*Value, num_t) catch unreachable; 395 + for (0..num_t) |t| { 396 + const kt = keys[li].items[t]; 397 + const dot = alloc.alloc(*Value, head_dim) catch unreachable; 398 + for (0..head_dim) |j_| { 399 + dot[j_] = q_h[j_].mulV(kt[hs + j_], alloc); 400 + } 401 + attn_logits[t] = sumValues(dot, alloc).mulS(attn_scale, alloc); 402 + } 403 + const attn_weights = doSoftmax(attn_logits, alloc); 404 + 405 + for (0..head_dim) |j_| { 406 + const weighted = alloc.alloc(*Value, num_t) catch unreachable; 407 + for (0..num_t) |t| { 408 + weighted[t] = attn_weights[t].mulV(kv_values[li].items[t][hs + j_], alloc); 409 + } 410 + x_attn[hs + j_] = sumValues(weighted, alloc); 411 + } 412 + } 413 + 414 + xn = doLinear(x_attn, sd.attn_wo[li], alloc); 415 + // residual connection 416 + const x_res1 = alloc.alloc(*Value, n_embd) catch unreachable; 417 + for (0..n_embd) |i| { 418 + x_res1[i] = xn[i].addV(x_residual[i], alloc); 419 + } 420 + 421 + // 2) MLP block 422 + const x_residual2 = x_res1; 423 + const xn2 = doRmsnorm(x_res1, alloc); 424 + var fc1_out = doLinear(xn2, sd.mlp_fc1[li], alloc); 425 + for (0..fc1_out.len) |i| { 426 + fc1_out[i] = fc1_out[i].relu(alloc); 427 + } 428 + const fc2_out = doLinear(fc1_out, sd.mlp_fc2[li], alloc); 429 + x = alloc.alloc(*Value, n_embd) catch unreachable; 430 + for (0..n_embd) |i| { 431 + x[i] = fc2_out[i].addV(x_residual2[i], alloc); 432 + } 433 + } 434 + 435 + return doLinear(x, sd.lm_head, alloc); 436 + } 437 + 438 + // ─── f64 fast-path for inference (no autograd) ─── 439 + 440 + fn doLinearF64(x: []const f64, w: []const []*Value, alloc: Allocator) []f64 { 441 + const nout = w.len; 442 + const out = alloc.alloc(f64, nout) catch unreachable; 443 + for (0..nout) |oi| { 444 + const row = w[oi]; 445 + var acc: f64 = 0; 446 + for (0..x.len) |j| { 447 + acc += row[j].data * x[j]; 448 + } 449 + out[oi] = acc; 450 + } 451 + return out; 452 + } 453 + 454 + fn doSoftmaxF64(logits: []f64) []f64 { 455 + var max_val: f64 = -math.inf(f64); 456 + for (logits) |v| { 457 + if (v > max_val) max_val = v; 458 + } 459 + var total: f64 = 0; 460 + for (logits) |*v| { 461 + v.* = @exp(v.* - max_val); 462 + total += v.*; 463 + } 464 + for (logits) |*v| { 465 + v.* /= total; 466 + } 467 + return logits; 468 + } 469 + 470 + fn doRmsnormF64(x: []const f64, alloc: Allocator) []f64 { 471 + const n = x.len; 472 + var ms: f64 = 0; 473 + for (x) |xi| { 474 + ms += xi * xi; 475 + } 476 + ms /= @floatFromInt(n); 477 + const s = 1.0 / @sqrt(ms + 1e-5); 478 + const result = alloc.alloc(f64, n) catch unreachable; 479 + for (0..n) |i| { 480 + result[i] = x[i] * s; 481 + } 482 + return result; 483 + } 484 + 485 + fn doGptF64( 486 + token_id: usize, 487 + pos_id: usize, 488 + keys: *[n_layer]ArrayList([]f64), 489 + kv_values: *[n_layer]ArrayList([]f64), 490 + sd: *const StateDict, 491 + alloc: Allocator, 492 + ) []f64 { 493 + const tok_emb = sd.wte[token_id]; 494 + const pos_emb = sd.wpe[pos_id]; 495 + var x = alloc.alloc(f64, n_embd) catch unreachable; 496 + for (0..n_embd) |i| { 497 + x[i] = tok_emb[i].data + pos_emb[i].data; 498 + } 499 + x = doRmsnormF64(x, alloc); 500 + 501 + for (0..n_layer) |li| { 502 + // 1) Multi-head attention block 503 + const x_residual = x; 504 + var xn = doRmsnormF64(x, alloc); 505 + 506 + const q = doLinearF64(xn, sd.attn_wq[li], alloc); 507 + const k_vec = doLinearF64(xn, sd.attn_wk[li], alloc); 508 + const v_vec = doLinearF64(xn, sd.attn_wv[li], alloc); 509 + 510 + keys[li].append(alloc, k_vec) catch unreachable; 511 + kv_values[li].append(alloc, v_vec) catch unreachable; 512 + 513 + const x_attn = alloc.alloc(f64, n_embd) catch unreachable; 514 + for (0..n_head) |h| { 515 + const hs = h * head_dim; 516 + const q_h = q[hs .. hs + head_dim]; 517 + 518 + const num_t = keys[li].items.len; 519 + const attn_logits = alloc.alloc(f64, num_t) catch unreachable; 520 + for (0..num_t) |t| { 521 + const kt = keys[li].items[t]; 522 + var dot: f64 = 0; 523 + for (0..head_dim) |j_| { 524 + dot += q_h[j_] * kt[hs + j_]; 525 + } 526 + attn_logits[t] = dot * attn_scale; 527 + } 528 + const attn_weights = doSoftmaxF64(attn_logits); 529 + 530 + for (0..head_dim) |j_| { 531 + var sum: f64 = 0; 532 + for (0..num_t) |t| { 533 + sum += attn_weights[t] * kv_values[li].items[t][hs + j_]; 534 + } 535 + x_attn[hs + j_] = sum; 536 + } 537 + } 538 + 539 + xn = doLinearF64(x_attn, sd.attn_wo[li], alloc); 540 + // residual connection 541 + const x_res1 = alloc.alloc(f64, n_embd) catch unreachable; 542 + for (0..n_embd) |i| { 543 + x_res1[i] = xn[i] + x_residual[i]; 544 + } 545 + 546 + // 2) MLP block 547 + const x_residual2 = x_res1; 548 + const xn2 = doRmsnormF64(x_res1, alloc); 549 + const fc1_out = doLinearF64(xn2, sd.mlp_fc1[li], alloc); 550 + for (fc1_out) |*v| { 551 + v.* = @max(0.0, v.*); 552 + } 553 + const fc2_out = doLinearF64(fc1_out, sd.mlp_fc2[li], alloc); 554 + x = alloc.alloc(f64, n_embd) catch unreachable; 555 + for (0..n_embd) |i| { 556 + x[i] = fc2_out[i] + x_residual2[i]; 557 + } 558 + } 559 + 560 + return doLinearF64(x, sd.lm_head, alloc); 561 + } 562 + 563 + // ─── Main ─── 564 + pub fn main(init: std.process.Init) !void { 565 + const io = init.io; 566 + const alloc = init.gpa; 567 + 568 + var stdout_buffer: [4096]u8 = undefined; 569 + var stdout_writer = Io.File.stdout().writer(io, &stdout_buffer); 570 + const stdout = &stdout_writer.interface; 571 + 572 + // Read input.txt 573 + const file_data = try Io.Dir.cwd().readFileAlloc(io, "input.txt", alloc, .limited(10 * 1024 * 1024)); 574 + defer alloc.free(file_data); 575 + 576 + // Split into docs 577 + var doc_list: ArrayList([]const u8) = .empty; 578 + defer doc_list.deinit(alloc); 579 + const trimmed_data = std.mem.trim(u8, file_data, &[_]u8{ '\n', '\r', ' ', '\t' }); 580 + var line_iter = std.mem.splitScalar(u8, trimmed_data, '\n'); 581 + while (line_iter.next()) |line| { 582 + const trimmed = std.mem.trim(u8, line, &[_]u8{ ' ', '\t', '\r' }); 583 + if (trimmed.len > 0) { 584 + try doc_list.append(alloc, trimmed); 585 + } 586 + } 587 + var docs = doc_list.items; 588 + 589 + mtSeed(42); 590 + mtShuffle([]const u8, docs); 591 + try stdout.print("num docs: {d}\n", .{docs.len}); 592 + 593 + // Tokenizer: unique characters 594 + var char_set = std.AutoHashMap(u8, void).init(alloc); 595 + defer char_set.deinit(); 596 + for (docs) |doc| { 597 + for (doc) |ch| { 598 + try char_set.put(ch, {}); 599 + } 600 + } 601 + var uchars_list: ArrayList(u8) = .empty; 602 + defer uchars_list.deinit(alloc); 603 + var it = char_set.keyIterator(); 604 + while (it.next()) |key| { 605 + try uchars_list.append(alloc, key.*); 606 + } 607 + std.mem.sort(u8, uchars_list.items, {}, std.sort.asc(u8)); 608 + const uchars = uchars_list.items; 609 + 610 + var char_to_id = std.AutoHashMap(u8, usize).init(alloc); 611 + defer char_to_id.deinit(); 612 + for (uchars, 0..) |ch, i| { 613 + try char_to_id.put(ch, i); 614 + } 615 + 616 + const BOS = uchars.len; 617 + const vocab_size = uchars.len + 1; 618 + try stdout.print("vocab size: {d}\n", .{vocab_size}); 619 + 620 + // Initialize parameters 621 + const param_alloc = std.heap.page_allocator; 622 + 623 + var sd: StateDict = undefined; 624 + sd.wte = makeMatrix(param_alloc, vocab_size, n_embd, 0.08); 625 + sd.wpe = makeMatrix(param_alloc, block_size, n_embd, 0.08); 626 + sd.lm_head = makeMatrix(param_alloc, vocab_size, n_embd, 0.08); 627 + for (0..n_layer) |i| { 628 + sd.attn_wq[i] = makeMatrix(param_alloc, n_embd, n_embd, 0.08); 629 + sd.attn_wk[i] = makeMatrix(param_alloc, n_embd, n_embd, 0.08); 630 + sd.attn_wv[i] = makeMatrix(param_alloc, n_embd, n_embd, 0.08); 631 + sd.attn_wo[i] = makeMatrix(param_alloc, n_embd, n_embd, 0.08); 632 + sd.mlp_fc1[i] = makeMatrix(param_alloc, 4 * n_embd, n_embd, 0.08); 633 + sd.mlp_fc2[i] = makeMatrix(param_alloc, n_embd, 4 * n_embd, 0.08); 634 + } 635 + 636 + // Flatten params 637 + var params_list: ArrayList(*Value) = .empty; 638 + const matrices = [_][][]*Value{ sd.wte, sd.wpe, sd.lm_head }; 639 + for (matrices) |mat_| { 640 + for (mat_) |row| { 641 + for (row) |p| { 642 + try params_list.append(param_alloc, p); 643 + } 644 + } 645 + } 646 + for (0..n_layer) |li| { 647 + const layer_mats = [_][][]*Value{ 648 + sd.attn_wq[li], sd.attn_wk[li], sd.attn_wv[li], sd.attn_wo[li], 649 + sd.mlp_fc1[li], sd.mlp_fc2[li], 650 + }; 651 + for (layer_mats) |mat_| { 652 + for (mat_) |row| { 653 + for (row) |p| { 654 + try params_list.append(param_alloc, p); 655 + } 656 + } 657 + } 658 + } 659 + const params = params_list.items; 660 + try stdout.print("num params: {d}\n", .{params.len}); 661 + try stdout_writer.flush(); 662 + 663 + // Adam optimizer buffers 664 + const m_buf = try param_alloc.alloc(f64, params.len); 665 + const v_buf = try param_alloc.alloc(f64, params.len); 666 + @memset(m_buf, 0); 667 + @memset(v_buf, 0); 668 + 669 + const learning_rate: f64 = 0.01; 670 + const beta1: f64 = 0.85; 671 + const beta2: f64 = 0.99; 672 + const eps_adam: f64 = 1e-8; 673 + const num_steps: usize = 1000; 674 + 675 + // Training loop 676 + for (0..num_steps) |step| { 677 + const doc = docs[step % docs.len]; 678 + var tokens: ArrayList(usize) = .empty; 679 + defer tokens.deinit(alloc); 680 + try tokens.ensureTotalCapacity(alloc, doc.len + 2); 681 + tokens.appendAssumeCapacity(BOS); 682 + for (doc) |ch| { 683 + tokens.appendAssumeCapacity(char_to_id.get(ch).?); 684 + } 685 + tokens.appendAssumeCapacity(BOS); 686 + const n = @min(block_size, tokens.items.len - 1); 687 + 688 + var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); 689 + defer arena.deinit(); 690 + const step_alloc = arena.allocator(); 691 + 692 + var keys: [n_layer]ArrayList([]*Value) = undefined; 693 + var kv_values: [n_layer]ArrayList([]*Value) = undefined; 694 + for (0..n_layer) |li| { 695 + keys[li] = .empty; 696 + kv_values[li] = .empty; 697 + } 698 + 699 + var losses: ArrayList(*Value) = .empty; 700 + losses.ensureTotalCapacity(step_alloc, n) catch unreachable; 701 + for (0..n) |pos_id| { 702 + const token_id = tokens.items[pos_id]; 703 + const target_id = tokens.items[pos_id + 1]; 704 + const logits = doGpt(token_id, pos_id, &keys, &kv_values, &sd, step_alloc); 705 + const probs = doSoftmax(logits, step_alloc); 706 + const loss_t = probs[target_id].logV(step_alloc).neg(step_alloc); 707 + losses.append(step_alloc, loss_t) catch unreachable; 708 + } 709 + const loss = sumValues(losses.items, step_alloc).mulS(1.0 / @as(f64, @floatFromInt(n)), step_alloc); 710 + 711 + doBackward(loss, step_alloc); 712 + 713 + const lr_t = learning_rate * (1.0 - @as(f64, @floatFromInt(step)) / @as(f64, @floatFromInt(num_steps))); 714 + const step_f: f64 = @floatFromInt(step + 1); 715 + const bc1 = 1.0 - math.pow(f64, beta1, step_f); 716 + const bc2 = 1.0 - math.pow(f64, beta2, step_f); 717 + for (params, 0..) |p, i| { 718 + m_buf[i] = beta1 * m_buf[i] + (1.0 - beta1) * p.grad; 719 + v_buf[i] = beta2 * v_buf[i] + (1.0 - beta2) * p.grad * p.grad; 720 + const m_hat = m_buf[i] / bc1; 721 + const v_hat = v_buf[i] / bc2; 722 + p.data -= lr_t * m_hat / (@sqrt(v_hat) + eps_adam); 723 + p.grad = 0; 724 + } 725 + 726 + try stdout.print("step {d:>4} / {d:>4} | loss {d:.4}\n", .{ step + 1, num_steps, loss.data }); 727 + try stdout_writer.flush(); 728 + } 729 + 730 + // Inference 731 + const temperature: f64 = 0.5; 732 + try stdout.print("\n--- inference (new, hallucinated names) ---\n", .{}); 733 + try stdout_writer.flush(); 734 + 735 + for (0..20) |sample_idx| { 736 + var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); 737 + defer arena.deinit(); 738 + const inf_alloc = arena.allocator(); 739 + 740 + var f_keys: [n_layer]ArrayList([]f64) = undefined; 741 + var f_values: [n_layer]ArrayList([]f64) = undefined; 742 + for (0..n_layer) |li| { 743 + f_keys[li] = .empty; 744 + f_values[li] = .empty; 745 + } 746 + 747 + var token_id: usize = BOS; 748 + var sample: ArrayList(u8) = .empty; 749 + 750 + for (0..block_size) |pos_id| { 751 + const logits = doGptF64(token_id, pos_id, &f_keys, &f_values, &sd, inf_alloc); 752 + const scaled = inf_alloc.alloc(f64, logits.len) catch unreachable; 753 + for (0..logits.len) |i| { 754 + scaled[i] = logits[i] / temperature; 755 + } 756 + const probs = doSoftmaxF64(scaled); 757 + token_id = mtChoices(probs, inf_alloc); 758 + if (token_id == BOS) break; 759 + sample.append(inf_alloc, uchars[token_id]) catch unreachable; 760 + } 761 + try stdout.print("sample {d:>2}: {s}\n", .{ sample_idx + 1, sample.items }); 762 + } 763 + try stdout_writer.flush(); 764 + }