Skip to content

Commit c36b8e5

Browse files
optimizing the 'reusing model' logic
1 parent a35f4f5 commit c36b8e5

File tree

2 files changed

+16 -23 lines changed

2 files changed

+16 -23 lines changed

keras_hub/src/models/causal_lm.py

Lines changed: 4 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -144,24 +144,9 @@ def ov_infer(inputs, stop_token_ids, fn):
144 144
inputs, stop_token_ids, fn
145 145
)
146 146

147-
if not hasattr(ov_infer, "_compiled_models"):
148-
ov_infer._compiled_models = {}
149-
150-
# Create hash based on inputs, inputs shapes, and input dtypes
151-
inputs_shapes = []
152-
inputs_dtypes = []
153-
for k, v in inputs.items():
154-
inputs_shapes.append(str(v.shape))
155-
inputs_dtypes.append(str(v.dtype))
156-
model_signature = (
157-
f"inputs_{len(inputs)}_"
158-
f"shapes_{'_'.join(inputs_shapes)}_"
159-
f"dtypes_{'_'.join(inputs_dtypes)}_"
160-
)
161-
162-
model_hash = hash(model_signature)
147+
if not hasattr(ov_infer, "_compiled_model"):
148+
ov_infer._compiled_model = None
163 149

164-
if model_hash not in ov_infer._compiled_models:
165 150
parameters = [
166 151
p.output.get_node() for p in tree.flatten(struct_params)
167 152
]
@@ -171,11 +156,9 @@ def ov_infer(inputs, stop_token_ids, fn):
171 156
]
172 157

173 158
ov_model = ov.Model(results=results, parameters=parameters)
174-
ov_infer._compiled_models[model_hash] = ov.compile_model(
175-
ov_model, "CPU"
176-
)
159+
ov_infer._compiled_model = ov.compile_model(ov_model, "CPU")
177 160

178-
compile_ov_model = ov_infer._compiled_models[model_hash]
161+
compile_ov_model = ov_infer._compiled_model
179 162
return get_outputs(inputs, struct_outputs, compile_ov_model)
180 163

181 164
def wrapped_generate_function(inputs, stop_token_ids=None):

keras_hub/src/utils/openvino_utils.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ def parameterize_inputs(inputs):
37 37
return {k: parameterize_inputs(v) for k, v in inputs.items()}
38 38
elif isinstance(inputs, np.ndarray):
39 39
ov_type = OPENVINO_DTYPES[str(inputs.dtype)]
40-
ov_shape = list(inputs.shape)
40+
ov_shape = [-1] * len(inputs.shape)
41 41
param = ov_opset.parameter(shape=ov_shape, dtype=ov_type)
42 42
return ops.convert_to_tensor(param.output(0))
43 43
elif isinstance(inputs, (int, np.integer)):
@@ -50,9 +50,19 @@ def parameterize_inputs(inputs):
50 50
raise TypeError(f"Unknown input type: {type(inputs)}")
51 51

52 52

53+
def reshape_params(params, inputs):
54+
if isinstance(params, (list, tuple)):
55+
return [reshape_params(p, i) for p, i in zip(params, inputs)]
56+
elif isinstance(params, dict):
57+
return {k: reshape_params(params[k], inputs[k]) for k in params}
58+
else:
59+
return ops.reshape(params, inputs.shape)
60+
61+
53 62
def get_struct_outputs(inputs, stop_token_ids, fn):
54 63
struct_params = parameterize_inputs(inputs)
55-
struct_outputs = fn(struct_params, stop_token_ids)
64+
struct_params_reshaped = reshape_params(struct_params, inputs)
65+
struct_outputs = fn(struct_params_reshaped, stop_token_ids)
56 66
return struct_params, struct_outputs
57 67

58 68

0 commit comments

Comments
 (0)