""" Tets a series of opt_einsum contraction paths to ensure the results are the same for different paths """ import numpy as np import pytest from opt_einsum import contract, contract_expression, contract_path, helpers from opt_einsum.paths import linear_to_ssa, ssa_to_linear tests = [ # Test hadamard-like products 'a,ab,abc->abc', 'a,b,ab->ab', # Test index-transformations 'ea,fb,gc,hd,abcd->efgh', 'ea,fb,abcd,gc,hd->efgh', 'abcd,ea,fb,gc,hd->efgh', # Test complex contractions 'acdf,jbje,gihb,hfac,gfac,gifabc,hfac', 'acdf,jbje,gihb,hfac,gfac,gifabc,hfac', 'cd,bdhe,aidb,hgca,gc,hgibcd,hgac', 'abhe,hidj,jgba,hiab,gab', 'bde,cdh,agdb,hica,ibd,hgicd,hiac', 'chd,bde,agbc,hiad,hgc,hgi,hiad', 'chd,bde,agbc,hiad,bdi,cgh,agdb', 'bdhe,acad,hiab,agac,hibd', # Test collapse 'ab,ab,c->', 'ab,ab,c->c', 'ab,ab,cd,cd->', 'ab,ab,cd,cd->ac', 'ab,ab,cd,cd->cd', 'ab,ab,cd,cd,ef,ef->', # Test outer prodcuts 'ab,cd,ef->abcdef', 'ab,cd,ef->acdf', 'ab,cd,de->abcde', 'ab,cd,de->be', 'ab,bcd,cd->abcd', 'ab,bcd,cd->abd', # Random test cases that have previously failed 'eb,cb,fb->cef', 'dd,fb,be,cdb->cef', 'bca,cdb,dbf,afc->', 'dcc,fce,ea,dbf->ab', 'fdf,cdd,ccd,afe->ae', 'abcd,ad', 'ed,fcd,ff,bcf->be', 'baa,dcf,af,cde->be', 'bd,db,eac->ace', 'fff,fae,bef,def->abd', 'efc,dbc,acf,fd->abe', # Inner products 'ab,ab', 'ab,ba', 'abc,abc', 'abc,bac', 'abc,cba', # GEMM test cases 'ab,bc', 'ab,cb', 'ba,bc', 'ba,cb', 'abcd,cd', 'abcd,ab', 'abcd,cdef', 'abcd,cdef->feba', 'abcd,efdc', # Inner than dot 'aab,bc->ac', 'ab,bcc->ac', 'aab,bcc->ac', 'baa,bcc->ac', 'aab,ccb->ac', # Randomly build test caes 'aab,fa,df,ecc->bde', 'ecb,fef,bad,ed->ac', 'bcf,bbb,fbf,fc->', 'bb,ff,be->e', 'bcb,bb,fc,fff->', 'fbb,dfd,fc,fc->', 'afd,ba,cc,dc->bf', 'adb,bc,fa,cfc->d', 'bbd,bda,fc,db->acf', 'dba,ead,cad->bce', 'aef,fbc,dca->bde', ] all_optimizers = [ 'optimal', 'branch-all', 'branch-2', 'branch-1', 'greedy', 'random-greedy', 'random-greedy-128', 'dp', 'auto', 'auto-hq' ] @pytest.mark.parametrize("string", tests) @pytest.mark.parametrize("optimize", all_optimizers) def test_compare(optimize, string): views = helpers.build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) opt = contract(string, *views, optimize=optimize, use_blas=False) assert np.allclose(ein, opt) @pytest.mark.parametrize("string", tests) def test_drop_in_replacement(string): views = helpers.build_views(string) opt = contract(string, *views) assert np.allclose(opt, np.einsum(string, *views)) @pytest.mark.parametrize("string", tests) @pytest.mark.parametrize("optimize", all_optimizers) def test_compare_greek(optimize, string): views = helpers.build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) # convert to greek string = ''.join(chr(ord(c) + 848) if c not in ',->.' else c for c in string) opt = contract(string, *views, optimize=optimize, use_blas=False) assert np.allclose(ein, opt) @pytest.mark.parametrize("string", tests) @pytest.mark.parametrize("optimize", all_optimizers) def test_compare_blas(optimize, string): views = helpers.build_views(string) ein = contract(string, *views, optimize=False) opt = contract(string, *views, optimize=optimize) assert np.allclose(ein, opt) @pytest.mark.parametrize("string", tests) @pytest.mark.parametrize("optimize", all_optimizers) def test_compare_blas_greek(optimize, string): views = helpers.build_views(string) ein = contract(string, *views, optimize=False) # convert to greek string = ''.join(chr(ord(c) + 848) if c not in ',->.' else c for c in string) opt = contract(string, *views, optimize=optimize) assert np.allclose(ein, opt) def test_some_non_alphabet_maintains_order(): # 'c beta a' should automatically go to -> 'a c beta' string = 'c' + chr(ord('b') + 848) + 'a' # but beta will be temporarily replaced with 'b' for which 'cba->abc' # so check manual output kicks in: x = np.random.rand(2, 3, 4) assert np.allclose(contract(string, x), contract('cxa', x)) def test_printing(): string = "bbd,bda,fc,db->acf" views = helpers.build_views(string) ein = contract_path(string, *views) assert len(str(ein[1])) == 728 @pytest.mark.parametrize("string", tests) @pytest.mark.parametrize("optimize", all_optimizers) @pytest.mark.parametrize("use_blas", [False, True]) @pytest.mark.parametrize("out_spec", [False, True]) def test_contract_expressions(string, optimize, use_blas, out_spec): views = helpers.build_views(string) shapes = [view.shape for view in views] expected = contract(string, *views, optimize=False, use_blas=False) expr = contract_expression(string, *shapes, optimize=optimize, use_blas=use_blas) if out_spec and ("->" in string) and (string[-2:] != "->"): out, = helpers.build_views(string.split('->')[1]) expr(*views, out=out) else: out = expr(*views) assert np.allclose(out, expected) # check representations assert string in expr.__repr__() assert string in expr.__str__() def test_contract_expression_interleaved_input(): x, y, z = (np.random.randn(2, 2) for _ in 'xyz') expected = np.einsum(x, [0, 1], y, [1, 2], z, [2, 3], [3, 0]) xshp, yshp, zshp = ((2, 2) for _ in 'xyz') expr = contract_expression(xshp, [0, 1], yshp, [1, 2], zshp, [2, 3], [3, 0]) out = expr(x, y, z) assert np.allclose(out, expected) @pytest.mark.parametrize("string,constants", [ ('hbc,bdef,cdkj,ji,ikeh,lfo', [1, 2, 3, 4]), ('bdef,cdkj,ji,ikeh,hbc,lfo', [0, 1, 2, 3]), ('hbc,bdef,cdkj,ji,ikeh,lfo', [1, 2, 3, 4]), ('hbc,bdef,cdkj,ji,ikeh,lfo', [1, 2, 3, 4]), ('ijab,acd,bce,df,ef->ji', [1, 2, 3, 4]), ('ab,cd,ad,cb', [1, 3]), ('ab,bc,cd', [0, 1]), ]) def test_contract_expression_with_constants(string, constants): views = helpers.build_views(string) expected = contract(string, *views, optimize=False, use_blas=False) shapes = [view.shape for view in views] expr_args = [] ctrc_args = [] for i, (shape, view) in enumerate(zip(shapes, views)): if i in constants: expr_args.append(view) else: expr_args.append(shape) ctrc_args.append(view) expr = contract_expression(string, *expr_args, constants=constants) print(expr) out = expr(*ctrc_args) assert np.allclose(expected, out) @pytest.mark.parametrize("optimize", ['greedy', 'optimal']) @pytest.mark.parametrize("n", [4, 5]) @pytest.mark.parametrize("reg", [2, 3]) @pytest.mark.parametrize("n_out", [0, 2, 4]) @pytest.mark.parametrize("global_dim", [False, True]) def test_rand_equation(optimize, n, reg, n_out, global_dim): eq, _, size_dict = helpers.rand_equation(n, reg, n_out, d_min=2, d_max=5, seed=42, return_size_dict=True) views = helpers.build_views(eq, size_dict) expected = contract(eq, *views, optimize=False) actual = contract(eq, *views, optimize=optimize) assert np.allclose(expected, actual) @pytest.mark.parametrize('equation', tests) def test_linear_vs_ssa(equation): views = helpers.build_views(equation) linear_path, _ = contract_path(equation, *views) ssa_path = linear_to_ssa(linear_path) linear_path2 = ssa_to_linear(ssa_path) assert linear_path2 == linear_path def test_contract_path_supply_shapes(): eq = 'ab,bc,cd' shps = [(2, 3), (3, 4), (4, 5)] contract_path(eq, *shps, shapes=True)