
PyTorch Chain Rules

Development Technology · 2022-01-27
Derivative Rules

\[ \frac{\partial E}{\partial w^1_{jk}} = \frac{\partial E}{\partial O_k^1}\frac{\partial O_k^1}{\partial w^1_{jk}} = \frac{\partial E}{\partial O_k^2}\frac{\partial O_k^2}{\partial O_k^1}\frac{\partial O_k^1}{\partial w^1_{jk}} \]
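The formula above chains three local derivatives. A minimal sketch that checks all three numerically with `torch.autograd.grad`, assuming a hypothetical scalar network (not from the original post) in which O^1 = w^1 x, O^2 = w^2 O^1 and E = (O^2 - t)^2 with a made-up target t. The indices j, k are dropped because every quantity here is a scalar.

```python
import torch

# Hypothetical toy network: scalar input, two linear "layers", squared-error loss
x = torch.tensor(1.)
w1 = torch.tensor(2., requires_grad=True)
w2 = torch.tensor(3., requires_grad=True)
t = torch.tensor(5.)  # hypothetical target value

o1 = x * w1           # O^1 = w^1 * x
o2 = o1 * w2          # O^2 = w^2 * O^1
E = (o2 - t) ** 2     # E = (O^2 - t)^2

# The three factors on the right-hand side of the chain rule;
# retain_graph=True keeps the graph so it can be differentiated again
dE_do2 = torch.autograd.grad(E, o2, retain_graph=True)[0]
do2_do1 = torch.autograd.grad(o2, o1, retain_graph=True)[0]
do1_dw1 = torch.autograd.grad(o1, w1, retain_graph=True)[0]

# The left-hand side, computed directly by autograd (last call, graph may be freed)
dE_dw1 = torch.autograd.grad(E, w1)[0]

chain = dE_do2 * do2_do1 * do1_dw1
print(chain, dE_dw1)  # tensor(6.) tensor(6.)
```

Here dE/dO^2 = 2(O^2 - t) = 2, dO^2/dO^1 = w^2 = 3 and dO^1/dw^1 = x = 1, so the product of the factors and the direct gradient are both 6.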

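import torch

# Scalar example: y1 plays the role of O^1 and y2 the role of E in the formula above
x = torch.tensor(1.)
w1, w2 = torch.tensor(2., requires_grad=True), torch.tensor(2., requires_grad=True)
b1, b2 = torch.tensor(1.), torch.tensor(1.)

y1 = x * w1 + b1
y2 = y1 * w2 + b2

# retain_graph=True keeps the graph alive so grad can be called on it more than once
dy2_dy1 = torch.autograd.grad(y2, [y1], retain_graph=True)[0]  # local derivative, equals w2
dy1_dw1 = torch.autograd.grad(y1, [w1], retain_graph=True)[0]  # local derivative, equals x
dy2_dw1 = torch.autograd.grad(y2, [w1], retain_graph=True)[0]  # direct gradient

# Chain rule: the product of the local derivatives equals the direct gradient
print(dy2_dy1 * dy1_dw1)  # tensor(2.)
print(dy2_dw1)            # tensor(2.)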
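Working the same numbers by hand shows why both printed values are 2: with x = 1, w1 = w2 = 2 and b1 = b2 = 1, each local derivative reduces to a single factor.

\[
\begin{aligned}
y_1 &= x w_1 + b_1 = 3, \qquad y_2 = y_1 w_2 + b_2 = 7,\\
\frac{\partial y_2}{\partial y_1} &= w_2 = 2, \qquad \frac{\partial y_1}{\partial w_1} = x = 1,\\
\frac{\partial y_2}{\partial w_1} &= \frac{\partial y_2}{\partial y_1}\,\frac{\partial y_1}{\partial w_1} = 2 \cdot 1 = 2.
\end{aligned}
\]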