我试图将一段代码从Matlab转换为Python,但遇到了一些错误:

Matlab:

function [beta] = linear_regression_train(traindata)
y = traindata(:,1); %output
ind2 = find(y == 2);
ind3 = find(y == 3);
y(ind2) = -1;
y(ind3) = 1;
X = traindata(:,2:257); %X matrix,with size of 1389x256
beta = inv(X'*X)*X'*y;


蟒蛇:

def linear_regression_train(traindata):
        y = traindata[:,0] # This is the output
        ind2 = (labels==2).nonzero()
        ind3 = (labels==3).nonzero()
        y[ind2] = -1
        y[ind3] = 1
        X = traindata[ : , 1:256]
        X_T = numpy.transpose(X)
        beta = inv(X_T*X)*X_T*y
        return beta


我收到一个错误:无法在计算beta的行上将操作数与形状(257,0,1389)(1389,0,257)一起广播。

任何帮助表示赞赏!

谢谢!

最佳答案

问题是您正在使用numpy数组,而不是MATLAB中的矩阵。默认情况下,矩阵执行矩阵数学运算。因此X*Y进行XY的矩阵乘法。但是,对于数组,默认设置是使用逐个元素的操作。因此,X*YXY的每个对应元素相乘。这等效于MATLAB的.*操作。

但是就像MATLAB的矩阵可以进行逐元素运算一样,Numpy的数组也可以进行矩阵乘法。因此,您需要做的是使用numpy的矩阵乘法,而不是逐个元素的乘法。对于Python 3.5或更高版本(用于这种工作的版本),它只是@运算符。因此,您的行变为:

beta = inv(X_T @ X) @ X_T @ y


或者,更好的是,您可以使用更简单的.T转置,它与np.transpose相同,但更加简洁(您可以完全摆脱`np.transpose行):

beta = inv(X.T @ X) @ X.T @ y


对于Python 3.4或更早版本,您将需要使用np.dot,因为这些版本的python没有@矩阵乘法运算符:

beta = np.dot(np.dot(inv(np.dot(X.T, X)), X.T), y)


Numpy有一个默认情况下像MATLAB矩阵一样使用矩阵运算的矩阵对象。不要使用它!它运行缓慢,支持不力,而且几乎从没有您真正想要的东西。 Python社区已经围绕数组进行了标准化,因此请使用它们。

traindata的尺寸可能还存在一些问题。为了使其正常工作,traindata.ndim应等于3。为了使yX为2D,traindata应该为3D

如果traindata是2D并且您希望y是MATLAB样式的“向量”(MATLAB所谓的“向量”并不是真正的向量),则可能会出现问题。在numpy中,使用像traindata[:, 0]这样的单个索引可以减少维数,而像traindata[:, :1]这样的切片则不能。因此,要在y为2D时保持traindata 2D,只需做一个长度为1的切片traindata[:, :1]。这是完全相同的值,但是保持与traindata相同的维数。

注意:使用逻辑索引可以大大简化您的代码:

def linear_regression_train(traindata):
    y = traindata[:, 0] # This is the output
    y[labels == 2] = -1
    y[labels == 3] = 1
    X = traindata[:, 1:257]
    return inv(X.T @ X) @ X.T @ y
    return beta


另外,定义X时,您的分片是错误的。 Python切片会排除最后一个值,因此要获得256个长切片,您需要像上面一样做1:257

最后,请记住,对函数内部数组的修改会保留在函数外部,并且索引不会复制。因此,对y的更改(将某些值设置为1,将其他值设置为-1)将影响功能之外的traindata。如果要避免这种情况,则需要先进行复制,然后再进行更改:

y = traindata[:, 0].copy()

关于python - 将线性回归从Matlab转换为Python,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/39195204/

10-11 23:19
查看更多