### Paper: PyTorch Implementation and Architecture of GPT-2

#### Abstract
In this paper, we present a PyTorch-based implementation of GPT-2, originally developed by OpenAI using TensorFlow. This implementation retains the architectural intricacies and training dynamics of GPT-2 while leveraging the capabilities of PyTorch for efficient computation. We discuss the underlying architecture, including the Causal Self-Attention mechanism, Multi-Layer Perceptron (MLP) blocks, and overall transformer structure. Furthermore, we describe the training process, data loading mechanism, and optimization strategy used to replicate GPT-2's performance.
#### Introduction
The Generative Pre-trained Transformer 2 (GPT-2) model has demonstrated remarkable performance in generating coherent and contextually relevant text. This work replicates the original TensorFlow implementation in PyTorch, providing insights into its architecture and training process. Our implementation includes a detailed code walkthrough and explanations of each component's function within the model.
## GPT-2 Architecture: Simplified
GPT-2 (Generative Pre-trained Transformer 2) is based on the Transformer model, which has been a significant breakthrough in natural language processing (NLP). The architecture consists of the following key components:
1. **Token Embeddings**
2. **Positional Embeddings**
3. **Transformer Blocks**
   - Layer Normalization
   - Causal Self-Attention
   - Feedforward Neural Network (MLP)
4. **Output Layer**
#### 1. Token Embeddings
Token embeddings are used to convert the input tokens (words or subwords) into dense vectors of fixed size. Each token in the vocabulary is represented by a unique vector.
- **Vocabulary Size (vocab_size):** Number of unique tokens in the vocabulary.
- **Embedding Dimension (n_embd):** Size of the dense vector representing each token.
```python
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
```
#### 2. Positional Embeddings
Since the Transformer architecture does not inherently capture the order of tokens in a sequence, positional embeddings are added to the token embeddings to provide positional information.
- **Block Size (block_size):** Maximum sequence length that the model can handle.
- **Positional Embedding Dimension:** Same as the token embedding dimension.
```python
self.wpe = nn.Embedding(config.block_size, config.n_embd)
```
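As a quick illustration of how the two embedding tables are used together (GPT-2 small sizes: vocab_size=50257, block_size=1024, n_embd=768): the positional embeddings depend only on position, so they broadcast across the batch dimension when added to the token embeddings. This is a minimal sketch with random inputs.
```python
import torch
import torch.nn as nn

wte = nn.Embedding(50257, 768)               # token embeddings (vocab_size, n_embd)
wpe = nn.Embedding(1024, 768)                # positional embeddings (block_size, n_embd)

idx = torch.randint(0, 50257, (2, 8))        # (B=2, T=8) batch of token ids
pos = torch.arange(0, 8, dtype=torch.long)   # positions 0..T-1

tok_emb = wte(idx)                           # (2, 8, 768)
pos_emb = wpe(pos)                           # (8, 768)
x = tok_emb + pos_emb                        # broadcasts to (2, 8, 768)
print(x.shape)                               # torch.Size([2, 8, 768])
```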
#### 3. Transformer Blocks
GPT-2 consists of multiple Transformer blocks (layers), each containing the following subcomponents:
##### a. Layer Normalization
Layer normalization is applied to stabilize and accelerate the training process. It normalizes the inputs across the features.
```python
self.ln_1 = nn.LayerNorm(config.n_embd)
self.ln_2 = nn.LayerNorm(config.n_embd)
```
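A small sketch of what "normalizes the inputs across the features" means in practice: each position's n_embd-dimensional vector is standardized over its features (before the learned scale and shift, which start at 1 and 0).
```python
import torch
import torch.nn as nn

ln = nn.LayerNorm(768)
x = torch.randn(2, 8, 768) * 3 + 5    # arbitrary input with nonzero mean and scale
y = ln(x)
print(y.mean(dim=-1).abs().max())     # close to 0: each position is normalized across its features
print(y.std(dim=-1).mean())           # close to 1
```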
##### b. Causal Self-Attention
The causal self-attention mechanism allows each token to attend only to itself and the tokens that precede it in the sequence. It consists of three main steps:
1. **Linear Projections:** Calculate queries (Q), keys (K), and values (V) from the input.
2. **Attention Mechanism:** Compute scaled dot-product attention using Q, K, and V.
3. **Output Projection:** Project the concatenated attention outputs back to the original embedding dimension.
- **Number of Heads (n_head):** Number of attention heads. Each head computes attention independently and the results are concatenated.
- **Head Size:** Dimension of each head, computed as embedding dimension divided by the number of heads.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # a single linear layer produces Q, K, and V for all heads at once
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimension
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        # reshape to (B, n_head, T, head_size) for multi-head attention
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # causal (masked) scaled dot-product attention
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # re-assemble head outputs side by side and project back to n_embd
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return y
```
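To make the tensor shapes concrete, here is a quick sketch that runs the module above on dummy input. `Cfg` is a stand-in for the config object, introduced only for this illustration, using GPT-2 small's n_embd=768 and n_head=12 (so each head has size 64).
```python
from dataclasses import dataclass
import torch

@dataclass
class Cfg:             # stand-in config for this illustration only
    n_embd: int = 768
    n_head: int = 12   # head size = 768 / 12 = 64

attn = CausalSelfAttention(Cfg())
x = torch.randn(2, 8, 768)    # (B, T, C)
y = attn(x)
print(y.shape)                # torch.Size([2, 8, 768]): attention preserves the (B, T, C) shape
```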
##### c. Feedforward Neural Network (MLP)
The feedforward neural network applies two linear transformations with a non-linear activation function (GELU) in between. This allows for more complex representations.
- **Intermediate Size:** Typically four times the embedding dimension.
```python
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x
```
##### d. Transformer Block Integration
Each Transformer block combines layer normalization, causal self-attention, and the feedforward neural network.
```python
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        # pre-norm residual connections around attention and the MLP
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
```
#### 4. Output Layer
After a final layer normalization (ln_f), a linear projection maps the hidden states of the last Transformer block to vocabulary-sized logits. This allows the model to predict the next token in the sequence.
```python
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
```
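A brief sketch of what this projection produces: one score (logit) per vocabulary entry at every position, and a softmax over the scores at the last position gives the distribution for the next token. The sizes below assume GPT-2 small.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

lm_head = nn.Linear(768, 50257, bias=False)   # n_embd -> vocab_size
h = torch.randn(2, 8, 768)                    # hidden states from the final block (after ln_f)
logits = lm_head(h)                           # (2, 8, 50257): one score per vocabulary entry
probs = F.softmax(logits[:, -1, :], dim=-1)   # distribution over the token following position T-1
print(probs.shape, probs.sum(dim=-1))         # torch.Size([2, 50257]), each row sums to 1
```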
#### Putting It All Together
The GPT-2 model stacks multiple Transformer blocks and includes token and positional embeddings. The forward pass involves the following steps:
1. **Embedding Lookup:** Convert input tokens to embeddings.
2. **Positional Encoding:** Add positional embeddings to the token embeddings.
3. **Transformer Blocks:** Pass the embeddings through the stacked Transformer blocks.
4. **Output Projection:** Project the final hidden states to the vocabulary size.
```python
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # weight tying: the token embedding and the output projection share parameters
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)  # _init_weights is defined in the full listing below

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        pos_emb = self.transformer.wpe(pos)
        tok_emb = self.transformer.wte(idx)
        x = tok_emb + pos_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
```
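The configuration object (GPTConfig) referenced throughout is not shown in the listing above; the sketch below assumes a simple dataclass whose field names are inferred from how `config` is used, with GPT-2 small defaults. The forward pass relies on the full GPT class from the technical section later in this document, since the simplified listing above defers `_init_weights` to that version.
```python
from dataclasses import dataclass
import torch

@dataclass
class GPTConfig:
    # assumed field names, inferred from how `config` is used above; defaults are GPT-2 small
    block_size: int = 1024
    vocab_size: int = 50257
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768

model = GPT(GPTConfig(n_layer=2))        # a shallow model keeps the demo cheap
idx = torch.randint(0, 50257, (2, 17))   # (B, T+1) random token ids
x, y = idx[:, :-1], idx[:, 1:]           # next-token targets are the inputs shifted by one
logits, loss = model(x, targets=y)
print(logits.shape, loss.item())         # torch.Size([2, 16, 50257]); loss near ln(50257) ≈ 10.8 at init
```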
### Summary
The GPT-2 architecture is built upon the powerful Transformer model, utilizing token and positional embeddings, multi-head self-attention, and feedforward neural networks to generate coherent and contextually relevant text. By stacking multiple layers of these components, GPT-2 can capture long-range dependencies and generate high-quality text.
---
> [!info]
> Below is a deeper dive into the inner workings of the GPT-2 architecture, with code blocks showing the different parts of the model.
## Architecture: Technical with Code
##### 1. **Transformer Block**
GPT-2 is built on the Transformer architecture, which relies on self-attention mechanisms and feedforward neural networks. Each block consists of the following components:
- **Layer Normalization (LayerNorm)**
- **Causal Self-Attention (CausalSelfAttention)**
- **Multi-Layer Perceptron (MLP)**
Each block processes the input sequentially through these components, ensuring effective learning of long-range dependencies.
##### 2. **Causal Self-Attention (CausalSelfAttention)**
The Causal Self-Attention mechanism ensures that the model can only attend to previous tokens in the sequence, maintaining the autoregressive property essential for text generation. The attention mechanism can be summarized as follows:
- Compute query (Q), key (K), and value (V) matrices using linear projections of the input.
- Reshape and transpose these matrices to allow multi-head attention.
- Apply scaled dot-product attention with a causal mask to ensure no future tokens are attended to.
- Aggregate the output of all attention heads and apply a final linear projection.
```python
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return y
```
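F.scaled_dot_product_attention with is_causal=True fuses these steps into a single efficient kernel. For reference, here is a sketch of the explicit (slower) equivalent, which materializes the (T, T) attention matrix and masks out future positions with a lower-triangular mask:
```python
import math
import torch
import torch.nn.functional as F

def manual_causal_attention(q, k, v):
    # q, k, v: (B, n_head, T, head_size)
    T = q.size(-2)
    att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))          # (B, nh, T, T) raw scores
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
    att = att.masked_fill(~mask, float('-inf'))                      # block attention to future tokens
    att = F.softmax(att, dim=-1)
    return att @ v                                                   # (B, nh, T, head_size)

# sanity check against the fused kernel
q, k, v = (torch.randn(2, 12, 8, 64) for _ in range(3))
y_ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(manual_causal_attention(q, k, v), y_ref, atol=1e-5))  # True
```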
##### 3. **Multi-Layer Perceptron (MLP)**
The MLP component consists of two linear transformations with a GELU activation function in between. This component is responsible for applying non-linear transformations to the input.
```python
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x
```
##### 4. **GPT Block**
A GPT block integrates the components described above, forming the core unit of the GPT-2 model.
```python
class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
```
##### 5. **GPT Model**
The GPT model stacks multiple blocks and includes token and position embeddings. The final layer is a linear projection to the vocabulary size for token prediction.
```python
import inspect
import torch
import torch.nn as nn
import torch.nn.functional as F

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # weight tying between the token embedding and the output projection
        self.transformer.wte.weight = self.lm_head.weight
        # initialize parameters
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                # scale down residual-path projections with depth
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        pos_emb = self.transformer.wpe(pos)
        tok_emb = self.transformer.wte(idx)
        x = tok_emb + pos_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @classmethod
    def from_pretrained(cls, model_type):
        """Load pretrained GPT-2 weights from Hugging Face into this implementation."""
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # model hyperparameters are determined by model_type
        config_args = {
            'gpt2':        dict(n_layer=12, n_head=12, n_embd=768),
            'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024),
            'gpt2-large':  dict(n_layer=36, n_head=20, n_embd=1280),
            'gpt2-xl':     dict(n_layer=48, n_head=25, n_embd=1600),
        }[model_type]
        config_args['vocab_size'] = 50257
        config_args['block_size'] = 1024
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')]  # drop causal-mask buffers (not parameters)

        # load the Hugging Face model and copy the weights over
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')]
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')]
        # the OpenAI checkpoints use Conv1D modules, so these weight matrices must be transposed
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])
        return model

    def configure_optimizers(self, weight_decay, learning_rate, device_type):
        # start with all candidate parameters that require gradients
        param_dict = {pn: p for pn, p in self.named_parameters()}
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # weight-decay all 2D+ tensors (matmul weights, embeddings); leave biases and norms undecayed
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        if master_process:  # rank-0 flag defined in the surrounding training script
            print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
            print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # use the fused AdamW kernel when it is available and we are running on CUDA
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == "cuda"
        if master_process:
            print(f"using fused AdamW: {use_fused}")
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
        return optimizer
```
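A sketch of how these two entry points might be used. The weight_decay=0.1 and learning_rate=6e-4 values mirror the optimizer call shown later in the Optimization subsection, from_pretrained requires the transformers package, and master_process is the rank-0 flag that configure_optimizers expects to find at module level.
```python
# load the released GPT-2 small weights into this implementation
model = GPT.from_pretrained('gpt2')
model.eval()

# or build a randomly initialized model and set up its optimizer as in the training script
master_process = True   # rank-0 flag referenced by configure_optimizers above
model = GPT(GPTConfig())
optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device_type='cpu')
```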
#### Training Process
##### 1. **Data Loading**
We use a custom DataLoaderLite class to load and preprocess training and validation data. The data is loaded in shards, with each shard being a chunk of tokenized text. The data loader ensures that each process gets its portion of the data in a distributed training setup.
```python
import os
import numpy as np
import torch

def load_tokens(filename):
    npt = np.load(filename)
    npt = npt.astype(np.int32)
    ptt = torch.tensor(npt, dtype=torch.long)
    return ptt

class DataLoaderLite:
    def __init__(self, B, T, process_rank, num_processes, split):
        self.B = B
        self.T = T
        self.process_rank = process_rank
        self.num_processes = num_processes
        assert split in {'train', 'val'}

        # collect the shard filenames for this split
        data_root = "edu_fineweb10B"
        shards = os.listdir(data_root)
        shards = [s for s in shards if split in s]
        shards = sorted(shards)
        shards = [os.path.join(data_root, s) for s in shards]
        self.shards = shards
        assert len(shards) > 0, f"no shards found for split {split}"
        if master_process:  # rank-0 flag defined in the surrounding training script
            print(f"found {len(shards)} shards for split {split}")
        self.reset()

    def reset(self):
        # start at shard zero, offset by this process's rank
        self.current_shard = 0
        self.tokens = load_tokens(self.shards[self.current_shard])
        self.current_position = self.B * self.T * self.process_rank

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position+B*T+1]
        x = (buf[:-1]).view(B, T)   # inputs
        y = (buf[1:]).view(B, T)    # targets, shifted by one token
        # advance the position; if the next batch would run past the shard, move to the next shard
        self.current_position += B * T * self.num_processes
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = load_tokens(self.shards[self.current_shard])
            self.current_position = B * T * self.process_rank
        return x, y
```
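A sketch of driving the loader in a single-process run; it assumes the edu_fineweb10B shard directory exists locally and that master_process is defined, as in the training script.
```python
master_process = True   # single process, so this rank prints the shard count
train_loader = DataLoaderLite(B=4, T=1024, process_rank=0, num_processes=1, split='train')
x, y = train_loader.next_batch()
print(x.shape, y.shape)   # both torch.Size([4, 1024]); y holds the next token for every position of x
```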
##### 2. **Training Loop**
The training loop iterates through the data, performs forward and backward passes, updates the model parameters, and logs training progress. The model is periodically evaluated on the validation set and the HellaSwag benchmark, and sample generations are printed to monitor output quality.
```python
for step in range(max_steps):
    t0 = time.time()
    last_step = (step == max_steps - 1)

    # once in a while evaluate the validation loss
    if step % 250 == 0 or last_step:
        model.eval()
        val_loader.reset()
        with torch.no_grad():
            val_loss_accum = 0.0
            val_loss_steps = 20
            for _ in range(val_loss_steps):
                x, y = val_loader.next_batch()
                x, y = x.to(device), y.to(device)
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    logits, loss = model(x, y)
                loss = loss / val_loss_steps
                val_loss_accum += loss.detach()
        if ddp:
            dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
        if master_process:
            print(f"validation loss: {val_loss_accum.item():.4f}")
            with open(log_file, "a") as f:
                f.write(f"{step} val {val_loss_accum.item():.4f}\n")
            if step > 0 and (step % 5000 == 0 or last_step):
                # save a model checkpoint
                checkpoint_path = os.path.join(log_dir, f"model_{step:05d}.pt")
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'config': raw_model.config,
                    'step': step,
                    'val_loss': val_loss_accum.item()
                }
                torch.save(checkpoint, checkpoint_path)

    # once in a while evaluate on HellaSwag
    if (step % 250 == 0 or last_step) and (not use_compile):
        num_correct_norm = 0
        num_total = 0
        for i, example in enumerate(iterate_examples("val")):
            # only process the examples assigned to this rank
            if i % ddp_world_size != ddp_rank:
                continue
            _, tokens, mask, label = render_example(example)
            tokens = tokens.to(device)
            mask = mask.to(device)
            with torch.no_grad():
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    logits, loss = model(tokens)
                pred_norm = get_most_likely_row(tokens, mask, logits)
            num_total += 1
            num_correct_norm += int(pred_norm == label)
        # reduce the stats across all processes
        if ddp:
            num_total = torch.tensor(num_total, dtype=torch.long, device=device)
            num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
            dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
            dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
            num_total = num_total.item()
            num_correct_norm = num_correct_norm.item()
        acc_norm = num_correct_norm / num_total
        if master_process:
            print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")
            with open(log_file, "a") as f:
                f.write(f"{step} hella {acc_norm:.4f}\n")

    # once in a while generate samples from the model
    if ((step > 0 and step % 250 == 0) or last_step) and (not use_compile):
        model.eval()
        num_return_sequences = 4
        max_length = 32
        tokens = enc.encode("Hello, I'm a language model,")
        tokens = torch.tensor(tokens, dtype=torch.long)
        tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
        xgen = tokens.to(device)
        sample_rng = torch.Generator(device=device)
        sample_rng.manual_seed(42 + ddp_rank)
        while xgen.size(1) < max_length:
            with torch.no_grad():
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    logits, loss = model(xgen)
                logits = logits[:, -1, :]                                   # logits at the last position
                probs = F.softmax(logits, dim=-1)
                topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)    # top-k sampling with k=50
                ix = torch.multinomial(topk_probs, 1, generator=sample_rng)
                xcol = torch.gather(topk_indices, -1, ix)
                xgen = torch.cat((xgen, xcol), dim=1)
        for i in range(num_return_sequences):
            tokens = xgen[i, :max_length].tolist()
            decoded = enc.decode(tokens)
            print(f"rank {ddp_rank} sample {i}: {decoded}")

    # one step of the optimization, with gradient accumulation
    model.train()
    optimizer.zero_grad()
    loss_accum = 0.0
    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        if ddp:
            # only synchronize gradients on the last micro-step of the accumulation
            model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(x, y)
        loss = loss / grad_accum_steps
        loss_accum += loss.detach()
        loss.backward()
    if ddp:
        dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    optimizer.step()
    if device_type == "cuda":
        torch.cuda.synchronize()
    t1 = time.time()
    dt = t1 - t0
    tokens_processed = train_loader.B * train_loader.T * grad_accum_steps * ddp_world_size
    tokens_per_sec = tokens_processed / dt
    if master_process:
        print(f"step {step:5d} | loss: {loss_accum.item():.6f} | lr {lr:.4e} | norm: {norm:.4f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")
        with open(log_file, "a") as f:
            f.write(f"{step} train {loss_accum.item():.6f}\n")
```
##### 3. **Evaluation**
We periodically evaluate the model on the validation set and on the HellaSwag benchmark to track its performance. These evaluations help verify that the model generalizes well and catch overfitting early.
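The HellaSwag check in the training loop calls a helper, get_most_likely_row, that is not reproduced in this document. A sketch consistent with how it is used there: score each candidate completion by its average per-token cross-entropy over the completion region (where mask == 1) and pick the lowest.
```python
import torch
import torch.nn.functional as F

def get_most_likely_row(tokens, mask, logits):
    # tokens, mask: (num_candidates, T); logits: (num_candidates, T, vocab_size)
    shift_logits = logits[..., :-1, :].contiguous()
    shift_tokens = tokens[..., 1:].contiguous()
    losses = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_tokens.view(-1),
        reduction='none',
    ).view(tokens.size(0), -1)                 # per-token loss, (num_candidates, T-1)
    shift_mask = mask[..., 1:].contiguous()    # only score the completion tokens
    avg_loss = (losses * shift_mask).sum(dim=1) / shift_mask.sum(dim=1)
    return avg_loss.argmin().item()            # index of the candidate with the lowest average loss
```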
##### 4. **Optimization**
The optimization strategy employs AdamW with weight decay, gradient clipping, and a learning rate schedule with warmup and cosine decay.
```python
optimizer = raw_model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device_type=device_type)
```
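The training loop calls get_lr(step), which is also not shown above. Below is a sketch of the warmup-plus-cosine-decay schedule described here; warmup_steps, max_steps, and the minimum/maximum learning rates are illustrative and may differ from the actual training configuration.
```python
import math

max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 715          # illustrative values; the real schedule constants may differ
max_steps = 19073

def get_lr(step):
    # 1) linear warmup for the first warmup_steps steps
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    # 2) after max_steps, hold at the minimum learning rate
    if step > max_steps:
        return min_lr
    # 3) cosine decay from max_lr down to min_lr in between
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)
```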
#### Conclusion
Our PyTorch implementation of GPT-2 maintains the architecture and performance characteristics of the original model. By leveraging PyTorch's capabilities, we provide a flexible and efficient training and inference framework for large-scale language models. This implementation can serve as a foundation for further research and development in the field of natural language processing.
This detailed explanation covers the architecture, training process, and key components of the PyTorch-based GPT-2 implementation.
References: [Neural Networks: Zero to Hero, by Andrej Karpathy](https://karpathy.ai/zero-to-hero.html)