"""
Handles dispatching array operations to the correct backend library, as well
as converting arrays to backend formats and then potentially storing them as
constants.
"""

import importlib

import numpy

from . import object_arrays
from . import cupy as _cupy
from . import jax as _jax
from . import tensorflow as _tensorflow
from . import theano as _theano
from . import torch as _torch

__all__ = ["get_func", "has_einsum", "has_tensordot", "build_expression", "evaluate_constants", "has_backend"]

# known non top-level imports: backend name -> module that actually provides
# the array functions (e.g. ``jax`` functions live in ``jax.numpy``)
_aliases = {
    'dask': 'dask.array',
    'theano': 'theano.tensor',
    'torch': 'opt_einsum.backends.torch',
    'jax': 'jax.numpy',
    'autograd': 'autograd.numpy',
    'mars': 'mars.tensor',
}


def _import_func(func, backend, default=None):
    """Try and import ``{backend}.{func}``.

    The module imported is ``_aliases[backend]`` if an alias exists, otherwise
    ``backend`` itself. If the function is found, return it; otherwise, if
    ``default`` is provided (not ``None``), return ``default``; otherwise raise
    an ``AttributeError`` pointing at the backend documentation.

    An ``ImportError`` (e.g. the backend library is not installed) is not
    caught here and propagates to the caller even when ``default`` is given.
    """
    lib = importlib.import_module(_aliases.get(backend, backend))
    try:
        return getattr(lib, func) if default is None else getattr(lib, func, default)
    except AttributeError:
        error_msg = ("{} doesn't seem to provide the function {} - see "
                     "https://optimized-einsum.readthedocs.io/en/latest/backends.html "
                     "for details on which functions are required for which contractions.")
        raise AttributeError(error_msg.format(backend, func))


# manually cache functions as python2 doesn't support functools.lru_cache
# other libs will be added to this if needed, but pre-populate with numpy
_cached_funcs = {
    ('tensordot', 'numpy'): numpy.tensordot,
    ('transpose', 'numpy'): numpy.transpose,
    ('einsum', 'numpy'): numpy.einsum,
    # also pre-populate with the arbitrary object backend
    ('tensordot', 'object'): numpy.tensordot,
    ('transpose', 'object'): numpy.transpose,
    ('einsum', 'object'): object_arrays.object_einsum,
}


def get_func(func, backend='numpy', default=None):
    """Return ``{backend}.{func}``, e.g. ``numpy.einsum``, or a default func if
    provided. Successful lookups are cached.

    Parameters
    ----------
    func : str
        Name of the function to look up.
    backend : str, optional
        Name of the backend library (see ``_aliases`` for remapped modules).
    default : callable, optional
        Fallback returned when the backend lacks ``func``.

    Raises
    ------
    AttributeError
        If the backend lacks ``func`` and no ``default`` was given.
    ImportError
        If the backend library cannot be imported.
    """
    try:
        return _cached_funcs[func, backend]
    except KeyError:
        fn = _import_func(func, backend, default)
        # Only cache a genuine lookup: a fallback ``default`` belongs to this
        # call alone and must not shadow later calls (with no default, or a
        # different one) for the same ``(func, backend)`` pair.
        if default is None or fn is not default:
            _cached_funcs[func, backend] = fn
        return fn


def _has_func(func, backend, cache):
    """Return whether ``{backend}.{func}`` exists, memoizing the answer in ``cache``.

    Only a missing function (``AttributeError``) is recorded as ``False``; an
    ``ImportError`` from an uninstalled backend propagates and is not cached.
    """
    try:
        return cache[backend]
    except KeyError:
        try:
            get_func(func, backend)
            cache[backend] = True
        except AttributeError:
            cache[backend] = False
        return cache[backend]


# mark libs with einsum, else try to use tensordot/transpose as much as possible
_has_einsum = {}


def has_einsum(backend):
    """Check if ``{backend}.einsum`` exists, cache result for performance.
    """
    return _has_func('einsum', backend, _has_einsum)


_has_tensordot = {}


def has_tensordot(backend):
    """Check if ``{backend}.tensordot`` exists, cache result for performance.
    """
    return _has_func('tensordot', backend, _has_tensordot)


# Dispatch to correct expression backend
# these are the backends which support explicit to-and-from numpy conversion
CONVERT_BACKENDS = {
    'tensorflow': _tensorflow.build_expression,
    'theano': _theano.build_expression,
    'cupy': _cupy.build_expression,
    'torch': _torch.build_expression,
    'jax': _jax.build_expression,
}

EVAL_CONSTS_BACKENDS = {
    'tensorflow': _tensorflow.evaluate_constants,
    'theano': _theano.evaluate_constants,
    'cupy': _cupy.evaluate_constants,
    'torch': _torch.evaluate_constants,
    'jax': _jax.evaluate_constants,
}


def build_expression(backend, arrays, expr):
    """Build an expression, based on ``expr`` and initial arrays ``arrays``,
    that evaluates using backend ``backend``.

    ``backend`` is matched case-insensitively, consistent with
    ``has_backend``. Raises ``KeyError`` for an unknown backend.
    """
    return CONVERT_BACKENDS[backend.lower()](arrays, expr)


def evaluate_constants(backend, arrays, expr):
    """Convert constant arrays to the correct backend, and perform as much of
    the contraction of ``expr`` with these as possible.

    ``backend`` is matched case-insensitively, consistent with
    ``has_backend``. Raises ``KeyError`` for an unknown backend.
    """
    return EVAL_CONSTS_BACKENDS[backend.lower()](arrays, expr)


def has_backend(backend):
    """Checks if the backend is known (case-insensitive), i.e. whether it
    supports explicit to-and-from numpy conversion via ``build_expression``.
    """
    return backend.lower() in CONVERT_BACKENDS