我刚刚在youtube上使用Siraj Raval的视频开始了机器学习,并尝试了视频“ Intro-The Math of Intelligence”的挑战,该视频是使用kaggle.com的数据集使用Gradient Descent执行线性回归。这是我的代码:

"""
An Example of a Linear Regression model.

Here i am taking an example from https://www.kaggle.com/alopez247/pokemon
to find a relation between variable "Total" and "HP".

"""
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import sys
import os

data = pd.read_csv("./pokemon_alopez247.csv")
d = {"Total": data['Total'],
     "HP": data['HP']}
smallData = pd.DataFrame(d)
test = smallData.values
epsilon = 0.001


def compute_error_for_line(b, m, points):
    """Return the Error for Line given the points."""
    totalError = 0
    for i in range(0, len(points)):
        x = test[i, 0]
        y = test[i, 1]
        totalError += (y - (m * x + b)) ** 2
    return totalError / float(len(points))


def step_gradient(b_current, m_current, points, learningRate):
    """Return the new b and m points."""
    b_gradient = 0
    m_gradient = 0
    N = float(len(points))
    for i in range(0, len(points)):
        x = points[i, 0]
        y = points[i, 1]
        error = y - ((m_current * x) + b_current)
        b_gradient += -(2 / N) * error
        m_gradient += -(2 / N) * x * error
    new_b = b_current - (learningRate * b_gradient)
    new_m = m_current - (learningRate * m_gradient)
    return [new_b, new_m]


def main():
    """Return and plot function here."""
    plt.figure(num=None, figsize=(20, 10), dpi=80,
               facecolor='w', edgecolor='k')
    plt.axis([0, 780, 0, 260])
    plt.ylabel("Total")
    plt.xlabel("HP")
    plt.scatter(test[:, [1]], test[:, [0]], c='r', s=1)

    m = 0.3
    b = -30
    x = np.arange(800)
    y = m * x + b
    for i in range(30):
        error = compute_error_for_line(b, m, test)
        print("error :", error)
        if(error > epsilon):
            y = m * x + b
            plt.plot(x, y)
            b, m = step_gradient(b, m, test, 0.0001)
            print("b , m :", b, ",", m)
            plt.pause(0.01)

    plt.show()

    plt.pause(0.001)

if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        print('Interrupted')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)


输出为:

error : 193676.072288
b , m : -29.91451362 , 6.46934413315
/usr/local/lib/python3.5/dist-packages/matplotlib/backend_bases.py:2445: MatplotlibDeprecationWarning: Using default event loop until function specific to this GUI is implemented
  warnings.warn(str, mplDeprecation)
error : 16427.2683093
b , m : -29.9134163218 , 6.04491523016
error : 15588.2873385
b , m : -29.9065147511 , 6.07401898958
error : 15583.8939554
b , m : -29.9000125838 , 6.07192788394
error : 15583.4489928
b , m : -29.8934831191 , 6.07198242461
error : 15583.0227312
b , m : -29.8869557061 , 6.07188938575
error : 15582.5965792
b , m : -29.8804283262 , 6.07180649992
error : 15582.1704489
b , m : -29.8739011182 , 6.07172291798
error : 15581.74434
b , m : -29.8673740726 , 6.07163938615
error : 15581.3182523
b , m : -29.86084719 , 6.0715558531
error : 15580.8921858
b , m : -29.8543204704 , 6.07147232236
error : 15580.4661407
b , m : -29.8477939138 , 6.0713887937
error : 15580.0401168
b , m : -29.8412675201 , 6.07130526712
error : 15579.6141143
b , m : -29.8347412894 , 6.07122174263
error : 15579.1881329
b , m : -29.8282152217 , 6.07113822022
error : 15578.7621729
b , m : -29.821689317 , 6.0710546999
error : 15578.3362341
b , m : -29.8151635752 , 6.07097118166
error : 15577.9103166
b , m : -29.8086379963 , 6.07088766551
error : 15577.4844204
b , m : -29.8021125804 , 6.07080415145
error : 15577.0585455
b , m : -29.7955873275 , 6.07072063947
error : 15576.6326918
b , m : -29.7890622375 , 6.07063712957
error : 15576.2068594
b , m : -29.7825373104 , 6.07055362176
error : 15575.7810482
b , m : -29.7760125462 , 6.07047011604
error : 15575.3552583
b , m : -29.769487945 , 6.0703866124
error : 15574.9294897
b , m : -29.7629635067 , 6.07030311084
error : 15574.5037423
b , m : -29.7564392314 , 6.07021961138
error : 15574.0780162
b , m : -29.7499151189 , 6.07013611399
error : 15573.6523114
b , m : -29.7433911694 , 6.07005261869
error : 15573.2266278
b , m : -29.7368673827 , 6.06996912548
error : 15572.8009655
b , m : -29.730343759 , 6.06988563435
[Finished in 73.209s]


因此输出表明一切都按计划进行。但是请看this。第一个蓝色是原始值,线越来越远!我尝试重新编写compute_error_for_line和step_gradient函数,但仍然没有执行任何操作。
感谢您阅读到底。

那么如何获得最适合我的样本空间的线的参数呢?

链接到我的csv文件here(此文件将在22小时后过期)。

最佳答案

    plt.scatter(test[:, [1]], test[:, [0]], c='r', s=1)


看起来您交换了x和y值。如果将[1]更改为[0],反之亦然,则该图看起来很好

关于python - 在Python中实现线性回归,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/45045355/

10-11 20:29