LLM From Scratch: Part 2

LLM
Author

Ankur Singh

Published

August 15, 2025

Open In Colab

In this notebook, we incorporate several architectural improvements into the LLM from the previous part. The code changes are small, but understanding the rationale behind each one is what matters, so for every change we explain the motivation and link to helpful resources.

Setup

The model architecture is copied directly from Part 1.

!pip install -Uq torch
!pip install -Uq datasets tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Misc
import math
import tiktoken
from tqdm.notebook import tqdm
from datasets import load_dataset
from dataclasses import dataclass
from prettytable import PrettyTable
class MultiheadAttention(nn.Module):
    """Causal self-attention wrapper around ``nn.MultiheadAttention``.

    Args:
        emb_dim: model dimension; must be divisible by ``heads``.
        heads: number of attention heads.
        context: maximum supported sequence length (size of the causal mask).
    """

    def __init__(self, emb_dim, heads, context):
        super().__init__()
        assert emb_dim % heads == 0, "`emb_dim` should be a multiple of `heads`"
        self.context = context
        self.mha = nn.MultiheadAttention(emb_dim, heads, batch_first=True, bias=False)
        # Boolean mask that is True strictly above the diagonal. For a boolean
        # attn_mask, nn.MultiheadAttention treats True as "may not attend", so
        # each position only sees itself and earlier positions.
        self.register_buffer(
            "mask", torch.triu(torch.ones(context, context), diagonal=1).bool()
        )

    def forward(self, x):
        """Attend over ``x`` of shape ``(batch, seq_len, emb_dim)``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``context``.
        """
        _, seq_len, _ = x.shape
        if seq_len > self.context:
            # Fail fast with a clear message: slicing a smaller mask would only
            # surface later as a shape error inside nn.MultiheadAttention.
            raise ValueError(
                f"sequence length {seq_len} exceeds the attention context {self.context}"
            )
        attn_mask = self.mask[:seq_len, :seq_len]
        attn_out, _ = self.mha(x, x, x, attn_mask=attn_mask, need_weights=False)
        return attn_out


class Block(nn.Module):
    """Pre-norm transformer block.

    LayerNorm -> causal self-attention and LayerNorm -> GELU MLP (hidden width
    ``4 * emb_dim``), each added back onto the residual stream.
    """

    def __init__(self, emb_dim, heads, context):
        super().__init__()
        hidden = 4 * emb_dim
        self.mha = MultiheadAttention(emb_dim, heads, context)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, emb_dim),
        )
        self.mha_norm = nn.LayerNorm(emb_dim)
        self.mlp_norm = nn.LayerNorm(emb_dim)

    def forward(self, x):
        h = x + self.mha(self.mha_norm(x))
        return h + self.mlp(self.mlp_norm(h))


class GPT(nn.Module):
    """Decoder-only language model (Part 1 architecture).

    Token embeddings plus learned absolute position embeddings feed a stack of
    ``Block`` layers. A final LayerNorm and a bias-free linear head then map
    each position to ``vocab`` logits.

    Args:
        config: object exposing ``vocab``, ``emb_dim``, ``heads``, ``layers``
            and ``context`` attributes (e.g. ``ModelConfig``).
    """

    def __init__(self, config):
        super().__init__()
        dim = config.emb_dim
        self.pos_embedding = nn.Embedding(config.context, dim)
        self.tok_embedding = nn.Embedding(config.vocab, dim)
        blocks = [Block(dim, config.heads, config.context) for _ in range(config.layers)]
        self.decoder = nn.Sequential(*blocks)
        self.output = nn.Linear(dim, config.vocab, bias=False)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        _, seq_len = x.shape  # token ids: (batch, seq_len)
        positions = torch.arange(seq_len, device=x.device)
        hidden = self.tok_embedding(x) + self.pos_embedding(positions)
        return self.output(self.norm(self.decoder(hidden)))

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device
# Utility Function: Number of Trainable Parameters
def count_parameters(model, verbose=False):
    """Print the number of trainable parameters of ``model`` in millions.

    With ``verbose=True``, a per-parameter PrettyTable is printed first.
    Returns ``None``.
    """
    rows = [
        (name, param.numel())
        for name, param in model.named_parameters()
        if param.requires_grad
    ]
    total = sum(count for _, count in rows)
    if verbose:
        table = PrettyTable(["Module", "Parameters"])
        for name, count in rows:
            table.add_row([name, count])
        print(table)
    print(f"Total Trainable Params: {total / 1e6:.2f} M")
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


@dataclass
class ModelConfig:
    """GPT-2 architecture hyper-parameters.

    ``GPT`` only reads these as attributes, so the class itself can be passed
    in place of an instance.
    """

    # 50_257 GPT-2 tokens, padded up to the next multiple of 64 (-> 50_304).
    vocab: int = math.ceil(50_257 / 64) * 64
    emb_dim: int = 768  # embedding / residual-stream width
    heads: int = 12  # attention heads per block
    layers: int = 12  # number of transformer blocks
    context: int = 1024  # maximum sequence length
# Instantiate the baseline model on the selected device and report its size.
model = GPT(ModelConfig).to(device)
count_parameters(model)

1. RMSNorm

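This is a simple change: replace nn.LayerNorm with nn.RMSNorm (Root Mean Square Layer Normalization). RMSNorm skips LayerNorm's mean-centering and bias. It rescales each token's vector (over the last, embedding dimension) by its root mean square, $\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot g$, where $g$ is a learned gain. This makes it slightly cheaper to compute.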

2. Post / Pre Normalization

The original Transformer (from the “Attention Is All You Need” paper) applied normalization after each residual addition, $x = \mathrm{Norm}(x + \mathrm{Sublayer}(x))$, for both the attention and feedforward sublayers. This is known as Post-Norm.

GPT-2 and most later LLMs use Pre-Norm. It places the normalization before each sublayer, inside the residual branch: $x = x + \mathrm{Sublayer}(\mathrm{Norm}(x))$.

In 2020, Xiong et al. showed that Pre-Norm leads to more stable gradients at initialization and can perform well without careful learning rate warm-up—unlike Post-Norm.

From an implementation standpoint, the change is minor: just swap the order of operations.

Note: We’re already using Pre-Norm here, so no code changes are needed.

# Post Norm
class BlockPostNorm(Block):
    """Post-Norm variant of ``Block`` (original Transformer layout).

    Normalization is applied *after* each residual addition:
    ``x = Norm(x + Sublayer(x))``. Shown for illustration only; the model in
    this notebook uses the Pre-Norm ``Block``.
    """

    def forward(self, x):
        x = self.mha_norm(x + self.mha(x))
        x = self.mlp_norm(x + self.mlp(x))
        return x


# Pre Norm
class BlockPreNorm(Block):
    """Pre-Norm variant of ``Block``.

    Each sublayer sees a normalized copy of its input, and its output is added
    back onto the un-normalized residual stream:
    ``x = x + Sublayer(Norm(x))``. This is the same computation as ``Block``.
    """

    def forward(self, x):
        x = x + self.mha(self.mha_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x

3. SwiGLU

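Again, this is a small change to the FFN in the transformer block. Instead of a two-layer feedforward network with a GELU activation, we use SwiGLU, a gated variant: $\mathrm{FFN}_{\mathrm{SwiGLU}}(x) = \big(\mathrm{SiLU}(xW_1) \odot xW_3\big)\,W_2$, where $\odot$ is element-wise multiplication.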

It’s a short and easy-to-read paper: GLU Variants Improve Transformer

class GatedFFN(nn.Module):
    """SwiGLU feed-forward network: ``w2(SiLU(w1(x)) * w3(x))``.

    Args:
        emb_dim: input/output (model) dimension.
        hidden_dim: width of the gated hidden layer. Defaults to
            ``4 * (2 * emb_dim // 3)``. SwiGLU uses three weight matrices
            instead of two, and the 2/3 factor keeps the parameter count equal
            to a classic ``4 * emb_dim`` two-matrix FFN.
    """

    def __init__(self, emb_dim, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            # Integer arithmetic: same value as 4 * int(2 / 3 * emb_dim) without
            # relying on float rounding.
            hidden_dim = 4 * (2 * emb_dim // 3)
        self.w1 = nn.Linear(emb_dim, hidden_dim, bias=False)  # gate branch (SiLU)
        self.w3 = nn.Linear(emb_dim, hidden_dim, bias=False)  # linear branch
        self.w2 = nn.Linear(hidden_dim, emb_dim, bias=False)  # output projection
        self.silu_act = nn.SiLU()

    def forward(self, x):
        x = self.silu_act(self.w1(x)) * self.w3(x)
        return self.w2(x)

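Here’s an excerpt from the paper explaining the rationale behind the 2/3 scaling factor. SwiGLU uses three weight matrices instead of two, so shrinking the hidden size by 2/3 keeps the parameter count and compute equal to the standard FFN ($3 \cdot d \cdot \tfrac{2}{3} \cdot 4d = 2 \cdot d \cdot 4d$). You can choose to ignore it, though: it was used mainly for an apples-to-apples comparison in the paper.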

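*(Figure: excerpt from “GLU Variants Improve Transformer” on the 2/3 scaling factor. The image file `image.png` is not included in this export.)*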

At this point, the updated Transformer block (aka Block class) looks like this:

class Block(nn.Module):
    """Pre-norm transformer block with RMSNorm and a SwiGLU (``GatedFFN``) MLP.

    Each sublayer receives an RMS-normalized input and its output is added to
    the residual stream.
    """

    def __init__(self, emb_dim, heads, context):
        super().__init__()
        self.mha = MultiheadAttention(emb_dim, heads, context)
        self.mlp = GatedFFN(emb_dim)  # <--- update
        self.mha_norm, self.mlp_norm = nn.RMSNorm(emb_dim), nn.RMSNorm(emb_dim)  # <--- update

    def forward(self, x):
        for norm, sublayer in ((self.mha_norm, self.mha), (self.mlp_norm, self.mlp)):
            x = x + sublayer(norm(x))
        return x

4. RoPE

In the basic implementation, we use absolute position embeddings, as introduced in the original Transformer paper (which used fixed sinusoidal embeddings; our version learns them). However, things have evolved: the current norm is to encode relative positions instead. Among the various approaches, Rotary Position Embedding (RoPE) has become the dominant choice.

Incorporating RoPE takes more than a few changed lines. It might seem overwhelming, but it’s quite simple once you understand it. We apply a position-dependent rotation to the Q and K vectors right before the attention scores are computed.

Let’s first define the transformation.

def compute_freqs(dim, context_length, base=10_000):
    """Precompute the RoPE cosine/sine tables.

    Args:
        dim: per-head dimension; must be even.
        context_length: number of positions to precompute.
        base: frequency base.

    Returns:
        ``(cos, sin)``, each of shape ``(context_length, dim)``. Row ``p``
        holds the angles ``p * inv_freq`` repeated twice, so both halves of a
        head vector share the same frequencies (the layout ``apply_rope``'s
        rotate-half form expects).
    """
    assert dim % 2 == 0, "Embedding dimension should be even"
    # One inverse frequency per channel pair: base^(-2i/dim). Shape: (dim//2,)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    pos_ids = torch.arange(context_length)  # shape: (context_length,)
    # Outer product of positions and frequencies -> rotation angles.
    thetas = pos_ids.unsqueeze(1) * inv_freq  # shape: (context_length, dim//2)
    thetas = torch.cat([thetas, thetas], dim=1)  # shape: (context_length, dim)
    return thetas.cos(), thetas.sin()


def apply_rope(x, cos, sin):
    """Rotate ``x`` by precomputed RoPE angles (rotate-half formulation).

    Args:
        x: tensor whose last two dims are ``(seq_len, dim)``. Any leading dims,
            e.g. ``(batch, heads)``, are broadcast over.
        cos, sin: tables of shape ``(>= seq_len, dim)``, e.g. from
            ``compute_freqs``; only the first ``seq_len`` rows are used.

    Returns:
        Tensor shaped like ``x`` and cast back to ``x.dtype`` (the tables may be
        float32 while ``x`` is lower precision, e.g. under autocast).
    """
    seq_len, dim = x.shape[-2], x.shape[-1]
    half = dim // 2
    x1, x2 = x[..., :half], x[..., half:]
    rotated = torch.cat([-x2, x1], dim=-1)  # (x1, x2) -> (-x2, x1)
    cos, sin = cos[:seq_len, :], sin[:seq_len, :]
    x_rotated = (x * cos) + (rotated * sin)
    return x_rotated.to(dtype=x.dtype)

If the above code looks cryptic, please refer to my detailed notebook on RoPE. It implements RoPE in gradual steps to build better intuition, includes visuals, and compares different implementations.

Note: Here, we follow the HuggingFace implementation of RoPE, as I plan to load the model weights from HF.

An important point about RoPE: position embeddings are added once at the input, whereas RoPE is applied to the queries and keys inside the attention layer of every transformer block.

class MultiheadAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings (RoPE).

    A single fused projection produces Q, K and V. RoPE is applied to Q and K,
    then scaled dot-product attention runs under a causal mask.

    Args:
        emb_dim: model dimension; must be divisible by ``heads``.
        heads: number of attention heads.
        context: maximum sequence length (size of the RoPE tables and mask).
    """

    def __init__(self, emb_dim, heads, context):
        super().__init__()
        assert emb_dim % heads == 0, "`emb_dim` should be a multiple of `heads`"
        self.heads = heads
        self.head_dim = emb_dim // heads
        self.context = context
        self.qkv_proj = nn.Linear(emb_dim, 3 * emb_dim, bias=False)
        self.out_proj = nn.Linear(emb_dim, emb_dim, bias=False)
        # RoPE tables and causal mask (True = future position, masked out).
        cos, sin = compute_freqs(self.head_dim, context)
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)
        self.register_buffer(
            "mask", torch.triu(torch.ones(context, context), diagonal=1).bool()
        )

    def forward(self, x):
        """Attend over ``x`` of shape ``(batch, seq_len, emb_dim)``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``context`` (the RoPE tables and
                mask only cover ``context`` positions).
        """
        batch, seq_len, emb_dim = x.shape
        if seq_len > self.context:
            raise ValueError(
                f"sequence length {seq_len} exceeds the attention context {self.context}"
            )
        qkv = self.qkv_proj(x)
        qkv = qkv.view(batch, seq_len, 3, self.heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each: (batch, heads, seq_len, head_dim)
        q, k = apply_rope(q, self.cos, self.sin), apply_rope(k, self.cos, self.sin)
        attn = (q @ k.mT) / (self.head_dim**0.5)  # (batch, heads, seq_len, seq_len)
        mask = self.mask[:seq_len, :seq_len]
        attn = attn.masked_fill(mask, float("-inf"))
        attn_out = torch.softmax(attn, dim=-1) @ v
        attn_out = attn_out.transpose(1, 2)  # (batch, seq_len, heads, head_dim)
        attn_out = attn_out.reshape(batch, seq_len, emb_dim)
        return self.out_proj(attn_out)


class Block(nn.Module):
    """Pre-norm transformer block: RMSNorm -> RoPE attention, RMSNorm -> SwiGLU FFN.

    Both sublayers are residual: their outputs are added onto the input stream.
    """

    def __init__(self, emb_dim, heads, context):
        super().__init__()
        self.mha = MultiheadAttention(emb_dim, heads, context)
        self.mlp = GatedFFN(emb_dim)
        self.mha_norm = nn.RMSNorm(emb_dim)
        self.mlp_norm = nn.RMSNorm(emb_dim)

    def forward(self, x):
        h = x + self.mha(self.mha_norm(x))
        return h + self.mlp(self.mlp_norm(h))


class GPT(nn.Module):
    """Decoder-only language model with RoPE attention, RMSNorm and SwiGLU.

    There is no learned position embedding: position information is injected
    by RoPE inside every attention layer.

    Args:
        config: object exposing ``vocab``, ``emb_dim``, ``heads``, ``layers``
            and ``context`` attributes (e.g. ``ModelConfig``).
    """

    def __init__(self, config):
        super().__init__()
        dim = config.emb_dim
        # Learned pos_embedding removed: RoPE handles positions.  # <--- update
        self.tok_embedding = nn.Embedding(config.vocab, dim)
        blocks = [Block(dim, config.heads, config.context) for _ in range(config.layers)]
        self.decoder = nn.Sequential(*blocks)
        self.output = nn.Linear(dim, config.vocab, bias=False)
        self.norm = nn.RMSNorm(dim)  # <--- update

    def forward(self, x):
        _, _ = x.shape  # token ids must be 2-D: (batch, seq_len)
        hidden = self.decoder(self.tok_embedding(x))  # <--- update: no position add
        return self.output(self.norm(hidden))

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device
# Rebuild the model with the upgraded architecture and compare its size.
model = GPT(ModelConfig).to(device)
count_parameters(model)

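We have about 0.85 million fewer parameters. Most of the savings come from removing pos_embedding (1024 × 768 ≈ 0.79 M). The rest comes from dropping the FFN biases (GatedFFN uses bias=False) and the normalization biases (unlike LayerNorm, RMSNorm has no bias).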

With these changes, let’s train the model and see if we get better scores and, consequently, better generations.

Training

# GPT-2 BPE tokenizer and the raw WikiText-2 dataset.
tokenizer = tiktoken.get_encoding("gpt2")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Join each split into a single string; the "test" split is used for validation.
# (`val_ds` / `train_ds` hold raw text here and are rebound further down.)
val_ds = "\n\n".join(dataset["test"]["text"])
train_ds = "\n\n".join(dataset["train"]["text"])

# Tokenize each split into one flat sequence of token ids.
val_tokens = tokenizer.encode(val_ds)
train_tokens = tokenizer.encode(train_ds)
# Notebook display: number of tokens per split.
len(val_tokens), len(train_tokens)
class WikiTextDataset(Dataset):
    """Split a flat token list into non-overlapping next-token-prediction windows.

    Item ``i`` is ``(x, y)``, where ``x = tokens[i*max_len : (i+1)*max_len]``
    and ``y`` is the same window shifted one token to the right. Windows that
    run past the end of ``tokens`` are right-padded to ``max_len`` with
    ``pad_id``.

    Args:
        tokens: list of token ids.
        max_len: window length (the model context).
        pad_id: padding id. Defaults to the module-level ``tokenizer.eot_token``,
            looked up lazily for backward compatibility.
    """

    def __init__(self, tokens, max_len, pad_id=None):
        self.tokens = tokens
        self.max_len = max_len
        self.pad_id = pad_id

    def __getitem__(self, idx):
        n = len(self)
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            # Required so plain iteration (which falls back to __getitem__)
            # terminates instead of yielding padding forever.
            raise IndexError(f"index {idx} out of range for dataset of length {n}")
        pad_id = tokenizer.eot_token if self.pad_id is None else self.pad_id
        start = idx * self.max_len
        x = self.tokens[start : start + self.max_len]
        y = self.tokens[start + 1 : start + 1 + self.max_len]
        x = x + [pad_id] * (self.max_len - len(x))
        y = y + [pad_id] * (self.max_len - len(y))
        return (torch.tensor(x), torch.tensor(y))

    def __len__(self):
        return math.ceil(len(self.tokens) / self.max_len)


batch_size = 4
# One `context`-length window per dataset item.
val_ds = WikiTextDataset(val_tokens, ModelConfig.context)
train_ds = WikiTextDataset(train_tokens, ModelConfig.context)
# Both loaders drop a trailing partial batch; only training data is shuffled.
loader_kwargs = dict(batch_size=batch_size, drop_last=True)
val_dl = DataLoader(val_ds, shuffle=False, **loader_kwargs)
train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
@torch.no_grad()
def generate(model, tokenizer, prefix, max_new_tokens=10, temp=1.0):
    """Sample a continuation of ``prefix`` from ``model``.

    Args:
        model: module mapping ``(1, seq_len)`` token ids to
            ``(1, seq_len, vocab)`` logits.
        tokenizer: object with ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str`` (e.g. tiktoken's GPT-2 encoding).
        prefix: prompt text; must encode to at least one token, and prompt plus
            new tokens must fit in the model's context.
        max_new_tokens: number of tokens to generate.
        temp: softmax temperature; ``temp <= 0`` means greedy (argmax) decoding.

    Returns:
        ``prefix`` followed by the decoded continuation. The model is left in
        eval mode.
    """
    model.eval()
    model_device = next(model.parameters()).device
    token_ids = torch.tensor(tokenizer.encode(prefix), device=model_device).unsqueeze(0)
    new_ids = []
    for _ in range(max_new_tokens):
        logits = model(token_ids)[:, -1, :]  # logits for the last position only
        if temp <= 0:
            next_idx = logits.argmax(dim=-1, keepdim=True)
        else:
            probs = torch.softmax(logits / temp, dim=-1)  # <-- update: scale using temp
            next_idx = torch.multinomial(probs, num_samples=1)
        new_ids.append(next_idx.item())
        token_ids = torch.cat((token_ids, next_idx), dim=1)
    # Decode once at the end: decoding token by token can split a multi-byte
    # character across BPE tokens and produce replacement characters.
    return prefix + tokenizer.decode(new_ids)


@torch.no_grad()
def evaluate(model, dl):
    """Return the mean per-batch cross-entropy of ``model`` over ``dl``.

    Batches are moved to the device of the model's parameters. The model is
    switched to eval mode for the pass and back to train mode afterwards.

    Raises:
        ValueError: if ``dl`` is empty.
    """
    if len(dl) == 0:
        raise ValueError("evaluate() needs a non-empty data loader")
    model_device = next(model.parameters()).device
    model.eval()
    loss = 0.0
    for x, y in dl:
        x, y = x.to(model_device), y.to(model_device)
        logits = model(x)
        loss += F.cross_entropy(logits.flatten(0, 1), y.flatten()).item()
    model.train()
    return loss / len(dl)


# Validation loss of the freshly initialised model (displayed by the notebook).
evaluate(model.to(device), val_dl)
prefix = "Once upon a time"
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
log_freq = 40  # evaluate, checkpoint and sample every `log_freq` steps
epochs = 2
losses = []  # validation-loss history, one entry per log step

for epoch in range(epochs):
    for i, (x, y) in enumerate(pbar := tqdm(train_dl, desc="Training")):
        if i % log_freq == 0:
            val_loss = evaluate(model, val_dl)
            losses.append(val_loss)
            pbar.set_postfix_str(f"[Epoch {epoch}] Val Loss: {val_loss:.3f}")
            torch.save(model.state_dict(), "model.pth")
            print("=" * 20)
            print(generate(model, tokenizer, prefix))

        model.train()  # `generate` leaves the model in eval mode
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = F.cross_entropy(logits.flatten(0, 1), y.flatten())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Periodic checkpoints are only written every `log_freq` steps. Also save the
# final weights so that reloading "model.pth" does not roll back the last steps.
torch.save(model.state_dict(), "model.pth")
# Reload the saved weights and sample a continuation for each prompt.
state_dict = torch.load("model.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
for j, prompt in enumerate(
    ["Once upon a time", "Internet is an", "AI will", "The meaning of life is"]
):
    if j:
        print("=" * 15)
    print(generate(model, tokenizer, prompt))

Gradient Accumulation

The generations are still not very good. Generally, larger batches help when training LLMs, but we’re limited by memory here.

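Gradient accumulation is a simple way to increase the effective batch size. We run several micro-batches, divide each loss by the number of accumulation steps, and let their gradients sum; only then do we take an optimizer step. With batch_size = 4 and accumulate = 4, each update effectively sees 16 sequences.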

model = GPT(ModelConfig)
model = model.to(device)
prefix = "Once upon a time"
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
log_freq = 40  # evaluate, checkpoint and sample every `log_freq` micro-batches
epochs = 2
losses = []  # validation-loss history, one entry per log step
accumulate = 4  # micro-batches per optimizer step

for epoch in range(epochs):
    for i, (x, y) in enumerate(pbar := tqdm(train_dl, desc="Training")):
        if i % log_freq == 0:
            val_loss = evaluate(model, val_dl)
            losses.append(val_loss)
            pbar.set_postfix_str(f"[Epoch {epoch}] Val Loss: {val_loss:.3f}")
            torch.save(model.state_dict(), "model.pth")
            print("=" * 20)
            print(generate(model, tokenizer, prefix))

        model.train()  # `generate` leaves the model in eval mode
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = F.cross_entropy(logits.flatten(0, 1), y.flatten())
        loss /= accumulate  # <--- Update: average over the accumulated micro-batches

        # Backward pass: gradients add up in .grad until the next step.
        loss.backward()
        # Step every `accumulate` micro-batches, and also on the last batch of the
        # epoch so leftover gradients are neither carried into the next epoch
        # nor dropped at the end of training.
        if (i + 1) % accumulate == 0 or i + 1 == len(train_dl):  # <--- Update
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

# Also save the final weights so reloading "model.pth" gets the trained model.
torch.save(model.state_dict(), "model.pth")
# Reload the saved weights and sample a continuation for each prompt.
state_dict = torch.load("model.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
for j, prompt in enumerate(
    ["Once upon a time", "Internet is an", "AI will", "The meaning of life is"]
):
    if j:
        print("=" * 15)
    print(generate(model, tokenizer, prompt))

Automatic Mixed Precision (AMP)

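AMP (automatic mixed precision) is another technique to speed up training. It runs eligible operations, such as matrix multiplications, in a lower-precision format (float16 or bfloat16) instead of float32, while keeping numerically sensitive operations in float32. It doesn’t reduce memory much, but it can noticeably speed up training on hardware with fast half-precision support. See the paper.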

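In PyTorch, it’s easy to implement: wrap the forward pass (and loss) in torch.autocast(). For the backward pass, scale the loss with scaler.scale(loss).backward(). Then call scaler.step(optimizer) instead of optimizer.step(), followed by scaler.update(). The GradScaler multiplies the loss by a large factor so that small float16 gradients don’t underflow to zero, and unscales them again before the optimizer step.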

# Fresh model for the mixed-precision experiment.
model = GPT(ModelConfig).to(device)
@torch.no_grad()
def evaluate(model, dl):
    """Return the mean per-batch cross-entropy of ``model`` over ``dl`` under autocast.

    Batches are moved to the device of the model's parameters, and autocast
    uses that device's type. The model is switched to eval mode for the pass
    and back to train mode afterwards.

    Raises:
        ValueError: if ``dl`` is empty.
    """
    if len(dl) == 0:
        raise ValueError("evaluate() needs a non-empty data loader")
    model_device = next(model.parameters()).device
    model.eval()
    loss = 0.0
    for x, y in dl:
        x, y = x.to(model_device), y.to(model_device)
        with torch.autocast(device_type=model_device.type):  # <--- Update
            logits = model(x)
            loss += F.cross_entropy(logits.flatten(0, 1), y.flatten()).item()
    model.train()
    return loss / len(dl)
prefix = "Once upon a time"
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
log_freq = 40  # evaluate, checkpoint and sample every `log_freq` micro-batches
epochs = 2
losses = []  # validation-loss history, one entry per log step
accumulate = 4  # micro-batches per optimizer step
scaler = torch.GradScaler()  # dynamic loss scaling for low-precision gradients

for epoch in range(epochs):
    for i, (x, y) in enumerate(pbar := tqdm(train_dl, desc="Training")):
        if i % log_freq == 0:
            val_loss = evaluate(model, val_dl)
            losses.append(val_loss)
            pbar.set_postfix_str(f"[Epoch {epoch}] Val Loss: {val_loss:.3f}")
            torch.save(model.state_dict(), "model.pth")
            print("=" * 20)
            print(generate(model, tokenizer, prefix))

        model.train()  # `generate` leaves the model in eval mode
        x, y = x.to(device), y.to(device)
        with torch.autocast(device_type=device):  # <--- Update
            logits = model(x)
            loss = F.cross_entropy(logits.flatten(0, 1), y.flatten())
        loss /= accumulate

        # Backward pass on the scaled loss; gradients accumulate until the step.
        scaler.scale(loss).backward()  # <--- Update
        # Step every `accumulate` micro-batches, and also on the last batch of the
        # epoch so leftover gradients are neither carried into the next epoch
        # nor dropped at the end of training.
        if (i + 1) % accumulate == 0 or i + 1 == len(train_dl):
            scaler.step(optimizer)  # <--- Update: unscales grads before stepping
            scaler.update()  # <--- Update: adjust the scale factor
            optimizer.zero_grad(set_to_none=True)

# Also save the final weights so reloading "model.pth" gets the trained model.
torch.save(model.state_dict(), "model.pth")
# Reload the saved weights and sample a continuation for each prompt.
state_dict = torch.load("model.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
for j, prompt in enumerate(
    ["Once upon a time", "Internet is an", "AI will", "The meaning of life is"]
):
    if j:
        print("=" * 15)
    print(generate(model, tokenizer, prompt))

Another improvement we could add is a learning rate scheduler, but I’ll leave that for now. It’s also time to add WandB logging to track our training progress.