Source code for pyadjoint_utils.numpy_backend.backend
from .base import NumpyBackend
from .jax import JaxNumpyBackend
from .torch import TorchBackend
from typing import Optional
# Registry of supported backends, keyed by the name accepted by get_backend().
# Each value is a zero-argument factory; calling it builds a new backend.
_valid_backends = {
    "jax": JaxNumpyBackend,
    "torch": TorchBackend,
    "numpy": NumpyBackend,
}
# Backends that have already been constructed, keyed by the same names, so
# each backend is instantiated at most once by get_backend().
_live_backends = {}
def get_backend(which: Optional[str] = None):
    """Return the numpy backend named ``which``, creating it on first use.

    The first request for a name builds the backend from its factory in
    ``_valid_backends`` and caches it in ``_live_backends``; later requests for
    the same name return that cached instance.

    :param which: Which backend to get: ``'numpy'``, ``'jax'``, or ``'torch'``.
        With ``None`` (the default), the current default backend is returned.
    :type which: str, optional
    :return: The requested numpy backend
    :rtype: NumpyBackend
    :raises ValueError: if ``which`` is not the name of a known backend
    """
    if which is None:
        return get_default_backend()
    if which in _live_backends:
        return _live_backends[which]
    if which in _valid_backends:
        backend = _valid_backends[which]()
        _live_backends[which] = backend
        return backend
    # Derive the list of valid names from the registry so the message cannot
    # drift out of sync with it. Formatting via an f-string also keeps a
    # non-str ``which`` from turning this into a TypeError.
    names = [f'"{name}"' for name in _valid_backends]
    if len(names) > 1:
        valid = ", ".join(names[:-1]) + ", and " + names[-1]
    else:
        valid = "".join(names)
    raise ValueError(
        f"backend {which} not available! Valid backends are {valid}"
    )
# Build the jax backend when this module is imported and use it as the
# initial default; set_default_backend() can replace it later.
_default_backend = get_backend(which="jax")
def set_default_backend(which: str = "jax"):
    """Make the backend named ``which`` the default one and return it.

    The backend is obtained through ``get_backend``, so it is created on first
    use. An unknown name raises ``ValueError`` before the current default is
    replaced.

    :param which: Name of the backend: ``'jax'``, ``'torch'``, or ``'numpy'``
    :type which: str
    :returns: The backend that is now the default
    :rtype: NumpyBackend
    """
    # ``global`` is required because this function rebinds the module-level name.
    global _default_backend
    chosen = get_backend(which)
    _default_backend = chosen
    return chosen
def get_default_backend():
    """Return the current default CRIKit numpy backend.

    This is the backend most recently installed by ``set_default_backend``,
    or the ``'jax'`` backend created at import time if it has never been called.

    :return: The default numpy backend
    :rtype: NumpyBackend
    """
    # No ``global`` declaration is needed: the module-level name is only read
    # here, never rebound.
    return _default_backend