Einsum is easy and useful
einsum is one of the most useful functions in Numpy/Pytorch/Tensorflow and yet many people don’t use it. It seems to have a reputation as being difficult to understand and use, which is completely backwards in my view: the reason einsum is great is precisely because it is easier to use and reason about than the alternatives. So this post tries to set the record straight and show how simple einsum really is.
The general syntax for einsum is
einsum("some string describing an operation", tensor_1, tensor_2, ...)
with an arbitrary number of tensors after the string. (I’ll be saying “tensors” but they could just as well be Numpy arrays.)
Let’s look at an example. Say we have two matrices, A and B, with shapes such that we can multiply them as A @ B. Using einsum, we can write this matrix product as
einsum("ij,jk->ik", A, B)
The interesting part is the string, "ij,jk->ik". These einsum strings always follow the same structure:
"<input indices> -> <output indices>"
In this example, the input indices are "ij,jk". These define letters for indices into each input tensor. Different tensors are comma-separated, so ij refers to A and jk refers to B. ij means we call the first axis of A i and the second axis j, and similarly, jk defines names for the axes of B. The specific letters we use are arbitrary here; we could just as well write "ga,aw->gw". There has to be one index per axis of the input tensor—in our case, both A and B have two axes (they’re matrices), so both get two indices.
What’s important is that we’re using the same letter, j, both for the second axis of A and for the first axis of B. That’s not just an arbitrary definition, it has an effect on the result! Think of it this way: the entire left-hand side, "ij,jk", defines a three-dimensional tensor, indexed by i, j, and k. We get its elements by multiplying the corresponding elements of A and B:
product[i, j, k] = A[i, j] * B[j, k]
So it matters that j appears twice—a string like "ij,lk" would define a four-dimensional tensor:
product[i, j, l, k] = A[i, j] * B[l, k]
(Don’t worry about the order of these indices into product—as we’ll see in a moment, the right-hand side of our string will explicitly specify the order we want.)
The right side of the -> arrow describes how to get our final output from this product tensor. It’s very simple: any index that appears on the left (i.e. in the product tensor) but doesn’t appear on the right is summed over. So in our matrix multiplication example, since our output indices are ik, we sum over j, and the final result is
out[i, k] = sum_j A[i, j] * B[j, k]
Precisely a matrix multiply, as promised!
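As a quick sanity check, we can verify this against the built-in matrix product (a minimal example in PyTorch, with arbitrary shapes):

import torch as t

A, B = t.randn(2, 3), t.randn(3, 4)
assert t.allclose(t.einsum("ij,jk->ik", A, B), A @ B)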
All of this generalizes in very nice and intuitive ways. On the left side of the -> arrow, we can have arbitrary patterns, and they’ll always describe a scheme for indexing into a product of the inputs. For example, the string "iij,kji,l" would define a four-dimensional tensor, given by
product[i, j, k, l] = A[i, i, j] * B[k, j, i] * C[l]
(for input tensors A, B, C).
Note how much easier this is compared to a version without einsum:
# with torch imported as t
n = A.shape[0]
product = (
    A[t.arange(n), t.arange(n), :, None, None]
    * B.permute(2, 1, 0)[:, :, :, None]
    * C[None, None, None, :]
)
Our output can now be any permutation of any subset of ijkl. For example, "iij,kji,l->ki" would implicitly compute the product tensor above, then sum over j and l, and finally permute the result so the order of axes is ki. Contrast with how messy this would be without einsum:
# With broadcasting and array indexing:
n = A.shape[0]
out = (
    A[t.arange(n), t.arange(n), :, None, None]
    * B.permute(2, 1, 0)[:, :, :, None]
    * C[None, None, None, :]
).sum((1, 3)).T
# With einsum:
out = t.einsum("iij,kji,l->ki", A, B, C)
The main point is not that the einsum version is shorter—the point is that the other version took me 10 minutes to write and I’m still not sure it’s correct.
That concludes the description of einsum, but let’s look at some more examples to get a better intuition:
- Say you want to compute the transpose of the matrix product, (A @ B).T. What that means is just that you want the indices in the output flipped, so the string now becomes "ij,jk->ki".
- Sometimes you want to sum over all axes; in that case, you can just leave the right-hand side empty. For example, "i,i->" will compute the inner product of two input vectors.
- Just like you can have repeated indices in different input tensors, you can repeat indices within the same tensor. For example, "ii->" computes the trace of a matrix. Or you could do "ii->i" to get the diagonal as a vector.
- You can trivially add batch dimensions to any operation. For example, a batched inner product would be "bi,bi->b". A batched matrix multiply would be "bij,bjk->bik". If for some reason your batch dimension is in the last position for the second batch of matrices, that’s no problem: "bij,jkb->bik".
- Batching also lets you take arbitrary diagonals of a tensor easily. For example, "ibi->bi" will give you the diagonal along the first and third axes, batched over the middle axis.
- A neat trick is that you can have a variable number of batch dimensions using a ... syntax: "...ij,...jk->...ik" is a batched matrix multiply that works for any number of batch dimensions. The ... can be anywhere, not just at the front. For example, "...ij,j...k->ik..." will work just fine, for any number of dimensions in place of the ... (a few of these patterns are sanity-checked right below).
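To make these concrete, here’s a minimal sanity check of a few of the patterns above, in PyTorch (numpy’s einsum behaves the same way):

import torch as t

A, B = t.randn(3, 4), t.randn(4, 5)
v = t.randn(3)
M = t.randn(3, 3)

# transpose of a matrix product: "ij,jk->ki"
assert t.allclose(t.einsum("ij,jk->ki", A, B), (A @ B).T)
# inner product: "i,i->"
assert t.allclose(t.einsum("i,i->", v, v), v @ v)
# trace and diagonal: "ii->" and "ii->i"
assert t.allclose(t.einsum("ii->", M), t.trace(M))
assert t.allclose(t.einsum("ii->i", M), t.diagonal(M))
# batched matrix multiply with any number of batch dimensions
Ab, Bb = t.randn(2, 7, 3, 4), t.randn(2, 7, 4, 5)
assert t.allclose(t.einsum("...ij,...jk->...ik", Ab, Bb), Ab @ Bb)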
Einops
The main problem with einsum is that it doesn’t support enough operations. For example, say you have an image tensor x with shape (batch, channels, height, width). If all you want to do is move the channel axis, you can just do
einsum("bchw->bhwc", x)
But what if you also want to flatten the height and width dimensions into a single axis? That’s where the einops library comes into play:
rearrange(x, "b c h w -> b (h w) c")
The rearrange function takes a tensor followed by a string similar to einsum strings. The only new aspect is the parentheses—here, they tell einops to combine the height and width dimensions into one axis. Just like einsum, rearrange is extremely flexible, so you can, for example, transpose these axes before flattening them just by doing
rearrange(x, "b c h w -> b (w h) c")
You can also have parentheses on the left side, to split one axis into multiple. So for example, we can invert the operation above using
rearrange(x, "b (w h) c -> b c h w", h=32)
The h=32 tells einops that the h axis should have length 32 (without this information, it would be unclear how to split up the single input axis).
Note how obvious it is that "b (w h) c -> b c h w" is the inverse of "b c h w -> b (w h) c". We just switched the left- and right-hand sides! In general, you can compose rearrange operations: doing "string_1 -> string_2" followed by "string_2 -> string_3" is the same as doing "string_1 -> string_3".
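For instance, here’s a quick check of that round trip (a minimal sketch; the concrete shape is arbitrary):

import torch as t
from einops import rearrange

x = t.randn(2, 3, 32, 32)  # (batch, channels, height, width)
y = rearrange(x, "b c h w -> b (w h) c")
assert t.equal(rearrange(y, "b (w h) c -> b c h w", h=32), x)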
There’s more to say about rearrange, and its cousin repeat, but the einops documentation does a good job explaining them, so I’ll leave it at that.
A more complex example
As a final example, let’s consider a multi-head attention mechanism. Say we have our Q, K and V tensors, each of shape (batch, seq_length, n_heads * head_dim) (the n_heads and head_dim dimensions are flattened into one because we did a single matrix multiply for all heads to obtain these tensors). We want to compute the attention pattern A of shape (batch, n_heads, seq_length, seq_length). We can use einops and einsum:
Q = rearrange(Q, "b q (h d) -> b q h d", h=n_heads)
K = rearrange(K, "b k (h d) -> b k h d", h=n_heads)
A = einsum("b q h d, b k h d -> b h q k", Q, K)
Note how simple this is: the output patterns of our rearranges, bqhd and bkhd, already tell us what the inputs for the einsum have to be. The output indices for the einsum are almost determined by the desired output shape. Implicitly, this is a batched matrix multiply over the d dimension (batched over all other axes), but we don’t have to think about that to write down the einsum.
Contrast with this alternative, where we have to carefully think about the position of each axis:
batch, seq_length, _ = Q.shape
Q = Q.view(batch, seq_length, 1, n_heads, -1)
K = K.view(batch, 1, seq_length, n_heads, -1)
A = (Q * K).sum(-1).moveaxis(-1, 1)  # broadcast multiply, sum over head_dim
einsum isn’t just easier to read and write, it’s also easier to refactor. Imagine we suddenly have to deal with multiple batch dimensions. In the einsum version, we just replace the b with ... and everything works, whereas the second version becomes even messier. Or imagine the format of our tensors changes so that the h and d axes are now swapped. In the einsum version, we just swap all occurrences of h and d. The other version requires us to carefully check all the magic numbers to see which ones need to be changed.
einsum could be even better
I’d love to see an einsum wrapper that combines the capabilities of einsum and rearrange. For example, we could write the attention mechanism above as a simple one-liner:
A = einsum("b q (h d), b k (h d) -> b h q k", Q, K, h=n_heads)
If anyone ends up implementing this, please let me know!
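In the meantime, here’s a rough sketch of how such a wrapper might look. It’s hypothetical (einsum_plus is a made-up name) and only handles parenthesized groups on the input side, but it shows the idea: split grouped axes with rearrange, then delegate to the ordinary einsum.

# Hypothetical sketch; einsum_plus is a made-up name, not a real API.
import re
import torch as t
from einops import rearrange

def einsum_plus(pattern, *tensors, **axis_sizes):
    lhs, rhs = pattern.split("->")
    flat_specs, inputs = [], []
    for spec, tensor in zip(lhs.split(","), tensors):
        spec = spec.strip()
        flat = re.sub(r"[()]", "", spec)  # e.g. "b q (h d)" -> "b q h d"
        if "(" in spec:
            # split the grouped axis, passing through the relevant axis sizes
            sizes = {k: v for k, v in axis_sizes.items() if k in spec}
            tensor = rearrange(tensor, f"{spec} -> {flat}", **sizes)
        flat_specs.append("".join(flat.split()))  # "b q h d" -> "bqhd"
        inputs.append(tensor)
    equation = ",".join(flat_specs) + "->" + "".join(rhs.split())
    return t.einsum(equation, *inputs)

# which would make the attention example above a one-liner:
# A = einsum_plus("b q (h d), b k (h d) -> b h q k", Q, K, h=n_heads)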