When using train_on_batch for training and test_on_batch for validation in a custom training loop, the loss reported by test_on_batch changes after saving and reloading the model, even though model weights and optimizer state are correctly restored.
The cause is that test_on_batch returns a running average accumulated in the model's metric state, and load_model() creates a new model object with that state empty, which can lead to misleading loss values in custom training loops.
import numpy as np
import keras
import tensorflow as tf
### --- Minimal model ---
model = keras.Sequential([
    keras.layers.LSTM(32, input_shape=(10, 4)),
    keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')
### --- Dummy data ---
X = np.random.randn(100, 10, 4).astype("float32")
y = np.random.randn(100, 1).astype("float32")
X_batch = X[:32]
y_batch = y[:32]
### --- Train for 5 batches to accumulate metric state ---
for _ in range(5):
    model.train_on_batch(X_batch, y_batch)
### --- Save the model ---
model.save("bug_repro.keras")
### --- BEFORE load_model ---
loss_before = model.test_on_batch(X_batch, y_batch)
y_pred_before = model(X_batch, training=False)
mse_before = float(tf.reduce_mean(tf.square(y_pred_before - y_batch)))
print(f"BEFORE load_model | test_on_batch : {loss_before:.6f}")
print(f"BEFORE load_model | manual MSE : {mse_before:.6f}")
### --- load_model creates a new model object with a fresh metric state ---
loaded_model = keras.models.load_model("bug_repro.keras")
### --- AFTER load_model ---
loss_after = loaded_model.test_on_batch(X_batch, y_batch)
y_pred_after = loaded_model(X_batch, training=False)
mse_after = float(tf.reduce_mean(tf.square(y_pred_after - y_batch)))
print(f"\nAFTER load_model | test_on_batch : {loss_after:.6f}")
print(f"AFTER load_model | manual MSE : {mse_after:.6f}")
### --- Diagnosis ---
print(f"\n{'='*50}")
print(f"Weights identical : {all(np.allclose(w1.numpy(), w2.numpy()) for w1, w2 in zip(model.weights, loaded_model.weights))}")
print(f"Manual MSE identical : {abs(mse_before - mse_after) < 1e-6}")
print(f"test_on_batch identical : {abs(loss_before - loss_after) < 1e-6}")
#############################################################
BEFORE load_model | test_on_batch : 1.608474
BEFORE load_model | manual MSE : 1.573039
AFTER load_model | test_on_batch : 1.573039
AFTER load_model | manual MSE : 1.573039
==================================================
Weights identical : True
Manual MSE identical : True
test_on_batch identical : False
Expected behavior: test_on_batch should return the same value as the manually computed MSE for the batch.
Actual behavior: test_on_batch returns a running average over the accumulated metric state. Because load_model() produces a model with a fresh (empty) metric state, the reported value only matches the manual MSE after reloading.
This causes inconsistent loss reporting when resuming training after load_model(), even though model weights
and optimizer state are correctly restored.
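The discrepancy can be reproduced without Keras at all. The sketch below uses a toy RunningMean class (a hypothetical stand-in, not Keras's actual Mean implementation, with made-up per-batch loss values) to show how a stateful mean metric carried across train_on_batch/test_on_batch calls differs from the fresh state a reloaded model starts with:

```python
class RunningMean:
    """Toy stand-in for a stateful mean metric (illustrative only)."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        # Accumulate, then report the running average -- this mirrors what
        # test_on_batch effectively returns when state is never reset.
        self.total += value
        self.count += 1
        return self.total / self.count

    def reset_state(self):
        self.total = 0.0
        self.count = 0


# Hypothetical per-batch MSE values seen during 5 training steps.
train_losses = [2.0, 1.9, 1.8, 1.7, 1.6]
batch_mse = 1.5  # the "manual MSE" of the validation batch

metric = RunningMean()
for v in train_losses:
    metric.update(v)

# BEFORE load_model: state carries over, so the reported loss is a
# running average over 6 values, not the per-batch MSE.
before = metric.update(batch_mse)  # (2.0+1.9+1.8+1.7+1.6+1.5) / 6 = 1.75

# AFTER load_model: the new model object has a fresh metric state,
# so the first update reports exactly the per-batch value.
fresh = RunningMean()
after = fresh.update(batch_mse)  # 1.5

print(before)  # 1.75
print(after)   # 1.5
```

In this toy model, calling reset_state() before the evaluation update makes both cases report the per-batch value; the analogous fix in Keras would be resetting the model's metrics before each test_on_batch call, if the installed version exposes such a method.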
Environment
- Keras version: 3.13.0
- TensorFlow version: 2.19.0
- Python version: 3.11
- GPU: Yes, cuDNN 9.3