// Fork of https://github.com/xenova/microgpt.js
1/**
2 * The most atomic way to train and inference a GPT in pure, dependency-free JavaScript.
3 * This file is the complete algorithm.
4 * Everything else is just efficiency.
5 *
6 * @karpathy (original Python), @xenova (JavaScript port)
7 */
8
9import fs from 'node:fs'; // for reading the input text file
10import random from './random.js'; // random.seed, random.choices, random.gauss, random.shuffle
random.seed(42); // Let there be order among chaos

// Load the training corpus: one document per non-empty line of input.txt.
const raw_lines = fs.readFileSync('input.txt', 'utf-8').trim().split('\n');
const docs = raw_lines.map(l => l.trim()).filter(l => l.length > 0); // list of documents
random.shuffle(docs); // shuffled once so epochs do not always replay the file order
console.log(`num docs: ${docs.length}`);
16
17// Let there be a Tokenizer to translate strings to discrete symbols and back
18const uchars = [...new Set(docs.join(''))].sort(); // unique characters in the dataset become token ids 0..n-1
19const char_to_id = new Map(uchars.map((ch, i) => [ch, i])); // fast character lookup
20const BOS = uchars.length; // token id for the special Beginning of Sequence (BOS) token
21const vocab_size = uchars.length + 1; // total number of unique tokens, +1 is for BOS
22console.log(`vocab size: ${vocab_size}`);
23
24// Let there be Autograd, to recursively apply the chain rule through a computation graph
let _gen = 0; // global generation counter for autograd, to help with topological sorting of the graph during backward pass

/**
 * A scalar node in the autograd computation graph (micrograd-style).
 * Wraps one number, remembers up to two children and the local derivative
 * of this node w.r.t. each child, so backward() can apply the chain rule
 * through the whole graph.
 */
class Value {
  constructor(data, children = [], local_grads = []) {
    this.data = data; // scalar value of this node calculated during forward pass
    this.grad = 0; // derivative of the loss w.r.t. this node, calculated in backward pass
    this._c0 = children[0]; // children of this node in the computation graph
    this._c1 = children[1];
    this._lg0 = local_grads[0]; // local derivative of this node w.r.t. its children
    this._lg1 = local_grads[1];
    this._nch = children.length; // number of children (0, 1, or 2)
    this._gen = 0; // generation stamp of the last backward pass that visited this node
  }

  // Each op returns a new Value whose children/local_grads encode its derivative.
  add(other) {
    if (other instanceof Value) return new Value(this.data + other.data, [this, other], [1, 1]);
    return new Value(this.data + other, [this], [1]);
  }

  mul(other) {
    if (other instanceof Value) return new Value(this.data * other.data, [this, other], [other.data, this.data]);
    return new Value(this.data * other, [this], [other]);
  }

  pow(other) { return new Value(this.data ** other, [this], [other * this.data ** (other - 1)]); } // other is a plain number exponent
  log() { return new Value(Math.log(this.data), [this], [1 / this.data]); }
  exp() { const e = Math.exp(this.data); return new Value(e, [this], [e]); }
  relu() { return new Value(Math.max(0, this.data), [this], [+(this.data > 0)]); }
  neg() { return new Value(-this.data, [this], [-1]); }
  sub(other) { return this.add(other instanceof Value ? other.neg() : -other); }
  div(other) { return this.mul(other instanceof Value ? other.pow(-1) : 1 / other); }

  /**
   * Backpropagate: sets .grad on every node reachable from this one to
   * d(this)/d(node), accumulating via the chain rule.
   * Fix: the topological sort is now an ITERATIVE post-order DFS with an
   * explicit stack (was recursive), so very deep graphs — e.g. the long op
   * chains built while summing per-token losses — cannot overflow the call
   * stack. The visit order (child 0 before child 1, children before parent)
   * matches the original recursive version exactly.
   * Note: grads accumulate across calls; callers zero them between steps.
   */
  backward() {
    const gen = ++_gen; // fresh generation id marks nodes already visited in this pass
    const topo = []; // nodes in topological order (children before parents)
    const stack = [[this, false]]; // [node, childrenDone] frames for iterative post-order DFS
    while (stack.length > 0) {
      const [v, childrenDone] = stack.pop();
      if (childrenDone) { topo.push(v); continue; } // all children emitted -> emit node
      if (v._gen === gen) continue; // already visited via another path in the DAG
      v._gen = gen;
      stack.push([v, true]); // revisit this node after its children
      // push c1 first so c0 is processed first, mirroring the recursive order
      if (v._nch === 2) stack.push([v._c1, false]);
      if (v._nch >= 1) stack.push([v._c0, false]);
    }
    this.grad = 1; // seed: d(this)/d(this) = 1
    for (let i = topo.length - 1; i >= 0; --i) { // reverse topo order: parents before children
      const v = topo[i], g = v.grad;
      if (v._nch >= 1) v._c0.grad += v._lg0 * g; // chain rule: local_grad * upstream grad
      if (v._nch === 2) v._c1.grad += v._lg1 * g;
    }
  }
}
75
76// Initialize the parameters, to store the knowledge of the model.
const n_embd = 16; // embedding dimension
const n_head = 4; // number of attention heads
const n_layer = 1; // number of layers
const block_size = 16; // maximum sequence length
const head_dim = Math.floor(n_embd / n_head); // dimension of each head
const scale = 1 / head_dim ** 0.5; // precomputed attention scale factor

// Build an nout x nin matrix of trainable Values, gaussian-initialized.
const matrix = (nout, nin, std = 0.08) => {
  const rows = [];
  for (let r = 0; r < nout; ++r) {
    const row = [];
    for (let c = 0; c < nin; ++c) row.push(new Value(random.gauss(0, std)));
    rows.push(row);
  }
  return rows;
};

// Token embeddings, position embeddings, and the final projection to logits.
const state_dict = {
  wte: matrix(vocab_size, n_embd),
  wpe: matrix(block_size, n_embd),
  lm_head: matrix(vocab_size, n_embd),
};
// Per-layer attention (q/k/v/output) and MLP (expand 4x, contract) weights.
for (let li = 0; li < n_layer; ++li) {
  state_dict[`layer${li}.attn_wq`] = matrix(n_embd, n_embd);
  state_dict[`layer${li}.attn_wk`] = matrix(n_embd, n_embd);
  state_dict[`layer${li}.attn_wv`] = matrix(n_embd, n_embd);
  state_dict[`layer${li}.attn_wo`] = matrix(n_embd, n_embd);
  state_dict[`layer${li}.mlp_fc1`] = matrix(4 * n_embd, n_embd);
  state_dict[`layer${li}.mlp_fc2`] = matrix(n_embd, 4 * n_embd);
}
const params = Object.values(state_dict).flat(Infinity); // flatten params into a single list of Values
console.log(`num params: ${params.length}`);
95
96// Define the model architecture: a stateless function mapping token sequence and parameters to logits over what comes next.
97// Follow GPT-2, blessed among the GPTs, with minor differences: layernorm -> rmsnorm, no biases, GeLU -> ReLU
// Fold a list of Values into their total via repeated .add (left-associative).
const sum = (arr) => arr.reduce((acc, v) => acc.add(v));
// Pair up two equal-length arrays element-wise: [[a0,b0],[a1,b1],...].
const zip = (a, b) => {
  const pairs = [];
  for (let i = 0; i < a.length; ++i) pairs.push([a[i], b[i]]);
  return pairs;
};
100
/**
 * Matrix-vector product y = W x over Value nodes.
 * @param {Value[]} x - input activation vector (length nin)
 * @param {Value[][]} w - weight matrix as rows (nout x nin)
 * @returns {Value[]} output vector (length nout)
 */
function linear(x, w) {
  const out = [];
  for (const row of w) {
    // left-associative dot product: ((r0*x0 + r1*x1) + ...) as sum() would build it
    let acc = row[0].mul(x[0]);
    for (let i = 1; i < row.length; ++i) acc = acc.add(row[i].mul(x[i]));
    out.push(acc);
  }
  return out;
}
104
/**
 * Numerically stable softmax over a vector of Value logits.
 * Subtracts the max logit (a plain number, so it does not enter the graph)
 * before exponentiating, then normalizes so the outputs sum to 1.
 * @param {Value[]} logits
 * @returns {Value[]} probabilities, same length as logits
 */
function softmax(logits) {
  let max_val = -Infinity;
  for (const v of logits) if (v.data > max_val) max_val = v.data;
  const exps = [];
  for (const v of logits) exps.push(v.sub(max_val).exp());
  const total = sum(exps);
  return exps.map(e => e.div(total));
}
111
/**
 * RMS normalization: scale x by 1/sqrt(mean(x_i^2) + eps).
 * GPT-2's layernorm replaced by rmsnorm (no mean-centering, no learned gain).
 * @param {Value[]} x
 * @returns {Value[]} normalized vector, same length as x
 */
function rmsnorm(x) {
  // mean of the squared entries, folded left-to-right like sum() would
  let sq = x[0].mul(x[0]);
  for (let i = 1; i < x.length; ++i) sq = sq.add(x[i].mul(x[i]));
  const ms = sq.mul(1 / x.length);
  const inv_rms = ms.add(1e-5).pow(-0.5); // eps keeps the rsqrt finite for all-zero input
  return x.map(xi => xi.mul(inv_rms));
}
117
/**
 * Forward pass of the GPT for ONE token at ONE position, returning logits
 * over the next token. Causality comes from the KV cache: attention only
 * sees keys/values for positions already processed (<= pos_id).
 *
 * @param {number} token_id - id of the current input token (0..vocab_size-1)
 * @param {number} pos_id - 0-based position of this token (< block_size)
 * @param {Value[][][]} keys - per-layer cache of past key vectors; MUTATED: this step's k is appended
 * @param {Value[][][]} values - per-layer cache of past value vectors; MUTATED likewise
 * @returns {Value[]} unnormalized logits over the vocabulary (length vocab_size)
 */
function gpt(token_id, pos_id, keys, values) {
  const tok_emb = state_dict['wte'][token_id]; // token embedding
  const pos_emb = state_dict['wpe'][pos_id]; // position embedding
  let x = zip(tok_emb, pos_emb).map(([t, p]) => t.add(p)); // joint token and position embedding
  x = rmsnorm(x); // extra normalization right after the embedding (minor deviation from GPT-2)

  for (let li = 0; li < n_layer; ++li) {
    // 1) Multi-head attention block (pre-norm, with residual connection)
    let x_residual = x; // keep the block input for the skip connection
    x = rmsnorm(x);
    const q = linear(x, state_dict[`layer${li}.attn_wq`]);
    const k = linear(x, state_dict[`layer${li}.attn_wk`]);
    const v = linear(x, state_dict[`layer${li}.attn_wv`]);
    keys[li].push(k); // append this position's key/value to the cache so the
    values[li].push(v); // attention below covers all positions <= pos_id
    const x_attn = [];
    for (let h = 0; h < n_head; ++h) {
      const hs = h * head_dim; // offset of this head's slice within the full vector
      const q_h = q.slice(hs, hs + head_dim);
      const k_h = keys[li].map(ki => ki.slice(hs, hs + head_dim));
      const v_h = values[li].map(vi => vi.slice(hs, hs + head_dim));
      // scaled dot-product attention: logits[t] = (q_h . k_t) / sqrt(head_dim)
      const attn_logits = k_h.map(kt => sum(zip(q_h, kt).map(([qi, ki]) => qi.mul(ki))).mul(scale));
      const attn_weights = softmax(attn_logits);
      // attention output: weighted sum of cached value vectors, per head dimension
      for (let j = 0; j < head_dim; ++j)
        x_attn.push(sum(attn_weights.map((aw, t) => aw.mul(v_h[t][j]))));
    }
    x = linear(x_attn, state_dict[`layer${li}.attn_wo`]); // output projection mixes the heads back together
    x = x.map((a, i) => a.add(x_residual[i])); // residual connection
    // 2) MLP block (pre-norm, expand 4x with ReLU, contract, residual)
    x_residual = x;
    x = rmsnorm(x);
    x = linear(x, state_dict[`layer${li}.mlp_fc1`]); // n_embd -> 4*n_embd
    x = x.map(xi => xi.relu()); // ReLU instead of GPT-2's GeLU
    x = linear(x, state_dict[`layer${li}.mlp_fc2`]); // 4*n_embd -> n_embd
    x = x.map((a, i) => a.add(x_residual[i])); // residual connection
  }

  return linear(x, state_dict['lm_head']); // project final hidden state to vocabulary logits
}
157
158// Let there be Adam, the blessed optimizer and its buffers
// Adam hyperparameters: step size, moment decay rates, divide-by-zero guard.
const learning_rate = 0.01;
const beta1 = 0.85;
const beta2 = 0.99;
const eps_adam = 1e-8;
const m_buf = new Float64Array(params.length); // first moment buffer (EMA of gradients)
const v_buf = new Float64Array(params.length); // second moment buffer (EMA of squared gradients)
162
163// Repeat in sequence
const num_steps = 1000; // number of training steps
for (let step = 0; step < num_steps; ++step) {

  // Take single document, tokenize it, surround it with BOS special token on both sides
  const doc = docs[step % docs.length]; // cycle through the shuffled corpus, one doc per step
  const tokens = [BOS, ...Array.from(doc, ch => char_to_id.get(ch)), BOS];
  const n = Math.min(block_size, tokens.length - 1); // number of next-token predictions this doc yields

  // Forward the token sequence through the model, building up the computation graph all the way to the loss.
  const keys = Array.from({ length: n_layer }, () => []); // fresh, empty KV cache for this document
  const values = Array.from({ length: n_layer }, () => []);
  const losses = [];
  for (let pos_id = 0; pos_id < n; ++pos_id) {
    const token_id = tokens[pos_id], target_id = tokens[pos_id + 1]; // input token and the token it should predict
    const logits = gpt(token_id, pos_id, keys, values);
    const probs = softmax(logits);
    const loss_t = probs[target_id].log().neg(); // cross-entropy: -log p(target)
    losses.push(loss_t);
  }
  const loss = sum(losses).mul(1 / n); // final average loss over the document sequence. May yours be low.

  // Backward the loss, calculating the gradients with respect to all model parameters.
  loss.backward();

  // Adam optimizer update: update the model parameters based on the corresponding gradients.
  const lr_t = learning_rate * (1 - step / num_steps); // linear learning rate decay down to 0
  const bc1 = 1 - beta1 ** (step + 1), bc2 = 1 - beta2 ** (step + 1); // bias-correction denominators
  for (let i = 0; i < params.length; ++i) {
    const p = params[i];
    m_buf[i] = beta1 * m_buf[i] + (1 - beta1) * p.grad; // EMA of the gradient
    v_buf[i] = beta2 * v_buf[i] + (1 - beta2) * p.grad ** 2; // EMA of the squared gradient
    const m_hat = m_buf[i] / bc1; // bias-corrected moments (buffers start at zero)
    const v_hat = v_buf[i] / bc2;
    p.data -= lr_t * m_hat / (Math.sqrt(v_hat) + eps_adam);
    p.grad = 0; // zero the gradient: backward() accumulates across steps otherwise
  }

  process.stdout.write(`step ${String(step + 1).padStart(4)} / ${String(num_steps).padStart(4)} | loss ${loss.data.toFixed(4)}\r`);
}
203
// Inference: may the model babble back to us
const temperature = 0.5; // in (0, 1], control the "creativity" of generated text, low to high
const token_ids = Array.from({ length: vocab_size }, (_, i) => i); // candidate ids for sampling
console.log('\n--- inference (new, hallucinated names) ---');
for (let s = 0; s < 20; ++s) {
  // Each sample starts from a fresh, empty KV cache.
  const keys = Array.from({ length: n_layer }, () => []);
  const values = Array.from({ length: n_layer }, () => []);
  let token_id = BOS; // generation is primed with the BOS token
  const sample = [];
  for (let pos = 0; pos < block_size; ++pos) {
    const logits = gpt(token_id, pos, keys, values);
    // Temperature scaling: divide logits before softmax (lower = greedier).
    const scaled = logits.map(l => l.div(temperature));
    const probs = softmax(scaled);
    token_id = random.choices(token_ids, probs.map(p => p.data));
    if (token_id === BOS) break; // model chose to end the sequence
    sample.push(uchars[token_id]);
  }
  console.log(`sample ${String(s + 1).padStart(2)}: ${sample.join('')}`);
}
221}