
ENH: speed up einsum with optimize using batched matmul #23513


Open
wants to merge 5 commits into main

Conversation


@jcmgray commented Mar 31, 2023

This PR dispatches all pairwise contractions in einsum, when optimize= is turned on, to batched matrix multiply (or to non-batched matmul and multiply as special cases). For a short explanation of the implementation see https://github.com/jcmgray/einsum_bmm.
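
To illustrate the basic idea with a hand-worked example (made-up shapes and equation, not code from this PR or the linked repo): the indices of a pairwise contraction can be grouped into batch, kept and contracted sets, each operand transposed and reshaped to 3D, and the whole contraction then done by a single np.matmul.

import numpy as np

# Hypothetical equation 'abij,ajbk->abik':
#   batch indices a, b; kept indices i (left) and k (right); contracted index j
rng = np.random.default_rng(0)
a, b, i, j, k = 4, 5, 6, 7, 8
x = rng.standard_normal((a, b, i, j))
y = rng.standard_normal((a, j, b, k))

# bring each operand into (batch, kept, contracted) / (batch, contracted, kept)
# order and fuse each group into a single axis
xt = x.reshape(a * b, i, j)                        # x is already ordered (a, b, i, j)
yt = y.transpose(0, 2, 1, 3).reshape(a * b, j, k)  # reorder y to (a, b, j, k)

# one batched matmul does the contraction, then the batch axis is unfused
z = np.matmul(xt, yt).reshape(a, b, i, k)

assert np.allclose(z, np.einsum('abij,ajbk->abik', x, y))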

Currently, falling back to c_einsum for these pairwise contractions when batch indices are present can induce slowdowns of many orders of magnitude (#22604, e.g. 7000x), particularly for large or high-dimensional tensors.

After fairly extensive benchmarking (below), I believe this should be a uniform improvement across essentially all cases where one would use optimize=True. This is partly enabled by the modern performance of matmul, which in the past was not always faster than einsum. However, I know einsum has an enormous number of uses, so I'd gladly accept more suggestions for benchmarks, and I would understand people's hesitation!

(@dgasmith, @seberg)

Other notes:

  • This replaces all the can_blas logic and tensordot calls, simplifying everything into bmm_einsum, which encapsulates internally what tensordot does
  • When there are no contracted indices, numpy.multiply is used for slightly better performance
  • Unlike tensordot, the current implementation can simply pass out and other kwargs on to multiply and matmul, meaning einsum('ij,jk->ik', x, y, optimize=True, out=z) should write directly to z
  • The index fusing allows einsum(eq, ..., optimize=True) to be used when there are more than 32 indices involved (Increase maximum number of array dimensions? #5744)
  • The low overhead of the actual bmm_einsum implementation comes from caching which operations to use, keyed on (eq, x.shape, y.shape), via @functools.lru_cache(2**12) (see the sketch after this list)
  • Performance is, broadly speaking, brought in line with torch.einsum
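
For reference, here is a very rough sketch of how the planning and caching fit together. The names plan_pairwise and pairwise_bmm are made up for illustration, and this toy version assumes no repeated indices within an operand and that every index appears either in both inputs or in the output; it also omits the multiply special case and the out= forwarding mentioned above.

import functools
import numpy as np

@functools.lru_cache(2**12)
def plan_pairwise(eq, shape_x, shape_y):
    # Parse a two-term equation once, classify each index as batch / kept /
    # contracted, and precompute the transpose orders and fused 3D shapes
    # needed to feed np.matmul.  Keying the cache on (eq, shape_x, shape_y)
    # means repeated contractions with the same signature skip this step.
    lhs, out = eq.split('->')
    ix, iy = lhs.split(',')
    size = {**dict(zip(ix, shape_x)), **dict(zip(iy, shape_y))}
    batch = [c for c in ix if c in iy and c in out]    # shared and kept
    con = [c for c in ix if c in iy and c not in out]  # summed over
    kx = [c for c in ix if c not in iy]                # kept, left only
    ky = [c for c in iy if c not in ix]                # kept, right only
    prod = lambda idx: int(np.prod([size[c] for c in idx], dtype=np.int64))
    perm_x = tuple(ix.index(c) for c in batch + kx + con)
    perm_y = tuple(iy.index(c) for c in batch + con + ky)
    shape3_x = (prod(batch), prod(kx), prod(con))
    shape3_y = (prod(batch), prod(con), prod(ky))
    raw_out = batch + kx + ky
    out_shape = tuple(size[c] for c in raw_out)
    perm_out = tuple(raw_out.index(c) for c in out)
    return perm_x, shape3_x, perm_y, shape3_y, out_shape, perm_out

def pairwise_bmm(eq, x, y):
    px, sx, py, sy, so, po = plan_pairwise(eq, x.shape, y.shape)
    z = np.matmul(x.transpose(px).reshape(sx), y.transpose(py).reshape(sy))
    return z.reshape(so).transpose(po)

# quick sanity check against einsum
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 5, 6, 7))
y = rng.standard_normal((4, 7, 5, 8))
assert np.allclose(pairwise_bmm('abij,ajbk->abik', x, y),
                   np.einsum('abij,ajbk->abik', x, y))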

TODO:

  • update docs?
  • decide cache size?
  • should I lint the existing einsumfunc.py code?

Benchmarks

In the following benchmarks the kernels are:

  • no optimize: the base np.einsum(eq, x, y) i.e. c_einsum
  • optimize + dot: the current np.einsum(eq, x, y, optimize=True) which calls tensordot where possible, and also induces some overhead from the potential path optimization (even though there are only two terms in the following, the equation must be checked and parsed etc.)
  • optimize + bmm: the proposed np.einsum(eq, x, y, optimize=True) which calls batched matmul, and still induces some overhead from the potential path optimization
  • no optimize + bmm: this calls the pairwise bmm einsum directly, and is provided simply to show what overhead comes from the bmm impl, and what from the potential path optimization

The main two to compare are optimize + dot and optimize + bmm. Overall I'd summarize the results as follows:

  1. when a contraction is memory bound, all methods are pretty much equivalent
  2. when a contraction involves no batch dimensions, performance of dot is retained
  3. when a contraction involves batch dimensions, performance can be orders of magnitude faster
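
For a quick feel of point 3, a batched case can be timed directly; this is roughly the batch_matmul_large case below with n = 500 (the snippet is illustrative only, and its printed numbers are not the plotted results):

import timeit
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((10, 500, 500))
y = rng.standard_normal((10, 500, 500))

for label, fn in [
    ("no optimize (c_einsum)", lambda: np.einsum('bij,bjk->bik', x, y)),
    ("optimize=True         ", lambda: np.einsum('bij,bjk->bik', x, y, optimize=True)),
    ("np.matmul             ", lambda: np.matmul(x, y)),
]:
    t = min(timeit.repeat(fn, number=3, repeat=3)) / 3
    print(f"{label}: {t * 1e3:.1f} ms")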

All dimensions are size n unless otherwise noted.

The benchmark cases are the following (one timing plot per case, not reproduced here):

  • vec_inner
  • vec_outer
  • square_matmul
  • square_matvec
  • batch_matmul_small: shapes (n, 2, 2), (n, 2, 2)
  • square_vecmat
  • batch_matmul_equal
  • batch_matmul_large: shapes (10, n, n), (10, n, n)
  • hadamard
  • hadamard_unalinged
  • CCSDT_1
  • interleaved_dot
  • interleaved_batched_dot
  • random_extreme: shapes (n, 4, 3, 4, 3, 4, 2, 4, 2, n, 2, n, 3, 2, 4) and (2, 4, n, n, 4, 4, 4, 3, 4, 4, 4, 3, 4)
  • many_small_dims_dot: shapes ([2] * n, [2] * n), with half of the dimensions contracted in an interleaved pattern
  • many_small_dims_batched_dot: shapes ([2] * n, [2] * n), with a third of the dimensions contracted and a third treated as batch indices, both in an interleaved pattern

@jcmgray (Author) commented May 6, 2023

No rush on my part, but I'm wondering if there are any thoughts about this change or things I can do for this PR?

@seberg (Member) commented May 17, 2023

TBH, I had hoped @dgasmith could give a recommendation. The timings do look like this should work pretty generically, although it's easy to miss cases, I guess.

matmul does currently have a pretty big flaw with strided inputs, unfortunately (there is another PR that would address this). Would that be relevant here?

@jcmgray (Author) commented May 17, 2023

Hi @seberg, ah yes, I just checked and currently that matmul behaviour would result in a performance regression for the specific case where arrays reach the matmul call (after the preparatory reshape/transpose etc.) with strides that trigger the non-BLAS, no-copy path (#23588).

import numpy as np
d = 1000
x = np.random.randn(d, d)
y = np.random.randn(d, d) + 1j * np.random.randn(d, d)

%%timeit
# y.real is a strided (non-contiguous) view into the complex array
z = np.einsum('ij,jk->ik', x, y.real, optimize=True)
# current dot backend:
#   11.5 ms ± 473 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# new matmul backend:
#   897 ms ± 23.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

If the solution in #23752 is to do the copy when needed, then that is very much aligned with the logic of optimize=True, i.e. trading memory and copying for being able to use BLAS. Otherwise one could check in the einsum_bmm call, but preferably it would be deferred to matmul, I suppose.

@dgasmith (Contributor) commented:

> TBH, I had hoped @dgasmith could give a recommendation. The timings do look like this should work pretty generically, although it's easy to miss cases, I guess.

It is hard for me to say: we benchmarked fairly thoroughly with optimize=True; however, the deployed result had a large number of edge cases that we missed. That being said, localizing the code to optimize=True is a good move, as it lets everyone switch the behaviour on/off.

@jcmgray (Author) commented Jul 26, 2023

Rather than wait on progress in #23588, would it be reasonable to simply require contiguity before the matmul in the bmm_einsum function here?

I.e. do we expect there to be cases where the arrays are not contiguous and it would be faster to leave them that way?

My guess is that this is only true for certain small contractions, which are not the main use case for optimize=True.

Of course having some rough heuristic in matmul itself would be cleaner.
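
For concreteness, the kind of check I mean would look roughly like this (a sketch only, with a hypothetical helper name, applied to the already transposed/reshaped operands just before the matmul call):

import numpy as np

def _maybe_copy_for_blas(a):
    # If the prepared operand is neither C- nor F-contiguous, copy it so that
    # matmul can take its BLAS path rather than the slow strided loop (#23588).
    # Trading a copy for a BLAS call matches the general spirit of optimize=True.
    if not (a.flags.c_contiguous or a.flags.f_contiguous):
        return np.ascontiguousarray(a)
    return a

# e.g. in the bmm path, just before the matmul:
# z = np.matmul(_maybe_copy_for_blas(xt), _maybe_copy_for_blas(yt))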
