ENH: speed up einsum with optimize using batched matmul #23513
Conversation
No rush on my part, but I am wondering if there are any thoughts about this change, or things I can do for this PR?
TBH, I had hoped @dgasmith could give a recommendation. The timings do look like this should work pretty generically, although it's easy to miss cases, I guess.
Hi @seberg, ah yes, I just checked and currently that matmul behaviour would result in a performance regression for the specific case where arrays reach the matmul call (after the preparatory reshape/transpose etc.) with whichever strides trigger the non-BLAS, no-copy call (#23588):

```python
import numpy as np

d = 1000
x = np.random.randn(d, d)
y = np.random.randn(d, d) + 1j * np.random.randn(d, d)
```

```python
%%timeit
z = np.einsum('ij,jk->ik', x, y.real, optimize=True)
# --> 11.5 ms ± 473 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)  (current dot backend)
# --> 897 ms ± 23.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)     (new matmul backend)
```

If the solution in #23752 is to do the copy when needed, then that is very much aligned with the logic here.
It is hard for me to say; we benchmarked fairly thoroughly with …
Rather than wait on progress in #23588, would it be reasonable to simply require contiguity before the matmul call? I.e. do we expect there to be cases where the arrays are not contiguous, and where it would actually be faster to leave them so? My guess is that this is only the case for certain small contractions, which are not the main focus when people use optimize anyway. Of course, having some rough heuristic in …
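For concreteness, here is a minimal sketch of that idea, assuming a simple "copy unless already contiguous" rule; the helper name `_matmul_ready` and the heuristic are illustrative assumptions, not anything from the PR:

```python
import numpy as np

def _matmul_ready(a):
    # Hypothetical helper: copy an operand into a contiguous buffer if its
    # strides would otherwise push matmul off the BLAS fast path, e.g. the
    # strided view returned by y.real on a complex array. Whether this copy
    # is always worth it is exactly the heuristic question raised above.
    if a.flags.c_contiguous or a.flags.f_contiguous:
        return a
    return np.ascontiguousarray(a)

d = 1000
x = np.random.randn(d, d)
y = np.random.randn(d, d) + 1j * np.random.randn(d, d)

# y.real is a non-contiguous view; copying it first keeps matmul on BLAS
z = np.matmul(_matmul_ready(x), _matmul_ready(y.real))
```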
This PR enables dispatching all pairwise calls in `einsum`, when `optimize=` is turned on, to batched matrix multiply (or non-batched `matmul` and `multiply` as special cases). For a short explanation of the implementation see here - https://github.com/jcmgray/einsum_bmm.

Currently, falling back to `c_einsum` for these, when there are batch indices, induces potentially many orders of magnitude slowdowns (#22604, e.g. 7000x), particularly in the case of large or high dimensional tensors.

After fairly extensive benchmarking (below), I believe that this should be a uniform improvement across essentially all the cases where one would use `optimize=True`. This is partly enabled by the modern performance of matmul, which in the past was not always faster than `einsum`. However, I know `einsum` has an enormous number of uses, so I'd gladly accept more suggestions for benchmarks, and would understand people's hesitation! (@dgasmith, @seberg)
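As a rough illustration of the underlying mapping (a hand-rolled sketch for one concrete equation, not the PR's `bmm_einsum` code; the function name is made up), a pairwise contraction with a batch index, a contracted index and free indices can be transposed into `(batch, free, contracted)` / `(batch, contracted, free)` form so that a single `np.matmul` call does the work:

```python
import numpy as np

def einsum_as_bmm(x, y):
    # Sketch for the specific contraction 'ijk,kjl->jil':
    # 'j' is a batch index, 'k' is contracted, 'i' and 'l' are free.
    xb = x.transpose(1, 0, 2)   # (i, j, k) -> (j, i, k)
    yb = y.transpose(1, 0, 2)   # (k, j, l) -> (j, k, l)
    return np.matmul(xb, yb)    # batched matmul over 'j' -> (j, i, l)

x = np.random.randn(6, 10, 7)   # (i, j, k)
y = np.random.randn(7, 10, 8)   # (k, j, l)
assert np.allclose(einsum_as_bmm(x, y), np.einsum('ijk,kjl->jil', x, y))
```

The general case additionally needs reshapes to group multiple indices of each kind, but the shape of the dispatch is the same.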
Other notes:
- This removes the `can_blas` logic and `tensordot` calls, and simplifies everything into `bmm_einsum`, which encapsulates how `tensordot` is implemented.
- `numpy.multiply` is used for slightly better performance.
- `out` and other kwargs are passed on to `multiply` and `matmul`, unlike with tensordot, meaning `einsum('ij,jk->ik', x, y, optimize=True, out=z)` should write directly to `z`.
- This allows `einsum(eq, ..., optimize=True)` to be used when there are more than 32 indices involved (Increase maximum number of array dimensions? #5744).
- Part of the speed of the `bmm_einsum` implementation comes from caching which operations to use based on `(eq, x.shape, y.shape)`, using `@functools.lru_cache(2**12)` (a minimal sketch of such a cache follows this list).
- `torch.einsum`
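For context, here is a minimal sketch of that kind of shape-keyed cache (an illustration only, assuming a hypothetical `_plan_pairwise` helper; this is not the actual `bmm_einsum` code, and it only classifies indices rather than building a full contraction plan):

```python
import functools

@functools.lru_cache(2**12)
def _plan_pairwise(eq, shape_x, shape_y):
    # Parse the pairwise equation once per (eq, shapes) combination and
    # record which indices are batch vs contracted, so repeated calls
    # with the same shapes skip the string parsing work entirely.
    lhs, out = eq.split('->')
    a, b = lhs.split(',')
    batch = tuple(i for i in a if i in b and i in out)
    contracted = tuple(i for i in a if i in b and i not in out)
    return batch, contracted

# shapes are tuples, so the whole key (eq, shape_x, shape_y) is hashable
print(_plan_pairwise('abc,acd->abd', (10, 4, 5), (10, 5, 6)))
# -> (('a',), ('c',))
```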
TODO:
- `einsumfunc.py` code?

Benchmarks
In the following, the kernels are:
- `no optimize`: the base `np.einsum(eq, x, y)`, i.e. `c_einsum`
- `optimize + dot`: the current `np.einsum(eq, x, y, optimize=True)`, which calls `tensordot` where possible, and also induces some overhead from the potential path optimization (even though there are only two terms in the following, the equation must be checked and parsed etc.)
- `optimize + bmm`: the proposed `np.einsum(eq, x, y, optimize=True)`, which calls batched matmul, and still induces some overhead from the potential path optimization
- `no optimize + bmm`: this calls the pairwise bmm einsum directly, and is provided simply to show what overhead comes from the bmm implementation, and what from the potential path optimization

The main two to compare are `optimize + dot` and `optimize + bmm`. Overall I'd summarize the results as follows: the speed of `dot` is retained.

All dimensions are size `n` unless otherwise noted.
For `batch_matmul_small` the sizes are `(n, 2, 2), (n, 2, 2)`.
For `batch_matmul_large` the sizes are `(10, n, n), (10, n, n)` (a small timing sketch for this case is given below).
In `random_extreme` the shapes are taken as the following: …
In `many_small_dims_dot` the shapes are `([2] * n, [2] * n)` and half of the indices, in an interleaved pattern, are contracted.
In `many_small_dims_batched_dot` the shapes are `([2] * n, [2] * n)`; a third of the indices, in an interleaved pattern, are contracted, and another third, also interleaved, are batch indices.
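As a rough guide to reproducing one data point from these cases (an illustrative harness only, not the actual benchmark script; `n = 100` and the repeat count are arbitrary choices here), the `batch_matmul_large` case could be timed along these lines:

```python
import timeit
import numpy as np

n = 100
x = np.random.randn(10, n, n)
y = np.random.randn(10, n, n)
eq = 'bij,bjk->bik'   # a batched matmul contraction

for label, opt in [('no optimize', False), ('optimize', True)]:
    t = timeit.timeit(lambda: np.einsum(eq, x, y, optimize=opt), number=50)
    print(f'{label:12s} {t / 50 * 1e3:8.3f} ms per call')
```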