
# Deep Learning Basics | Gradient Descent

## 2. On the Loss ("Evaluating the Model")

• "mean" means summing the loss of every sample and dividing by the total number of samples, i.e., the average loss over all samples
• The goal of training is to find the weight w that makes this average loss as small as possible
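In symbols (a restatement of the two bullets above; $\hat{y}_n$ denotes the model's prediction for sample $n$, and $N$ the number of samples):

```latex
\operatorname{cost}(w) = \frac{1}{N} \sum_{n=1}^{N} \left( \hat{y}_n - y_n \right)^2,
\qquad \hat{y}_n = w \cdot x_n
```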

## 3. Methods for Solving for w

### 3.3 The Gradient Descent Algorithm
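The whole algorithm is one repeated update rule. Differentiating the MSE cost of the linear model $\hat{y} = w \cdot x$ with respect to $w$ gives (here $\alpha$ denotes the learning rate; the notation is mine, not from the original post):

```latex
\frac{\partial \operatorname{cost}}{\partial w}
= \frac{1}{N} \sum_{n=1}^{N} 2 x_n \left( x_n w - y_n \right),
\qquad
w \leftarrow w - \alpha \, \frac{\partial \operatorname{cost}}{\partial w}
```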

#### Code

``````python
import matplotlib.pyplot as plt

# prepare the training set
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# initial guess of weight
w = 1.0

# define the linear model y = w * x
def forward(x):
    return x * w

# define the cost function (MSE over the whole training set)
def cost(xs, ys):
    cost = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)
        cost += (y_pred - y) ** 2
    return cost / len(xs)

# define the gradient function (gradient of the MSE cost w.r.t. w)
def gradient(xs, ys):
    grad = 0
    for x, y in zip(xs, ys):
        grad += 2 * x * (x * w - y)
    return grad / len(xs)

epoch_list = []
cost_list = []
print('predict (before training)', 4, forward(4))
for epoch in range(100):
    cost_val = cost(x_data, y_data)
    grad_val = gradient(x_data, y_data)
    w -= 0.01 * grad_val  # 0.01 is the learning rate
    print('epoch:', epoch, 'w=', w, 'loss=', cost_val)
    epoch_list.append(epoch)
    cost_list.append(cost_val)

print('predict (after training)', 4, forward(4))
plt.plot(epoch_list, cost_list)
plt.ylabel('cost')
plt.xlabel('epoch')
plt.show()
``````
``````
predict (before training) 4 4.0
epoch: 0 w= 1.0933333333333333 loss= 4.666666666666667
epoch: 1 w= 1.1779555555555554 loss= 3.8362074074074086
epoch: 2 w= 1.2546797037037036 loss= 3.1535329869958857
epoch: 3 w= 1.3242429313580246 loss= 2.592344272332262
epoch: 4 w= 1.3873135910979424 loss= 2.1310222071581117
epoch: 5 w= 1.4444976559288012 loss= 1.7517949663820642
epoch: 6 w= 1.4963445413754464 loss= 1.440053319920117
epoch: 7 w= 1.5433523841804047 loss= 1.1837878313441108
epoch: 8 w= 1.5859728283235668 loss= 0.9731262101573632
epoch: 9 w= 1.6246153643467005 loss= 0.7999529948031382
epoch: 10 w= 1.659651263674342 loss= 0.6575969151946154
epoch: 11 w= 1.6914171457314033 loss= 0.5405738908195378
epoch: 12 w= 1.7202182121298057 loss= 0.44437576375991855
epoch: 13 w= 1.7463311789976905 loss= 0.365296627844598
epoch: 14 w= 1.7700069356245727 loss= 0.3002900634939416
epoch: 15 w= 1.7914729549662791 loss= 0.2468517784170642
epoch: 16 w= 1.8109354791694263 loss= 0.2029231330489788
epoch: 17 w= 1.8285815011136133 loss= 0.16681183417217407
epoch: 18 w= 1.8445805610096762 loss= 0.1371267415488235
epoch: 19 w= 1.8590863753154396 loss= 0.11272427607497944
epoch: 20 w= 1.872238313619332 loss= 0.09266436490145864
epoch: 21 w= 1.8841627376815275 loss= 0.07617422636521683
epoch: 22 w= 1.8949742154979183 loss= 0.06261859959338009
epoch: 23 w= 1.904776622051446 loss= 0.051475271914629306
epoch: 24 w= 1.9136641373266443 loss= 0.04231496130368814
epoch: 25 w= 1.9217221511761575 loss= 0.03478477885657844
epoch: 26 w= 1.9290280837330496 loss= 0.02859463421027894
epoch: 27 w= 1.9356521292512983 loss= 0.023506060193480772
epoch: 28 w= 1.9416579305211772 loss= 0.01932302619282764
epoch: 29 w= 1.9471031903392007 loss= 0.015884386331668398
epoch: 30 w= 1.952040225907542 loss= 0.01305767153735723
epoch: 31 w= 1.9565164714895047 loss= 0.010733986344664803
epoch: 32 w= 1.9605749341504843 loss= 0.008823813841374291
epoch: 33 w= 1.9642546069631057 loss= 0.007253567147113681
epoch: 34 w= 1.9675908436465492 loss= 0.005962754575689583
epoch: 35 w= 1.970615698239538 loss= 0.004901649272531298
epoch: 36 w= 1.9733582330705144 loss= 0.004029373553099482
epoch: 37 w= 1.975844797983933 loss= 0.0033123241439168096
epoch: 38 w= 1.9780992835054327 loss= 0.0027228776607060357
epoch: 39 w= 1.980143350378259 loss= 0.002238326453885249
epoch: 40 w= 1.9819966376762883 loss= 0.001840003826269386
epoch: 41 w= 1.983676951493168 loss= 0.0015125649231412608
epoch: 42 w= 1.9852004360204722 loss= 0.0012433955919298103
epoch: 43 w= 1.9865817286585614 loss= 0.0010221264385926248
epoch: 44 w= 1.987834100650429 loss= 0.0008402333603648631
epoch: 45 w= 1.9889695845897222 loss= 0.0006907091659248264
epoch: 46 w= 1.9899990900280147 loss= 0.0005677936325753796
epoch: 47 w= 1.9909325082920666 loss= 0.0004667516012495216
epoch: 48 w= 1.9917788075181404 loss= 0.000383690560742734
epoch: 49 w= 1.9925461188164473 loss= 0.00031541069384432885
epoch: 50 w= 1.9932418143935788 loss= 0.0002592816085930997
epoch: 51 w= 1.9938725783835114 loss= 0.0002131410058905752
epoch: 52 w= 1.994444471067717 loss= 0.00017521137977565514
epoch: 53 w= 1.9949629871013967 loss= 0.0001440315413480261
epoch: 54 w= 1.9954331083052663 loss= 0.0001184003283899171
epoch: 55 w= 1.9958593515301082 loss= 9.733033217332803e-05
epoch: 56 w= 1.9962458120539648 loss= 8.000985883901657e-05
epoch: 57 w= 1.9965962029289281 loss= 6.57716599593935e-05
epoch: 58 w= 1.9969138906555615 loss= 5.406722767150764e-05
epoch: 59 w= 1.997201927527709 loss= 4.444566413387458e-05
epoch: 60 w= 1.9974630809584561 loss= 3.65363112808981e-05
epoch: 61 w= 1.9976998600690001 loss= 3.0034471708953996e-05
epoch: 62 w= 1.9979145397958935 loss= 2.4689670610172655e-05
epoch: 63 w= 1.9981091827482769 loss= 2.0296006560253656e-05
epoch: 64 w= 1.9982856590251044 loss= 1.6684219437262796e-05
epoch: 65 w= 1.9984456641827613 loss= 1.3715169898293847e-05
epoch: 66 w= 1.9985907355257035 loss= 1.1274479219506377e-05
epoch: 67 w= 1.9987222668766378 loss= 9.268123006398985e-06
epoch: 68 w= 1.9988415219681517 loss= 7.61880902783969e-06
epoch: 69 w= 1.9989496465844576 loss= 6.262999634617916e-06
epoch: 70 w= 1.9990476795699081 loss= 5.1484640551938914e-06
epoch: 71 w= 1.9991365628100501 loss= 4.232266273994499e-06
epoch: 72 w= 1.999217150281112 loss= 3.479110977946351e-06
epoch: 73 w= 1.999290216254875 loss= 2.859983851026929e-06
epoch: 74 w= 1.9993564627377531 loss= 2.3510338359374262e-06
epoch: 75 w= 1.9994165262155628 loss= 1.932654303533636e-06
epoch: 76 w= 1.999470983768777 loss= 1.5887277332523938e-06
epoch: 77 w= 1.9995203586170245 loss= 1.3060048068548734e-06
epoch: 78 w= 1.9995651251461022 loss= 1.0735939958924364e-06
epoch: 79 w= 1.9996057134657994 loss= 8.825419799121559e-07
epoch: 80 w= 1.9996425135423248 loss= 7.254887315754342e-07
epoch: 81 w= 1.999675878945041 loss= 5.963839812987369e-07
epoch: 82 w= 1.999706130243504 loss= 4.902541385825727e-07
epoch: 83 w= 1.9997335580874436 loss= 4.0301069098738336e-07
epoch: 84 w= 1.9997584259992822 loss= 3.312926995781724e-07
epoch: 85 w= 1.9997809729060159 loss= 2.723373231729343e-07
epoch: 86 w= 1.9998014154347876 loss= 2.2387338352920307e-07
epoch: 87 w= 1.9998199499942075 loss= 1.8403387118941732e-07
epoch: 88 w= 1.9998367546614149 loss= 1.5128402140063082e-07
epoch: 89 w= 1.9998519908930161 loss= 1.2436218932547864e-07
epoch: 90 w= 1.9998658050763347 loss= 1.0223124683409346e-07
epoch: 91 w= 1.9998783299358769 loss= 8.403862850836479e-08
epoch: 92 w= 1.9998896858085284 loss= 6.908348768398496e-08
epoch: 93 w= 1.9998999817997325 loss= 5.678969725349543e-08
epoch: 94 w= 1.9999093168317574 loss= 4.66836551287917e-08
epoch: 95 w= 1.9999177805941268 loss= 3.8376039345125727e-08
epoch: 96 w= 1.9999254544053418 loss= 3.154680994333735e-08
epoch: 97 w= 1.9999324119941766 loss= 2.593287985380858e-08
epoch: 98 w= 1.9999387202080534 loss= 2.131797981222471e-08
epoch: 99 w= 1.9999444396553017 loss= 1.752432687141379e-08
predict (after training) 4 7.999777758621207
``````

Two ways the loss curve can go wrong:

(1) Heavy local oscillation;

(2) Training diverges (fails).
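Divergence, failure mode (2), is easy to reproduce on this same toy dataset: each gradient step contracts the error only if the learning rate is small enough. The rate 0.5 below is a hypothetical value chosen for illustration (the original code uses 0.01); with it, every step overshoots the minimum and the cost grows:

```python
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

def cost(w):
    # mean squared error over the training set, as in the main example
    return sum((w * x - y) ** 2 for x, y in zip(x_data, y_data)) / len(x_data)

def gradient(w):
    # gradient of the MSE cost with respect to w
    return sum(2 * x * (w * x - y) for x, y in zip(x_data, y_data)) / len(x_data)

w = 1.0
history = []
for epoch in range(10):
    history.append(cost(w))
    w -= 0.5 * gradient(w)  # learning rate far too large: each step overshoots

print(history[0], history[-1])  # the cost grows instead of shrinking
```

For this dataset the error factor per step is $1 - \alpha \cdot 28/3$, so any learning rate above roughly 0.21 diverges.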

## 5. On Stochastic Gradient Descent

### The Code Above, Rewritten with Stochastic Gradient Descent

``````python
import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w = 1.0

def forward(x):
    return x * w

# calculate the loss of a single training sample
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

# define the gradient function (single-sample gradient, sgd)
def gradient(x, y):
    return 2 * x * (x * w - y)

epoch_list = []
loss_list = []
print('predict (before training)', 4, forward(4))
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        grad = gradient(x, y)
        w = w - 0.01 * grad    # update the weight with the gradient of every single sample
        l = loss(x, y)
    print("progress:", epoch, "w=", w, "loss=", l)
    epoch_list.append(epoch)
    loss_list.append(l)

print('predict (after training)', 4, forward(4))
plt.plot(epoch_list, loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()
``````
``````
predict (before training) 4 4.0
progress: 0 w= 1.260688 loss= 4.919240100095999
progress: 1 w= 1.453417766656 loss= 2.688769240265834
progress: 2 w= 1.5959051959019805 loss= 1.4696334962911515
progress: 3 w= 1.701247862192685 loss= 0.8032755585999681
progress: 4 w= 1.7791289594933983 loss= 0.43905614881022015
progress: 5 w= 1.836707389300983 loss= 0.2399802903801062
progress: 6 w= 1.8792758133988885 loss= 0.1311689630744999
progress: 7 w= 1.910747160155559 loss= 0.07169462478267678
progress: 8 w= 1.9340143044689266 loss= 0.03918700813247573
progress: 9 w= 1.9512159834655312 loss= 0.021418922423117836
progress: 10 w= 1.9639333911678687 loss= 0.01170720245384975
progress: 11 w= 1.9733355232910992 loss= 0.006398948863435593
progress: 12 w= 1.9802866323953892 loss= 0.003497551760830656
progress: 13 w= 1.9854256707695 loss= 0.001911699652671057
progress: 14 w= 1.9892250235079405 loss= 0.0010449010656399273
progress: 15 w= 1.9920339305797026 loss= 0.0005711243580809696
progress: 16 w= 1.994110589284741 loss= 0.0003121664271570621
progress: 17 w= 1.9956458879852805 loss= 0.0001706246229305199
progress: 18 w= 1.9967809527381737 loss= 9.326038746484765e-05
progress: 19 w= 1.9976201197307648 loss= 5.097447086306101e-05
progress: 20 w= 1.998240525958391 loss= 2.7861740127856012e-05
progress: 21 w= 1.99869919972735 loss= 1.5228732143933469e-05
progress: 22 w= 1.9990383027488265 loss= 8.323754426231206e-06
progress: 23 w= 1.9992890056818404 loss= 4.549616284094891e-06
progress: 24 w= 1.999474353368653 loss= 2.486739429417538e-06
progress: 25 w= 1.9996113831376856 loss= 1.3592075910762856e-06
progress: 26 w= 1.9997126908902887 loss= 7.429187207079447e-07
progress: 27 w= 1.9997875889274812 loss= 4.060661735575354e-07
progress: 28 w= 1.9998429619451539 loss= 2.2194855602869353e-07
progress: 29 w= 1.9998838998815958 loss= 1.213131374411496e-07
progress: 30 w= 1.9999141657892625 loss= 6.630760559646474e-08
progress: 31 w= 1.9999365417379913 loss= 3.624255915449335e-08
progress: 32 w= 1.9999530845453979 loss= 1.9809538924707548e-08
progress: 33 w= 1.9999653148414271 loss= 1.0827542027017377e-08
progress: 34 w= 1.999974356846045 loss= 5.9181421028034105e-09
progress: 35 w= 1.9999810417085633 loss= 3.2347513278475087e-09
progress: 36 w= 1.9999859839076413 loss= 1.7680576050779005e-09
progress: 37 w= 1.9999896377347262 loss= 9.6638887447731e-10
progress: 38 w= 1.999992339052936 loss= 5.282109892545845e-10
progress: 39 w= 1.9999943361699042 loss= 2.887107421958329e-10
progress: 40 w= 1.9999958126624442 loss= 1.5780416225633037e-10
progress: 41 w= 1.999996904251097 loss= 8.625295142578772e-11
progress: 42 w= 1.999997711275687 loss= 4.71443308235547e-11
progress: 43 w= 1.9999983079186507 loss= 2.5768253628059826e-11
progress: 44 w= 1.9999987490239537 loss= 1.4084469615916932e-11
progress: 45 w= 1.9999990751383971 loss= 7.698320862431846e-12
progress: 46 w= 1.9999993162387186 loss= 4.20776540913866e-12
progress: 47 w= 1.9999994944870796 loss= 2.299889814334344e-12
progress: 48 w= 1.9999996262682318 loss= 1.2570789110540446e-12
progress: 49 w= 1.999999723695619 loss= 6.870969979249939e-13
progress: 50 w= 1.9999997957248556 loss= 3.7555501141274804e-13
progress: 51 w= 1.9999998489769344 loss= 2.052716967104274e-13
progress: 52 w= 1.9999998883468353 loss= 1.1219786256679713e-13
progress: 53 w= 1.9999999174534755 loss= 6.132535848018759e-14
progress: 54 w= 1.999999938972364 loss= 3.351935118167793e-14
progress: 55 w= 1.9999999548815364 loss= 1.8321081844499955e-14
progress: 56 w= 1.9999999666433785 loss= 1.0013977760018664e-14
progress: 57 w= 1.9999999753390494 loss= 5.473462367088053e-15
progress: 58 w= 1.9999999817678633 loss= 2.991697274308627e-15
progress: 59 w= 1.9999999865207625 loss= 1.6352086111474931e-15
progress: 60 w= 1.999999990034638 loss= 8.937759877335403e-16
progress: 61 w= 1.9999999926324883 loss= 4.885220495987371e-16
progress: 62 w= 1.99999999455311 loss= 2.670175009618106e-16
progress: 63 w= 1.9999999959730488 loss= 1.4594702493172377e-16
progress: 64 w= 1.9999999970228268 loss= 7.977204100704301e-17
progress: 65 w= 1.9999999977989402 loss= 4.360197735196887e-17
progress: 66 w= 1.9999999983727301 loss= 2.3832065197304227e-17
progress: 67 w= 1.9999999987969397 loss= 1.3026183953845832e-17
progress: 68 w= 1.999999999110563 loss= 7.11988308874388e-18
progress: 69 w= 1.9999999993424284 loss= 3.89160224698574e-18
progress: 70 w= 1.9999999995138495 loss= 2.1270797208746147e-18
progress: 71 w= 1.9999999996405833 loss= 1.1626238773828175e-18
progress: 72 w= 1.999999999734279 loss= 6.354692062078993e-19
progress: 73 w= 1.9999999998035491 loss= 3.4733644793346653e-19
progress: 74 w= 1.9999999998547615 loss= 1.8984796531526204e-19
progress: 75 w= 1.9999999998926234 loss= 1.0376765851119951e-19
progress: 76 w= 1.9999999999206153 loss= 5.671751114309842e-20
progress: 77 w= 1.9999999999413098 loss= 3.100089617511693e-20
progress: 78 w= 1.9999999999566096 loss= 1.6944600977692705e-20
progress: 79 w= 1.9999999999679208 loss= 9.2616919156479e-21
progress: 80 w= 1.9999999999762834 loss= 5.062350511130293e-21
progress: 81 w= 1.999999999982466 loss= 2.7669155644059242e-21
progress: 82 w= 1.9999999999870368 loss= 1.5124150106147723e-21
progress: 83 w= 1.999999999990416 loss= 8.26683933105326e-22
progress: 84 w= 1.9999999999929146 loss= 4.518126871054872e-22
progress: 85 w= 1.9999999999947617 loss= 2.469467919185614e-22
progress: 86 w= 1.9999999999961273 loss= 1.349840097651456e-22
progress: 87 w= 1.999999999997137 loss= 7.376551550022107e-23
progress: 88 w= 1.9999999999978835 loss= 4.031726170507742e-23
progress: 89 w= 1.9999999999984353 loss= 2.2033851437431755e-23
progress: 90 w= 1.9999999999988431 loss= 1.2047849775995315e-23
progress: 91 w= 1.9999999999991447 loss= 6.5840863393251405e-24
progress: 92 w= 1.9999999999993676 loss= 3.5991747246272455e-24
progress: 93 w= 1.9999999999995324 loss= 1.969312363793734e-24
progress: 94 w= 1.9999999999996543 loss= 1.0761829795642296e-24
progress: 95 w= 1.9999999999997444 loss= 5.875191475205477e-25
progress: 96 w= 1.999999999999811 loss= 3.2110109830478153e-25
progress: 97 w= 1.9999999999998603 loss= 1.757455879087579e-25
progress: 98 w= 1.9999999999998967 loss= 9.608404711682446e-26
progress: 99 w= 1.9999999999999236 loss= 5.250973729513143e-26
predict (after training) 4 7.9999999999996945
``````

## 6. Comparing Gradient Descent and Stochastic Gradient Descent

• 1. The cost function cost() is replaced by loss(). cost() computes the loss over all training data, while loss() computes the loss of a single training sample; in the source code this amounts to two fewer for loops (one in the loss, one in the gradient).
• 3. "Stochastic" here means each update uses one training sample. In this code the weight is updated 100 (epochs) × 3 = 300 times in total, whereas in (batch) gradient descent it is updated 100 (epochs) times.

(1) In (batch) gradient descent, the values computed for $x_i$ and $x_{i+1}$ do not depend on each other, so the per-sample computations can run in parallel;

(2) In stochastic gradient descent, each update of w looks at a single sample, so the next w cannot be computed until the previous update has finished. (The "gradient descent" steps for adjacent samples cannot run in parallel; there is a dependency between them.)
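Point (1) is why batch gradient descent vectorizes well: all per-sample gradient terms can be computed in one shot. A sketch with NumPy (using NumPy is my choice here; the original code uses plain Python lists):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
y = np.array([2.0, 4.0, 6.0])
w = 1.0

# all per-sample terms 2*x*(x*w - y) computed at once, with no dependency between them
grad_terms = 2 * x * (x * w - y)
grad = grad_terms.mean()
print(grad)
```

SGD permits no such batching of updates: each term would need the w produced by the previous sample's update.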

## 7. On Mini-Batch

• The difference between mini-batch and (full-)batch training
• On Batch_Size in deep learning
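Mini-batch gradient descent is the compromise between the two methods above: it updates w once per small batch of samples, averaging the gradient within each batch. A minimal sketch on the same toy data (the batch size of 2 and the helper name `batches` are illustrative choices, not from the original post):

```python
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

def batches(xs, ys, batch_size):
    # yield consecutive mini-batches of (x, y) pairs
    for i in range(0, len(xs), batch_size):
        yield xs[i:i + batch_size], ys[i:i + batch_size]

w = 1.0
lr = 0.01
for epoch in range(100):
    for xb, yb in batches(x_data, y_data, batch_size=2):
        # one update per mini-batch: average gradient over the batch
        grad = sum(2 * x * (w * x - y) for x, y in zip(xb, yb)) / len(xb)
        w -= lr * grad

print('w after training:', w)  # w approaches 2.0
```

With batch_size=1 this reduces to the SGD code above, and with batch_size=len(x_data) to batch gradient descent, which is why batch size trades update frequency against gradient noise and parallelism.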