!pip install -Uq torch
!pip install -Uq datasets tiktoken prettytable
To create a minimal GPT-style model, you need these core components:
- Model Architecture: Defines how tokens are processed and contextual relationships are modeled.
- Inference (Next Token Generation): Uses the trained model to generate the next token from input tokens.
- Training Data: A tokenized text dataset for training the model.
- Training Loop: Iteratively updates model parameters to minimize prediction error.
Each component will be implemented simply for clarity. Later notebooks will introduce improvements and optimizations.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# Misc
import math
import tiktoken
from tqdm.notebook import tqdm
from datasets import load_dataset
from dataclasses import dataclass
from prettytable import PrettyTable
Model Architecture
We will start by defining the model architecture and generating some text to make sure everything is working as expected.
class MultiheadAttention(nn.Module):
    def __init__(self, emb_dim, heads, context):
        super().__init__()
        assert emb_dim % heads == 0, "`emb_dim` should be a multiple of `heads`"
        self.context = context
        self.mha = nn.MultiheadAttention(emb_dim, heads, batch_first=True)
        self.proj = nn.Linear(emb_dim, emb_dim)
        # Causal mask: True above the diagonal blocks attention to future tokens
        self.register_buffer(
            "mask", torch.triu(torch.ones(context, context), diagonal=1).bool()
        )

    def forward(self, x):
        batch, seq_len, _ = x.shape
        seq_len = min(seq_len, self.context)
        attn_mask = self.mask[:seq_len, :seq_len]
        attn_out, _ = self.mha(x, x, x, attn_mask=attn_mask, need_weights=False)
        return self.proj(attn_out)
class Block(nn.Module):
    def __init__(self, emb_dim, heads, context):
        super().__init__()
        self.mha = MultiheadAttention(emb_dim, heads, context)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim), nn.GELU(), nn.Linear(4 * emb_dim, emb_dim)
        )
        self.sa_norm = nn.LayerNorm(emb_dim)
        self.mlp_norm = nn.LayerNorm(emb_dim)

    def forward(self, x):
        x = x + self.mha(self.sa_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.pos_embedding = nn.Embedding(config.context, config.emb_dim)
        self.tok_embedding = nn.Embedding(config.vocab, config.emb_dim)
        self.decoder = nn.Sequential(
            *[
                Block(config.emb_dim, config.heads, config.context)
                for _ in range(config.layers)
            ]
        )
        self.output = nn.Linear(config.emb_dim, config.vocab, bias=False)
        self.norm = nn.LayerNorm(config.emb_dim)

    def forward(self, x):
        batch, seq_len = x.shape
        pos = torch.arange(seq_len, device=x.device)
        x = self.tok_embedding(x) + self.pos_embedding(pos)
        x = self.decoder(x)
        return self.output(self.norm(x))

    @property
    def device(self):
        return next(self.parameters()).device
@dataclass
class ModelConfig:
    # GPT-2 architecture
    vocab: int = math.ceil(50_257 / 64) * 64  # nearest multiple of 64
    emb_dim: int = 768
    heads: int = 12
    layers: int = 12
    context: int = 1024
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GPT(ModelConfig)
model = model.to(device)
Utility Function: Number of Trainable Parameters
def count_parameters(model, verbose=False):
    if verbose:
        table = PrettyTable(["Module", "Parameters"])
        total = 0
        for name, param in model.named_parameters():
            if param.requires_grad:
                count = param.numel()
                table.add_row([name, count])
                total += count
        print(table)
    else:
        total = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total Trainable Params: {total / 1e6:.2f} M")
count_parameters(model)
The count is in the right ballpark for a GPT-2-sized model.
Note: This is not exactly the same as GPT-2 (124M). That is because we do not use weight tying, plus a few other small differences. Read this to learn more about weight tying.
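If you want to get closer to GPT-2's count, weight tying simply reuses the token-embedding matrix as the output projection. Here is a minimal sketch built on the GPT class above (the `TiedGPT` name is just illustrative, not part of this notebook):
class TiedGPT(GPT):
    def __init__(self, config):
        super().__init__(config)
        # Share one (vocab x emb_dim) matrix between the input embedding and the
        # output head, removing roughly config.vocab * config.emb_dim (~38.6M here)
        # parameters from the count.
        self.output.weight = self.tok_embedding.weight

count_parameters(TiedGPT(ModelConfig))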
Inference (Next Token Generation)
tokenizer = tiktoken.get_encoding("gpt2")
@torch.no_grad()
def generate(model, tokenizer, prefix, max_new_tokens=10):
    model.eval()
    token_ids = torch.tensor(tokenizer.encode(prefix), device=device).unsqueeze(0)
    for _ in range(max_new_tokens):
        logits = model(token_ids)
        logits = logits[:, -1, :]  # keep only the last position's logits
        next_idx = torch.argmax(logits, dim=-1, keepdim=True)
        prefix += tokenizer.decode([next_idx.item()])
        token_ids = torch.cat((token_ids, next_idx), dim=1)
    return prefix
prefix = "Once upon a time"
print(generate(model, tokenizer, prefix))
The generated text is gibberish because the model is not trained yet.
Note: You will get the same output each time you run the cell, since greedy (argmax) decoding has no randomness. The model is initialized with random weights, so to get a different output you would have to reinitialize the model.
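A quick sanity check of this determinism, reusing the generate function and prefix defined above:
# Greedy decoding is deterministic: two runs on the same prefix with the same
# weights produce exactly the same continuation.
assert generate(model, tokenizer, prefix) == generate(model, tokenizer, prefix)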
Data
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
dataset
val_ds = "\n\n".join(dataset["test"]["text"])
train_ds = "\n\n".join(dataset["train"]["text"])
val_tokens = tokenizer.encode(val_ds)
train_tokens = tokenizer.encode(train_ds)
len(val_tokens), len(train_tokens)
class WikiTextDataset(Dataset):
    def __init__(self, tokens, max_len):
        self.tokens = tokens
        self.max_len = max_len

    def __getitem__(self, idx):
        idx = idx * self.max_len
        x = self.tokens[idx : idx + self.max_len]
        y = self.tokens[idx + 1 : idx + 1 + self.max_len]  # targets are the inputs shifted by one token
        if len(x) < self.max_len:
            x = x + [tokenizer.eot_token] * (self.max_len - len(x))
        if len(y) < self.max_len:
            y = y + [tokenizer.eot_token] * (self.max_len - len(y))
        return (torch.tensor(x), torch.tensor(y))

    def __len__(self):
        return math.ceil(len(self.tokens) / self.max_len)
val_ds = WikiTextDataset(val_tokens, ModelConfig.context)
train_ds = WikiTextDataset(train_tokens, ModelConfig.context)
len(val_ds), len(train_ds)
batch_size = 6
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=True)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
next(iter(val_dl))
Training Loop
@torch.no_grad()
def evaluate(model, dl):
    model.eval()
    loss = 0
    for x, y in dl:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss += F.cross_entropy(logits.flatten(0, 1), y.flatten()).cpu().item()
    model.train()
    return loss / len(dl)
model.to(device)
evaluate(model, val_dl)
This looks correct. Initially, the predicted distribution is roughly uniform, i.e., every token has about the same probability, so the expected loss is -ln(1/50304) ≈ 10.826.
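A quick check of that number, using the vocab size from ModelConfig:
# Cross-entropy under a uniform distribution over the vocabulary:
# -ln(1/vocab) = ln(vocab)
print(-math.log(1 / ModelConfig.vocab))  # ~10.826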
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
log_freq = 40
epochs = 2
losses = []
for epoch in range(epochs):
    for i, (x, y) in enumerate(pbar := tqdm(train_dl, desc="Training")):
        if i % log_freq == 0:
            val_loss = evaluate(model, val_dl)
            losses.append(val_loss)
            pbar.set_postfix_str(f"[Epoch {epoch}] Val Loss: {val_loss:.3f}")
            torch.save(model.state_dict(), "model.pth")
            print("=" * 20)
            print(generate(model, tokenizer, prefix))
            model.train()
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = F.cross_entropy(logits.flatten(0, 1), y.flatten())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Inference
Let's load the saved model and try to generate some sample text...
state_dict = torch.load("model.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
print(generate(model, tokenizer, "Once upon a time"))
print(generate(model, tokenizer, "Internet is an"))
print(generate(model, tokenizer, "AI will"))
print(generate(model, tokenizer, "The meaning of life is"))
The generated text is not very good. Let's add some randomness.
Instead of always picking the token with the highest probability, we can sample the next token index from the probability distribution. This involves two steps: 1. Convert the logits to probabilities. 2. Sample the next token index from this distribution.
Note: Sampling adds randomness, so you will see different outputs each time you run the cell.
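As a toy illustration of these two steps on a made-up 5-token distribution (not part of the model):
toy_logits = torch.tensor([2.0, 1.0, 0.5, 0.1, -1.0])
toy_probs = torch.softmax(toy_logits, dim=-1)           # step 1: logits -> probabilities
toy_next = torch.multinomial(toy_probs, num_samples=1)  # step 2: sample an index
print(toy_probs, toy_next)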
@torch.no_grad()
def generate(model, tokenizer, prefix, max_new_tokens=10):
    model.eval()
    token_ids = torch.tensor(tokenizer.encode(prefix), device=device).unsqueeze(0)
    for _ in range(max_new_tokens):
        logits = model(token_ids)
        logits = logits[:, -1, :]
        probs = torch.softmax(logits, dim=-1)  # <-- update
        next_idx = torch.multinomial(probs, num_samples=1)  # <-- update
        prefix += tokenizer.decode([next_idx.item()])
        token_ids = torch.cat((token_ids, next_idx), dim=1)
    return prefix
print(generate(model, tokenizer, "Once upon a time"))
print(generate(model, tokenizer, "Internet is an"))
print(generate(model, tokenizer, "AI will"))
print(generate(model, tokenizer, "The meaning of life is"))
Temperature is a useful parameter that controls how sharp or flat the softmax output is.
@torch.no_grad()
def generate(model, tokenizer, prefix, max_new_tokens=10, temp=1.0):
    model.eval()
    token_ids = torch.tensor(tokenizer.encode(prefix), device=device).unsqueeze(0)
    for _ in range(max_new_tokens):
        logits = model(token_ids)
        logits = logits[:, -1, :]
        probs = torch.softmax(logits / temp, dim=-1)  # <-- update: scale logits by temp
        next_idx = torch.multinomial(probs, num_samples=1)
        prefix += tokenizer.decode([next_idx.item()])
        token_ids = torch.cat((token_ids, next_idx), dim=1)
    return prefix
Here's why temperature affects creativity:
import matplotlib.pyplot as plt
logits = torch.randn(4, 32)
plt.plot(torch.softmax(logits[0], dim=-1), label="No Temperature")
plt.plot(torch.softmax(logits[0] / 0.5, dim=-1), label="0.5 Temperature")
plt.plot(torch.softmax(logits[0] / 4, dim=-1), label="4 Temperature")
plt.legend()
As shown:
- If temperature is low (< 1), softmax is sharp and only a few tokens have high probability.
- If temperature is high (> 1), softmax is flatter and more tokens have similar probabilities.
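A quick numeric check of the same effect, reusing the random logits from the plot above:
# Lower temperature concentrates probability mass on the top token;
# higher temperature spreads it out.
for temp in (0.5, 1.0, 4.0):
    p = torch.softmax(logits[0] / temp, dim=-1)
    print(f"temp={temp}: max prob = {p.max():.3f}")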
Another improvement is to sample only from the top-k probabilities.
def topk(logits, k=5):
    topk_vals, topk_idxs = torch.topk(logits, k)
    probs = torch.zeros_like(logits)
    # place the softmax of the top-k logits back at their original positions
    probs.scatter_(-1, topk_idxs, torch.softmax(topk_vals, dim=-1))
    return probs
@torch.no_grad()
def generate(model, tokenizer, prefix, max_new_tokens=10, temp=1.0):
    model.eval()
    token_ids = torch.tensor(tokenizer.encode(prefix), device=device).unsqueeze(0)
    for _ in range(max_new_tokens):
        logits = model(token_ids)
        logits = logits[:, -1, :]
        probs = topk(logits / temp)  # <-- update: only `topk` probabilities
        next_idx = torch.multinomial(probs, num_samples=1)
        prefix += tokenizer.decode([next_idx.item()])
        token_ids = torch.cat((token_ids, next_idx), dim=1)
    return prefix
Let's try these improvements and see how the generated text changes.
print(generate(model, tokenizer, "Once upon a time"))
print(generate(model, tokenizer, "Internet is an"))
print(generate(model, tokenizer, "AI will"))
print(generate(model, tokenizer, "The meaning of life is"))