- 加载由
tf.keras训练并保存的模型(最常见、最推荐的方式) - 加载由原生 TensorFlow 或 TensorFlow 1.x 保存的检查点
下面我将详细讲解这两种方法。

(图片来源网络,侵删)
加载由 tf.keras 保存的模型 (推荐)
当你使用 tf.keras 构建和训练模型时,最佳实践是使用 model.save() 方法,这个方法不仅保存了模型的权重,还保存了模型的结构、配置和优化器的状态。
保存模型
在训练完成后,你可以用以下方式保存整个模型:
import tensorflow as tf
# 1. 创建并编译一个简单的模型
model = tf.keras.Sequential([
tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# (这里省略训练过程 model.fit(...))
# 2. 保存整个模型到单个文件
# 这会保存模型架构、权重、训练配置(优化器、损失等)
model.save('my_keras_model.keras')
# 或者保存为 TensorFlow SavedModel 格式(更通用,推荐用于部署)
# model.save('my_keras_model_tf')
加载模型
加载模型非常简单,只需一行代码:
# 方法 A: 加载 .keras 格式 (HDF5)
# 这是最直接的方式
loaded_model = tf.keras.models.load_model('my_keras_model.keras')
# 方法 B: 加载 TensorFlow SavedModel 格式
# loaded_model = tf.keras.models.load_model('my_keras_model_tf')
# 3. 验证加载的模型
# 检查模型结构
loaded_model.summary()
# 检查模型是否能正常工作
# 假设我们有一些测试数据
import numpy as np
test_data = np.random.rand(1, 784) # 创建一个随机样本
prediction = loaded_model.predict(test_data)
print("Prediction:", prediction)
优点:
- 极其方便:一个函数
load_model()搞定所有事情。 - 完整性:保存了模型的所有信息,无需重新定义模型结构。
- 鲁棒性:
tf.keras会自动处理版本兼容性问题。
加载原生 TensorFlow 检查点
在某些情况下,你可能不使用 tf.keras,或者你想在训练过程中定期保存检查点以便从中断处恢复,这时你会使用 tf.train.Checkpoint。
保存检查点
你需要明确指定要保存哪些对象(通常是模型、优化器等)。
import tensorflow as tf
# 1. 创建模型和优化器
# 注意:这里使用的是原生 TensorFlow 层,而不是 Keras 层
model = tf.keras.Sequential([
tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
tf.keras.layers.Dense(10, activation='softmax')
])
optimizer = tf.keras.optimizers.Adam()
# 2. 创建一个检查点对象
# 你需要告诉它要保存哪些对象
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
# 3. 定义检查点管理器(可选但推荐)
# 它可以帮助你管理多个检查点文件,只保留最新的 N 个
checkpoint_manager = tf.train.CheckpointManager(checkpoint, directory='./checkpoints', max_to_keep=3)
# (这里省略训练过程)
# 在训练循环中,定期保存
# for epoch in range(epochs):
# # ... 训练步骤 ...
# if epoch % 5 == 0:
# save_path = checkpoint_manager.save()
# print(f"Saved checkpoint at {save_path}")
加载检查点
加载检查点时,你需要先创建一个完全相同结构的模型和优化器实例,然后调用 checkpoint.restore()。
# 1. 重新创建一个完全相同的模型和优化器结构
# 这一步至关重要!模型的结构必须和保存时一模一样。
new_model = tf.keras.Sequential([
tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
tf.keras.layers.Dense(10, activation='softmax')
])
new_optimizer = tf.keras.optimizers.Adam()
# 2. 创建一个新的检查点对象,指向相同的目录
new_checkpoint = tf.train.Checkpoint(model=new_model, optimizer=new_optimizer)
new_checkpoint_manager = tf.train.CheckpointManager(new_checkpoint, directory='./checkpoints', max_to_keep=3)
# 3. 恢复最新的检查点
# restore() 方法会返回一个状态对象,你可以用它来检查是否加载成功
status = new_checkpoint_manager.restore(new_checkpoint_manager.latest_checkpoint)
if status.assert_existing_objects_matched():
print("检查点加载成功!模型权重已恢复。")
else:
print("警告:检查点中的某些对象在当前模型中找不到对应项。")
# 4. 验证模型
# 模型现在已经加载了权重,可以用来进行预测
new_model.summary()
test_data = np.random.rand(1, 784)
prediction = new_model.predict(test_data)
print("Prediction from restored model:", prediction)
关键点:
- 结构必须匹配:加载时创建的模型结构(层数、类型、形状等)必须与保存时完全一致,否则,TensorFlow 无法将权重张量映射到模型的变量上。
- 显式指定对象:在创建
tf.train.Checkpoint时,必须显式地将需要保存和恢复的对象(如model,optimizer)作为参数传入。 - 检查状态:
restore()返回一个tf.train.CheckpointStatus对象,可以用来验证加载是否成功。
总结与对比
| 特性 | tf.keras.models.save_model / load_model |
tf.train.Checkpoint |
|---|---|---|
| 适用场景 | Keras 模型的完整保存与加载,生产部署,迁移学习 | 原生 TensorFlow 训练,训练过程中的断点续训 |
| 模型架构、权重、优化器状态、训练配置 | 你指定的对象的权重和状态(通常是模型和优化器) | |
| 加载方式 | tf.keras.models.load_model('path') |
创建相同结构的模型 创建 Checkpoint 对象调用 restore() |
| 优点 | 简单、方便、完整,是 Keras 的标准实践 | 灵活,可以保存训练过程中的任何状态 |
| 缺点 | 对于非 Keras 模型不适用 | 需要手动管理模型结构,稍显繁琐 |
最佳实践建议
- 优先使用
tf.keras:如果你的模型是用 Keras 构建的,请始终使用model.save()和tf.keras.models.load_model(),这是最简单、最可靠、最不容易出错的方法。 - 仅在必要时使用
tf.train.Checkpoint:当你需要非常精细地控制训练过程,或者使用非 Keras 的原生 TensorFlow API 时,才使用tf.train.Checkpoint,对于绝大多数深度学习任务,Keras 已经足够强大和方便。 - 保存路径:将模型文件和检查点文件保存在一个专门的目录中,便于管理。
- 版本控制:对于重要的模型,考虑将
.keras文件或 SavedModel 目录放入你的版本控制系统(如 Git),对于非常大的模型,可以只保存模型架构文件(JSON/YAML)和权重文件(HDF5),但SavedModel格式通常是更好的选择。
