# Source code for pyadjoint_utils.numpy_backend.backend

from .base import NumpyBackend
from .jax import JaxNumpyBackend
from .torch import TorchBackend
from typing import Optional

# Registry of backend factories: each supported backend name maps to a
# zero-argument callable that builds a new backend instance. The class names
# are looked up only when a factory is actually called (i.e. when
# ``get_backend`` first needs that backend).
_valid_backends = dict(
    jax=lambda: JaxNumpyBackend(),
    torch=lambda: TorchBackend(),
    numpy=lambda: NumpyBackend(),
)

# Backends that have already been constructed, keyed by name, so that
# ``get_backend`` hands back the same instance on repeated requests.
_live_backends = dict()


def get_backend(which: Optional[str] = None):
    """Return the numpy backend named ``which``, creating it on first use.

    Backends are constructed from the factories in ``_valid_backends`` the
    first time they are requested and cached in ``_live_backends``, so later
    calls with the same name return the same backend object.

    :param which: Name of the backend to return: ``'numpy'``, ``'jax'`` or
        ``'torch'``. If ``None`` (the default), the current default backend
        is returned instead.
    :type which: str, optional
    :return: The requested backend.
    :rtype: NumpyBackend
    :raises ValueError: If ``which`` does not name an available backend.
    """
    # No ``global`` declaration is needed: the two registries are only read
    # and mutated in place here, never rebound.
    if which is None:
        return get_default_backend()
    if which in _live_backends:
        return _live_backends[which]
    if which in _valid_backends:
        backend = _valid_backends[which]()
        _live_backends[which] = backend
        return backend
    # Format with an f-string rather than ``+`` so that a non-string ``which``
    # (e.g. an int) still yields the documented ValueError instead of a
    # TypeError from string concatenation. The message is unchanged for str.
    raise ValueError(
        f"backend {which} not available! "
        'Valid backends are "jax", "torch", and "numpy"'
    )
# At import time, build (and cache) the JAX backend and use it as the initial
# module-wide default; ``set_default_backend`` can change it later.
_default_backend = get_backend(which="jax")
def set_default_backend(which: str = "jax"):
    """Make the backend named ``which`` the module-wide default.

    The backend is obtained through ``get_backend``, which creates and caches
    it if it has not been used before.

    :param which: Name of the backend to make the default: ``'jax'``,
        ``'torch'`` or ``'numpy'``. Defaults to ``'jax'``.
    :type which: str
    :returns: The backend that is now the default.
    :rtype: NumpyBackend
    :raises ValueError: Propagated from ``get_backend`` when ``which`` does
        not name an available backend.
    """
    global _default_backend  # required: this rebinds the module-level name
    chosen = get_backend(which)
    _default_backend = chosen
    return chosen
def get_default_backend():
    """Return the module-wide default numpy backend.

    This is the backend most recently chosen with ``set_default_backend``,
    or the JAX backend selected when this module was imported.

    :return: The default numpy backend.
    :rtype: NumpyBackend
    """
    # Only reading the module-level name, so no ``global`` declaration needed.
    return _default_backend