本文淺顯直白的介紹反向傳播算法(BP 算法, The Back Propagation Algorithm)的理解思路, 計算過程和程式實現
我們考慮一個最簡單的layer,這個layer有兩個input x1,x2 和1個output y1:
則數學式可以寫成: y1=w1x1+w2x2
上式的w下標是跟著x在跑, 考慮到y可以擴展成yj(不只一個y), 我們改一下w的下標, 也把y的標號考慮進去: y1=w11x1+w12x2
其中wji的第一個下標j是跟著y在跑, 而i是跟著x在跑, 整個式子可以寫成矩陣形式: Y=WX
其中令輸入X有n維, 輸出Y有m維:
X=[x1, x2,...,xi, ... , xn ]T
Y=[y1, y2,...,yj, ... , ym ]T
W=[w11, w12,..., w1n ...,wji,...wm1, wm2,..., wmn ]={wji}m×n
通常我們還會把上式加個偏置(bias) B:
Y=WX+B
B維度會和Y相同 B=[b1, b2,...,bj, ... , bm ]T
通常還會再加個可微分的激活函數σ來增加其非線性:
Y=σ(WX+B)
以上是"前饋"(forward)的部分
接下來考慮"反向傳播"(backward)的部分,假設損失函數(loss function)J 有d維, 下標用k表示:
Assume lostfunction J has dim d
J=[J1, J2,...,Jk, ... , Jd ]T
損失梯度(Gradient)要從輸出Y流回輸入X, 並流到W和B以進行權重和偏置的更新(update), 也就是可以把問題定義成:
Given ∂Jk∂yj,find ∂Jk∂wji, ∂Jk∂bjand ∂Jk∂xi
梯度傳遞會用到微積分的鏈鎖率(Chain Rule), 為了方便起見我們多定義一個Z:
Z=WX+B
Y=σ(Z)
Z=[z1, z2,...,zj, ... , zm ]T
可以發現, 上式和原本的Y=σ(WX+B)並沒有不同, 只是計算過程中間多一個Z
定義輸出梯度(已知): ∇out=∂Jk∂yj
定義輸入梯度(待求): ∇in=∂Jk∂xi
用鏈鎖率(Chain Rule)將待求的項目展開:
∂Jk∂wji=∂Jk∂yj∂yj∂zj∂zj∂wji=Gkj ∂zj∂wji ...(1)
∂Jk∂bj=∂Jk∂yj∂yj∂zj∂zj∂bj=Gkj ∂zj∂bj ...(2)
∂Jk∂xi (∇in)=∂Jk∂yj∂yj∂zj∂zj∂xi=Gkj ∂zj∂xi ...(3)
上列各式重複∂Jk∂yj∂yj∂zj的部分, 方便起見, 定義中介梯度G:
G =∂Jk∂yj∂yj∂zj
G=[G11, G12,..., G1m ...,Gkj,...Gk1, Gk2,..., Gdm ]={Gkj}d×m
至此, 整個計算流程(計算圖)已經很明朗了, 我們的目的就是找到:
Find G, ∂zj∂wji, ∂zj∂bj, ∂zj∂xi
只要算出上述四項, 就可以得到所求
- 計算Gkj
∂yj∂zj=σ′(zj)
=>Gkj=∂Jk∂yj σ′(zj)=(dout) σ′(zj)
- 計算∂zj∂wji, ∂zj∂bj, ∂zj∂xi
∂zj∂wji=xi
∂zj∂bj=1
∂zj∂xi=wji
- 代入式(1)(2)(3), 得到結果
=>∂Jk∂wji=Gkj xi , ∂Jk∂bj=Gkj, ∂Jk∂xi (∇in)=Gkj wji
以上就是整個BP演算法"單層"的計算過程, 化成python程式碼如下:
1 | def backward(self,dout): |
程式碼來源: https://github.com/purelyvivid/DeepLearning_practice/blob/master/3.%20BP%20algo.py (用numpy寫一個BP algorithm)