# Fork of https://github.com/xenova/microgpt.js
"""
The most atomic way to train and run inference with a GPT in pure, dependency-free Python.
This file is the complete algorithm.
Everything else is just efficiency.

@karpathy
"""
8
import math  # math.log, math.exp
import random  # random.seed, random.choices, random.gauss, random.shuffle
random.seed(42)  # Let there be order among chaos

# Load the dataset: one document per non-blank line of input.txt, with surrounding whitespace stripped.
# The `with` block closes the file handle once it has been read. The explicit encoding makes the decoded
# characters (and therefore the vocabulary built from them) independent of the platform's locale default.
with open('input.txt', encoding='utf-8') as input_file:
    docs = [line.strip() for line in input_file.read().strip().split('\n') if line.strip()]  # list[str] of documents
random.shuffle(docs)
print(f"num docs: {len(docs)}")
16
# Let there be a Tokenizer to translate strings to discrete symbols and back.
# Every distinct character in the corpus gets an id 0..n-1 (in sorted order); one extra id is reserved for BOS.
uchars = sorted({ch for doc in docs for ch in doc})  # index -> character; position in this list is the token id
BOS = len(uchars)  # token id for the special Beginning of Sequence (BOS) token, right after the last character id
vocab_size = BOS + 1  # every character id plus the BOS id
print(f"vocab size: {vocab_size}")
22
# Let there be Autograd, to apply the chain rule backwards through a computation graph
class Value:
    """A scalar node in a computation graph with reverse-mode automatic differentiation.

    Every arithmetic op returns a new Value that records its inputs (``_children``)
    and the partial derivative of the result w.r.t. each input (``_local_grads``).
    ``backward()`` walks the graph in reverse topological order and accumulates
    ``grad`` on every node via the chain rule.
    """
    __slots__ = ('data', 'grad', '_children', '_local_grads')  # Python optimization for memory usage

    def __init__(self, data, children=(), local_grads=()):
        self.data = data  # scalar value of this node calculated during forward pass
        self.grad = 0  # derivative of the loss w.r.t. this node, calculated in backward pass
        self._children = children  # children of this node in the computation graph
        self._local_grads = local_grads  # local derivative of this node w.r.t. its children

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data + other.data, (self, other), (1, 1))

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def __pow__(self, other):
        # `other` must be a plain number (constant exponent): d(x**c)/dx = c * x**(c-1)
        return Value(self.data**other, (self,), (other * self.data**(other-1),))

    def log(self):
        # Natural log; math.log raises ValueError for data <= 0.
        return Value(math.log(self.data), (self,), (1/self.data,))

    def exp(self):
        out = math.exp(self.data)  # d(e^x)/dx = e^x, so the forward value doubles as the local gradient
        return Value(out, (self,), (out,))

    def relu(self): return Value(max(0, self.data), (self,), (float(self.data > 0),))
    def __neg__(self): return self * -1
    def __radd__(self, other): return self + other
    def __sub__(self, other): return self + (-other)
    def __rsub__(self, other): return other + (-self)
    def __rmul__(self, other): return self * other
    def __truediv__(self, other): return self * other**-1
    def __rtruediv__(self, other): return other * self**-1

    def backward(self):
        """Backpropagate from this node into every node it depends on.

        Sets this node's grad to 1 and adds (``+=``) each contribution into the
        children's ``grad``, so grads accumulate across calls and must be reset
        by the caller between passes.

        The topological sort uses an explicit stack instead of recursion, so deep
        graphs (e.g. long chains built by ``sum()`` over many Values) cannot hit
        Python's recursion limit. It produces the same post-order as a recursive
        depth-first search, so gradients are accumulated in the same order.
        """
        topo = []
        visited = {self}
        # Each stack entry is a node plus an iterator over its not-yet-explored children.
        stack = [(self, iter(self._children))]
        while stack:
            node, children = stack[-1]
            for child in children:
                if child not in visited:
                    visited.add(child)
                    stack.append((child, iter(child._children)))
                    break  # descend first; resume this node's iterator afterwards
            else:
                # All children are finished: the node goes after everything it depends on.
                stack.pop()
                topo.append(node)
        self.grad = 1
        for v in reversed(topo):
            for child, local_grad in zip(v._children, v._local_grads):
                child.grad += local_grad * v.grad
67
# Initialize the parameters, to store the knowledge of the model.
n_embd = 16  # embedding dimension
n_head = 4  # number of attention heads
n_layer = 1  # number of layers
block_size = 16  # maximum sequence length
# Each head owns an equal, contiguous slice of the embedding. With a remainder, the heads would cover
# only n_head * head_dim channels: the attention output would be shorter than attn_wo's rows, and
# linear()'s zip would silently truncate instead of failing. Reject such configs up front.
if n_embd % n_head != 0:
    raise ValueError(f"n_embd ({n_embd}) must be divisible by n_head ({n_head})")
head_dim = n_embd // n_head  # dimension of each head


def matrix(nout, nin, std=0.08):
    """Return an nout x nin list of rows of fresh Value parameters drawn from N(0, std**2), row by row."""
    return [[Value(random.gauss(0, std)) for _ in range(nin)] for _ in range(nout)]


# Matrices are stored as (out_features, in_features): one row per output, the layout linear() consumes.
# Creation order is fixed (wte, wpe, lm_head, then per-layer weights), which keeps the seeded RNG draws reproducible.
state_dict = {'wte': matrix(vocab_size, n_embd), 'wpe': matrix(block_size, n_embd), 'lm_head': matrix(vocab_size, n_embd)}
for i in range(n_layer):
    state_dict[f'layer{i}.attn_wq'] = matrix(n_embd, n_embd)
    state_dict[f'layer{i}.attn_wk'] = matrix(n_embd, n_embd)
    state_dict[f'layer{i}.attn_wv'] = matrix(n_embd, n_embd)
    state_dict[f'layer{i}.attn_wo'] = matrix(n_embd, n_embd)
    state_dict[f'layer{i}.mlp_fc1'] = matrix(4 * n_embd, n_embd)
    state_dict[f'layer{i}.mlp_fc2'] = matrix(n_embd, 4 * n_embd)
params = [p for mat in state_dict.values() for row in mat for p in row]  # flatten params into a single list[Value]
print(f"num params: {len(params)}")
85
# Define the model architecture: a function mapping one token, its position, the KV cache of earlier
# positions and the parameters to logits over what comes next.
# Follow GPT-2, blessed among the GPTs, with minor differences: layernorm -> rmsnorm, no biases, GeLU -> ReLU
def linear(x, w):
    """Bias-free matrix-vector product: one output per row of ``w``, each the dot product of that row with ``x``."""
    def dot(row):
        return sum(w_ij * x_j for w_ij, x_j in zip(row, x))
    return [dot(row) for row in w]
90
def softmax(logits):
    """Map a non-empty list of Value logits to a list of Value probabilities that sum to 1.

    The largest logit's ``.data`` (a plain float) is subtracted from every logit
    before exponentiating. This keeps exp() from overflowing, and the constant
    shift cancels out in the normalization.
    """
    shift = max(logits, key=lambda val: val.data).data
    unnormalized = [(val - shift).exp() for val in logits]
    partition = sum(unnormalized)
    return [u / partition for u in unnormalized]
96
def rmsnorm(x):
    """Rescale the vector ``x`` to unit root-mean-square (RMSNorm with no learned gain).

    The constant 1e-5 inside the inverse square root keeps an all-zero input from
    dividing by zero.
    """
    mean_square = sum(xi * xi for xi in x) / len(x)
    inv_rms = (mean_square + 1e-5) ** -0.5
    return [xi * inv_rms for xi in x]
101
def gpt(token_id, pos_id, keys, values):
    """Run the transformer on one token and return logits for the token that follows it.

    Args:
        token_id: id of the current token; selects a row of state_dict['wte'].
        pos_id: position of that token; selects a row of state_dict['wpe'], so it
            must be < block_size.
        keys, values: per-layer KV caches (keys[li] is the list of key vectors for
            layer li). This call appends the current position's key and value
            vectors, then attends over everything cached so far, so only the
            current and earlier positions are visible.

    Returns:
        list[Value] of length vocab_size: unnormalized next-token logits.
    """
    wte_row = state_dict['wte'][token_id]  # token embedding
    wpe_row = state_dict['wpe'][pos_id]  # position embedding
    x = rmsnorm([tok + pos for tok, pos in zip(wte_row, wpe_row)])  # joint token and position embedding
    score_scale = head_dim**0.5  # scaled dot-product attention divides by sqrt(head_dim)

    for li in range(n_layer):
        layer = f'layer{li}.'

        # 1) Multi-head attention block: pre-norm, residual connection around it.
        residual = x
        normed = rmsnorm(x)
        q = linear(normed, state_dict[layer + 'attn_wq'])
        keys[li].append(linear(normed, state_dict[layer + 'attn_wk']))
        values[li].append(linear(normed, state_dict[layer + 'attn_wv']))
        heads_out = []
        for head in range(n_head):
            lo = head * head_dim
            hi = lo + head_dim
            q_h = q[lo:hi]
            scores = [sum(qj * kj for qj, kj in zip(q_h, k_t[lo:hi])) / score_scale for k_t in keys[li]]
            weights = softmax(scores)
            v_slices = [v_t[lo:hi] for v_t in values[li]]
            heads_out.extend([sum(w * v_s[j] for w, v_s in zip(weights, v_slices)) for j in range(head_dim)])
        x = [a + b for a, b in zip(linear(heads_out, state_dict[layer + 'attn_wo']), residual)]

        # 2) MLP block (expand, ReLU, project back): pre-norm, residual connection around it.
        residual = x
        hidden = [u.relu() for u in linear(rmsnorm(x), state_dict[layer + 'mlp_fc1'])]
        x = [a + b for a, b in zip(linear(hidden, state_dict[layer + 'mlp_fc2']), residual)]

    return linear(x, state_dict['lm_head'])
139
# Let there be Adam, the blessed optimizer, and its per-parameter state.
learning_rate = 0.01  # initial step size; the training loop decays it linearly
beta1 = 0.85  # decay rate of the first-moment (gradient) moving average
beta2 = 0.99  # decay rate of the second-moment (squared-gradient) moving average
eps_adam = 1e-8  # added to the update's denominator to avoid division by zero
m = [0.0 for _ in params]  # first moment buffer, one slot per parameter
v = [0.0 for _ in params]  # second moment buffer, one slot per parameter
144
# Repeat in sequence
num_steps = 1000  # number of training steps
for step in range(num_steps):
    # Take the next document (cycling through the shuffled list), tokenize it, and wrap it in
    # BOS on both sides: the leading BOS is the starting context, the trailing BOS is the final target.
    doc = docs[step % len(docs)]
    tokens = [BOS, *(uchars.index(ch) for ch in doc), BOS]
    n = min(block_size, len(tokens) - 1)  # number of (input, target) pairs, capped at the context length

    # Forward pass, one position at a time, growing a fresh per-layer KV cache and
    # building the computation graph all the way to the loss.
    keys = [[] for _ in range(n_layer)]
    values = [[] for _ in range(n_layer)]
    losses = []
    for pos_id, (token_id, target_id) in enumerate(zip(tokens[:n], tokens[1:n + 1])):
        logits = gpt(token_id, pos_id, keys, values)
        probs = softmax(logits)
        loss_t = -probs[target_id].log()  # negative log-likelihood of the true next token
        losses.append(loss_t)
    loss = (1 / n) * sum(losses)  # final average loss over the document sequence. May yours be low.

    # Backward pass: gradients of the loss with respect to every model parameter.
    loss.backward()

    # Adam update with bias-corrected moment estimates and linear learning-rate decay.
    lr_t = learning_rate * (1 - step / num_steps)
    bias_corr1 = 1 - beta1 ** (step + 1)
    bias_corr2 = 1 - beta2 ** (step + 1)
    for i, p in enumerate(params):
        grad = p.grad
        m[i] = beta1 * m[i] + (1 - beta1) * grad
        v[i] = beta2 * v[i] + (1 - beta2) * grad ** 2
        m_hat = m[i] / bias_corr1
        v_hat = v[i] / bias_corr2
        p.data -= lr_t * m_hat / (v_hat ** 0.5 + eps_adam)
        p.grad = 0  # Value.backward accumulates with +=, so clear before the next step

    print(f"step {step+1:4d} / {num_steps:4d} | loss {loss.data:.4f}")
179
# Inference: may the model babble back to us
temperature = 0.5  # in (0, 1], control the "creativity" of generated text, low to high
print("\n--- inference (new, hallucinated names) ---")
for sample_idx in range(20):
    # Each sample starts from BOS with an empty KV cache and stops at the next BOS or at block_size tokens.
    keys = [[] for _ in range(n_layer)]
    values = [[] for _ in range(n_layer)]
    token_id, sample = BOS, []
    for pos_id in range(block_size):
        scaled_logits = [logit / temperature for logit in gpt(token_id, pos_id, keys, values)]
        probs = softmax(scaled_logits)
        weights = [p.data for p in probs]
        token_id = random.choices(range(vocab_size), weights=weights)[0]
        if token_id == BOS:  # the model chose to end the name
            break
        sample.append(uchars[token_id])
    print(f"sample {sample_idx+1:2d}: {''.join(sample)}")