文章目录
线性回归-标准方程法示例(python原生实现)
1. 导包
import numpy as np
import matplotlib.pyplot as plt
import random
2. 原始数据生成与展示
# 原始数据(因为是随机的,所以生成的数据可能会不同)
x_data = np.linspace(1,10,100)
y_data = np.array([j * 1.5 + random.gauss(0,0.9) for j in x_data])
# 作图
plt.scatter(x_data, y_data)
plt.show()
3. 数据预处理
# 数据预处理
x_data = x_data[:, np.newaxis]
y_data = y_data[:, np.newaxis]
# 添加偏置
X_data = np.concatenate((np.ones((100, 1)), x_data), axis = 1)
print(X_data)
4. 定义利用标准方程法求weights的方法
def calc_weights(X_data, y_data):
"""
标准方程法求weights
算法: weights = (X的转置矩阵 * X矩阵)的逆矩阵 * X的转置矩阵 * Y矩阵
:param x_data: 特征数据
:param y_data: 标签数据
"""
x_mat = np.mat(X_data)
y_mat = np.mat(y_data)
xT_x = x_mat.T * x_mat
if np.linalg.det(xT_x) == 0:
print("x_mat为不可逆矩阵,不能使用标准方程法求解")
return
weights = xT_x.I * x_mat.T * y_mat
return weights
5. 求出weights
weights = calc_weights(X_data, y_data)
print(weights)
6. 结果展示
x_test = np.array([2, 5, 8, 9])[:, np.newaxis]
predict_result = weights[0] + x_test * weights[1]
# 原始数据
plt.plot(x_data, y_data, "b.")
# 预测数据
plt.plot(x_test, predict_result, "r")
plt.show()