我已经尝试过使用这段代码进行多变量回归来找到系数,但是找不到我在哪里出错或者在正确的道路上?
问题是MSE值未收敛。
这里x1,x2,x3是我拥有的3个特征变量(我已将每个特征列切成这些x1,x2,x3变量)
def gradientDescent(x,y):
mCurrent1=mCurrent2=mCurrent3=bCurrent=0
iteration=1000
learningRate=0.0000001
n=len(x)
for i in range(0,iteration):
y_predict=mCurrent1*x1+mCurrent2*x2+mCurrent3*x3+bCurrent
mse=(1/n)*np.sum([val**2 for val in (y-y_predict)])
mPartDerivative1=-(2/n)*np.sum(x1*(y-y_predict))
mPartDerivative2=-(2/n)*np.sum(x2*(y-y_predict))
mPartDerivative3=-(2/n)*np.sum(x3*(y-y_predict))
bPartDerivative=-(2/n)*np.sum(y-y_predict)
mCurrent1=mCurrent1-(learningRate*mPartDerivative1)
mCurrent2=mCurrent2-(learningRate*mPartDerivative2)
mCurrent3=mCurrent3-(learningRate*mPartDerivative3)
bCurrent=bCurrent-(learningRate*bPartDerivative)
print('m1:{} m2:{} m3:{} b:{} iter:{} mse:{}'.format(mCurrent1,mCurrent2,mCurrent3,bCurrent,i,mse))
return(round(mCurrent1,3),round(mCurrent2,3),round(mCurrent3,3),round(bCurrent,3))
最佳答案
看起来您的程序应该可以工作。但是,您的学习速度可能太小。请记住,学习率就是您要降低成本函数的步骤的大小。如果学习率太小,则会使成本曲线沿速度下降得太慢,并且需要很长时间才能达到收敛(需要较大的迭代次数)。但是,如果学习率太大,那么就会出现发散问题。选择正确的学习率和迭代次数(换句话说,调整您的超参数)更多的是艺术而不是科学。您应该以不同的学习率玩耍。
我创建了自己的数据集并随机生成了数据(其中(m1, m2, m3, b) = (10, 5, 4, 2)
)并运行了代码:
import pandas as pd
import numpy as np
x1 = np.random.rand(100,1)
x2 = np.random.rand(100,1)
x3 = np.random.rand(100,1)
y = 2 + 10 * x1 + 5 * x2 + 4 * x3 + 2 * np.random.randn(100,1)
df = pd.DataFrame(np.c_[y,x1,x2,x3],columns=['y','x1','x2','x3'])
#df.head()
# y x1 x2 x3
# 0 11.970573 0.785165 0.012989 0.634274
# 1 19.980349 0.919672 0.971063 0.752341
# 2 2.884538 0.170164 0.991058 0.003270
# 3 8.437686 0.474261 0.326746 0.653011
# 4 14.026173 0.509091 0.921010 0.375524
以
0.0000001
的学习率运行算法会产生以下结果:(m1, m2, m3, b) = (0.001, 0.001, 0.001, 0.002)
以
.1
的学习率运行算法会产生以下结果:(m1, m2, m3, b) = (9.382, 4.841, 4.117, 2.485)
请注意,当学习速率为
0.0000001
时,您的系数与其开始的位置(0
)并没有太大差异。就像我之前说的那样,学习率低使之成为现实,因此我们以太小的速率来改变系数,因为我们正在以超小的步长减小成本函数。我添加了一张图片以帮助可视化选择步长。请注意,第一张图片使用较小的学习率,第二张图片使用较大的学习率。
学习率低:
学习率高:
关于python - 多元回归值的梯度下降未收敛,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/54352913/