Einsum is easy and useful

einsum is one of the most useful functions in NumPy/PyTorch/TensorFlow, and yet many people don’t use it. It seems to have a reputation for being difficult to understand and use, which is completely backwards in my view: einsum is great precisely because it is easier to use and reason about than the alternatives. So this post tries to set the record straight and show how simple einsum really is.

The general syntax for einsum is

einsum("some string describing an operation", tensor_1, tensor_2, ...)


with an arbitrary number of tensors after the string. (I’ll be saying “tensors”, but they could just as well be NumPy arrays. In the PyTorch snippets below, t is shorthand for torch, as in import torch as t.)

Let’s look at an example. Say we have two matrices, A and B, with shapes such that we can multiply them as A @ B. Using einsum, we can write this matrix product as

einsum("ij,jk->ik", A, B)


The interesting part is the string, "ij,jk->ik". These einsum strings always follow the same structure:

"<input indices> -> <output indices>"


In this example, the input indices are "ij,jk". These define letters for indices into each input tensor. Different tensors are comma-separated, so ij refers to A and jk refers to B. ij means we call the first axis of A i and the second axis j, and similarly, jk defines names for the axes of B. The specific letters are arbitrary; we could just as well write "ga,aw->gw". There has to be one index per axis of each input tensor. In our case, both A and B have two axes (they’re matrices), so both get two indices.

What’s important is that we’re using the same letter, j, both for the second axis of A and for the first axis of B. That’s not just an arbitrary definition; it has an effect on the result! Think of it this way: the entire left-hand side, "ij,jk", defines a three-dimensional tensor, indexed by i, j, and k. We get its elements by multiplying the corresponding elements of A and B:

product[i, j, k] = A[i, j] * B[j, k]


So it matters that j appears twice—a string like "ij,lk" would define a four-dimensional tensor:

product[i, j, l, k] = A[i, j] * B[l, k]


(Don’t worry about the order of these indices into product—as we’ll see in a moment, the right-hand side of our string will explicitly specify the order we want).

The right side of the -> arrow describes how to get our final output from this product tensor. It’s very simple: any index that appears on the left (i.e. in the product tensor) but doesn’t appear on the right is summed over. In our matrix multiplication example, the output indices are ik, so we sum over j. The final result is

out[i, k] = sum_j A[i, j] * B[j, k]


Precisely a matrix multiply, as promised!
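To make the two-step picture concrete, here is a minimal NumPy sketch that builds the product tensor explicitly and then sums over j. The array shapes and variable names are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 3))   # axes: i, j
B = rng.normal(size=(3, 4))   # axes: j, k

# Step 1: the left-hand side "ij,jk" as an explicit 3-D product tensor,
# product[i, j, k] = A[i, j] * B[j, k].
product = A[:, :, None] * B[None, :, :]

# Step 2: "->ik" keeps i and k, so j (axis 1) is summed over.
out_manual = product.sum(axis=1)

out_einsum = np.einsum("ij,jk->ik", A, B)
print(np.allclose(out_manual, out_einsum), np.allclose(out_einsum, A @ B))

# Using different letters for the inner axes ("ij,lk") keeps all four axes.
four_d = np.einsum("ij,lk->ijlk", A, B)
print(product.shape, four_d.shape)
```

einsum implementations generally avoid materializing the full product tensor, so this is a mental model rather than a description of how the result is actually computed.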

All of this generalizes in very nice and intuitive ways. On the left side of the -> arrow, we can have arbitrary patterns, and they’ll always describe a scheme for indexing into a product of the inputs. For example, the string "iij,kji,l" would define a four-dimensional tensor, given by

product[i, j, k, l] = A[i, i, j] * B[k, j, i] * C[l]


(for input tensors A, B, C). Note how much easier this is compared to a version without einsum:

n = A.shape[0]
product = A[t.arange(n), t.arange(n), :, None, None] \
* B.permute(2, 1, 0)[:, :, :, None] \
* C[None, None, None, :]


Our output can now be any permutation of any subset of ijkl. For example, "iij,kji,l->ki" would implicitly compute the product tensor above, then sum over j and l, and finally permute the result so the order of axes was ki. Contrast with how messy this would be without einsum:

# With broadcasting and array indexing:
n = A.shape[0]
out = (
A[t.arange(n), t.arange(n), :, None, None]
* B.permute(2, 1, 0)[:, :, :, None]
* C[None, None, None, :]
).sum(dim=(1, 3)).T  # sum over j and l, then transpose (i, k) -> (k, i)

# With einsum:
out = t.einsum("iij,kji,l->ki", A, B, C)


The main point is not that the einsum version is shorter—the point is that the other version took me 10 minutes to write and I’m still not sure it’s correct.
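One way to settle that doubt is a numerical check. The sketch below translates the broadcasting version to NumPy (transpose instead of permute, sum(axis=...) instead of sum(dim=...)) and compares it against einsum on random inputs. The sizes are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, j, k, l = 3, 4, 5, 2
A = rng.normal(size=(n, n, j))   # indexed as A[i, i, j]
B = rng.normal(size=(k, j, n))   # indexed as B[k, j, i]
C = rng.normal(size=(l,))        # indexed as C[l]

idx = np.arange(n)
product = (
    A[idx, idx, :, None, None]             # shape (i, j, 1, 1)
    * B.transpose(2, 1, 0)[:, :, :, None]  # shape (i, j, k, 1)
    * C[None, None, None, :]               # shape (1, 1, 1, l)
)
# Sum over j (axis 1) and l (axis 3), then flip (i, k) to (k, i).
out_manual = product.sum(axis=(1, 3)).T

out_einsum = np.einsum("iij,kji,l->ki", A, B, C)
print(out_manual.shape, np.allclose(out_manual, out_einsum))
```

Agreement on random inputs is not a proof, but a wrong axis order or a missed sum would almost certainly show up as a mismatch.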

That concludes the description of einsum, but let’s look at some more examples to get a better intuition:

• Say you want to compute the transpose of the matrix product, (A @ B).T. What that means is just that you want the indices in the output flipped, so the string now becomes "ij,jk->ki".
• Sometimes you want to sum over all axes; in that case, you can just leave the right-hand side empty. For example, "i,i->" will compute the inner product of two input vectors.
• Just like you can have repeated indices in different input tensors, you can repeat indices within the same tensor. For example, "ii->" computes the trace of a matrix. Or you could do "ii->i" to get the diagonal as a vector.
• You can trivially add batch dimensions to any operation. For example, a batched inner product would be "bi,bi->b". A batched matrix multiply would be "bij,bjk->bik". If for some reason, your batch dimension is in the last position for the second batch of matrices, that’s no problem: "bij,jkb->bik".
• Batching also lets you take arbitrary diagonals of a tensor easily. For example, "ibi->bi" will give you the diagonal along the first and third axis, batched over the middle axis.
• A neat trick is that you can have a variable number of batch dimensions using a ... syntax: "...ij,...jk->...ik" is a batched matrix multiply that works for any number of batch dimensions. The ... can be anywhere, not just at the front. For example, "...ij,j...k->ik..." will work just fine, whatever number of dimensions the ... stands for.
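The patterns in the list above can each be checked against a more conventional NumPy expression. This is a small sketch with made-up shapes; the reference expressions on the right are my choice of equivalent, not the only one.

```python
import numpy as np

rng = np.random.default_rng(0)
A, B = rng.normal(size=(2, 3)), rng.normal(size=(3, 4))
u, v = rng.normal(size=5), rng.normal(size=5)
M = rng.normal(size=(4, 4))
P, R = rng.normal(size=(6, 5)), rng.normal(size=(6, 5))        # axes b, i
X, Y = rng.normal(size=(6, 2, 3)), rng.normal(size=(6, 3, 4))  # axes b, i, j / b, j, k
T = rng.normal(size=(4, 6, 4))                                 # axes i, b, i
Xs, Ys = rng.normal(size=(5, 6, 2, 3)), rng.normal(size=(5, 6, 3, 4))

checks = {
    "ij,jk->ki": np.allclose(np.einsum("ij,jk->ki", A, B), (A @ B).T),
    "i,i->": np.allclose(np.einsum("i,i->", u, v), u @ v),
    "ii->": np.allclose(np.einsum("ii->", M), np.trace(M)),
    "ii->i": np.allclose(np.einsum("ii->i", M), np.diag(M)),
    "bi,bi->b": np.allclose(np.einsum("bi,bi->b", P, R), (P * R).sum(axis=1)),
    "bij,bjk->bik": np.allclose(np.einsum("bij,bjk->bik", X, Y), X @ Y),
    # Second operand has its batch axis last: shape (j, k, b).
    "bij,jkb->bik": np.allclose(
        np.einsum("bij,jkb->bik", X, np.moveaxis(Y, 0, -1)), X @ Y
    ),
    "ibi->bi": np.allclose(
        np.einsum("ibi->bi", T), np.diagonal(T, axis1=0, axis2=2)
    ),
    "...ij,...jk->...ik": np.allclose(
        np.einsum("...ij,...jk->...ik", Xs, Ys), Xs @ Ys
    ),
    # Second operand has shape (j, batch..., k); output is (i, k, batch...).
    "...ij,j...k->ik...": np.allclose(
        np.einsum("...ij,j...k->ik...", Xs, np.moveaxis(Ys, -2, 0)),
        np.moveaxis(Xs @ Ys, (-2, -1), (0, 1)),
    ),
}
for pattern, ok in checks.items():
    print(f"{pattern:22s} {'ok' if ok else 'MISMATCH'}")
```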

Einops

The main problem with einsum is that it doesn’t support enough operations. For example, say you have an image tensor x with shape (batch, channels, height, width). If all you want to do is move the channel axis, you can just do

einsum("bchw->bhwc", x)


But what if you also want to flatten the height and width dimensions into a single axis? That’s where the einops library comes into play:

rearrange(x, "b c h w -> b (h w) c")


The rearrange function takes a tensor followed by a string similar to einsum strings. The only new aspect is the parentheses: here, they tell einops to combine the height and width dimensions into one axis. Just like einsum, rearrange is extremely flexible, so you can, for example, transpose these axes before flattening them just by doing

rearrange(x, "b c h w -> b (w h) c")


You can also have parentheses on the left side, to split one axis into multiple. So for example, we can invert the operation above using

rearrange(x, "b (w h) c -> b c h w", h=32)


The h=32 tells einops that the h axis should have length 32 (without this information, it would be unclear how to split up the single input axis).

Note how obvious it is that "b (w h) c -> b c h w" is the inverse of "b c h w -> b (w h) c". We just switched the left- and right-hand sides! In general, you can compose rearrange operations: doing "string_1 -> string_2" followed by "string_2 -> string_3" is the same as doing "string_1 -> string_3".
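For readers who want to see what the parentheses mean in terms of plain array operations, here is a NumPy-only sketch of the three patterns above written as transpose plus reshape. It assumes, as I understand einops to behave, that a group like (w h) merges axes in row-major order with the leftmost name varying slowest. The shapes are made up.

```python
import numpy as np

b, c, h, w = 2, 3, 32, 16
x = np.arange(b * c * h * w).reshape(b, c, h, w)

# "b c h w -> b (h w) c": reorder to (b, h, w, c), then merge h and w.
flat_hw = x.transpose(0, 2, 3, 1).reshape(b, h * w, c)

# "b c h w -> b (w h) c": same idea, but w is the outer axis of the merged one.
flat_wh = x.transpose(0, 3, 2, 1).reshape(b, w * h, c)

# "b (w h) c -> b c h w" with h=32: split the merged axis, then reorder.
restored = flat_wh.reshape(b, -1, h, c).transpose(0, 3, 2, 1)

print(flat_hw.shape, flat_wh.shape, np.array_equal(restored, x))
```

The rearrange strings say the same thing without any axis numbers to keep track of, which is the point of using them.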

There’s more to say about rearrange, and its cousin repeat, but the einops documentation does a good job explaining them, so I’ll leave it at that.

A more complex example

As a final example, let’s consider a multi-head attention mechanism. Say we have our Q, K and V tensors, each of shape (batch, seq_length, n_heads * head_dim) (the n_heads and head_dim dimensions are flattened into one because we did a single matrix multiply for all heads to obtain these tensors). We want to compute the attention pattern A of shape (batch, n_heads, seq_length, seq_length). We can use einops and einsum:

Q = rearrange(Q, "b q (h d) -> b q h d", h=n_heads)
K = rearrange(K, "b k (h d) -> b k h d", h=n_heads)
A = einsum("b q h d, b k h d -> b h q k", Q, K)


Note how simple this is: the output patterns of our rearrange calls, b q h d and b k h d, already tell us what the input patterns for the einsum have to be. The output indices for the einsum are almost determined by the desired output shape. Implicitly, this is a matrix multiply that contracts the d dimension, with q and k as the matrix axes, batched over b and h, but we don’t have to think about that to write down the einsum.

Contrast with this alternative, where we have to carefully think about the position of each axis:

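batch, seq_length, _ = Q.shape
# Split the last axis into (n_heads, head_dim), then move the head axis forward.
Q = Q.view(batch, seq_length, n_heads, -1).transpose(1, 2)  # (batch, n_heads, seq_length, head_dim)
K = K.view(batch, seq_length, n_heads, -1).transpose(1, 2)  # (batch, n_heads, seq_length, head_dim)
A = Q @ K.transpose(-2, -1)  # (batch, n_heads, seq_length, seq_length)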


einsum isn’t just easier to read and write, it’s also easier to refactor. Imagine we suddenly have to deal with multiple batch dimensions. In the einsum version, we just replace the b with ... and everything works, whereas the second version becomes even messier. Or imagine the format of our tensors changes so that the h and d axes are now swapped in the flattened dimension. In the einsum version, we just change (h d) to (d h) in the two rearrange patterns. The other version requires us to carefully check all the magic numbers to see which ones need to be changed.
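The refactoring claim about multiple batch dimensions can be tried out directly. The sketch below is NumPy-only: it uses reshape in place of rearrange, computes the raw scores as in the snippets above (no scaling or softmax), and the function name attention_scores is a hypothetical helper introduced for illustration.

```python
import numpy as np

def attention_scores(Q, K, n_heads):
    # Split the flattened (h d) axis; any leading batch axes are left untouched.
    Q = Q.reshape(*Q.shape[:-1], n_heads, -1)   # (..., q, h, d)
    K = K.reshape(*K.shape[:-1], n_heads, -1)   # (..., k, h, d)
    return np.einsum("...qhd,...khd->...hqk", Q, K)

rng = np.random.default_rng(0)
batch, seq_length, n_heads, head_dim = 2, 5, 3, 4
Q = rng.normal(size=(batch, seq_length, n_heads * head_dim))
K = rng.normal(size=(batch, seq_length, n_heads * head_dim))

A = attention_scores(Q, K, n_heads)   # (batch, n_heads, seq_length, seq_length)

# Positional version for comparison: (b, h, q, d) @ (b, h, d, k).
Qh = Q.reshape(batch, seq_length, n_heads, -1).transpose(0, 2, 1, 3)
Kh = K.reshape(batch, seq_length, n_heads, -1).transpose(0, 2, 1, 3)
A_manual = Qh @ Kh.transpose(0, 1, 3, 2)
print(A.shape, np.allclose(A, A_manual))

# Two batch dimensions: the einsum-based function works unchanged.
Q2 = rng.normal(size=(7, batch, seq_length, n_heads * head_dim))
K2 = rng.normal(size=(7, batch, seq_length, n_heads * head_dim))
A2 = attention_scores(Q2, K2, n_heads)
print(A2.shape)
```

The positional version would need every axis number revisited to handle the extra batch dimension, while the einsum pattern only needed ... in place of b.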

einsum could be even better

I’d love to see an einsum wrapper that combines the capabilities of einsum and rearrange. For example, we could write the attention mechanism above as a simple one-liner:

A = einsum("b q (h d), b k (h d) -> b h q k", Q, K, h=n_heads)


If anyone ends up implementing this, please let me know!
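As a starting point, here is a rough NumPy-only sketch of what such a wrapper might look like. grouped_einsum is a hypothetical name, not a function from NumPy or einops, and the sketch makes simplifying assumptions: axis names are single letters, groups appear only on the input side, and at most one size per group is unknown.

```python
import math
import re
import numpy as np

def grouped_einsum(pattern, *tensors, **sizes):
    """Hypothetical helper: einsum whose input patterns may contain
    parenthesised groups such as "(h d)"."""
    lhs, rhs = pattern.split("->")
    in_patterns = lhs.split(",")
    if len(in_patterns) != len(tensors):
        raise ValueError("expected one input pattern per tensor")

    known = dict(sizes)   # axis name -> length, extended as sizes are inferred
    operands, subscripts = [], []
    for pat, x in zip(in_patterns, tensors):
        # Each token is a bare axis ("b") or a group ("(h d)").
        groups = [tok.strip("()").split()
                  for tok in re.findall(r"\([^)]*\)|[^\s()]+", pat)]
        if len(groups) != x.ndim:
            raise ValueError(f"pattern {pat!r} does not match shape {x.shape}")
        new_shape = []
        for names, length in zip(groups, x.shape):
            unknown = [n for n in names if n not in known]
            if len(unknown) > 1:
                raise ValueError(f"cannot infer sizes of {unknown}")
            rest = math.prod(known[n] for n in names if n in known)
            if unknown:
                if length % rest:
                    raise ValueError(f"length {length} is not divisible by {rest}")
                known[unknown[0]] = length // rest
            new_shape.extend(known[n] for n in names)
        # Splitting groups is a plain reshape; it fails if sizes are inconsistent.
        operands.append(x.reshape(new_shape))
        subscripts.append("".join(n for names in groups for n in names))

    spec = ",".join(subscripts) + "->" + "".join(rhs.split())
    return np.einsum(spec, *operands)

rng = np.random.default_rng(0)
n_heads = 3
Q = rng.normal(size=(2, 5, 12))
K = rng.normal(size=(2, 7, 12))
A = grouped_einsum("b q (h d), b k (h d) -> b h q k", Q, K, h=n_heads)
print(A.shape)
```

A fuller version would also need groups on the output side, multi-letter axis names, and ... support, which is where most of the real work lies.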