Rotary Position Embedding (RoPE)

LLM
RoPE
Author

Ankur Singh

Published

July 25, 2025



After listening to Umar Jamil’s talk on GPU Mode, I was really inspired. I think it was one of the best talks breaking down his mindset behind learning and producing super high-quality content. I decided to apply his approach to learning and gaining a deeper understanding of various LLM components.

In another part of my life, I was working on a PR to add LongRoPE support for Phi3-mini-128k in torchtune. I referred to multiple LongRoPE implementations—both within torchtune and in external projects—but I just couldn’t get it to work.

So, after the talk, I thought: Huh 🤔, let me start with RoPE. I’ll learn it the hard way—going through the paper, implementing it from scratch, then studying other people’s implementations. Only after fully understanding it (i.e., being able to implement all of it from scratch) will I say that I truly know and understand RoPE. From there, I’ll build up to LongRoPE.

This notebook serves as a logbook, documenting everything I learn on this adventurous journey to understanding RoPE from the ground up.

Reading the RoPE paper (aka RoFormer)

Following the technique outlined by Umar, I started by reading the Abstract, Introduction, and Conclusion, highlighting key points. This is a great way to gain a high-level understanding and to prepare for what to expect from the paper. Here are some important highlights from these sections:

  • RoPE encodes absolute position using a rotation matrix while incorporating explicit relative position dependency in the self-attention formulation.
  • RoPE has several valuable and desirable properties, including:
    • Sequence length flexibility
    • Decaying inter-token dependency with increasing relative distances
    • The capability to equip linear self-attention (see Linformer) with relative position encoding
  • The key idea is to encode relative position by multiplying context representations with a rotation matrix.
  • With absolute position information encoded through a rotation matrix, relative position can be naturally formulated using vector projection in self-attention.

For the math part, it’s highly recommended to use pen and paper—write down the equations and work through the derivations. TBH, I still don’t fully understand all the mathematics in the RoFormer paper. But the process of writing things down forces your brain to pick up details that it wouldn’t through passive reading.

Here are a few minor things I picked up:

  • Absolute Position Embeddings, whether from a predefined sinusoidal function (as in the original Transformer by Vaswani et al.) or as trainable vectors, add position information to all q, k, and v vectors. Look at the subscript of the function \(f\).

[Figure: absolute position embedding formulation, from the RoFormer paper]
  • (Almost all) Relative Position Embeddings add position information to only k and v vectors.

[Figure: relative position embedding formulation, from the RoFormer paper]
  • Finally, RoPE adds position information to only q and k, not the v vector.

[Figure: RoPE formulation, from the RoFormer paper]

Another important difference is that both absolute and relative methods add the position encoding to the context representations, whereas RoPE multiplies the context representations by the position encoding (a rotation matrix). This point is emphasized repeatedly in the paper, but writing it down on paper really helped solidify it for me.
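To make the add-versus-multiply distinction concrete, here is a minimal sketch for a single 2D slice of a query/key vector. The position `m`, the frequency `theta`, and the sinusoidal vector `p` are toy values chosen for illustration, not taken from the paper. Adding a position vector changes the vector's length, while multiplying by a rotation matrix only changes its direction.

```python
import math
import torch

x = torch.tensor([1.0, 2.0])  # a 2D slice of a query/key vector (toy values)
m, theta = 3, 0.5             # toy position and frequency (assumed values)

# Additive (absolute) position encoding: x + p_m
p = torch.tensor([math.sin(m * theta), math.cos(m * theta)])
x_add = x + p

# Multiplicative (RoPE-style) encoding: R(m * theta) @ x
a = m * theta
R = torch.tensor([[math.cos(a), -math.sin(a)],
                  [math.sin(a),  math.cos(a)]])
x_rot = R @ x

# The rotation keeps the norm; the addition does not
print(x.norm().item(), x_add.norm().item(), x_rot.norm().item())
```

Because a rotation preserves length, RoPE changes only the angle between queries and keys, which is exactly what the dot product in attention measures.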

These differences are subtle but quite important to fully understand and appreciate RoPE. After picking up on these subtle distinctions, I went through the paper again, looking for the authors’ motivation for making these choices. Here is an excerpt from Section 3.1 in the paper:

[Figure: excerpt from Section 3.1 of the RoFormer paper]
  • Sentences 1 & 2: Since positional information is leveraged in the self-attention mechanism, particularly in the \(q_m^T k_n\) operation, the authors decided to update only the q and k vectors with position information.
  • Sentence 3: Prior relative position embedding papers had already made the benefits of relative position encoding clear. So, the authors required the inner product of the query and key vectors to depend only on the word embeddings and their relative position \(m - n\).

Looks like the authors are some sort of math ninjas—they quickly provide the solution (i.e., the mathematical transformation that meets both constraints). Here is the solution for the 2D case:

[Figure: RoPE solution for the 2D case, from the RoFormer paper]
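Here is a short pen-and-paper sketch (not copied from the paper) of why the 2D rotation meets the relative-position constraint. Write \(R_m\) for the 2D rotation matrix with angle \(m\theta\), \(q = W_q x_m\), and \(k = W_k x_n\). Two standard facts about rotations do all the work: \(R_m^T = R_{-m}\) and \(R_a R_b = R_{a+b}\).

\[
\begin{aligned}
\langle f_q(x_m, m), f_k(x_n, n) \rangle
  &= (R_m q)^T (R_n k) \\
  &= q^T R_m^T R_n k \\
  &= q^T R_{-m} R_n k \\
  &= q^T R_{n-m} k
\end{aligned}
\]

The positions \(m\) and \(n\) appear only through \(n - m\), so the attention score depends only on the relative position.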

Here is the general form of the solution:

[Figure: general form of the RoPE solution, from the RoFormer paper]

Just after presenting the general form solution, the authors quickly point out that due to sparsity, applying matrix multiplication directly is not very computationally efficient (see the last two lines in the above image). They also present a more computationally efficient realization.

[Figure: computationally efficient realization of the rotary matrix (Equation 34), from the RoFormer paper]
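To see the sparsity argument in code, the sketch below builds the dense block-diagonal rotation matrix for one position (a hypothetical helper `rotary_matrix`, interleaved pairs, base 10,000) and checks that multiplying by it gives the same result as the element-wise form. Each row has only 2 non-zero entries, so the dense matmul does \(O(d^2)\) work per vector where the element-wise version needs \(O(d)\).

```python
import torch

def theta_vector(d: int, base: float = 10_000.0) -> torch.Tensor:
    # theta_i = base^(-2i/d) for i = 0 .. d/2 - 1
    return 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float64) / d))

def rotary_matrix(m: int, d: int) -> torch.Tensor:
    """Dense d x d block-diagonal rotation matrix for position m (hypothetical helper)."""
    R = torch.zeros(d, d, dtype=torch.float64)
    for i, t in enumerate(theta_vector(d)):
        c, s = torch.cos(m * t), torch.sin(m * t)
        R[2 * i, 2 * i], R[2 * i, 2 * i + 1] = c, -s
        R[2 * i + 1, 2 * i], R[2 * i + 1, 2 * i + 1] = s, c
    return R

def rotate_elementwise(x: torch.Tensor, m: int) -> torch.Tensor:
    """Element-wise form: x * cos + x_half * sin, with x_half = [-x2, x1, -x4, x3, ...]."""
    d = x.shape[-1]
    angle = m * theta_vector(d)
    cos = angle.cos().repeat_interleave(2)  # [c1, c1, c2, c2, ...]
    sin = angle.sin().repeat_interleave(2)
    x_half = torch.stack([-x[1::2], x[0::2]], dim=-1).flatten()
    return x * cos + x_half * sin

d, m = 8, 3
x = torch.randn(d, dtype=torch.float64)
R = rotary_matrix(m, d)
print((R != 0).sum().item(), "non-zeros out of", d * d)
print(torch.allclose(R @ x, rotate_elementwise(x, m)))
```

For \(d = 8\), only 16 of the 64 entries are non-zero, and the fraction shrinks as \(d\) grows, which is why the element-wise realization is preferred.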

Equipped with all this information, I was eager to implement RoPE in PyTorch. But before I jump to the next section, here are three important lessons for reading a paper:

  1. For theoretical sections like the abstract, introduction, and conclusion, highlight key points as you read.
  2. For the mathematics, use pen and paper. Highly recommended!
  3. Re-read the paper a few times.

Implementing RoPE

import torch
import matplotlib.pyplot as plt

# batch_size, seq_len, dim
B, S, D = 4, 8, 128
x = torch.randn(B, S, D)

# RoPE initializations
base = 10_000


def get_theta(S, D):
    """Return the [S, D//2] matrix of rotation angles m * theta_i."""
    pos_ids = torch.arange(S, dtype=torch.float)  # token positions m = 0..S-1
    freqs = base ** (torch.arange(0, D, 2) / D)  # base^(2i/D) for i = 0..D/2-1
    theta = torch.outer(pos_ids, 1.0 / freqs)  # m * theta_i, with theta_i = base^(-2i/D)
    return theta


print(f"{x.shape=}")
theta = get_theta(S, D)
print(f"{theta.shape=}")

Compute-Efficient Implementation (based on Equation 34)

Below is a naive one-to-one implementation of the computationally efficient realization of the rotary matrix based on equation 34 from the paper. Let me jot down the steps:

  1. Calculate freqs (\(\theta\)) and pos_ids (\(m\))
  2. Take the outer product to get a matrix of all \(m\theta\) values
  3. Calculate \(\cos m\theta\) and \(\sin m\theta\)
  4. Duplicate each value so that both elements of a pair share the same angle
  5. Calculate the transformed vector \(x\_half = [-x_2, x_1, -x_4, x_3, \dots]\) from the second term of the equation
  6. Finally, calculate the rotated vector using \(x\_rot = (x * \cos m\theta) + (x\_half * \sin m\theta)\)
def apply_rope34(x):
    _, S, D = x.shape

    # [mθ1, mθ2, mθ3, mθ4, . . . mθ(d//2)] | Shape: [seq_len, dim//2]
    theta = get_theta(S, D)

    # [seq_len, dim//2]
    sin, cos = theta.sin(), theta.cos()

    # Expand [t1, t2, t3, t4] -> [t1, t1, t2, t2, t3, t3, t4, t4]
    sin_pos = torch.stack([sin, sin], dim=-1).view(S, D)
    cos_pos = torch.stack([cos, cos], dim=-1).view(S, D)

    # 2nd term: [-x2, x1, -x4, x3, -x6, x5, -x8, x7]
    x_half = torch.stack([-x[..., 1::2], x[..., 0::2]], dim=-1).reshape_as(x)

    x_rot = x * cos_pos + x_half * sin_pos
    return x_rot
x_rotated = apply_rope34(x)
print(f"{x.shape=}")
print(f"{x_rotated.shape=}")  # same shape as `x`

The above code is similar to the RoFormer implementation in the transformers library.

Simplified 2D Implementation (based on Equation 13)

[Figure: Equation 13 from the RoFormer paper (2D case)]

If we assume that the embedding \(x\) has already been multiplied by the \(W_q\) or \(W_k\) projection matrix (i.e., \(x\) is a query or key vector), then we can simplify the above equation as follows:

\[ f(x_m, m) = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix} \begin{pmatrix} x_1 \\ x_2 \end{pmatrix} = \begin{pmatrix} x_1 \cos m\theta - x_2 \sin m\theta \\ x_1 \sin m\theta + x_2 \cos m\theta \end{pmatrix} \]

Here’s what the code implementation looks like:

def apply_rope13(x):
    _, S, D = x.shape

    # [mθ1, mθ2, mθ3, mθ4, . . . mθ(d//2)] | Shape: [seq_len, dim//2]
    theta = get_theta(S, D)

    # [seq_len, dim//2]
    sin, cos = theta.sin(), theta.cos()

    # even and odd terms
    x1, x2 = x[..., 0::2], x[..., 1::2]

    # [cos_mθ, -sin_mθ] [x1]
    # [sin_mθ,  cos_mθ] [x2]
    # => [x1 * cos_mθ - x2 * sin_mθ, x1 * sin_mθ + x2 * cos_mθ]
    x_rot = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).reshape_as(
        x
    )
    return x_rot
x_rotated = apply_rope13(x)
print(f"{x.shape=}")
print(f"{x_rotated.shape=}")  # same shape as `x`
# Both implementations should produce the same result
assert torch.allclose(apply_rope34(x), apply_rope13(x)), "Not Equal"

In this approach, we don’t have to duplicate theta: sin and cos stay at shape [seq_len, dim//2] instead of [seq_len, dim], which should save some memory. The official RoFormer PyTorch repo inspired the above implementation.

Llama3.1 Implementation

Most RoPE implementations I found online use one of the two approaches above. The Meta Llama3.1 implementation, however, implements RoPE directly from Equation 12 using complex-number arithmetic. It might sound complicated, but the code is pretty concise.
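The complex-number approach relies on a standard identity. Treat each pair \((x_1, x_2)\) as the complex number \(x_1 + i x_2\) and multiply it by \(e^{im\theta} = \cos m\theta + i \sin m\theta\):

\[
(x_1 + i x_2)(\cos m\theta + i \sin m\theta) = (x_1 \cos m\theta - x_2 \sin m\theta) + i\,(x_1 \sin m\theta + x_2 \cos m\theta)
\]

The real and imaginary parts are exactly the two rows of the 2D rotation from Equation 13, which is why `torch.view_as_complex`, a complex multiply, and `torch.view_as_real` reproduce the interleaved implementation.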

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rope_llama31(x):
    _, S, D = x.shape

    # [mθ1, mθ2, mθ3, mθ4, . . . mθ(d//2)] | Shape: [seq_len, dim//2]
    theta = get_theta(S, D)

    # constructs a complex tensor from polar coordinates
    freq_cis = torch.polar(torch.ones_like(theta), theta)

    # turn `x` into complex number
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freq_cis = reshape_for_broadcast(freq_cis, x_)

    # x * e^(i*mθ) rotates each pair; view_as_real unpacks (real, imag) back into interleaved pairs
    x_rotated = torch.view_as_real(x_ * freq_cis).flatten(-2)
    return x_rotated
x_rotated = apply_rope_llama31(x)
print(f"{x.shape=}")
print(f"{x_rotated.shape=}")  # same shape as `x`
# should be equal to the other two implementations
assert torch.allclose(apply_rope_llama31(x), apply_rope13(x)), "Not Equal"
assert torch.allclose(apply_rope_llama31(x), apply_rope34(x)), "Not Equal"

The complete code can be found here: Llama3.1 RoPE implementation. Also, Umar Jamil has this amazing YouTube video where he codes Llama2 from scratch, implementing RoPE step-by-step.

The last objective was to check out the HF implementation and call it a day. But I didn’t know what I was stepping into.

HuggingFace implementation

I decided to look at the RoPE implementation in the Llama model. Their implementation was very similar to our equation 34 implementation, where we duplicated theta, calculated cos, sin, and x_half, and finally combined them to derive x_rotated.

But there was one subtle difference in how they calculated x_half. Here is the code snippet:

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

Instead of taking even and odd indexes, they simply take the first and second halves. My initial thought was, “This seems wrong,” but if that were the case, others would have noticed it. Maybe I’m missing something. So, I extracted the relevant parts and compared them with my implementation. The results didn’t match.

My next hunch was that they might have used some combination of reshape, transpose, and broadcast to do it and still get the same results. But why? My guess was performance!

I started hunting for these magic transformations in the codebase. After wasting a couple of hours looking through the code, copy-pasting, writing code, and seeking help from LLMs (conversation link: ChatGPT & Claude), I became convinced that the HF implementation was wrong (even though that seemed very unlikely).

The next step was to create an issue, but to be a good open-source contributor and avoid duplicates, I first searched for existing ones. It turned out that many people before me had noticed the same thing. Here is one such issue with an in-depth discussion and numerous resources. TL;DR: comment.

Basically, when adding a model to the Transformers library, contributors usually include a convert_xxxx_weights_to_hf.py script. This script converts the original weight and layer names to fit the HF convention. The common belief is that the weights themselves should not be changed. But the convert script for Llama has a permute function that rearranges the weights:

# permute for sliced rotary
# (`dim` is the model's hidden size, defined earlier in the conversion script)
def permute(w, n_heads, dim1=dim, dim2=dim):
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

This rearrangement is what allows them to use x[..., : D // 2] and x[..., D // 2 :] in the RoPE operation, rather than splitting x by even and odd indices. So my guess was right: the magic transformation I was looking for does exist, but it lives in the convert script, not in the training/inference code. This can be a source of confusion, as it’s not well documented and sits in a completely different part of the codebase.
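Here is a minimal sketch (not code from the Transformers library) that checks this claim on random weights. `rope_interleaved` and `rope_half` are hypothetical helpers; `rotate_half` and `permute` mirror the snippets above, with `dim` passed explicitly. Permuting the rows of a query projection and then applying the half-split RoPE gives the same numbers as the paper's interleaved RoPE, just reordered within each head. Assuming the same permutation is applied to the key weights, the dot products \(q^T k\), and therefore the attention scores, are unchanged.

```python
import torch

torch.manual_seed(0)
base = 10_000.0

def rope_interleaved(x):
    """RoPE on pairs (x0, x1), (x2, x3), ... as in the paper; x is [..., S, head_dim]."""
    S, D = x.shape[-2], x.shape[-1]
    theta = 1.0 / (base ** (torch.arange(0, D, 2, dtype=x.dtype) / D))
    ang = torch.outer(torch.arange(S, dtype=x.dtype), theta)  # [S, D/2]
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)

def rotate_half(x):
    """Same as the HF snippet above: [-second_half, first_half]."""
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def rope_half(x):
    """HF-style RoPE: angles are concatenated (not interleaved) to length D."""
    S, D = x.shape[-2], x.shape[-1]
    theta = 1.0 / (base ** (torch.arange(0, D, 2, dtype=x.dtype) / D))
    ang = torch.outer(torch.arange(S, dtype=x.dtype), theta)
    emb = torch.cat((ang, ang), dim=-1)  # [S, D]
    return x * emb.cos() + rotate_half(x) * emb.sin()

def permute(w, n_heads, dim1, dim2):
    """The conversion-script permutation, with `dim` passed explicitly."""
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

n_heads, head_dim, hidden, S = 2, 8, 16, 5
W_q = torch.randn(n_heads * head_dim, hidden, dtype=torch.float64)  # "original" q weights
x = torch.randn(S, hidden, dtype=torch.float64)

# Original weights + interleaved RoPE, per head: [heads, S, head_dim]
q_orig = (x @ W_q.T).view(S, n_heads, head_dim).transpose(0, 1)
out_orig = rope_interleaved(q_orig)

# Permuted weights + rotate_half RoPE, per head
W_hf = permute(W_q, n_heads, n_heads * head_dim, hidden)
q_hf = (x @ W_hf.T).view(S, n_heads, head_dim).transpose(0, 1)
out_hf = rope_half(q_hf)

# Within each head, HF index i holds original index perm[i] (evens first, then odds)
perm = torch.cat([torch.arange(0, head_dim, 2), torch.arange(1, head_dim, 2)])
print(torch.allclose(out_hf, out_orig[..., perm]))   # same values, different order
print(torch.allclose(rope_half(q_orig), out_orig))   # rotate_half without the permutation
```

The second check reproduces the mismatch described earlier: applying `rotate_half` to unpermuted weights does not match the paper's interleaved formulation.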

This discovery was a fantastic outcome that truly validated my understanding of RoPE. It was a subtle, latent finding that confirmed my thought process and added a layer of depth to my comprehension. The whole journey was a thrilling adventure filled with learning, and I thoroughly enjoyed the process. It’s a perfect example of how diving deep into a problem not only strengthens your expertise but also makes the journey of discovery incredibly fun.

Building intuition about RoPE

Note: This section was added later, on April 13th, 2025.

Having a good intuition about RoPE is really important for understanding various RoPE-based LLM context window extension methods. After spending a few weeks going through different papers, implementing them, and pulling my hair out for days trying to make sense of it all, I found that visualizing the theta matrix was one of the most effective ways to understand them. Visualization turned out to be a game changer in this quest.

Let’s say your embedding vector \(X\) for a token has shape \([1, 16]\). RoPE introduces a rotation vector \(\theta\) that is half the size of \(X\), i.e., \([1, 8]\). The values in \(X\) are rotated in pairs, so \(x_1\) and \(x_2\) are rotated by angle \(\theta_1\). Similarly, each other pair is rotated by its corresponding angle. I have color-coded the elements of \(X\) and \(\theta\) to make this easy to visualize.

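[Figure: color-coded elements of \(X\) and \(\theta\), showing which pair of \(X\) is rotated by which angle]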
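Here is a minimal sketch of the pairing for a single token, assuming the interleaved layout used in the implementations above (pairs \((x_1, x_2), (x_3, x_4), \dots\)) and a token at position \(m = 1\), so each pair is rotated by exactly its \(\theta\). Reshaping \(X\) to `[8, 2]` lines each pair up with its angle; the values in `X` are toy numbers.

```python
import torch

emb_dim, base = 16, 10_000
X = torch.arange(1, emb_dim + 1, dtype=torch.float64)  # [x1, ..., x16] (toy values)
theta = 1.0 / (base ** (torch.arange(0, emb_dim, 2, dtype=torch.float64) / emb_dim))  # [8]

pairs = X.view(-1, 2)  # [8, 2]: row i holds one (x_odd, x_even) pair
m = 1                  # token position (assumed for this example)
angle = m * theta      # one angle per pair
cos, sin = angle.cos(), angle.sin()
rotated = torch.stack(
    [pairs[:, 0] * cos - pairs[:, 1] * sin,
     pairs[:, 0] * sin + pairs[:, 1] * cos], dim=-1
).flatten()            # back to [16]

print(pairs.shape, theta.shape, rotated.shape)
```

Each pair keeps its length after rotation; only its direction changes, and the last pairs barely move because their \(\theta\) values are tiny.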

So, the first and obvious question is: where do we get \(\theta\)? Simple: we calculate it using the formula \[\theta_d = b^{-2d/|D|}\] where \(d = 0, 1, \dots, |D|/2 - 1\) indexes the pairs (counting from zero, as in the code below), \(b\) is called the base (generally set to 10,000), and \(|D|\) is the dimensionality of \(X\) (16 in our example).
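As a quick worked example for our \(|D| = 16\) case with \(b = 10{,}000 = 10^4\), the exponent simplifies nicely, so the values are easy to check by hand:

\[
\theta_d = \left(10^{4}\right)^{-2d/16} = 10^{-d/2}
\quad\Rightarrow\quad
\theta_0 = 1,\;\; \theta_1 \approx 0.316,\;\; \theta_2 = 0.1,\;\; \theta_3 \approx 0.0316,\;\; \dots,\;\; \theta_7 \approx 0.000316
\]

Each step to the next pair divides the angle by \(\sqrt{10} \approx 3.16\), which is where the exponential decay comes from.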

Below is the PyTorch code to calculate the \(\theta\) vector:

emb_dim = 16
base = 10_000

theta = 1.0 / (base ** (torch.arange(0, emb_dim, 2) / emb_dim))
theta

Let’s plot our theta values:

plt.plot(theta)
plt.title("Theta Vector")
plt.xlabel("Index")
plt.ylabel("Theta value")
plt.show()

Here’s how to interpret the plot:

  • The y-axis shows the actual value of \(\theta\) for each index.
  • From left to right, the \(\theta\) values decay exponentially: lower dimensions are assigned higher frequencies (large \(\theta\) values), while higher dimensions receive lower frequencies.
  • The rotation angle becomes nearly zero for higher dimensions. In the plot above, indices 5, 6, and 7 are all approximately zero.

Because these values are computed from a simple formula rather than learned, we can change them even after training. This is what makes RoPE flexible for extending the context window of pre-trained LLMs that use it.

Processing a Sequence

Moving on, when processing a sequence, we have multiple tokens. To encode positional information, we multiply the rotation vector \(θ\) by the position of each token in the sequence.

As shown in the image below, the same rotation vector \(\theta\) is scaled (i.e., multiplied) by the position of the token in the sequence. This results in a 2D matrix, where each value represents the angle by which a corresponding pair of embedding values will be rotated. Notice that each row of this matrix (i.e., each token position) contains distinct angles, enabling the model to differentiate, and thus learn, based on token position.

[Figure: the rotation vector \(\theta\) scaled by each token position to form the theta matrix]

Here is the PyTorch code for generating the matrix:

max_seq_len = 8

pos_ids = torch.arange(0, max_seq_len)
thetas = torch.outer(pos_ids, theta)
thetas.shape

Our theta vector has 8 values, and we allow at most 8 tokens in the sequence, so thetas has shape [8, 8]. Plotting the thetas matrix:

for i, t in enumerate(thetas):
    plt.plot(t, label=f"Position {i}")
plt.title("Theta Vector Scaled by Token Position")
plt.xlabel("Index")
plt.ylabel("Theta value")
plt.legend()
plt.show()

As the token position in the sequence increases, the angle of rotation also increases linearly.

However, the theta values for the higher dimensions (indices 5, 6, and 7) are almost indistinguishable across token positions. This effect is much clearer when visualized as a heatmap.

plt.imshow(thetas, cmap="Blues")
plt.colorbar()
plt.title("Theta Matrix")
plt.xlabel("Index")
plt.ylabel("Token Position")
plt.show()

Wow, most values in the matrix are close to zero!

What you’re looking at is a 2D matrix where:

  • Going left to right (i.e., as the index of the theta vector increases), the values decrease.

  • Going top to bottom (i.e., as the token position increases), the values increase.

This pattern is quite important to understand. So, here’s a quick question to test your grasp of it:

Where is the smallest value and where is the largest value in the matrix?

Now, scroll back up and take a closer look at the get_theta function defined above. At this point, it should be a piece of cake to understand and even write this function from scratch!
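If you want to check both your answer to the question and your understanding of get_theta, here is a small sketch. It re-declares a get_theta-style function (with `base` turned into a keyword argument so the snippet runs on its own) and compares it with the step-by-step thetas matrix built in this section.

```python
import torch

def get_theta(S, D, base=10_000):
    # same formula as get_theta above: angle[m, i] = m * base^(-2i/D)
    pos_ids = torch.arange(S, dtype=torch.float)
    freqs = base ** (torch.arange(0, D, 2) / D)
    return torch.outer(pos_ids, 1.0 / freqs)

# step-by-step version from this section
emb_dim, max_seq_len, base = 16, 8, 10_000
theta = 1.0 / (base ** (torch.arange(0, emb_dim, 2) / emb_dim))
thetas = torch.outer(torch.arange(max_seq_len, dtype=torch.float), theta)

print(torch.allclose(get_theta(max_seq_len, emb_dim), thetas))
row, col = divmod(thetas.argmax().item(), thetas.shape[1])
print("largest value at (position, index):", (row, col), thetas.max().item())
print("row 0 (position 0) is all zeros:", bool((thetas[0] == 0).all()))
```

The largest angle sits at the last position and the first index, while the first row (position 0) is entirely zero because a token at position 0 is not rotated at all.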

Further Work

Now that we understand the mechanics of RoPE embeddings, I plan to dive deeper into various methods of extending the context window of LLMs built on top of RoPE. Specifically, I want to better understand how the theta values are manipulated to extend the context window. The goal is to develop a more comprehensive understanding of RoPE by exploring it from multiple perspectives. I believe this will provide valuable insights and make it easier to further deepen my knowledge in this area.