理解反向传播计算梯度

最后更新于

在神经网络的基础知识中,反向传播求梯度对初学者来说是比较不好理解的。本文将以图文并茂、通俗易懂的方式解释如何通过反向传播求梯度。

数学方法计算梯度

我们之所以要求梯度,是希望更新参数和偏置到”使得损失函数值最小的“程度。

我们从最小的例子开始理解,假设我们有如下一个二元一次函数:

F(x,y)=3x+4yF(x, y) = 3x + 4y

根据高等数学知识,梯度的计算公式为:

gradF(x,y)=(Fx,Fy)grad F(x, y) = (\frac{∂F}{∂x},\frac{∂F}{∂y})

容易求得梯度为(3, 4)。为了贴近真实的使用场景,我们再来看一个复杂一点的式子:

F1(x)=w1x+b1F_{1}(x) = w_{1}x + b_{1}
F2(x)=w2x+b2F_{2}(x) = w_{2}x + b_{2}
F(x)=F2(F1(x))F(x) = F_{2}(F_{1}(x))

我们带入神经网络的知识来理解。这个式子相当于使用了两个Affine函数并且使用线性激活层(可以暂时理解为忽略该层梯度存在)。我们的目标是求w_{1}的梯度。根据高等数学的链式法则有:

yw1=yF1F1w1\frac{∂y}{w_{1}} = \frac{y}{∂F_{1}} \frac{∂F_{1}}{w_{1}}

容易解出梯度为w_{2}x。

一般的计算梯度方法为微分法,即:

def numerical_gradient(f, x):
    h = 1e-4 # 0.0001
    grad = np.zeros_like(x)

    it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
    while not it.finished:
        idx = it.multi_index
        tmp_val = x[idx]
        x[idx] = float(tmp_val) + h
        fxh1 = f(x) # f(x+h)

        x[idx] = tmp_val - h
        fxh2 = f(x) # f(x-h)
        grad[idx] = (fxh1 - fxh2) / (2*h)

        x[idx] = tmp_val # 还原值
        it.iternext()

这也叫做正向传播求梯度法,需要将整个神经网络函数投入进行计算。由于这种方法需要消耗大量的资源,所以我们引入了反向传播计算梯度的方法。

反向传播计算梯度

我们先将上面用到的函数抽象成代码。在Affine层的backward方法中,我们更新了三个变量,其中

  • dx代表继续反向传播的信号
  • dw为权重的局部导数
  • db为偏置的局部导数
class Affine:
    def __init__(self, W, b) -> None:
        self.W = W
        self.b = b
        self.x = None
        self.dW = None
        self.db = None
    
    def forward(self, x):
        self.x = x
        dot = np.dot(self.x, self.W)
        out = dot + self.b # Boardcasting...

        return out

    def backward(self, dout):
        dx = np.dot(dout, self.W) # A
        self.dW = np.dot(self.x, dout) # B
        self.db = np.sum(dout, axis=0)

        return dx

那么如何使用反向传播来计算呢?可以这么理解:我们先计算E关于x的梯度,然后反向传播给F_{2},随即计算出损失值关于w_{2}的梯度。在上面的方法中,通常我们是从左往右计算,而反向传播就是从右往左计算。

我们通过一个更贴近实际的例子来加深理解。

E=(yt)22E = \frac{(y - t)^2}{2}

为损失函数,其中t代表教师标签,y代表隐藏层输出。根据链式法则,w_{1}梯度为:

Ew1=EyyF1F1w1\frac{∂E}{∂w_{1}} = \frac{∂E}{∂y} * \frac{∂y}{∂F_{1}}*\frac{∂F_{1}}{∂w_{1}}

在反向传播中,我们首先计算出输出层(E)的导数,然后反向传播给上一层(F_{2})。于是我们可以先得到w_{2}关于E的梯度为

Ew2=Eyx2\frac{∂E}{∂w_{2}} = \frac{∂E}{∂y} * x_{2}

这个式子对应以上代码 B 注释处。

之后我们就可以根据此梯度,结合各种梯度下降算法,更新权重。接着,我们还要求出x2关于E的梯度以便继续往前传播,容易得到

Ex2=EyEx2\frac{∂E}{∂x_{2}} = \frac{∂E}{∂y} * \frac{∂E}{∂x_{2}}

即:

Ex2=Eyw2\frac{∂E}{∂x_{2}} = \frac{∂E}{∂y} * w_{2}

这个式子对应以上代码 A 注释处。

也就是说,在每一层的反向传播方法中,我们做了两件事情:

  1. 计算当前层权重关于损失函数的梯度(其实是关于上一层的梯度,但最终还是关于损失函数),以此更新当前层权重。
  2. 计算当前层x关于损失函数的梯度,其中x代表上一层,应用链式法则向前传播。

这两个梯度均依赖于上一层的关于 x 的梯度,因为上一层的x即为这一层。

值得一提的是,代码中的两次点乘的顺序是不一样的,因为实际运算时是使用矩阵运算,而矩阵运算是对顺序有要求的。