参考:
https://www.ngui.cc/el/507608.html?action=onClick
这里面简单回顾一下PyTorch 里面的两个常用的梯度自动计算的API
autoGrad 和 Backward, 最后结合 softmax 简单介绍一下一下应用场景。
目录:
1 autoGrad
2 Backward
3 softmax
一 autoGrad
输入
x
输出
 
损失函数
 
参数更新
 
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 13 21:28:26 2023
@author: cxf
"""
import torch
import torch.nn.functional as F
def grad():
    
    x = torch.tensor([[1.0,2.0]]).view(2,1)
    w = torch.full([2,1], 1.0,requires_grad= True)
    target = torch.ones((1,1))
    out = torch.matmul(w.T, x)
    print(out)
    mse = F.mse_loss(out, target)
    
    print("\n mse",mse)
    grad_w = torch.autograd.grad(mse,[w])    
    print(grad_w)
if __name__ == "__main__":
    
    grad()
 
    
   二 Backward
求梯度另一种方法,可以通过backward
在创建动态图后,直接调用backward,更加方便
import torch
import torch.nn.functional as F
def grad():
    
    x = torch.tensor([[1.0,2.0]]).view(2,1)
    w = torch.full([2,1], 1.0,requires_grad= True)
    target = torch.ones((1,1))
    out = torch.matmul(w.T, x)
    print(out)
    mse = F.mse_loss(out, target)
    
    print("\n mse",mse)
    mse.backward()   
    print(w.grad)
if __name__ == "__main__":
    
    grad() 
   三 softmax
多分类模型常用的激活函数
 
    
   这种模型通常用交叉熵做损失函数
 
 
 因为标签中只有一个为1,其它都为0,假设为
 
 
    
   则:
  (j=i)
(j=i)
  (
( )
)
则写成向量形式为
 
import torch
import torch.nn.functional as F
from torch import nn
#自己实现该梯度计算
def calcGrad(a,target):
    
    grad =a -target
    print("\n 直接计算",grad)
    # 直接计算 tensor([[ 0.0900, -0.7553,  0.6652]], grad_fn=<SubBackward0>)
#调用API 方式实现
def grad():
    CEL =  nn.CrossEntropyLoss()
    z = torch.tensor([[1.0,2.0,3.0]],requires_grad=True)
    a = F.softmax(z,dim=1)
    
    print("\n 神经元输出",a)
    target = torch.tensor([[0.0,1.0,0.0]])
    
    loss =CEL(z,target)
    loss.backward()
   
 
    print("\n API 计算",z.grad)
    # API 计算 tensor([[ 0.0900, -0.7553,  0.6652]])
    calcGrad(a,target)
if __name__ == "__main__":
    
    grad()这里面要注意nn.CrossEntropyLoss
是相当于对z 先做softmax,得到a, 然后再做交叉熵
![buu [UTCTF2020]basic-crypto 1](https://img-blog.csdnimg.cn/6011adb31fbd4c42ada2ccc3f5e37ae8.png)


















