torch.autograd.Function自定义反向求导

@TOC

Extending torch.autograd

在某些情况下,我们的函数不可微(not differentiable),但我们仍然需要对它求导,这时就需要自定义求导方式。下面我们根据 PyTorch 官网给出的例子,来看一下 torch.autograd.Function 是如何运行的。

官网给出的例子为 LinearFunction,代码如下。这里我们假设输入 $X$ 为 $N\times M$ 的矩阵,权重 $W$ 为 $K\times M$ 的矩阵,bias $b$ 为 $1\times K$ 的矩阵(前向计算时按行广播到每个样本),则

forward:

$$Y = XW^{\top} + b,\qquad Y \in \mathbb{R}^{N\times K}$$

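backward:记上游梯度 $G = \frac{\partial L}{\partial Y} \in \mathbb{R}^{N\times K}$,则

$$\frac{\partial L}{\partial X} = G\,W \in \mathbb{R}^{N\times M},\qquad \frac{\partial L}{\partial W} = G^{\top} X \in \mathbb{R}^{K\times M},\qquad \frac{\partial L}{\partial b} = \sum_{i=1}^{N} G_{i,:} \in \mathbb{R}^{1\times K}$$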

完整代码如下:
from numpy import double
import torch
from torch.autograd import Function
# Inherit from Function
class LinearFunction(Function):
    """Affine map ``output = input @ weight.T + bias`` with a hand-written backward.

    This follows the ``LinearFunction`` example from PyTorch's "Extending
    torch.autograd" notes. Invoke it as
    ``LinearFunction.apply(input, weight, bias=None)``; ``forward`` and
    ``backward`` are called by autograd, not directly.

    Shapes implied by the code below: ``input`` is (N, M) and ``weight`` is
    (K, M), since both go through ``torch.mm``. The optional ``bias`` is a
    length-K vector, which is the shape ``backward``'s ``sum(0)`` gradient
    produces. The output is (N, K).
    """

    # Both forward and backward must be @staticmethods.
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        """Compute ``input @ weight.T`` (plus ``bias``) and save tensors for backward.

        Args:
            ctx: Autograd context object supplied by ``Function.apply``.
            input: 2-D tensor of shape (N, M).
            weight: 2-D tensor of shape (K, M).
            bias: Optional 1-D tensor of length K, added to every output row.

        Returns:
            2-D tensor of shape (N, K).
        """
        # bias may be None here; it is saved anyway and comes back as None
        # from ctx.saved_tensors, which backward checks for.
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            # Broadcast the length-K bias over all N rows of the output.
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # forward has a single output, so backward receives a single gradient.
    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients of the loss w.r.t. ``(input, weight, bias)``.

        Args:
            ctx: Context holding the tensors saved in ``forward``.
            grad_output: Gradient w.r.t. the (N, K) output.

        Returns:
            ``(grad_input, grad_weight, grad_bias)``. An entry is ``None`` when
            that input does not need a gradient, or when no bias was given.
            Autograd ignores extra trailing ``None`` values, so returning three
            entries is fine even when ``apply`` received only two inputs.
        """
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # The needs_input_grad checks only skip unnecessary work. Returning a
        # gradient for an input that does not require one is not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)  # dL/dX = G W -> (N, M)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)  # dL/dW = G^T X -> (K, M)
        # Test `bias is not None` first: if apply() got only (input, weight),
        # needs_input_grad has only two entries and index 2 does not exist.
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)  # dL/db = column sums of G -> (K,)

        return grad_input, grad_weight, grad_bias

from torch.autograd import gradcheck

# Small double-precision example: a 2x3 input, a 2x3 weight (two output
# features over three input features) and a length-2 bias, so the result
# is 2x2. Double precision keeps the later finite-difference check accurate.
input = torch.tensor([[2.0, 1.5, 2.5], [1.0, 2.0, 3.0]], dtype=torch.double).requires_grad_()
weight = torch.tensor([[3.0, 2.0, 3.5], [1.0, 2.0, 3.0]], dtype=torch.double).requires_grad_()
bias = torch.tensor([0.1, 0.2], dtype=torch.double).requires_grad_()

# A custom Function is always run through its `apply` entry point.
# Option 1: call `apply` directly on the class.
output = LinearFunction.apply(input, weight, bias)
print(output)

# Option 2: bind `apply` to a short alias and use it like an ordinary function.
linear = LinearFunction.apply
output = linear(input, weight, bias)
print(output)

# Check the hand-written backward against numerical (finite-difference)
# gradients computed at these same tensors. With gradcheck's default
# settings a mismatch raises an exception, so this prints True on success.
test = gradcheck(linear, (input, weight, bias), eps=1e-6, atol=1e-4)
print(test)