|
11 | 11 | from keras.src import models
|
12 | 12 | from keras.src import saving
|
13 | 13 | from keras.src import testing
|
| 14 | +from keras.src.backend.torch.core import get_device |
14 | 15 | from keras.src.utils.torch_utils import TorchModuleWrapper
|
15 | 16 |
|
16 | 17 |
|
@@ -246,3 +247,27 @@ def test_build_model(self):
|
246 | 247 | model = keras.Model(x, y)
|
247 | 248 | self.assertEqual(model.predict(np.zeros([5, 4])).shape, (5, 16))
|
248 | 249 | self.assertEqual(model(np.zeros([5, 4])).shape, (5, 16))
|
| 250 | + |
| 251 | + def test_save_load(self): |
| 252 | + @keras.saving.register_keras_serializable() |
| 253 | + class M(keras.Model): |
| 254 | + def __init__(self, channels=10, **kwargs): |
| 255 | + super().__init__() |
| 256 | + self.sequence = torch.nn.Sequential( |
| 257 | + torch.nn.Conv2d(1, channels, kernel_size=(3, 3)), |
| 258 | + ) |
| 259 | + |
| 260 | + def call(self, x): |
| 261 | + return self.sequence(x) |
| 262 | + |
| 263 | + m = M() |
| 264 | + device = get_device() # Get the current device (e.g., "cuda" or "cpu") |
| 265 | + x = torch.ones( |
| 266 | + (10, 1, 28, 28), device=device |
| 267 | + ) # Place input on the correct device |
| 268 | + m(x) |
| 269 | + temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.keras") |
| 270 | + m.save(temp_filepath) |
| 271 | + new_model = saving.load_model(temp_filepath) |
| 272 | + for ref_w, new_w in zip(m.get_weights(), new_model.get_weights()): |
| 273 | + self.assertAllClose(ref_w, new_w, atol=1e-5) |
0 commit comments