Skip to content

Commit 85b7866

Browse files
update wrapped_generate_function to be more general
1 parent b41265a commit 85b7866

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

keras_hub/src/models/causal_lm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,7 @@ def ov_infer(inputs, stop_token_ids, fn):
179179
return get_outputs(inputs, struct_outputs, compile_ov_model)
180180

181181
def wrapped_generate_function(inputs, stop_token_ids=None):
182-
for k, v in inputs.items():
183-
inputs[k] = ops.array(v)
182+
inputs = tree.map_structure(ops.array, inputs)
184183
return ov_infer(inputs, stop_token_ids, self.generate_step)
185184

186185
self.generate_function = wrapped_generate_function

0 commit comments

Comments
 (0)