Skip to content

序列化语义

保存模型的推荐方法

序列化和恢复模型有两种主要方法。

第一个(推荐)只保存和加载模型参数:

torch.save(the_model.state_dict(), PATH) 

然后:

the_model = TheModelClass(*args, **kwargs)
the_mdel.load_state_dict(torch.load(PATH)) 

第二个方法是保存并加载整个模型:

torch.save(the_model, PATH) 

然后:

the_model = torch.load(PATH) 

然而,在这种情况下,序列化的数据会与特定的类结构和准确的目录结构相绑定,所以在其他项目中使用或经大量重构之后,这些结构可能会以各种方式被破坏。



回到顶部