LLM From Scratch: Part 1

LLM
Author

Ankur Singh

Published

August 6, 2025

Open In Colab

To create a minimal GPT-style model, you need these core components:

- a tokenizer that converts text to token ids (we use the GPT-2 tokenizer from tiktoken)
- the model architecture: token and position embeddings, masked multi-head attention blocks, and an output head
- a dataset and dataloader that yield (input, next-token target) pairs
- a training loop with a loss function and optimizer
- a generation (inference) routine that samples new text from the model

Each component will be implemented simply for clarity. Later notebooks will introduce improvements and optimizations.

!pip install -Uq torch
!pip install -Uq datasets tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Misc
import math
import tiktoken
from tqdm.notebook import tqdm
from datasets import load_dataset
from dataclasses import dataclass
from prettytable import PrettyTable

Model Architecture

We will start by defining the model architecture and then generate some text to make sure everything is working as expected.

class MultiheadAttention(nn.Module):
    def __init__(self, emb_dim, heads, context):
        super().__init__()
        assert emb_dim % heads == 0, "`emb_dim` should be a multiple of `heads`"
        self.context = context
        self.mha = nn.MultiheadAttention(emb_dim, heads, batch_first=True)
        self.proj = nn.Linear(emb_dim, emb_dim)
        self.register_buffer(
            "mask", torch.triu(torch.ones(context, context), diagonal=1).bool()
        )

    def forward(self, x):
        batch, seq_len, _ = x.shape
        seq_len = min(seq_len, self.context)
        attn_mask = self.mask[:seq_len, :seq_len]
        attn_out, _ = self.mha(x, x, x, attn_mask=attn_mask, need_weights=False)
        return self.proj(attn_out)


class Block(nn.Module):
    def __init__(self, emb_dim, heads, context):
        super().__init__()
        self.mha = MultiheadAttention(emb_dim, heads, context)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim), nn.GELU(), nn.Linear(4 * emb_dim, emb_dim)
        )
        self.sa_norm = nn.LayerNorm(emb_dim)
        self.mlp_norm = nn.LayerNorm(emb_dim)

    def forward(self, x):
        x = x + self.mha(self.sa_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.pos_embedding = nn.Embedding(config.context, config.emb_dim)
        self.tok_embedding = nn.Embedding(config.vocab, config.emb_dim)
        self.decoder = nn.Sequential(
            *[
                Block(config.emb_dim, config.heads, config.context)
                for _ in range(config.layers)
            ]
        )
        self.output = nn.Linear(config.emb_dim, config.vocab, bias=False)
        self.norm = nn.LayerNorm(config.emb_dim)

    def forward(self, x):
        batch, seq_len = x.shape
        pos = torch.arange(seq_len, device=x.device)
        x = self.tok_embedding(x) + self.pos_embedding(pos)
        x = self.decoder(x)
        return self.output(self.norm(x))

    @property
    def device(self):
        return next(self.parameters()).device
@dataclass
class ModelConfig:
    # GPT2 architecture
    vocab: int = math.ceil(50_257 / 64) * 64  # nearest multiple of 64
    emb_dim: int = 768
    heads: int = 12
    layers: int = 12
    context: int = 1024


device = "cuda" if torch.cuda.is_available() else "cpu"
model = GPT(ModelConfig())
model = model.to(device)
# Utility Function: Number of Trainable Parameters
def count_parameters(model, verbose=False):
    if verbose:
        table = PrettyTable(["Module", "Parameters"])
        total = 0
        for name, param in model.named_parameters():
            if param.requires_grad:
                count = param.numel()
                table.add_row([name, count])
                total += count
        print(table)
    else:
        total = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total Trainable Params: {total / 1e6:.2f} M")


count_parameters(model)

This matches a quick back-of-the-envelope estimate for the architecture (token and position embeddings, 12 transformer blocks, and an untied output head), so the model is wired up as intended.
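
For a per-module breakdown, the same utility can be called with verbose=True:

count_parameters(model, verbose=True)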

Note: This is not exactly the same as GPT-2 (124M). That is because there is no weight tying here, plus other small differences. Read this to learn more about weight tying.
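
For reference, weight tying shares a single (vocab, emb_dim) matrix between the token embedding and the output head. Here is a minimal sketch of the idea; it is not applied to the model in this notebook:

# Sketch only: tie the output head to the token embedding so both share one matrix
import torch.nn as nn

tok_embedding = nn.Embedding(ModelConfig.vocab, ModelConfig.emb_dim)
output = nn.Linear(ModelConfig.emb_dim, ModelConfig.vocab, bias=False)
output.weight = tok_embedding.weight  # (vocab, emb_dim) shapes match for both layers
print(output.weight is tok_embedding.weight)  # True -> vocab * emb_dim fewer parameters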

Inference (Next Token Generation)

tokenizer = tiktoken.get_encoding("gpt2")
@torch.no_grad()
def generate(model, tokenizer, prefix, max_new_tokens=10):
    model.eval()
    token_ids = torch.tensor(tokenizer.encode(prefix), device=device).unsqueeze(0)
    for _ in range(max_new_tokens):
        logits = model(token_ids)
        logits = logits[:, -1, :]
        next_idx = torch.argmax(logits, dim=-1, keepdim=True)
        prefix += tokenizer.decode([next_idx.item()])
        token_ids = torch.cat((token_ids, next_idx), dim=1)
    return prefix


prefix = "Once upon a time"
print(generate(model, tokenizer, prefix))

The generated text is gibberish because the model is not trained yet.

Note: You will get the same output each time you run the cell, since greedy (argmax) decoding has no randomness. The model is initialized with random weights, so to get a different output you must reinitialize the model.
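
For example (illustrative only; the rest of the notebook does not depend on this), a freshly initialized model gives a different greedy completion:

# A model with fresh random weights produces a different (still gibberish) completion
fresh_model = GPT(ModelConfig()).to(device)
print(generate(fresh_model, tokenizer, prefix))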

Data

dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
dataset
val_ds = "\n\n".join(dataset["test"]["text"])
train_ds = "\n\n".join(dataset["train"]["text"])

val_tokens = tokenizer.encode(val_ds)
train_tokens = tokenizer.encode(train_ds)
len(val_tokens), len(train_tokens)
class WikiTextDataset(Dataset):
    def __init__(self, tokens, max_len):
        self.tokens = tokens
        self.max_len = max_len

    def __getitem__(self, idx):
        idx = idx * self.max_len
        x = self.tokens[idx : idx + self.max_len]
        y = self.tokens[idx + 1 : idx + 1 + self.max_len]
        if len(x) < self.max_len:
            x = x + [tokenizer.eot_token] * (self.max_len - len(x))
        if len(y) < self.max_len:
            y = y + [tokenizer.eot_token] * (self.max_len - len(y))
        return (torch.tensor(x), torch.tensor(y))

    def __len__(self):
        return math.ceil(len(self.tokens) / self.max_len)


val_ds = WikiTextDataset(val_tokens, ModelConfig.context)
train_ds = WikiTextDataset(train_tokens, ModelConfig.context)
len(val_ds), len(train_ds)
batch_size = 6
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=True)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
next(iter(val_dl))
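
Note that each target sequence is just the input shifted left by one token, which is exactly what next-token prediction needs. A quick illustrative check:

# Illustrative check: y is x shifted left by one position
x, y = next(iter(val_dl))
print(x.shape, y.shape)                    # both (batch_size, context)
print(torch.equal(x[0, 1:], y[0, :-1]))    # True: y[t] is the token that follows x[t]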

Training Loop

@torch.no_grad()
def evaluate(model, dl):
    model.eval()
    loss = 0
    for x, y in dl:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss += F.cross_entropy(logits.flatten(0, 1), y.flatten()).cpu().item()
    model.train()
    return loss / len(dl)
model.to(device)
evaluate(model, val_dl)

This looks correct. With random weights, the predicted distribution is roughly uniform, i.e., each of the 50,304 tokens has about the same probability. The expected cross-entropy loss is therefore -ln(1/50304) ≈ 10.826.
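
As a quick check of that number:

# Cross-entropy of a uniform distribution over the vocabulary
import math
print(-math.log(1 / ModelConfig.vocab))  # ~10.826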

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
log_freq = 40
epochs = 2
losses = []

for epoch in range(epochs):
    for i, (x, y) in enumerate(pbar := tqdm(train_dl, desc="Training")):
        if i % log_freq == 0:
            val_loss = evaluate(model, val_dl)
            losses.append(val_loss)
            pbar.set_postfix_str(f"[Epoch {epoch}] Val Loss: {val_loss:.3f}")
            torch.save(model.state_dict(), "model.pth")
            print("=" * 20)
            print(generate(model, tokenizer, prefix))

        model.train()
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = F.cross_entropy(logits.flatten(0, 1), y.flatten())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
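
The `losses` list now holds one validation loss per `log_freq` training batches; plotting it gives a quick view of training progress (a small optional sketch):

# Plot the validation loss recorded every `log_freq` batches
import matplotlib.pyplot as plt

plt.plot(losses)
plt.xlabel(f"Evaluation step (every {log_freq} batches)")
plt.ylabel("Validation loss")
plt.show()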

Inference

Let's load the saved model and try to generate some sample text.

state_dict = torch.load("model.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
print(generate(model, tokenizer, "Once upon a time"))
print(generate(model, tokenizer, "Internet is an"))
print(generate(model, tokenizer, "AI will"))
print(generate(model, tokenizer, "The meaning of life is"))

The generated text is not very good. Let’s add some randomness.

Instead of always picking the highest-probability token, we can sample the next token index from the probability distribution. This involves two steps:

1. Convert the logits to probabilities.
2. Sample the next token index from this distribution.

Note: Sampling adds randomness, so you will see different outputs each time you run the cell.

@torch.no_grad()
def generate(model, tokenizer, prefix, max_new_tokens=10):
    model.eval()
    token_ids = torch.tensor(tokenizer.encode(prefix), device=device).unsqueeze(0)
    for _ in range(max_new_tokens):
        logits = model(token_ids)
        logits = logits[:, -1, :]
        probs = torch.softmax(logits, dim=-1)  # <-- update
        next_idx = torch.multinomial(probs, num_samples=1)  # <-- update
        prefix += tokenizer.decode([next_idx.item()])
        token_ids = torch.cat((token_ids, next_idx), dim=1)
    return prefix
print(generate(model, tokenizer, "Once upon a time"))
print(generate(model, tokenizer, "Internet is an"))
print(generate(model, tokenizer, "AI will"))
print(generate(model, tokenizer, "The meaning of life is"))

Temperature is a useful parameter that controls how sharp or flat the softmax output is: the logits are divided by the temperature before applying the softmax.

@torch.no_grad()
def generate(model, tokenizer, prefix, max_new_tokens=10, temp=1.0):
    model.eval()
    token_ids = torch.tensor(tokenizer.encode(prefix), device=device).unsqueeze(0)
    for _ in range(max_new_tokens):
        logits = model(token_ids)
        logits = logits[:, -1, :]
        probs = torch.softmax(logits / temp, dim=-1)  # <-- update: scale using temp
        next_idx = torch.multinomial(probs, num_samples=1)
        prefix += tokenizer.decode([next_idx.item()])
        token_ids = torch.cat((token_ids, next_idx), dim=1)
    return prefix

Here’s why temperature affects creativity:

import matplotlib.pyplot as plt

logits = torch.randn(4, 32)
plt.plot(torch.softmax(logits[0], dim=-1), label="Temperature 1.0")
plt.plot(torch.softmax(logits[0] / 0.5, dim=-1), label="Temperature 0.5")
plt.plot(torch.softmax(logits[0] / 4, dim=-1), label="Temperature 4.0")
plt.legend()

As shown:

- If temperature is low (< 1), the softmax is sharp and only a few tokens have high probability.
- If temperature is high (> 1), the softmax is flatter and more tokens have similar probabilities.

Another improvement is to sample only from the top-k probabilities.

def topk(logits, k=5):
    topk_vals, topk_idxs = torch.topk(logits, k)
    probs = torch.zeros_like(logits)
    # Softmax over only the top-k logits, scattered back to their original positions
    probs.scatter_(-1, topk_idxs, torch.softmax(topk_vals, dim=-1))
    return probs


@torch.no_grad()
def generate(model, tokenizer, prefix, max_new_tokens=10, temp=1.0):
    model.eval()
    token_ids = torch.tensor(tokenizer.encode(prefix), device=device).unsqueeze(0)
    for _ in range(max_new_tokens):
        logits = model(token_ids)
        logits = logits[:, -1, :]
        probs = topk(logits / temp)  # <-- update: only `topk` probabilities
        next_idx = torch.multinomial(probs, num_samples=1)
        prefix += tokenizer.decode([next_idx.item()])
        token_ids = torch.cat((token_ids, next_idx), dim=1)
    return prefix

Let’s try these improvements and see how the generated text changes.

print(generate(model, tokenizer, "Once upon a time"))
print(generate(model, tokenizer, "Internet is an"))
print(generate(model, tokenizer, "AI will"))
print(generate(model, tokenizer, "The meaning of life is"))