阅读量:3
在Linux环境下,使用PyTorch保存和加载模型非常简单。以下是一个简单的例子来说明如何保存和加载一个PyTorch模型。
首先,我们需要导入所需的库并定义一个简单的模型:
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
model = SimpleModel()
接下来,我们可以使用torch.save()函数将模型保存到文件中:
torch.save(model, 'model.pth')
现在,模型已经被保存到了名为model.pth的文件中。要加载模型,我们可以使用torch.load()函数:
loaded_model = torch.load('model.pth')
加载模型后,我们可以像使用原始模型一样使用它:
input_data = torch.randn(1, 10)
output = loaded_model(input_data)
注意:在加载模型时,确保你的环境中已经安装了与保存模型时相同的PyTorch版本。否则,可能会出现兼容性问题。
以上就是关于“Linux环境下PyTorch模型保存与加载”的相关介绍,筋斗云是国内较早的云主机应用的服务商,拥有10余年行业经验,提供丰富的云服务器、租用服务器等相关产品服务。云服务器资源弹性伸缩,主机vCPU、内存性能强悍、超高I/O速度、故障秒级恢复;电子化备案,提交快速,专业团队7×24小时服务支持!
简单好用、高性价比云服务器租用链接:https://www.jindouyun.cn/product/cvm