Description
I am trying to create a simple linear layer as follows:
from penzai import pz
import jax
embed_axis = "embed_axis"
head_axis = "head_axis"
num_heads = 4
embed_size = 10
layer = pz.nn.Linear.from_config(
    name="layer_name",
    init_base_rng=jax.random.key(42),
    input_axes={embed_axis: embed_size},
    output_axes={
        head_axis: num_heads,
        f"{embed_axis}/{head_axis}": embed_size // num_heads,
    },
    initializer=jax.nn.initializers.xavier_normal(),
)
I am getting the following error:
TypeError Traceback (most recent call last)
in <cell line: 0>()
7 embed_size = 10
8
----> 9 layer = pz.nn.Linear.from_config(
10 name="layer_name",
11 init_base_rng=jax.random.key(42),

/usr/local/lib/python3.11/dist-packages/penzai/nn/parameters.py in make_parameter(name, init_base_rng, initializer, metadata, *init_args, **init_kwargs)
110 metadata = {}
111 return variables.Parameter(
--> 112 value=initializer(
113 derive_param_key(init_base_rng, name), *init_args, **init_kwargs
114 ),

TypeError: variance_scaling.<locals>.init() got an unexpected keyword argument 'input_axes'
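The traceback shows penzai's make_parameter calling the initializer as initializer(key, *init_args, **init_kwargs), with the layer's axis dictionaries passed as keyword arguments. The minimal sketch below uses only JAX and reproduces the same TypeError outside penzai: the function returned by jax.nn.initializers.xavier_normal() takes a key and a positional shape, so a keyword such as input_axes is rejected. The shape (10, 8) and passing only input_axes are illustrative choices, not the full set of arguments penzai actually forwards.

```python
import jax

# The same JAX initializer used in the issue. The returned function has the
# positional signature init(key, shape, dtype=...), with no axis-dict keywords.
jax_init = jax.nn.initializers.xavier_normal()

# Called the way JAX expects (key plus an explicit shape), it works.
weights = jax_init(jax.random.key(42), (10, 8))
print(weights.shape)  # (10, 8)

# Called the way the traceback shows penzai calling it (key plus axis-dict
# keyword arguments), Python rejects the unknown keyword before running it.
message = ""
try:
    jax_init(jax.random.key(42), input_axes={"embed_axis": 10})
except TypeError as err:
    message = str(err)
print(message)
```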
I don't see anything in the documentation that explains what causes this error.