"""
Tests the BLAS capability for the opt_einsum module.
"""

import numpy as np
import pytest

from opt_einsum import blas, contract, helpers

# Each entry is ((input_strings, output_string, removed_indices), expected),
# where ``expected`` is the classification string returned by
# ``blas.can_blas`` or ``False`` when the contraction cannot use BLAS.
blas_tests = [
    # DOT
    ((['k', 'k'], '', set('k')), 'DOT'),  # DDOT
    ((['ijk', 'ijk'], '', set('ijk')), 'DOT'),  # DDOT

    # GEMV?

    # GEMM
    ((['ij', 'jk'], 'ik', set('j')), 'GEMM'),  # GEMM N N
    ((['ijl', 'jlk'], 'ik', set('jl')), 'GEMM'),  # GEMM N N Tensor
    ((['ij', 'kj'], 'ik', set('j')), 'GEMM'),  # GEMM N T
    ((['ijl', 'kjl'], 'ik', set('jl')), 'GEMM'),  # GEMM N T Tensor
    ((['ji', 'jk'], 'ik', set('j')), 'GEMM'),  # GEMM T N
    ((['jli', 'jlk'], 'ik', set('jl')), 'GEMM'),  # GEMM T N Tensor
    ((['ji', 'kj'], 'ik', set('j')), 'GEMM'),  # GEMM T T
    ((['jli', 'kjl'], 'ik', set('jl')), 'GEMM'),  # GEMM T T Tensor

    # GEMM with final transpose
    ((['ij', 'jk'], 'ki', set('j')), 'GEMM'),  # GEMM N N
    ((['ijl', 'jlk'], 'ki', set('jl')), 'GEMM'),  # GEMM N N Tensor
    ((['ij', 'kj'], 'ki', set('j')), 'GEMM'),  # GEMM N T
    ((['ijl', 'kjl'], 'ki', set('jl')), 'GEMM'),  # GEMM N T Tensor
    ((['ji', 'jk'], 'ki', set('j')), 'GEMM'),  # GEMM T N
    ((['jli', 'jlk'], 'ki', set('jl')), 'GEMM'),  # GEMM T N Tensor
    ((['ji', 'kj'], 'ki', set('j')), 'GEMM'),  # GEMM T T
    ((['jli', 'kjl'], 'ki', set('jl')), 'GEMM'),  # GEMM T T Tensor

    # Tensor Dot (requires copy), let's not deal with this for now.
    # FT/ST: the contracted indices of the First/Second Tensor are permuted
    # relative to the plain GEMM layout, which forces a tensordot.
    ((['ilj', 'jlk'], 'ik', set('jl')), 'TDOT'),  # FT GEMM N N Tensor
    ((['ijl', 'ljk'], 'ik', set('jl')), 'TDOT'),  # ST GEMM N N Tensor
    ((['ilj', 'kjl'], 'ik', set('jl')), 'TDOT'),  # FT GEMM N T Tensor
    ((['ijl', 'klj'], 'ik', set('jl')), 'TDOT'),  # ST GEMM N T Tensor
    ((['lji', 'jlk'], 'ik', set('jl')), 'TDOT'),  # FT GEMM T N Tensor
    ((['jli', 'ljk'], 'ik', set('jl')), 'TDOT'),  # ST GEMM T N Tensor
    ((['lji', 'kjl'], 'ik', set('jl')), 'TDOT'),  # FT GEMM T T Tensor
    ((['jli', 'klj'], 'ik', set('jl')), 'TDOT'),  # ST GEMM T T Tensor

    # Tensor Dot (requires copy), let's not deal with this for now, with a
    # final transpose of the output (mirrors the "GEMM with final transpose"
    # section above).
    ((['ilj', 'jlk'], 'ki', set('lj')), 'TDOT'),  # FT GEMM N N Tensor
    ((['ijl', 'ljk'], 'ki', set('lj')), 'TDOT'),  # ST GEMM N N Tensor
    ((['ilj', 'kjl'], 'ki', set('lj')), 'TDOT'),  # FT GEMM N T Tensor
    ((['ijl', 'klj'], 'ki', set('lj')), 'TDOT'),  # ST GEMM N T Tensor
    ((['lji', 'jlk'], 'ki', set('lj')), 'TDOT'),  # FT GEMM T N Tensor
    ((['jli', 'ljk'], 'ki', set('lj')), 'TDOT'),  # ST GEMM T N Tensor
    ((['lji', 'kjl'], 'ki', set('lj')), 'TDOT'),  # FT GEMM T T Tensor
    ((['jli', 'klj'], 'ki', set('lj')), 'TDOT'),  # ST GEMM T T Tensor

    # Other
    ((['ijk', 'ikj'], '', set('ijk')), 'DOT/EINSUM'),  # Transpose DOT
    ((['i', 'j'], 'ij', set()), 'OUTER/EINSUM'),  # Outer
    ((['ijk', 'ik'], 'j', set('ik')), 'GEMV/EINSUM'),  # Matrix-vector
    ((['ijj', 'jk'], 'ik', set('j')), False),  # Double index
    ((['ijk', 'j'], 'ij', set()), False),  # Index sum 1
    ((['ij', 'ij'], 'ij', set()), False),  # Index sum 2
]


@pytest.mark.parametrize("inp,benchmark", blas_tests)
def test_can_blas(inp, benchmark):
    """``blas.can_blas`` classifies each contraction as expected."""
    result = blas.can_blas(*inp)
    assert result == benchmark


@pytest.mark.parametrize("inp,benchmark", blas_tests)
def test_tensor_blas(inp, benchmark):
    """``blas.tensor_blas`` agrees with ``np.einsum`` for BLAS-able cases."""
    # Non-BLAS cases have nothing to check here; report them as skipped
    # rather than silently passing.
    if benchmark is False:
        pytest.skip("contraction is not BLAS-compatible")

    tensor_strs, output, reduced_idx = inp
    einsum_str = ','.join(tensor_strs) + '->' + output

    # Only binary operations should be here
    assert len(tensor_strs) == 2, "tensor_blas only handles binary contractions"

    view_left, view_right = helpers.build_views(einsum_str)

    einsum_result = np.einsum(einsum_str, view_left, view_right)
    blas_result = blas.tensor_blas(view_left, tensor_strs[0], view_right, tensor_strs[1], output, reduced_idx)

    assert np.allclose(einsum_result, blas_result)


def test_blas_out():
    """``contract`` writes BLAS-backed results into a caller-supplied ``out``."""
    a = np.random.rand(4, 4)
    b = np.random.rand(4, 4)
    c = np.random.rand(4, 4)
    d = np.empty((4, 4))

    contract('ij,jk->ik', a, b, out=d)
    assert np.allclose(d, np.dot(a, b))

    contract('ij,jk,kl->il', a, b, c, out=d)
    assert np.allclose(d, np.dot(a, b).dot(c))