This repository was archived by the owner on May 6, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 234
This repository was archived by the owner on May 6, 2025. It is now read-only.
Feature masks do not get reduced in the kernel #159
Copy link
Copy link
Open
Labels
bug: Something isn't working
Description
I am observing an error message when providing masked inputs with more than one feature dimension to a kernel that involves stax.GlobalAvgPool().
Reproducer:
import jax
import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax
if __name__ == '__main__':
    # Reproducer: masked inputs with a trailing feature axis of size > 1 are
    # passed to the kernel function returned by stax.GlobalAvgPool().

    # input tokens: a (10, 512) array with every entry set to token id 3
    X = 3*np.ones((10,512))
    mask_constant = 10
    pad_token = 0
    # pad some elements
    X = X.at[0,4].set(pad_token)
    X = X.at[7,422].set(pad_token)
    print('before encode ',X.shape)
    # vocabulary size
    n_vocab = 5

    def encode(x, mask_constant):
        """One-hot encode `x` and mark padded positions with `mask_constant`.

        Args:
            x: array of token ids; entries equal to `pad_token` count as padding.
            mask_constant: value written into every feature entry of a padded
                position.

        Returns:
            Array with one extra trailing axis of size `n_vocab`. Non-padded
            positions hold the one-hot vector minus its mean over that axis;
            padded positions hold `mask_constant` in all `n_vocab` entries.
        """
        # zero mean embeddings
        res = jax.nn.one_hot(x, n_vocab)
        res -= np.mean(res, axis=-1, keepdims=True)
        # x[..., None] broadcasts the pad test across the whole feature axis,
        # so a padded position is masked in every feature entry at once.
        return np.where(x[..., None] == pad_token, mask_constant, res)

    X = encode(X, mask_constant=mask_constant)
    print('after encode ', X.shape)
    # trace over output correlations
    _, _, kernel_fn_avg = stax.GlobalAvgPool()
    # evaluate the kernel in batches of 2 inputs
    input_fn = nt.batch(kernel_fn_avg, batch_size=2)
    # This call is the one that raises NotImplementedError in the traceback.
    cov = input_fn(X, X, 'nngp', mask_constant=mask_constant, diagonal_spatial=True)
    print('output ', cov.shape)
Output
$ python mask_reproducer.py
Attempting to register factory for plugin cuBLAS when one has already been registered
before encode (10, 512)
after encode (10, 512, 5)
Traceback (most recent call last):
File "/gpfs/alpine/bif136/world-shared/gpbind/mask_reproducer.py", line 32, in <module>
cov = input_fn(X, X, 'nngp', mask_constant=mask_constant, diagonal_spatial=True)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
return g(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 471, in serial_fn
return serial_fn_x1(x1_or_kernel, x2, *args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 398, in serial_fn_x1
_, kernel = _scan(row_fn, 0, (x1s, kwargs_np1))
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 151, in _scan
carry, y = f(carry, x)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 387, in row_fn
return _, _scan(col_fn, x1, (x2s, kwargs_np2))[1]
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 151, in _scan
carry, y = f(carry, x)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 396, in col_fn
return (x1, kwargs1), kernel_fn(x1, x2, *args, **kwargs_merge)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
return g(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 758, in f_pmapped
return _f(x_or_kernel, *args_np, **kwargs_np)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/api.py", line 525, in cache_miss
out_flat = xla.xla_call(
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/core.py", line 1919, in bind
return call_bind(self, fun, *args, **params)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/core.py", line 1935, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/core.py", line 687, in process_call
return primitive.impl(f, *tracers, **params)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/dispatch.py", line 199, in _xla_call_impl
compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/linear_util.py", line 295, in memoized_fun
ans = call(fun, *args)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/dispatch.py", line 248, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars, False,
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/profiler.py", line 294, in wrapper
return func(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/dispatch.py", line 293, in lower_xla_callable
jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/_src/profiler.py", line 294, in wrapper
return func(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2167, in trace_to_jaxpr_final2
jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2117, in trace_to_subjaxpr_dynamic2
ans = fun.call_wrapped(*in_tracers_)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 751, in _f
return f(_x_or_kernel, *_args, **_kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
return g(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 222, in getter_fn
fn_out = fn(*canonicalized_args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 1008, in kernel_fn_any
return kernel_fn_x1(x1_or_kernel, x2, get,
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 921, in kernel_fn_x1
out_kernel = kernel_fn(kernel, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 222, in kernel_fn_with_masking
mask1, mask2 = mask_fn(mask1, shape1), mask_fn(mask2, shape2)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
return g(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 188, in mask_fn
return _mask_fn(mask, input_shape)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/linear.py", line 1756, in mask_fn
_check_is_implemented(mask, channel_axis)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/linear.py", line 3621, in _check_is_implemented
raise NotImplementedError(
jax._src.traceback_util.UnfilteredStackTrace: NotImplementedError: Different channel-wise masks as inputs to pooling layers are not yet supported. Please let us know about your use case at https://github.com/google/neural-tangents/issues/new
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/gpfs/alpine/bif136/world-shared/gpbind/mask_reproducer.py", line 32, in <module>
cov = input_fn(X, X, 'nngp', mask_constant=mask_constant, diagonal_spatial=True)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
return g(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 471, in serial_fn
return serial_fn_x1(x1_or_kernel, x2, *args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 398, in serial_fn_x1
_, kernel = _scan(row_fn, 0, (x1s, kwargs_np1))
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 151, in _scan
carry, y = f(carry, x)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 387, in row_fn
return _, _scan(col_fn, x1, (x2s, kwargs_np2))[1]
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 151, in _scan
carry, y = f(carry, x)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 396, in col_fn
return (x1, kwargs1), kernel_fn(x1, x2, *args, **kwargs_merge)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
return g(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 758, in f_pmapped
return _f(x_or_kernel, *args_np, **kwargs_np)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/batching.py", line 751, in _f
return f(_x_or_kernel, *_args, **_kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
return g(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 222, in getter_fn
fn_out = fn(*canonicalized_args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 1008, in kernel_fn_any
return kernel_fn_x1(x1_or_kernel, x2, get,
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 921, in kernel_fn_x1
out_kernel = kernel_fn(kernel, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 222, in kernel_fn_with_masking
mask1, mask2 = mask_fn(mask1, shape1), mask_fn(mask2, shape2)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/utils/utils.py", line 188, in h
return g(*args, **kwargs)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py", line 188, in mask_fn
return _mask_fn(mask, input_shape)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/linear.py", line 1756, in mask_fn
_check_is_implemented(mask, channel_axis)
File "/gpfs/alpine/world-shared/bif136/jax_env_summit_clone/lib/python3.10/site-packages/neural_tangents/_src/stax/linear.py", line 3621, in _check_is_implemented
raise NotImplementedError(
NotImplementedError: Different channel-wise masks as inputs to pooling layers are not yet supported. Please let us know about your use case at https://github.com/google/neural-tangents/issues/new
Expected output (with the fix from #158):
before encode (10, 512)
after encode (10, 512, 5)
/gpfs/alpine/world-shared/bif136/jax_env_summit/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py:769: UserWarning: Assuming consistent masks (all zero or one) for features of dimension > 1, which is not verified.
warnings.warn("Assuming consistent masks (all zero or one) for features of dimension > 1, which is not verified.")
/gpfs/alpine/world-shared/bif136/jax_env_summit/lib/python3.10/site-packages/neural_tangents/_src/stax/requirements.py:769: UserWarning: Assuming consistent masks (all zero or one) for features of dimension > 1, which is not verified.
warnings.warn("Assuming consistent masks (all zero or one) for features of dimension > 1, which is not verified.")
output (10, 10)
I admit that the warning is a little noisy; perhaps it could be omitted and the reduction mentioned in the documentation instead.
Metadata
Metadata
Assignees
Labels
bug: Something isn't working