一文看明白PyTorch 模型设计训练保存加载预测
需求输入x128维fc1Linear 128→96ReLU激活Dropout 0.2fc2Linear 96→64ReLU激活Dropout 0.2fc3Linear 64→32输出out32维代码样例包含训练 → 保存 → 加载 → 预测代码可以直接运行importtorchimporttorch.nnasnnimporttorch.optimasoptimfromtorch.utils.dataimportDataLoader,TensorDataset# -----------------------------# 1. 定义模型# -----------------------------classSimpleModel(nn.Module):def__init__(self):super(SimpleModel,self).__init__()self.fc1nn.Linear(128,96)self.fc2nn.Linear(96,64)self.fc3nn.Linear(64,32)self.relunn.ReLU()self.dropoutnn.Dropout(0.2)defforward(self,x):xself.relu(self.fc1(x))xself.dropout(x)xself.relu(self.fc2(x))xself.dropout(x)outself.fc3(x)returnout# -----------------------------# 2. 准备数据 (示例随机数据)# -----------------------------Xtorch.randn(1000,128)ytorch.randn(1000,32)datasetTensorDataset(X,y)batch_size32dataloaderDataLoader(dataset,batch_sizebatch_size,shuffleTrue)# -----------------------------# 3. 定义损失函数和优化器# MSELoss Mean Squared Error均方误差# -----------------------------modelSimpleModel()criterionnn.MSELoss()optimizeroptim.Adam(model.parameters(),lr0.001)# -----------------------------# 4. 训练循环# -----------------------------num_epochs20forepochinrange(num_epochs):model.train()# 训练模式epoch_loss0forbatch_X,batch_yindataloader:optimizer.zero_grad()outputsmodel(batch_X)losscriterion(outputs,batch_y)loss.backward()optimizer.step()epoch_lossloss.item()*batch_X.size(0)epoch_loss/len(dataset)print(fEpoch{epoch1}/{num_epochs}, Loss:{epoch_loss:.4f})# -----------------------------# 5. 保存训练好的模型参数# -----------------------------torch.save(model.state_dict(),simple_model.pth)print(模型参数已保存到 simple_model.pth)# -----------------------------# 6. 加载模型进行预测# -----------------------------# 重新创建模型对象model_loadedSimpleModel()# 加载保存的参数model_loaded.load_state_dict(torch.load(simple_model.pth))# 切换到评估模式model_loaded.eval()# 假设有新样本 x_newx_newtorch.randn(5,128)withtorch.no_grad():# 推理时禁用梯度y_predmodel_loaded(x_new)print(加载模型预测结果形状:,y_pred.shape)# [5, 32]✅ 特点训练完成后保存权重simple_model.pth可以随时加载。加载模型时必须重新创建类然后load_state_dict。推理时切换到eval()模式保证 Dropout 不随机失活。使用torch.no_grad()提升预测效率减少显存占用。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2637873.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!