Processing math: 100%

BP Algorithm的理解思路

本文淺顯直白的介紹反向傳播算法(BP 算法, The Back Propagation Algorithm)的理解思路, 計算過程和程式實現


我們考慮一個最簡單的layer,這個layer有兩個input x1,x2 和1個output y1:

則數學式可以寫成: y1=w1x1+w2x2

上式的w下標是跟著x在跑, 考慮到y可以擴展成yj(不只一個y), 我們改一下w的下標, 也把y的標號考慮進去: y1=w11x1+w12x2

其中wji的第一個下標j是跟著y在跑, 而i是跟著x在跑, 整個式子可以寫成矩陣形式: Y=WX

其中令輸入Xn維, 輸出Ym維:

X=[x1, x2,...,xi, ... , xn ]T

Y=[y1, y2,...,yj, ... , ym ]T

W=[w11, w12,..., w1n ...,wji,...wm1, wm2,..., wmn ]={wji}m×n

通常我們還會把上式加個偏置(bias) B:

Y=WX+B

B維度會和Y相同 B=[b1, b2,...,bj, ... , bm ]T

通常還會再加個可微分的激活函數σ來增加其非線性:

Y=σ(WX+B)

以上是"前饋"(forward)的部分


接下來考慮"反向傳播"(backward)的部分,假設損失函數(loss function)Jd維, 下標用k表示:

Assume lostfunction J has dim d

J=[J1, J2,...,Jk, ... , Jd ]T

損失梯度(Gradient)要從輸出Y流回輸入X, 並流到WB以進行權重和偏置的更新(update), 也就是可以把問題定義成:

Given Jkyj,find Jkwji, Jkbjand Jkxi

梯度傳遞會用到微積分的鏈鎖率(Chain Rule), 為了方便起見我們多定義一個Z:

Z=WX+B

Y=σ(Z)

Z=[z1, z2,...,zj, ... , zm ]T

可以發現, 上式和原本的Y=σ(WX+B)並沒有不同, 只是計算過程中間多一個Z

定義輸出梯度(已知): out=Jkyj

定義輸入梯度(待求): in=Jkxi

用鏈鎖率(Chain Rule)將待求的項目展開:

Jkwji=Jkyjyjzjzjwji=Gkj zjwji  ...(1)

Jkbj=Jkyjyjzjzjbj=Gkj zjbj  ...(2)

Jkxi (in)=Jkyjyjzjzjxi=Gkj zjxi  ...(3)

上列各式重複Jkyjyjzj的部分, 方便起見, 定義中介梯度G:

G =Jkyjyjzj

G=[G11, G12,..., G1m ...,Gkj,...Gk1, Gk2,..., Gdm ]={Gkj}d×m

至此, 整個計算流程(計算圖)已經很明朗了, 我們的目的就是找到:

Find  G, zjwji, zjbj, zjxi

只要算出上述四項, 就可以得到所求

  1. 計算Gkj

yjzj=σ(zj)

=>Gkj=Jkyj σ(zj)=(dout) σ(zj)

  1. 計算zjwji, zjbj, zjxi

zjwji=xi

zjbj=1

zjxi=wji

  1. 代入式(1)(2)(3), 得到結果

=>Jkwji=Gkj xi , Jkbj=Gkj, Jkxi (in)=Gkj wji 


以上就是整個BP演算法"單層"的計算過程, 化成python程式碼如下:

1
2
3
4
5
6
7
def backward(self,dout):  
Z,x = self.cache
G = dout * self.dactive_fn(Z)
din = G.dot(self.W.T)
dW = x.T.dot(G)
db = G.sum(axis=0)
return din,dW,db

程式碼來源: https://github.com/purelyvivid/DeepLearning_practice/blob/master/3.%20BP%20algo.py (用numpy寫一個BP algorithm)