@@ -1359,10 +1359,32 @@ def save_own_variables(self, store):
13591359 Args:
13601360 store: Dict where the state of the model will be saved.
13611361 """
1362+ if not getattr (self , "_is_quantized" , False ):
1363+ all_vars = self ._trainable_variables + self ._non_trainable_variables
1364+ for i , v in enumerate (all_vars ):
1365+ store [f"{ i } " ] = v
1366+ return
1367+
1368+ # Case: quantized layer
1369+ quantized_vars = self ._get_quantized_variables ()
1370+ for i , v in enumerate (quantized_vars ):
1371+ store [f"quantized_{ i } " ] = v
1372+
1373+ # Save non-quantized variables
13621374 all_vars = self ._trainable_variables + self ._non_trainable_variables
1363- for i , v in enumerate (all_vars ):
1375+ non_quantized_vars = [
1376+ v for v in all_vars if v not in quantized_vars and v .trainable
1377+ ]
1378+ for i , v in enumerate (non_quantized_vars ):
13641379 store [f"{ i } " ] = v
13651380
1381+ def _get_quantized_variables (self ):
1382+ quantized_vars = []
1383+ for v in self ._trainable_variables + self ._non_trainable_variables :
1384+ if not backend .is_float_dtype (v .dtype ):
1385+ quantized_vars .append (v )
1386+ return quantized_vars
1387+
13661388 def load_own_variables (self , store ):
13671389 """Loads the state of the layer.
13681390
@@ -1372,6 +1394,10 @@ def load_own_variables(self, store):
13721394 Args:
13731395 store: Dict from which the state of the model will be loaded.
13741396 """
1397+ if any (key .startswith ("quantized_" ) for key in store .keys ()):
1398+ self ._load_quantized_variables (store )
1399+ return
1400+
13751401 all_vars = self ._trainable_variables + self ._non_trainable_variables
13761402 if len (store .keys ()) != len (all_vars ):
13771403 if len (all_vars ) == 0 and not self .built :
@@ -1407,6 +1433,19 @@ def load_own_variables(self, store):
14071433 for i , v in enumerate (all_vars ):
14081434 v .assign (store [f"{ i } " ])
14091435
1436+ def _load_quantized_variables (self , store ):
1437+ quantized_vars = self ._get_quantized_variables ()
1438+ for i , v in enumerate (quantized_vars ):
1439+ v .assign (store [f"quantized_{ i } " ])
1440+
1441+ # Load non-quantized variables
1442+ all_vars = self ._trainable_variables + self ._non_trainable_variables
1443+ non_quantized_vars = [
1444+ v for v in all_vars if v not in quantized_vars and v .trainable
1445+ ]
1446+ for i , v in enumerate (non_quantized_vars ):
1447+ v .assign (store [f"{ i } " ])
1448+
14101449 def _track_variable (self , variable ):
14111450 if variable .trainable :
14121451 self ._tracker .add_to_store ("trainable_variables" , variable )
0 commit comments