Introduce wrapper eager_function #41

Open · wants to merge 4 commits into master
1 change: 1 addition & 0 deletions eagerpy/__init__.py
@@ -33,6 +33,7 @@ def __getitem__(self, index: _T) -> _T:
from .astensor import astensors # noqa: F401,E402
from .astensor import astensor_ # noqa: F401,E402
from .astensor import astensors_ # noqa: F401,E402
from .astensor import eager_function # noqa: F401,E402

from .modules import torch # noqa: F401,E402
from .modules import tensorflow # noqa: F401,E402
98 changes: 94 additions & 4 deletions eagerpy/astensor.py
@@ -1,6 +1,18 @@
from typing import TYPE_CHECKING, Union, overload, Tuple, TypeVar, Generic, Any
import functools
from typing import (
TYPE_CHECKING,
Union,
overload,
Tuple,
TypeVar,
Generic,
Any,
Callable,
)
import sys

from jax import tree_flatten, tree_unflatten

from .tensor import Tensor
from .tensor import TensorType

@@ -36,7 +48,7 @@ def astensor(x: NativeTensor) -> Tensor: # type: ignore
...


def astensor(x: Union[NativeTensor, Tensor]) -> Tensor: # type: ignore
def astensor(x: Union[NativeTensor, Tensor, Any]) -> Union[Tensor, Any]: # type: ignore
if isinstance(x, Tensor):
return x
# we use the module name instead of isinstance
@@ -52,7 +64,9 @@ def astensor(x: Union[NativeTensor, Tensor]) -> Tensor: # type: ignore
return JAXTensor(x)
if name == "numpy" and isinstance(x, m[name].ndarray): # type: ignore
return NumPyTensor(x)
raise ValueError(f"Unknown type: {type(x)}")

# non-Tensor types are returned unmodified
return x
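
With this change, astensor no longer raises ValueError on unrecognized input; anything that is not a known native tensor type is passed through unchanged. A minimal sketch of the new behavior (not part of the diff; assumes numpy is installed):

import numpy as np
import eagerpy as ep

t = ep.astensor(np.zeros(3))  # native array: wrapped as an eagerpy Tensor
s = ep.astensor("metadata")   # unknown type: returned unmodified
assert isinstance(t, ep.Tensor)
assert s == "metadata"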


def astensors(*xs: Union[NativeTensor, Tensor]) -> Tuple[Tensor, ...]: # type: ignore
@@ -84,7 +98,7 @@ def __call__(self, *args: Any) -> Any:
...

def __call__(self, *args): # type: ignore # noqa: F811
result = tuple(x.raw for x in args) if self.unwrap else args
result = tuple(as_raw_tensor(x) for x in args) if self.unwrap else args
if len(result) == 1:
(result,) = result
return result
@@ -96,3 +110,79 @@ def astensor_(x: T) -> Tuple[Tensor, RestoreTypeFunc[T]]:

def astensors_(x: T, *xs: T) -> Tuple[Tuple[Tensor, ...], RestoreTypeFunc[T]]:
return astensors(x, *xs), RestoreTypeFunc[T](x)


def as_tensors(data: Any) -> Any:
leaf_values, tree_def = tree_flatten(data)
leaf_values = tuple(astensor(value) for value in leaf_values)
return tree_unflatten(tree_def, leaf_values)


def has_tensor(tree_def: Any) -> bool:
    # the repr of a tree_def names the classes of its registered pytree nodes,
    # so a substring check is a cheap way to detect eagerpy Tensor nodes
    return "<class 'eagerpy.tensor" in str(tree_def)


def as_tensors_any(data: Any) -> Tuple[Any, bool]:
"""Convert data structure leaves in Tensor and detect if any of the input data contains a Tensor.

Parameters
----------
data
data structure.

Returns
-------
Any
modified data structure.
bool
True if input data contains a Tensor type.
"""
leaf_values, tree_def = tree_flatten(data)
transformed_leaf_values = tuple(astensor(value) for value in leaf_values)
return tree_unflatten(tree_def, transformed_leaf_values), has_tensor(tree_def)
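
For illustration (not part of the diff), a sketch of as_tensors_any on a nested structure; note that the returned flag reports whether the input already held an eagerpy Tensor, not whether the output does:

import numpy as np
import eagerpy as ep
from eagerpy.astensor import as_tensors_any

data = {"x": np.ones(2), "y": [1.0, "label"]}
converted, had_tensor = as_tensors_any(data)
assert isinstance(converted["x"], ep.Tensor)  # native arrays get wrapped
assert converted["y"][1] == "label"  # other leaves pass through
assert not had_tensor  # no eagerpy Tensor in the input itself

_, had_tensor = as_tensors_any({"x": ep.astensor(np.ones(2))})
assert had_tensor  # an eagerpy Tensor was already present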


def as_raw_tensor(x: T) -> Any:
if isinstance(x, Tensor):
return x.raw
else:
return x


def as_raw_tensors(data: Any) -> Any:
leaf_values, tree_def = tree_flatten(data)

if not has_tensor(tree_def):
return data

leaf_values = tuple(as_raw_tensor(value) for value in leaf_values)
unwrap_leaf_values = []
for x in leaf_values:
name = _get_module_name(x)
m = sys.modules
if name == "torch" and isinstance(x, m[name].Tensor): # type: ignore
unwrap_leaf_values.append((x, True))
elif name == "tensorflow" and isinstance(x, m[name].Tensor): # type: ignore
unwrap_leaf_values.append((x, True))
elif (name == "jax" or name == "jaxlib") and isinstance(x, m["jax"].numpy.ndarray): # type: ignore
unwrap_leaf_values.append((x, True))
elif name == "numpy" and isinstance(x, m[name].ndarray): # type: ignore
unwrap_leaf_values.append((x, True))
else:
unwrap_leaf_values.append(x)
return tree_unflatten(tree_def, unwrap_leaf_values)
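
Why the (x, True) tuples: tree_unflatten replays the Tensor node constructors recorded in tree_def, which would re-wrap the freshly unwrapped leaves back into Tensors; the True flag tells the unflatten hook registered in base.py (below) to return the raw value as-is. A sketch of the net effect (not part of the diff):

import numpy as np
import eagerpy as ep
from eagerpy.astensor import as_raw_tensors

t = ep.astensor(np.ones(2))
raw = as_raw_tensors({"x": t})
assert isinstance(raw["x"], np.ndarray)  # unwrapped, not re-wrapped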


def eager_function(func: Callable[..., T]) -> Callable[..., T]:
@functools.wraps(func)
def eager_func(*args: Any, **kwargs: Any) -> Any:
(args, kwargs), has_tensor = as_tensors_any((args, kwargs))
unwrap = not has_tensor
result = func(*args, **kwargs)
if unwrap:
raw_result = as_raw_tensors(result)
return raw_result
else:
return result

return eager_func
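
A usage sketch of the new wrapper (not part of the diff): a function written once against the eagerpy API returns native tensors when called with native inputs, and eagerpy Tensors when called with Tensor inputs:

import numpy as np
import eagerpy as ep

@ep.eager_function
def normalize(x):
    return x / x.square().sum().sqrt()

out = normalize(np.arange(1.0, 4.0))  # native in -> native out
assert isinstance(out, np.ndarray)

out = normalize(ep.astensor(np.arange(1.0, 4.0)))  # Tensor in -> Tensor out
assert isinstance(out, ep.Tensor)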
29 changes: 28 additions & 1 deletion eagerpy/tensor/base.py
@@ -1,5 +1,5 @@
from typing_extensions import final
from typing import Any, cast
from typing import Any, Type, Union, cast, Tuple

from .tensor import Tensor
from .tensor import TensorType
@@ -17,6 +17,33 @@ def unwrap1(t: Any) -> Any:
class BaseTensor(Tensor):
__slots__ = "_raw"

_registered = False

def __new__(cls: Type["BaseTensor"], *args: Any, **kwargs: Any) -> "BaseTensor":
if not cls._registered:
import jax

def flatten(t: Tensor) -> Tuple[Any, None]:
return ((t.raw,), None)

def unflatten(aux_data: None, children: Tuple) -> Union[Tensor, Any]:
assert len(children) == 1
x = children[0]
del children

if isinstance(x, tuple):
x, unwrap = x
if unwrap:
return x

if isinstance(x, Tensor):
return x
return cls(x)

jax.tree_util.register_pytree_node(cls, flatten, unflatten)
cls._registered = True
return cast("BaseTensor", super().__new__(cls))

def __init__(self: TensorType, raw: Any):
assert not isinstance(raw, Tensor)
self._raw = raw
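For context (not part of the diff): the lazy registration above means that as soon as a Tensor subclass is instantiated, JAX's tree utilities can traverse structures containing it, which is what as_tensors and as_raw_tensors rely on. A sketch:

import numpy as np
import eagerpy as ep
from jax import tree_flatten, tree_unflatten

t = ep.astensor(np.ones(2))  # first instantiation registers NumPyTensor
leaves, tree_def = tree_flatten({"a": t, "b": 1})
assert any(isinstance(leaf, np.ndarray) for leaf in leaves)  # flattened to raw
restored = tree_unflatten(tree_def, leaves)
assert isinstance(restored["a"], ep.Tensor)  # plain leaves are re-wrapped
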
1 change: 0 additions & 1 deletion eagerpy/tensor/extensions.py
@@ -6,7 +6,6 @@

from .tensor import Tensor


T = TypeVar("T")


76 changes: 19 additions & 57 deletions eagerpy/tensor/jax.py
@@ -9,8 +9,8 @@
Optional,
overload,
Callable,
Type,
)

from typing_extensions import Literal
from importlib import import_module
import numpy as onp
@@ -58,27 +58,11 @@ def getitem_preprocess(x: Any) -> Any:

class JAXTensor(BaseTensor):
__slots__ = ()

# more specific types for the extensions
norms: "NormsMethods[JAXTensor]"

_registered = False
key = None

def __new__(cls: Type["JAXTensor"], *args: Any, **kwargs: Any) -> "JAXTensor":
if not cls._registered:
import jax

def flatten(t: JAXTensor) -> Tuple[Any, None]:
return ((t.raw,), None)

def unflatten(aux_data: None, children: Tuple) -> JAXTensor:
return cls(*children)

jax.tree_util.register_pytree_node(cls, flatten, unflatten)
cls._registered = True
return cast(JAXTensor, super().__new__(cls))

def __init__(self, raw: "np.ndarray"): # type: ignore
global jax
global np
@@ -434,46 +418,24 @@ def _value_and_grad_fn(
def _value_and_grad_fn( # noqa: F811 (waiting for pyflakes > 2.1.1)
self: TensorType, f: Callable, has_aux: bool = False
) -> Callable[..., Tuple]:
# f takes and returns JAXTensor instances
# jax.value_and_grad accepts functions that take JAXTensor instances
# because we registered JAXTensor as a JAX type, but it still requires
# the output to be a scalar (that is not wrapped as a JAXTensor)

# f_jax is like f but unwraps loss
if has_aux:

def f_jax(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
loss, aux = f(*args, **kwargs)
return loss.raw, aux

else:

def f_jax(*args: Any, **kwargs: Any) -> Any: # type: ignore
loss = f(*args, **kwargs)
return loss.raw

value_and_grad_jax = jax.value_and_grad(f_jax, has_aux=has_aux)

# value_and_grad is like value_and_grad_jax but wraps loss
if has_aux:

def value_and_grad(
x: JAXTensor, *args: Any, **kwargs: Any
) -> Tuple[JAXTensor, Any, JAXTensor]:
assert isinstance(x, JAXTensor)
(loss, aux), grad = value_and_grad_jax(x, *args, **kwargs)
assert grad.shape == x.shape
return JAXTensor(loss), aux, grad

else:

def value_and_grad( # type: ignore
x: JAXTensor, *args: Any, **kwargs: Any
) -> Tuple[JAXTensor, JAXTensor]:
assert isinstance(x, JAXTensor)
loss, grad = value_and_grad_jax(x, *args, **kwargs)
assert grad.shape == x.shape
return JAXTensor(loss), grad
from eagerpy.astensor import as_tensors, as_raw_tensors

def value_and_grad(
x: JAXTensor, *args: Any, **kwargs: Any
) -> Union[Tuple[JAXTensor, JAXTensor], Tuple[JAXTensor, Any, JAXTensor]]:
assert isinstance(x, JAXTensor)
x, args, kwargs = as_raw_tensors((x, args, kwargs))

loss_aux, grad = jax.value_and_grad(f, has_aux=has_aux)(x, *args, **kwargs)
assert grad.shape == x.shape
loss_aux, grad = as_tensors((loss_aux, grad))

if has_aux:
loss, aux = loss_aux
return loss, aux, grad
else:
loss = loss_aux
return loss, grad

return value_and_grad

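The rework above collapses the duplicated has_aux branches into a single path: Tensor arguments are unwrapped to raw values, jax.value_and_grad is applied to f directly, and loss and gradient are wrapped back into Tensors afterwards. Because f now receives raw JAX values, it pairs naturally with the new eager_function wrapper. A usage sketch (not part of the diff):

import jax.numpy as jnp
import eagerpy as ep

@ep.eager_function
def loss_fn(x):
    return x.square().sum()

x = ep.astensor(jnp.arange(3.0))
loss, grad = ep.value_and_grad(loss_fn, x)
assert isinstance(grad, ep.Tensor) and grad.shape == x.shape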