【优化器】带动量 Momentum 的SGD算法

张开发
2026/4/14 16:30:54 15 分钟阅读

分享文章

【优化器】带动量 Momentum 的SGD算法
思想让参数更新具有惯性每一步更新都是由前面梯度累积v vv和当前点梯度g gg组合而成公式累计梯度动量更新v ← α v ( 1 − α ) g v \leftarrow \alpha v (1-\alpha) gv←αv(1−α)g参数更新x ← x − η ⋅ v x \leftarrow x - \eta \cdot vx←x−η⋅v其中α \alphaα为动量参数v vv为累计梯度g gg为当前梯度η \etaη为学习率优点加快收敛能帮助参数在正确的方向上加速前进可以帮助跳出局部最小值实验一损失函数f ( x ) 0.1 x 1 2 2 x 2 2 f(x) 0.1x_1^2 2x_2^2f(x)0.1x12​2x22​初始值x 1 − 5 x_1 -5x1​−5x 2 − 2 x_2 -2x2​−2 学习率η 0.4 \eta 0.4η0.4我们使用不带动量的传统梯度下降法观察下降过程预期分析因为x 1 x_1x1​和x 2 x_2x2​的系数分别为 0.1 和 2。这就使得x 1 x_1x1​和x 2 x_2x2​的梯度相差一个量级 如果使用相同的学习率x 2 x_2x2​的更新幅度会较x 1 x_1x1​的更大些importnumpyasnpimportmatplotlib.pyplotaspltdefloss_func(x1,x2):# 定义目标函数return0.1*x1**22*x2**2x1,x2-5,-2eta0.4num_epochs20result[(x1,x2)]forepochinrange(num_epochs):gd10.2*x1 gd24*x2 x1-eta*gd1 x2-eta*gd2 result.append((x1,x2))plt.plot(*zip(*result),-o,color#ff7f0e)x1,x2np.meshgrid(np.arange(-5.5,1.0,0.1),np.arange(-3.0,1.0,0.1))plt.contour(x1,x2,loss_func(x1,x2),colors#1f77b4)plt.title(learning rate {}.format(eta))plt.xlabel(x1)plt.ylabel(x2)plt.show()结果分析与预想一致使用相同的学习率x 2 x_2x2​的更新幅度会较x 1 x_1x1​的更大些变化快得多而x 1 x_1x1​收敛速度太慢实验二依然使用不带动量的梯度下降算法将学习率设置为 0.6更新过程x 1 ← x 1 − 0.06 x 1 x_1 \leftarrow x_1 - 0.06x_1x1​←x1​−0.06x1​x 2 ← x 2 − 2.4 x 2 x_2 \leftarrow x_2-2.4x_2x2​←x2​−2.4x2​更新过程如下这时我们会陷入一个两难的选择如果我们选择小的学习率x 1 x_1x1​收敛速度慢如果我们选择大的学习率x 1 x_1x1​方向会收敛很快但在x 2 x_2x2​方向不会收敛实验三我们使用带动量的梯度下降法将历史的梯度考虑在内动量参数设置为 0.5 学习率设置为0.4累计梯度更新v ← α v ( 1 − α ) g v \leftarrow \alpha v (1-\alpha) gv←αv(1−α)g权重更新x ← x − η ⋅ v x \leftarrow x - \eta \cdot vx←x−η⋅vimportnumpyasnpimportmatplotlib.pyplotaspltdefloss_func(x1,x2):# 定义目标函数return0.1*x1**22*x2**2x1,x2-5,-2v1,v20,0eta,alpha0.4,0.5num_epochs20result[(x1,x2)]forepochinrange(num_epochs):v1alpha*v1(1-alpha)*(0.2*x1)v2alpha*v2(1-alpha)*(4*x2)x1-eta*v1 x2-eta*v2 result.append((x1,x2))plt.plot(*zip(*result),-o,color#ff7f0e)x1,x2np.meshgrid(np.arange(-5.5,1.0,0.1),np.arange(-3.0,1.0,0.1))plt.contour(x1,x2,loss_func(x1,x2),colors#1f77b4)plt.xlabel(x1)plt.ylabel(x2)plt.show()即使我们将学习率设置为 0.6 x 2 x_2x2​的梯度也不会发散了参考连接https://www.bilibili.com/video/BV1jh4y1q7ua/?spm_id_from333.1387.favlist.content.clickvd_sourcecf0b4c9c919d381324e8f3466e714d7a

更多文章