周志华西瓜书3.4题。
本文所编写的代码均使用python3.7进行调试,依靠的sklearn进行的实验。
第一步,导入iris数据集,数据集使用sklearn包里面自带的。

from sklearn.linear_model import LogisticRegression
from sklearn import model_selection
from sklearn.datasets import load_iris

# 载入iris数据
data = load_iris()

第二步,用10次十折交叉验证法估计对率回归的精度。(这里所用的循环即为10次)

# 十折交叉验证生成训练集和测试集
def tenfolds():
    k = 0
    truth = []
    while k < 10:
        kf = model_selection.KFold(n_splits=10, random_state=None, shuffle=True)
        for x_train_index, x_test_index in kf.split(data.data):
            x_train = data.data[x_train_index]
            y_train = data.target[x_train_index]
            x_test = data.data[x_test_index]
            y_test = data.target[x_test_index]

        # 验证生成数组长度是否符合规格
        print(len(x_train),len(x_test))

        # 用对率回归进行训练,拟合数据
        log_model = LogisticRegression(multi_class= 'ovr', solver = 'liblinear')
        log_model.fit(x_train, y_train)

        # 用训练好的模型预测
        y_pred = log_model.predict(x_test)
        for i in range(15):
            if y_pred[i] == y_test[i]:
                truth.append(y_pred[i] == y_test)
        k += 1

        # 计算精度
    accuracy = len(truth)/150
    print("用10折交叉验证对率回归的精度是:", accuracy)

第三步,用留一法估计对率回归的精度。(这里循环了150次)

# 用留一法验证
def leaveone():
    loo = model_selection.LeaveOneOut()
    i = 0
    true = 0
    while i < 150:
        for x_train_index, x_test_index in loo.split(data.data):
            x_train = data.data[x_train_index]
            y_train = data.target[x_train_index]
            x_test = data.data[x_test_index]
            y_test = data.target[x_test_index]

        # 用对率回归进行训练,拟合数据
        log_model = LogisticRegression(multi_class='ovr', solver='liblinear')
        log_model.fit(x_train, y_train)

        # 用训练好的模型预测

        y_pred = log_model.predict(x_test)
        if y_pred == y_test:
            true += 1

        i += 1

    # 计算精度
    accuracy = true / 150
    print("用留一法验证对率回归的精度是:", accuracy)

注:使用的时候直接调用相应的定义函数即可。
主要参考的博文:https://blog.csdn.net/catherined/article/details/82015857
欢迎学习交流!

03-14 09:59