Fork of https://github.com/xenova/microgpt.js
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 194 lines 8.8 kB view raw
"""
The most atomic way to train and inference a GPT in pure, dependency-free Python.
This file is the complete algorithm.
Everything else is just efficiency.

@karpathy
"""

import math    # math.log, math.exp
import random  # random.seed, random.choices, random.gauss, random.shuffle
random.seed(42)  # Let there be order among chaos

# Load training data: one document per non-empty line of input.txt.
# A context manager guarantees the file handle is closed even if reading raises,
# and an explicit encoding makes behavior identical across platforms/locales
# (the bare open() relied on the locale's default encoding).
with open('input.txt', encoding='utf-8') as f:
    docs = [line.strip() for line in f.read().strip().split('\n') if line.strip()]  # list[str] of documents
random.shuffle(docs)
print(f"num docs: {len(docs)}")

# Let there be a Tokenizer to translate strings to discrete symbols and back
uchars = sorted(set(''.join(docs)))  # unique characters in the dataset become token ids 0..n-1
BOS = len(uchars)                    # token id for the special Beginning of Sequence (BOS) token
vocab_size = len(uchars) + 1         # total number of unique tokens, +1 is for BOS
print(f"vocab size: {vocab_size}")

# Let there be Autograd, to recursively apply the chain rule through a computation graph
class Value:
    """A scalar node in a computation graph: stores its value, its gradient, and
    how to locally differentiate itself with respect to its children."""

    __slots__ = ('data', 'grad', '_children', '_local_grads')  # Python optimization for memory usage

    def __init__(self, data, children=(), local_grads=()):
        self.data = data                 # scalar value of this node calculated during forward pass
        self.grad = 0                    # derivative of the loss w.r.t. this node, calculated in backward pass
        self._children = children        # children of this node in the computation graph
        self._local_grads = local_grads  # local derivative of this node w.r.t. its children

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data + other.data, (self, other), (1, 1))

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    # other is assumed to be a plain number here (power rule), not a Value
    def __pow__(self, other): return Value(self.data**other, (self,), (other * self.data**(other-1),))
    def log(self): return Value(math.log(self.data), (self,), (1/self.data,))
    def exp(self): return Value(math.exp(self.data), (self,), (math.exp(self.data),))
    def relu(self): return Value(max(0, self.data), (self,), (float(self.data > 0),))
    def __neg__(self): return self * -1
    def __radd__(self, other): return self + other
    def __sub__(self, other): return self + (-other)
    def __rsub__(self, other): return other + (-self)
    def __rmul__(self, other): return self * other
    def __truediv__(self, other): return self * other**-1
    def __rtruediv__(self, other): return other * self**-1

    def backward(self):
        """Backpropagate: set self.grad = 1 and apply the chain rule through the
        graph in reverse topological order, accumulating into each node's .grad."""
        topo = []
        visited = set()
        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._children:
                    build_topo(child)
                topo.append(v)
        build_topo(self)
        self.grad = 1  # d(self)/d(self) = 1 seeds the chain rule
        for v in reversed(topo):
            for child, local_grad in zip(v._children, v._local_grads):
                child.grad += local_grad * v.grad

# Initialize the parameters, to store the knowledge of the model.
n_embd = 16      # embedding dimension
n_head = 4       # number of attention heads
n_layer = 1      # number of layers
block_size = 16  # maximum sequence length
head_dim = n_embd // n_head  # dimension of each head

def matrix(nout, nin, std=0.08):
    """Return an nout x nin matrix of Value parameters drawn from N(0, std).

    A named def (not an assigned lambda, per PEP 8) so tracebacks and repr show
    a useful name.
    """
    return [[Value(random.gauss(0, std)) for _ in range(nin)] for _ in range(nout)]

state_dict = {'wte': matrix(vocab_size, n_embd), 'wpe': matrix(block_size, n_embd), 'lm_head': matrix(vocab_size, n_embd)}
for i in range(n_layer):
    state_dict[f'layer{i}.attn_wq'] = matrix(n_embd, n_embd)      # attention query projection
    state_dict[f'layer{i}.attn_wk'] = matrix(n_embd, n_embd)      # attention key projection
    state_dict[f'layer{i}.attn_wv'] = matrix(n_embd, n_embd)      # attention value projection
    state_dict[f'layer{i}.attn_wo'] = matrix(n_embd, n_embd)      # attention output projection
    state_dict[f'layer{i}.mlp_fc1'] = matrix(4 * n_embd, n_embd)  # MLP expansion (4x)
    state_dict[f'layer{i}.mlp_fc2'] = matrix(n_embd, 4 * n_embd)  # MLP contraction
params = [p for mat in state_dict.values() for row in mat for p in row]  # flatten params into a single list[Value]
print(f"num params: {len(params)}")

# Define the model architecture: a stateless function mapping token sequence and parameters to logits over what comes next.
# Follow GPT-2, blessed among the GPTs, with minor differences: layernorm -> rmsnorm, no biases, GeLU -> ReLU
def linear(x, w):
    """Matrix-vector product: return w @ x, where w is a list of rows and x a vector."""
    return [sum(wi * xi for wi, xi in zip(wo, x)) for wo in w]

def softmax(logits):
    """Normalize a list of Value logits into a probability distribution.

    Shifts by the max logit first so exp() cannot overflow (numerical stability);
    the shift cancels out in the normalization.
    """
    max_val = max(val.data for val in logits)
    exps = [(val - max_val).exp() for val in logits]
    total = sum(exps)
    return [e / total for e in exps]

def rmsnorm(x):
    """Scale vector x to unit root-mean-square (used here in place of layernorm)."""
    ms = sum(xi * xi for xi in x) / len(x)
    scale = (ms + 1e-5) ** -0.5  # epsilon guards against division by zero on an all-zero vector
    return [xi * scale for xi in x]

def gpt(token_id, pos_id, keys, values):
    """Forward one token through the Transformer; return logits over the next token.

    token_id/pos_id identify the current token and its position. keys/values are
    per-layer KV caches (lists of past key/value vectors); this call appends the
    current step's k/v so attention can look back over the sequence so far.
    """
    tok_emb = state_dict['wte'][token_id]  # token embedding
    pos_emb = state_dict['wpe'][pos_id]    # position embedding
    x = [t + p for t, p in zip(tok_emb, pos_emb)]  # joint token and position embedding
    x = rmsnorm(x)

    attn_scale = head_dim ** -0.5  # 1/sqrt(d_k); loop-invariant, hoisted out of the per-head/per-step loops
    for li in range(n_layer):
        # 1) Multi-head attention block
        x_residual = x
        x = rmsnorm(x)
        q = linear(x, state_dict[f'layer{li}.attn_wq'])
        k = linear(x, state_dict[f'layer{li}.attn_wk'])
        v = linear(x, state_dict[f'layer{li}.attn_wv'])
        keys[li].append(k)
        values[li].append(v)
        x_attn = []
        for h in range(n_head):
            hs = h * head_dim  # this head's slice of the joint q/k/v vectors
            q_h = q[hs:hs+head_dim]
            k_h = [ki[hs:hs+head_dim] for ki in keys[li]]
            v_h = [vi[hs:hs+head_dim] for vi in values[li]]
            attn_logits = [sum(q_h[j] * k_h[t][j] for j in range(head_dim)) * attn_scale for t in range(len(k_h))]
            attn_weights = softmax(attn_logits)
            head_out = [sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h))) for j in range(head_dim)]
            x_attn.extend(head_out)
        x = linear(x_attn, state_dict[f'layer{li}.attn_wo'])
        x = [a + b for a, b in zip(x, x_residual)]
        # 2) MLP block
        x_residual = x
        x = rmsnorm(x)
        x = linear(x, state_dict[f'layer{li}.mlp_fc1'])
        x = [xi.relu() for xi in x]
        x = linear(x, state_dict[f'layer{li}.mlp_fc2'])
        x = [a + b for a, b in zip(x, x_residual)]

    logits = linear(x, state_dict['lm_head'])
    return logits
# Let there be Adam, the blessed optimizer and its buffers
learning_rate, beta1, beta2, eps_adam = 0.01, 0.85, 0.99, 1e-8
m = [0.0] * len(params)  # first moment buffer
v = [0.0] * len(params)  # second moment buffer

# Precompute char -> token id once: O(1) dict lookups per character instead of
# an O(vocab) uchars.index scan on every character of every step.
ctoi = {ch: i for i, ch in enumerate(uchars)}

# Repeat in sequence
num_steps = 1000  # number of training steps
for step in range(num_steps):

    # Take single document, tokenize it, surround it with BOS special token on both sides
    doc = docs[step % len(docs)]
    tokens = [BOS] + [ctoi[ch] for ch in doc] + [BOS]
    n = min(block_size, len(tokens) - 1)  # number of (input, target) pairs we can form

    # Forward the token sequence through the model, building up the computation graph all the way to the loss.
    keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)]
    losses = []
    for pos_id in range(n):
        token_id, target_id = tokens[pos_id], tokens[pos_id + 1]
        logits = gpt(token_id, pos_id, keys, values)
        probs = softmax(logits)
        loss_t = -probs[target_id].log()  # cross-entropy: negative log-likelihood of the target token
        losses.append(loss_t)
    loss = (1 / n) * sum(losses)  # final average loss over the document sequence. May yours be low.

    # Backward the loss, calculating the gradients with respect to all model parameters.
    loss.backward()

    # Adam optimizer update: update the model parameters based on the corresponding gradients.
    lr_t = learning_rate * (1 - step / num_steps)  # linear learning rate decay
    for i, p in enumerate(params):
        m[i] = beta1 * m[i] + (1 - beta1) * p.grad       # EMA of gradients
        v[i] = beta2 * v[i] + (1 - beta2) * p.grad ** 2  # EMA of squared gradients
        m_hat = m[i] / (1 - beta1 ** (step + 1))         # bias correction for zero-initialized buffers
        v_hat = v[i] / (1 - beta2 ** (step + 1))
        p.data -= lr_t * m_hat / (v_hat ** 0.5 + eps_adam)
        p.grad = 0  # reset for the next backward pass

    print(f"step {step+1:4d} / {num_steps:4d} | loss {loss.data:.4f}")

# Inference: may the model babble back to us
temperature = 0.5  # in (0, 1], control the "creativity" of generated text, low to high
print("\n--- inference (new, hallucinated names) ---")
for sample_idx in range(20):
    keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)]
    token_id = BOS
    sample = []
    for pos_id in range(block_size):
        logits = gpt(token_id, pos_id, keys, values)
        probs = softmax([l / temperature for l in logits])  # temperature reshapes the distribution before sampling
        token_id = random.choices(range(vocab_size), weights=[p.data for p in probs])[0]
        if token_id == BOS:  # model chose to end the sequence
            break
        sample.append(uchars[token_id])
    print(f"sample {sample_idx+1:2d}: {''.join(sample)}")