快速报名
首页 / 干货教程 / 干货教程 / python机器学习教程03 感知机对偶形式(鸢尾花分类)

python机器学习教程03 感知机对偶形式(鸢尾花分类)

上一张简单讲解了感知机得问题,后面就需要深入得解析感知机得各种形式,需要学习的小伙伴不要错过本章节哦~~~~

感知机对偶形式(鸢尾花分类)

导入模块

from matplotlib.font_manager import FontProperties

import matplotlib.pyplot as plt

import numpy as np

import pandas as pd

import random

%matplotlib inline

font = FontProperties(fname='/Library/Fonts/Heiti.ttc')

获取数据

def get_data():

    df = pd.read_csv(

        'http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data', header=None)

    X = df.iloc[0:100, [0, 2]].values

    train_data_p = df.iloc[0:50, [0, 2, 4]].values

    train_data_n = df.iloc[50:100, [0, 2, 4]].values

    train_data_p[:, [2]], train_data_n[:, [2]] = -1, 1

    train_data = train_data_p.tolist() + train_data_n.tolist()

    return train_data, X

训练模型

def train(num_iter, train_data, learning_rate):

    w = 0.0

    b = 0

    data_length = len(train_data)

    alpha = [0 for _ in range(data_length)]

    train_data = np.array(train_data)

    gram = np.matmul(train_data[:, 0:-1], train_data[:, 0:-1].T)

    for i in range(num_iter):

        count = 0

        i = random.randint(0, data_length - 1)

        yi = train_data[i, -1]

        for j in range(data_length):

            count += alpha[j] * train_data[j, -1] * gram[i, j]

        count += b

        if (yi * count <= 0):

            alpha[i] = alpha[i] + learning_rate

            b = b + learning_rate * yi

    for i in range(data_length):

        w += alpha[i] * train_data[i, 0:-1] * train_data[i, -1]

    return w, b, alpha, gram

可视化

def plot_points(w, b, X):

    plt.figure()

    x1 = np.linspace(4, 7, 100)

    x2 = (-b - w[0] * x1) / (w[1] + 1e-10)

    plt.plot(x1, x2, color='k')

    plt.scatter(X[:50, 0], X[:50, 1], color='r', s=50, marker='o', label='山鸢尾')

    plt.scatter(X[50:100, 0], X[50:100, 1], color='b',

                s=50, marker='x', label='变色鸢尾')

    plt.xlabel('萼片长度(cm)', fontproperties=font)

    plt.ylabel('花瓣长度(cm)', fontproperties=font)

    plt.legend(prop=font)

    plt.show()

运行

train_data, X = get_data()

w, b, alpha, gram = train(

    num_iter=1000, train_data=train_data, learning_rate=0.1)

plot_points(w, b, X)

抢先报名    优先占座