TensorFlow加载参数时路径怎么写?

99ANYc3cd6
预计阅读时长 17 分钟
位置: 首页 参数 正文
  1. 加载由 tf.keras 训练并保存的模型(最常见、最推荐的方式)
  2. 加载由原生 TensorFlow 或 TensorFlow 1.x 保存的检查点

下面我将详细讲解这两种方法。

tensorflow 加载参数
(图片来源网络,侵删)

加载由 tf.keras 保存的模型 (推荐)

当你使用 tf.keras 构建和训练模型时,最佳实践是使用 model.save() 方法,这个方法不仅保存了模型的权重,还保存了模型的结构、配置和优化器的状态。

保存模型

在训练完成后,你可以用以下方式保存整个模型:

import tensorflow as tf
# 1. 创建并编译一个简单的模型
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# (这里省略训练过程 model.fit(...))
# 2. 保存整个模型到单个文件
# 这会保存模型架构、权重、训练配置(优化器、损失等)
model.save('my_keras_model.keras')
# 或者保存为 TensorFlow SavedModel 格式(更通用,推荐用于部署)
# model.save('my_keras_model_tf')

加载模型

加载模型非常简单,只需一行代码:

# 方法 A: 加载 .keras 格式 (HDF5)
# 这是最直接的方式
loaded_model = tf.keras.models.load_model('my_keras_model.keras')
# 方法 B: 加载 TensorFlow SavedModel 格式
# loaded_model = tf.keras.models.load_model('my_keras_model_tf')
# 3. 验证加载的模型
# 检查模型结构
loaded_model.summary()
# 检查模型是否能正常工作
# 假设我们有一些测试数据
import numpy as np
test_data = np.random.rand(1, 784) # 创建一个随机样本
prediction = loaded_model.predict(test_data)
print("Prediction:", prediction)

优点:

  • 极其方便:一个函数 load_model() 搞定所有事情。
  • 完整性:保存了模型的所有信息,无需重新定义模型结构。
  • 鲁棒性tf.keras 会自动处理版本兼容性问题。

加载原生 TensorFlow 检查点

在某些情况下,你可能不使用 tf.keras,或者你想在训练过程中定期保存检查点以便从中断处恢复,这时你会使用 tf.train.Checkpoint

保存检查点

你需要明确指定要保存哪些对象(通常是模型、优化器等)。

import tensorflow as tf
# 1. 创建模型和优化器
# 注意:这里使用的是原生 TensorFlow 层,而不是 Keras 层
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax')
])
optimizer = tf.keras.optimizers.Adam()
# 2. 创建一个检查点对象
# 你需要告诉它要保存哪些对象
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
# 3. 定义检查点管理器(可选但推荐)
# 它可以帮助你管理多个检查点文件,只保留最新的 N 个
checkpoint_manager = tf.train.CheckpointManager(checkpoint, directory='./checkpoints', max_to_keep=3)
# (这里省略训练过程)
# 在训练循环中,定期保存
# for epoch in range(epochs):
#     # ... 训练步骤 ...
#     if epoch % 5 == 0:
#         save_path = checkpoint_manager.save()
#         print(f"Saved checkpoint at {save_path}")

加载检查点

加载检查点时,你需要先创建一个完全相同结构的模型和优化器实例,然后调用 checkpoint.restore()

# 1. 重新创建一个完全相同的模型和优化器结构
# 这一步至关重要!模型的结构必须和保存时一模一样。
new_model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax')
])
new_optimizer = tf.keras.optimizers.Adam()
# 2. 创建一个新的检查点对象,指向相同的目录
new_checkpoint = tf.train.Checkpoint(model=new_model, optimizer=new_optimizer)
new_checkpoint_manager = tf.train.CheckpointManager(new_checkpoint, directory='./checkpoints', max_to_keep=3)
# 3. 恢复最新的检查点
# restore() 方法会返回一个状态对象,你可以用它来检查是否加载成功
status = new_checkpoint_manager.restore(new_checkpoint_manager.latest_checkpoint)
if status.assert_existing_objects_matched():
    print("检查点加载成功!模型权重已恢复。")
else:
    print("警告:检查点中的某些对象在当前模型中找不到对应项。")
# 4. 验证模型
# 模型现在已经加载了权重,可以用来进行预测
new_model.summary()
test_data = np.random.rand(1, 784)
prediction = new_model.predict(test_data)
print("Prediction from restored model:", prediction)

关键点:

  • 结构必须匹配:加载时创建的模型结构(层数、类型、形状等)必须与保存时完全一致,否则,TensorFlow 无法将权重张量映射到模型的变量上。
  • 显式指定对象:在创建 tf.train.Checkpoint 时,必须显式地将需要保存和恢复的对象(如 model, optimizer)作为参数传入。
  • 检查状态restore() 返回一个 tf.train.CheckpointStatus 对象,可以用来验证加载是否成功。

总结与对比

特性 tf.keras.models.save_model / load_model tf.train.Checkpoint
适用场景 Keras 模型的完整保存与加载,生产部署,迁移学习 原生 TensorFlow 训练,训练过程中的断点续训
模型架构、权重、优化器状态、训练配置 你指定的对象的权重和状态(通常是模型和优化器)
加载方式 tf.keras.models.load_model('path') 创建相同结构的模型
创建 Checkpoint 对象
调用 restore()
优点 简单、方便、完整,是 Keras 的标准实践 灵活,可以保存训练过程中的任何状态
缺点 对于非 Keras 模型不适用 需要手动管理模型结构,稍显繁琐

最佳实践建议

  1. 优先使用 tf.keras:如果你的模型是用 Keras 构建的,请始终使用 model.save()tf.keras.models.load_model(),这是最简单、最可靠、最不容易出错的方法。
  2. 仅在必要时使用 tf.train.Checkpoint:当你需要非常精细地控制训练过程,或者使用非 Keras 的原生 TensorFlow API 时,才使用 tf.train.Checkpoint,对于绝大多数深度学习任务,Keras 已经足够强大和方便。
  3. 保存路径:将模型文件和检查点文件保存在一个专门的目录中,便于管理。
  4. 版本控制:对于重要的模型,考虑将 .keras 文件或 SavedModel 目录放入你的版本控制系统(如 Git),对于非常大的模型,可以只保存模型架构文件(JSON/YAML)和权重文件(HDF5),但 SavedModel 格式通常是更好的选择。
-- 展开阅读全文 --
头像
三星智能电视怎样装泰捷
« 上一篇 02-03
lifebook u772 拆机
下一篇 » 02-03

相关文章

取消
微信二维码
支付宝二维码

最近发表

标签列表

目录[+]