Transfer Learning: Saving and Loading Pretrained Models

1. Saving and loading a model:

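import os
from copy import deepcopy
import torch
# To keep the best parameters seen so far, use: best_model_state = deepcopy(model.state_dict())
model_save_path = os.path.join(model_save_dir, 'model.pt')  # model_save_dir: the directory you save to
torch.save(model.state_dict(), model_save_path)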
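
The deepcopy mentioned in the comment above matters because state_dict() hands back tensors that share memory with the live parameters. The following is a minimal sketch of that behavior, not part of the original post; the nn.Linear model and the in-place update that stands in for a training step are assumptions made for illustration.

```python
from copy import deepcopy

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)  # hypothetical stand-in for the real network

# state_dict() returns tensors that share storage with the live parameters,
# so a plain reference keeps changing as training continues.
live_view = model.state_dict()
best_model_state = deepcopy(model.state_dict())  # independent snapshot

# Simulate a training step that updates the weights in place.
with torch.no_grad():
    model.weight.add_(1.0)

# The deep copy still holds the old weights; the plain reference does not.
snapshot_is_independent = not torch.equal(best_model_state["weight"], model.weight)
live_view_tracks_model = torch.equal(live_view["weight"], model.weight)
```

In a real training loop you would take the snapshot whenever the validation metric improves, then pass best_model_state to torch.save once training ends.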

# Load the model parameters
model = LeNet5()
model_save_path = os.path.join(model_save_dir, 'model.pt')
if os.path.exists(model_save_path):
    loaded_paras = torch.load(model_save_path)
    model.load_state_dict(loaded_paras)
    
# You can also save the optimizer state and other training information:
model_save_path = os.path.join(model_save_dir, 'model.pt')
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, model_save_path)

# Load the checkpoint:
checkpoint = torch.load(model_save_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
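
To see the checkpoint save and load steps working end to end, here is a self-contained sketch under stated assumptions: the nn.Linear model, the SGD optimizer, the temporary directory, and the epoch value 3 are all hypothetical placeholders rather than anything from the original post. It stores the loss as a plain float, which is one reasonable choice and not the only one.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical stand-in for the real network
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer has momentum buffers worth saving.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

model_save_dir = tempfile.mkdtemp()  # hypothetical directory for this demo
model_save_path = os.path.join(model_save_dir, "model.pt")
torch.save({
    "epoch": 3,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss.item(),  # a plain float, detached from the autograd graph
}, model_save_path)

# Build fresh objects first, then restore their state from the checkpoint.
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
checkpoint = torch.load(model_save_path, map_location="cpu")
new_model.load_state_dict(checkpoint["model_state_dict"])
new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1  # resume from the next epoch
```

Note that the model and optimizer must be constructed before their state can be loaded: a state dict holds only tensors and hyperparameters, not the class definitions.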

2. Transfer learning

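Save and load the model as described above. When part of the new model differs from the saved one, you can selectively keep the loaded parameters by matching on parameter name and shape: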

def para_state_dict(model, model_save_dir):
    # Start from the new model's own (freshly initialized) parameters.
    state_dict = deepcopy(model.state_dict())
    model_save_path = os.path.join(model_save_dir, 'model.pt')
    if os.path.exists(model_save_path):
        loaded_paras = torch.load(model_save_path)
        for key in state_dict:  # iterate over the parameters of the new model
            # Keep a loaded parameter only if both its name and its shape match.
            if key in loaded_paras and state_dict[key].size() == loaded_paras[key].size():
                print("Initialized parameter from checkpoint:", key)
                state_dict[key] = loaded_paras[key]
    return state_dict
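
As a usage sketch for the function above, the example below transfers a shared backbone between two toy networks whose output layers differ in size. The make_net helper, the layer sizes, and the temporary directory are hypothetical and chosen only for illustration; para_state_dict is repeated so that the block runs on its own.

```python
import os
import tempfile
from copy import deepcopy

import torch
import torch.nn as nn


def make_net(num_classes):
    # Hypothetical toy network: layer "0" is the shared backbone, layer "2" the head.
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, num_classes))


def para_state_dict(model, model_save_dir):
    # Same logic as the function above, repeated so this sketch is self-contained.
    state_dict = deepcopy(model.state_dict())
    model_save_path = os.path.join(model_save_dir, "model.pt")
    if os.path.exists(model_save_path):
        loaded_paras = torch.load(model_save_path)
        for key in state_dict:
            if key in loaded_paras and state_dict[key].size() == loaded_paras[key].size():
                state_dict[key] = loaded_paras[key]
    return state_dict


# Save a "pretrained" 10-class model.
model_save_dir = tempfile.mkdtemp()  # hypothetical directory for this demo
pretrained = make_net(num_classes=10)
torch.save(pretrained.state_dict(), os.path.join(model_save_dir, "model.pt"))

# The new 3-class model takes the backbone weights and keeps its own head.
new_model = make_net(num_classes=3)
new_model.load_state_dict(para_state_dict(new_model, model_save_dir))

backbone_copied = torch.equal(new_model[0].weight, pretrained[0].weight)
head_shape = tuple(new_model[2].weight.shape)
```

The explicit shape check is what makes this work: load_state_dict(strict=False) tolerates missing or unexpected keys, but it still raises an error when a key exists in both models with different tensor sizes.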
 