自动求导
2025年1月18日大约 1 分钟
我们来逐步分析代码,并解释为什么 b.grad
是 [3.]
。
代码分析
import torch
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x) # a = w + x
b = torch.add(w, 1) # b = w + 1
y = torch.mul(a, b) # y = a * b
b.retain_grad() # 保留 b 的梯度
y.backward() # 反向传播计算梯度
print(w.grad) # 打印 w 的梯度
print(b.grad) # 打印 b 的梯度
1. 计算图
a = w + x
:a
是w
和x
的和。b = w + 1
:b
是w
和常数1
的和。y = a * b
:y
是a
和b
的乘积。
计算图如下:
w ────┐
├── a = w + x ───┐
x ────┘ │
y = a * b
w ────┐ │
├── b = w + 1 ───┘
1 ────┘
2. 计算值
w = [1.]
,x = [2.]
。a = w + x = [1.] + [2.] = [3.]
。b = w + 1 = [1.] + [1.] = [2.]
。y = a * b = [3.] * [2.] = [6.]
。
3. 反向传播
反向传播的目的是计算 y
对各个变量的梯度。根据链式法则:
y = a * b
,所以:a = w + x
,所以:b = w + 1
,所以:
4. 计算梯度
y
对b
的梯度:因此,
b.grad = [3.]
。y
对w
的梯度:代入值:
因此,
w.grad = [5.]
。
结果
b.grad = [3.]
,因为y
对b
的梯度是a = [3.]
。w.grad = [5.]
,因为y
对w
的梯度是[5.]
。
总结
b.grad
是 [3.]
,因为 y
对 b
的偏导数是 a = [3.]
。这是通过链式法则和反向传播计算得出的结果。