test_on_batch/train_on_batch returns cumulative metric, causing inconsistent loss after load_model #22596

@coderW52

Description

When using train_on_batch for training and test_on_batch for validation, the loss reported by test_on_batch changes after saving and reloading the model.
Model weights and optimizer state are restored correctly, but test_on_batch returns a running average over the model's accumulated metric state, so the value it reports depends on how many batches have been processed since that state was last reset. This makes loss values misleading in custom training loops.
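The mechanism can be sketched without Keras at all. The hypothetical `MeanMetric` below mirrors how Keras-style stateful metrics accumulate across `update_state()` calls until reset: after several training batches, the next evaluation reports a blended running mean rather than the per-batch loss, while a freshly constructed (i.e. freshly loaded) model reports the per-batch value.

```python
# Minimal sketch of a stateful running-mean metric (hypothetical class,
# not the Keras implementation): state accumulates until reset_state().
class MeanMetric:
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count

    def reset_state(self):
        self.total = 0.0
        self.count = 0


metric = MeanMetric()
# Five "training" batches accumulate into the metric state...
for loss in [2.0, 1.9, 1.8, 1.7, 1.6]:
    metric.update_state(loss)
# ...so the next "evaluation" batch is blended into a running mean.
metric.update_state(1.5)
print(metric.result())  # 1.75 -- mean of all six values, not 1.5

# A freshly loaded model starts with empty metric state, so the very
# same batch yields the pure per-batch value.
fresh = MeanMetric()
fresh.update_state(1.5)
print(fresh.result())  # 1.5
```

This reproduces the before/after asymmetry in the report: neither model is wrong about the data, they simply carry different amounts of accumulated history.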

import numpy as np
import keras
import tensorflow as tf

# --- Minimal model ---
model = keras.Sequential([
    keras.Input(shape=(10, 4)),  # Keras 3: pass input_shape via keras.Input
    keras.layers.LSTM(32),
    keras.layers.Dense(1)
])
model.compile(optimizer='adam', loss='mse')

# --- Dummy data ---
X = np.random.randn(100, 10, 4).astype("float32")
y = np.random.randn(100, 1).astype("float32")

X_batch = X[:32]
y_batch = y[:32]

# --- Train for 5 batches to accumulate metric state ---
for _ in range(5):
    model.train_on_batch(X_batch, y_batch)

# --- Save the model ---
model.save("bug_repro.keras")

# --- BEFORE load_model ---
loss_before = model.test_on_batch(X_batch, y_batch)
y_pred_before = model(X_batch, training=False)
mse_before = float(tf.reduce_mean(tf.square(y_pred_before - y_batch)))

print(f"BEFORE load_model | test_on_batch : {loss_before:.6f}")
print(f"BEFORE load_model | manual MSE    : {mse_before:.6f}")

# --- load_model creates a new model object with a fresh metric state ---
loaded_model = keras.models.load_model("bug_repro.keras")

# --- AFTER load_model ---
loss_after = loaded_model.test_on_batch(X_batch, y_batch)
y_pred_after = loaded_model(X_batch, training=False)
mse_after = float(tf.reduce_mean(tf.square(y_pred_after - y_batch)))

print(f"\nAFTER  load_model | test_on_batch : {loss_after:.6f}")
print(f"AFTER  load_model | manual MSE    : {mse_after:.6f}")

# --- Diagnosis ---
print(f"\n{'='*50}")
print(f"Weights identical          : {all(np.allclose(w1.numpy(), w2.numpy()) for w1, w2 in zip(model.weights, loaded_model.weights))}")
print(f"Manual MSE identical       : {abs(mse_before - mse_after) < 1e-6}")
print(f"test_on_batch identical    : {abs(loss_before - loss_after) < 1e-6}")

#############################################################
BEFORE load_model | test_on_batch : 1.608474
BEFORE load_model | manual MSE    : 1.573039

AFTER  load_model | test_on_batch : 1.573039
AFTER  load_model | manual MSE    : 1.573039

==================================================
Weights identical          : True
Manual MSE identical       : True
test_on_batch identical    : False

Expected behavior: test_on_batch should return the loss for the batch it was given, i.e. the same value as the manual MSE computation.

Actual behavior: test_on_batch returns a running average over the accumulated metric state. load_model() constructs a new model object whose metric state is fresh, so the reported value changes after reloading even though model weights and optimizer state are restored correctly. This makes loss reporting inconsistent when resuming training from a saved model.
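A workaround sketch for per-batch numbers, under the assumption that the metric state can be cleared between calls (in Keras this would be something along the lines of model.reset_metrics(); check the API of the version in use): reset the accumulated state immediately before each evaluation so the result covers only the current batch. Illustrated with a hypothetical running-mean metric rather than a real model:

```python
# Hypothetical stateful mean metric standing in for a model's loss tracker.
class MeanMetric:
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count

    def reset_state(self):
        self.total = 0.0
        self.count = 0


def evaluate_batch(metric, batch_loss):
    # Reset accumulated state so the result reflects only this batch,
    # making the value independent of prior training history.
    metric.reset_state()
    metric.update_state(batch_loss)
    return metric.result()


m = MeanMetric()
for loss in [2.0, 1.9, 1.8]:
    m.update_state(loss)       # training accumulates state
print(evaluate_batch(m, 1.5))  # 1.5, regardless of the prior history
```

With this pattern the value before and after load_model() would agree, since neither call depends on accumulated state.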

Environment

  • Keras version: 3.13.0
  • TensorFlow version: 2.19.0
  • Python version: 3.11
  • GPU: Yes, cuDNN 9.3

Metadata

Assignees: no one assigned

Labels: backend:tensorflow, type:support (user is asking for help / an implementation question; Stack Overflow would be better suited)
