Fork of https://github.com/xenova/microgpt.js — the complete algorithm for training and sampling a GPT in pure, dependency-free JavaScript (main branch, 221 lines, ~10 kB).
1/** 2 * The most atomic way to train and inference a GPT in pure, dependency-free JavaScript. 3 * This file is the complete algorithm. 4 * Everything else is just efficiency. 5 * 6 * @karpathy (original Python), @xenova (JavaScript port) 7 */ 8 9import fs from 'node:fs'; // for reading the input text file 10import random from './random.js'; // random.seed, random.choices, random.gauss, random.shuffle 11random.seed(42); // Let there be order among chaos 12 13const docs = fs.readFileSync('input.txt', 'utf-8').trim().split('\n').map(l => l.trim()).filter(l => l.length > 0); // list of documents 14random.shuffle(docs); 15console.log(`num docs: ${docs.length}`); 16 17// Let there be a Tokenizer to translate strings to discrete symbols and back 18const uchars = [...new Set(docs.join(''))].sort(); // unique characters in the dataset become token ids 0..n-1 19const char_to_id = new Map(uchars.map((ch, i) => [ch, i])); // fast character lookup 20const BOS = uchars.length; // token id for the special Beginning of Sequence (BOS) token 21const vocab_size = uchars.length + 1; // total number of unique tokens, +1 is for BOS 22console.log(`vocab size: ${vocab_size}`); 23 24// Let there be Autograd, to recursively apply the chain rule through a computation graph 25let _gen = 0; // global generation counter for autograd, to help with topological sorting of the graph during backward pass 26class Value { 27 constructor(data, children = [], local_grads = []) { 28 this.data = data; // scalar value of this node calculated during forward pass 29 this.grad = 0; // derivative of the loss w.r.t. this node, calculated in backward pass 30 this._c0 = children[0]; // children of this node in the computation graph 31 this._c1 = children[1]; 32 this._lg0 = local_grads[0]; // local derivative of this node w.r.t. 
its children 33 this._lg1 = local_grads[1]; 34 this._nch = children.length; // number of children (0, 1, or 2) 35 this._gen = 0; 36 } 37 38 add(other) { 39 if (other instanceof Value) return new Value(this.data + other.data, [this, other], [1, 1]); 40 return new Value(this.data + other, [this], [1]); 41 } 42 43 mul(other) { 44 if (other instanceof Value) return new Value(this.data * other.data, [this, other], [other.data, this.data]); 45 return new Value(this.data * other, [this], [other]); 46 } 47 48 pow(other) { return new Value(this.data ** other, [this], [other * this.data ** (other - 1)]); } 49 log() { return new Value(Math.log(this.data), [this], [1 / this.data]); } 50 exp() { const e = Math.exp(this.data); return new Value(e, [this], [e]); } 51 relu() { return new Value(Math.max(0, this.data), [this], [+(this.data > 0)]); } 52 neg() { return new Value(-this.data, [this], [-1]); } 53 sub(other) { return this.add(other instanceof Value ? other.neg() : -other); } 54 div(other) { return this.mul(other instanceof Value ? other.pow(-1) : 1 / other); } 55 56 backward() { 57 const gen = ++_gen; 58 const topo = []; 59 function build_topo(v) { 60 if (v._gen === gen) return; 61 v._gen = gen; 62 if (v._nch >= 1) build_topo(v._c0); 63 if (v._nch === 2) build_topo(v._c1); 64 topo.push(v); 65 } 66 build_topo(this); 67 this.grad = 1; 68 for (let i = topo.length - 1; i >= 0; --i) { 69 const v = topo[i], g = v.grad; 70 if (v._nch >= 1) v._c0.grad += v._lg0 * g; 71 if (v._nch === 2) v._c1.grad += v._lg1 * g; 72 } 73 } 74} 75 76// Initialize the parameters, to store the knowledge of the model. 
77const n_embd = 16; // embedding dimension 78const n_head = 4; // number of attention heads 79const n_layer = 1; // number of layers 80const block_size = 16; // maximum sequence length 81const head_dim = Math.floor(n_embd / n_head); // dimension of each head 82const scale = 1 / head_dim ** 0.5; // precomputed attention scale factor 83const matrix = (nout, nin, std = 0.08) => Array.from({ length: nout }, () => Array.from({ length: nin }, () => new Value(random.gauss(0, std)))); 84const state_dict = { wte: matrix(vocab_size, n_embd), wpe: matrix(block_size, n_embd), lm_head: matrix(vocab_size, n_embd) }; 85for (let i = 0; i < n_layer; ++i) { 86 state_dict[`layer${i}.attn_wq`] = matrix(n_embd, n_embd); 87 state_dict[`layer${i}.attn_wk`] = matrix(n_embd, n_embd); 88 state_dict[`layer${i}.attn_wv`] = matrix(n_embd, n_embd); 89 state_dict[`layer${i}.attn_wo`] = matrix(n_embd, n_embd); 90 state_dict[`layer${i}.mlp_fc1`] = matrix(4 * n_embd, n_embd); 91 state_dict[`layer${i}.mlp_fc2`] = matrix(n_embd, 4 * n_embd); 92} 93const params = Object.values(state_dict).flat(Infinity); // flatten params into a single list of Values 94console.log(`num params: ${params.length}`); 95 96// Define the model architecture: a stateless function mapping token sequence and parameters to logits over what comes next. 
// Follow GPT-2, blessed among the GPTs, with minor differences: layernorm -> rmsnorm, no biases, GeLU -> ReLU

// Reduce a list of Values to their sum (building the addition subgraph as it goes).
const sum = (arr) => arr.reduce((a, b) => a.add(b));
// Pair two equal-length lists element-wise: [[a0,b0], [a1,b1], ...].
const zip = (a, b) => a.map((ai, i) => [ai, b[i]]);

// Matrix-vector product: output element o is the dot product of weight row o with x.
function linear(x, w) {
  const out = [];
  for (const row of w) {
    let acc = row[0].mul(x[0]);
    for (let i = 1; i < row.length; ++i) acc = acc.add(row[i].mul(x[i]));
    out.push(acc);
  }
  return out;
}

// Softmax over a list of Value logits, with max-subtraction for numerical stability.
function softmax(logits) {
  const peak = Math.max(...logits.map(v => v.data));
  const exps = logits.map(v => v.sub(peak).exp());
  const denom = sum(exps);
  return exps.map(e => e.div(denom));
}

// Root-mean-square normalization: layernorm without mean-centering or learned gain.
function rmsnorm(x) {
  const mean_sq = sum(x.map(xi => xi.mul(xi))).mul(1 / x.length);
  const inv_rms = mean_sq.add(1e-5).pow(-0.5); // epsilon guards against divide-by-zero
  return x.map(xi => xi.mul(inv_rms));
}

// One forward step of the transformer for a single token at position pos_id.
// keys/values are the per-layer KV caches that grow by one entry per call;
// returns the (pre-softmax) logits over the vocabulary for the next token.
function gpt(token_id, pos_id, keys, values) {
  const tok_emb = state_dict['wte'][token_id]; // token embedding
  const pos_emb = state_dict['wpe'][pos_id]; // position embedding
  let x = zip(tok_emb, pos_emb).map(([t, p]) => t.add(p)); // joint token and position embedding
  x = rmsnorm(x);

  for (let li = 0; li < n_layer; ++li) {
    // 1) Multi-head attention block
    let residual = x;
    x = rmsnorm(x);
    const q = linear(x, state_dict[`layer${li}.attn_wq`]);
    const k = linear(x, state_dict[`layer${li}.attn_wk`]);
    const v = linear(x, state_dict[`layer${li}.attn_wv`]);
    keys[li].push(k); // cache this position's keys/values for all future tokens
    values[li].push(v);
    const x_attn = [];
    for (let h = 0; h < n_head; ++h) {
      const lo = h * head_dim; // this head owns channels [lo, lo + head_dim)
      const q_h = q.slice(lo, lo + head_dim);
      const k_h = keys[li].map(ki => ki.slice(lo, lo + head_dim));
      const v_h = values[li].map(vi => vi.slice(lo, lo + head_dim));
      // Scaled dot-product attention over every cached position (causal by construction).
      const attn_logits = k_h.map(kt => sum(zip(q_h, kt).map(([qi, ki]) => qi.mul(ki))).mul(scale));
      const attn_weights = softmax(attn_logits);
      for (let j = 0; j < head_dim; ++j) {
        x_attn.push(sum(attn_weights.map((aw, t) => aw.mul(v_h[t][j]))));
      }
    }
    x = linear(x_attn, state_dict[`layer${li}.attn_wo`]);
    x = x.map((xi, i) => xi.add(residual[i])); // residual connection around attention

    // 2) MLP block: 4x expansion, ReLU, projection back, residual connection
    residual = x;
    x = rmsnorm(x);
    x = linear(x, state_dict[`layer${li}.mlp_fc1`]);
    x = x.map(xi => xi.relu());
    x = linear(x, state_dict[`layer${li}.mlp_fc2`]);
    x = x.map((xi, i) => xi.add(residual[i]));
  }

  return linear(x, state_dict['lm_head']);
}

// Let there be Adam, the blessed optimizer and its buffers
const learning_rate = 0.01, beta1 = 0.85, beta2 = 0.99, eps_adam = 1e-8;
const m_buf = new Float64Array(params.length); // first moment buffer (running mean of gradients)
const v_buf = new Float64Array(params.length); // second moment buffer (running mean of squared gradients)

// Repeat in sequence
const num_steps = 1000; // number of training steps
for (let step = 0; step < num_steps; ++step) {

  // Take a single document, tokenize it, and bracket it with the BOS special token on both sides.
  const doc = docs[step % docs.length];
  const tokens = [BOS, ...Array.from(doc, ch => char_to_id.get(ch)), BOS];
  const n = Math.min(block_size, tokens.length - 1); // number of next-token predictions to train on

  // Forward the token sequence through the model, building up the computation graph all the way to the loss.
  const keys = Array.from({ length: n_layer }, () => []);
  const values = Array.from({ length: n_layer }, () => []);
  const losses = [];
  for (let pos_id = 0; pos_id < n; ++pos_id) {
    const token_id = tokens[pos_id], target_id = tokens[pos_id + 1];
    const logits = gpt(token_id, pos_id, keys, values);
    const probs = softmax(logits);
    losses.push(probs[target_id].log().neg()); // cross-entropy at this position
  }
  const loss = sum(losses).mul(1 / n); // final average loss over the document sequence. May yours be low.

  // Backward the loss, calculating the gradients with respect to all model parameters.
  loss.backward();

  // Adam optimizer update: update the model parameters based on the corresponding gradients.
  const lr_t = learning_rate * (1 - step / num_steps); // linear learning rate decay
  const bc1 = 1 - beta1 ** (step + 1), bc2 = 1 - beta2 ** (step + 1); // bias-correction factors
  for (let i = 0; i < params.length; ++i) {
    const p = params[i];
    m_buf[i] = beta1 * m_buf[i] + (1 - beta1) * p.grad;
    v_buf[i] = beta2 * v_buf[i] + (1 - beta2) * p.grad ** 2;
    p.data -= lr_t * (m_buf[i] / bc1) / (Math.sqrt(v_buf[i] / bc2) + eps_adam);
    p.grad = 0; // reset for the next step
  }

  process.stdout.write(`step ${String(step + 1).padStart(4)} / ${String(num_steps).padStart(4)} | loss ${loss.data.toFixed(4)}\r`);
}

// Inference: may the model babble back to us
const temperature = 0.5; // in (0, 1], control the "creativity" of generated text, low to high
const token_ids = Array.from({ length: vocab_size }, (_, i) => i);
console.log('\n--- inference (new, hallucinated names) ---');
for (let sample_idx = 0; sample_idx < 20; ++sample_idx) {
  const keys = Array.from({ length: n_layer }, () => []);
  const values = Array.from({ length: n_layer }, () => []);
  let token_id = BOS;
  const sample = [];
  for (let pos_id = 0; pos_id < block_size; ++pos_id) {
    const logits = gpt(token_id, pos_id, keys, values);
    const probs = softmax(logits.map(l => l.div(temperature))); // temperature-sharpened distribution
    token_id = random.choices(token_ids, probs.map(p => p.data));
    if (token_id === BOS) break; // the model chose to end this sample
    sample.push(uchars[token_id]);
  }
  console.log(`sample ${String(sample_idx + 1).padStart(2)}: ${sample.join('')}`);
}