去年,我在Matlab中为线性回归程序中的设计矩阵编写了代码。它工作正常。现在,我需要将其翻译为Python并在Pycharm中运行。我已经使用了好几天了,虽然我是Python的新手,但是我在翻译中找不到任何错误,但是当代码与程序的其余部分一起运行时却出现错误。
Matlab中的代码:
function DesignMatrix = design_matrix( xTrain, M )
% This function calculates the Design Matrix for
% a M-th degree polynomial
% xTrain - training set Nx1
% M - polynomial degree 0,1,2,...
N = size(xTrain,1);
DesignMatrix = zeros(N,M+1);
for i=1:M+1
DesignMatrix(:,i)=xTrain.^(i-1)
end
end
和我在Python中的翻译(np代表numpy,已导入):
def design_matrix(x_train,M):
'''
:param x_train: input vector Nx1
:param M: polynomial degree 0,1,2,...
:return: Design Matrix Nx(M+1) for M degree polynomial
'''
desm = np.zeros(shape =(len(x_train), M+1))
for i in range(1, M+1):
desm[:,i] = np.power(x_train, (i-1))
return desm
pass
错误指向此行:
desm[:,i] = np.power(x_train, (i-1))
,这是一个值错误。我尝试使用在线翻译器ompc,但由于它对我不起作用,它似乎已过时。如果我的翻译中有任何明显的错误,有人可以向我解释吗?我知道这是更大程序的一部分,但是我要问的只是语法翻译本身。如果是正确的话,尽管到目前为止我没有提出任何建议,但我会尝试查找其他任何错误。谢谢。编辑:回溯
ERROR: test_design_matrix (test.TestDesignMatrix)
----------------------------------------------------------------------
Traceback (most recent call last):
File "...\test.py", line 61, in test_design_matrix
dm_computed = design_matrix(x_train, M)
File "...\content.py", line 34, in design_matrix
desm[:,i] = np.power(x_train, (i-1))
ValueError: could not broadcast input array from shape (20,1) into shape (20)
我无法更改test.py文件,该文件已提供给我,无法更改,因此我仅依靠第二个错误。
给出错误的函数的test.py摘录:
def test_design_matrix(self):
x_train = TEST_DATA['design_matrix']['x_train']
M = TEST_DATA['design_matrix']['M']
dm = TEST_DATA['design_matrix']['dm']
dm_computed = design_matrix(x_train, M)
max_diff = np.max(np.abs(dm - dm_computed))
self.assertAlmostEqual(max_diff, 0, 8)
最佳答案
你可以尝试一下:
def design_matrix(x_train,M):
'''
:param x_train: input vector Nx1
:param M: polynomial degree 0,1,2,...
:return: Design Matrix Nx(M+1) for M degree polynomial
'''
x_train = np.asarray(x_train)
desm = np.zeros(shape =(len(x_train), M+1))
for i in range(0, M+1):
desm[:,i] = np.power(x_train, i).reshape(x_train.shape[0],)
return desm
该错误来自于不兼容的Numpy数组尺寸。 desm [:,i]的形状为(n,),但是您要存储的值的形状为(n,1),因此需要将其重塑为(n,)。另外,正如GLR所述,Python索引从0开始,因此您需要修改索引,并且函数执行在返回行处停止,因此根本无法到达传递行。
关于python - Matlab将Python转换为设计矩阵函数,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/42737252/