• 微信公众号:美女很有趣。 工作之余,放松一下,关注即送10G+美女照片!

鸢尾花数据集—-决策树vs神经网络

开发技术 开发技术 3小时前 1次浏览

为方便理解两种不同预测分类算法 我们均调用 sklearn 里 datasets 的鸢尾花数据集

决策树:

  1 import numpy as np
  2 from sklearn import datasets
  3 from sklearn.model_selection import train_test_split
  4 import matplotlib as mpl
  5 import matplotlib.pyplot as plt
  6 from sklearn import tree
  7 from sklearn.pipeline import Pipeline
  8 from sklearn.tree import DecisionTreeClassifier
  9 from sklearn.preprocessing import StandardScaler
 10 
 11 # 防止画图汉字乱码
 12 mpl.rcParams['font.sans-serif'] = [u'SimHei']
 13 mpl.rcParams['axes.unicode_minus'] = False
 14 
 15 #数据准备
 16 dataset = datasets.load_iris()  # 此时 训练数据(train)与标签(target) 已经分离 为 字典 数据集
 17 # 数据集 已经将标签数据化(化为0-2标签值) 无需再处理
 18 
 19 data = dataset['data']  # 取出对应键 的值  值为array类型
 20 target = dataset['target']
 21 # input = torch.FloatTensor(dataset['data'])
 22 # y = torch.LongTensor(dataset['target'])
 23 
 24 x = np.array(data)
 25 y = np.array(target)
 26 x = x[:, :2]  # 此时的数据为 150行 4列   为方便画图  我们只取前两个特征
 27 # 将数据集 7 / 3 分
 28 x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=1)  
 29 
 30 model = Pipeline([
 31     ('ss', StandardScaler()),
 32     ('DTC', DecisionTreeClassifier(criterion='entropy', max_depth=3))])
 33 # clf = DecisionTreeClassifier(criterion='entropy', max_depth=3)
 34 model = model.fit(x_train, y_train)
 35 y_test_hat = model.predict(x_test)  # 测试数据  y_test_hat 为预测值
 36 # print(y_test)         45个预测样本的真实标签
 37 # [0 1 1 0 2 1 2 0 0 2 1 0 2 1 1 0 1 1 0 0 1 1 1 0 2 1 0 0 1 2 1 2 1 2 2 0 1 0 1 2 2 0 2 2 1]
 38 # print(y_test_hat)     45个预测样本的预测标签
 39 # [0 1 2 0 2 2 2 0 0 2 1 0 2 2 1 0 1 1 0 0 1 0 2 0 2 1 0 0 1 2 1 2 1 2 1 0 1 0 2 2 2 0 1 2 2]
 40 
 41 
 42 # 保存
 43 # dot -Tpng -o 1.png 1.dot
 44 f = open('.\iris_tree.dot', 'w')
 45 tree.export_graphviz(model.get_params('DTC')['DTC'], out_file=f)
 46 
 47 # 画图
 48 N, M = 100, 100  # 横纵各采样多少个值
 49 x1_min, x1_max = x[:, 0].min(), x[:, 0].max()  # 第0列的范围
 50 x2_min, x2_max = x[:, 1].min(), x[:, 1].max()  # 第1列的范围
 51 t1 = np.linspace(x1_min, x1_max, N)
 52 t2 = np.linspace(x2_min, x2_max, M)
 53 x1, x2 = np.meshgrid(t1, t2)  # 生成 v 网格采样点
 54 x_show = np.stack((x1.flat, x2.flat), axis=1)  # 测试点
 55 
 56 # # 无意义,只是为了凑另外两个维度
 57 # # 打开该注释前,确保注释掉x = x[:, :2]
 58 # x3 = np.ones(x1.size) * np.average(x[:, 2])
 59 # x4 = np.ones(x1.size) * np.average(x[:, 3])
 60 # x_test = np.stack((x1.flat, x2.flat, x3, x4), axis=1)  # 测试点
 61 
 62 cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])
 63 cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])
 64 y_show_hat = model.predict(x_show)  # 预测值  预测的标签值
 65 
 66 y_show_hat = y_show_hat.reshape(x1.shape)  # 使之与输入的形状相同
 67 plt.figure(facecolor='w')
 68 plt.pcolormesh(x1, x2, y_show_hat, cmap=cm_light)  # 预测值的显示
 69 plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test.ravel(), edgecolors='k', s=100, cmap=cm_dark, marker='o')  # 测试数据
 70 plt.scatter(x[:, 0], x[:, 1], c=y.ravel(), edgecolors='k', s=40, cmap=cm_dark)  # 全部数据
 71 plt.xlabel("花萼长度", fontsize=15)  # 花萼长度、花萼宽度
 72 plt.ylabel("花萼宽度", fontsize=15)
 73 plt.xlim(x1_min, x1_max)
 74 plt.ylim(x2_min, x2_max)
 75 plt.grid(True)
 76 plt.title(u'鸢尾花数据的决策树分类', fontsize=17)
 77 plt.show()
 78 
 79 # 训练集上的预测结果
 80 y_test = y_test.reshape(-1)
 81 
 82 result = (y_test_hat == y_test)  # True则预测正确,False则预测错误
 83 acc = np.mean(result)
 84 print('准确度: %.2f%%' % (100 * acc))
 85 
 86 # 过拟合:错误率
 87 depth = np.arange(1, 45)
 88 err_list = []
 89 for d in depth:  # 进行15
 90     clf = DecisionTreeClassifier(criterion='entropy', max_depth=d)
 91     clf = clf.fit(x_train, y_train)
 92     y_test_hat = clf.predict(x_test)  # 测试数据
 93     result = (y_test_hat == y_test)  # True则预测正确,False则预测错误
 94     err = 1 - np.mean(result)
 95     err_list.append(err)
 96     print(d, ' 准确度: %.2f%%' % (100 * err))
 97 plt.figure(facecolor='w')
 98 plt.plot(depth, err_list, 'ro-', lw=2)
 99 plt.xlabel(u'决策树深度', fontsize=15)
100 plt.ylabel(u'错误率', fontsize=15)
101 plt.title(u'决策树深度与过拟合', fontsize=17)
102 plt.grid(True)
103 
104 plt.show()
105 
106 from sklearn import tree  # 需要导入的包
107 
108 f = open('D:\py_project\iris_tree.dot', 'w')
109 
110 tree.export_graphviz(model.get_params('DTC')['DTC'], out_file=f)

鸢尾花数据集----决策树vs神经网络

鸢尾花数据集----决策树vs神经网络


 

神经网络:

 1 import numpy as np
 2 from collections import Counter
 3 from sklearn import datasets
 4 import torch.nn.functional as Fun
 5 from torch.autograd import Variable
 6 import matplotlib.pyplot as plt
 7 import torch
 8 
 9 dataset = datasets.load_iris()
10 dataut=dataset['data']
11 priciple=dataset['target']
12 
13 input=torch.FloatTensor(dataset['data'])
14 label=torch.LongTensor(dataset['target'])
15 
16 #定义BP神经网络
17 class Net(torch.nn.Module):
18     def __init__(self, n_feature, n_hidden, n_output):
19         super(Net, self).__init__()
20         self.hidden = torch.nn.Linear(n_feature, n_hidden)   # hidden layer
21         self.out = torch.nn.Linear(n_hidden, n_output)   # output layer
22 
23     def forward(self, x):
24         x = Fun.relu(self.hidden(x))      # activation function for hidden layer we choose sigmoid
25         x = self.out(x)
26         return x
27 
28 net = Net(n_feature=4, n_hidden=20, n_output=3)
29 optimizer = torch.optim.SGD(net.parameters(), lr=0.02) #SGD: 随机梯度下降
30 loss_func = torch.nn.CrossEntropyLoss() #针对分类问题的损失函数!
31 
32 #训练数据
33 for t in range(500):
34     out = net(input)                 # input x and predict based on x
35     loss = loss_func(out, label)     # 输出与label对比
36     optimizer.zero_grad()   # clear gradients for next train
37     loss.backward()         # backpropagation, compute gradients
38     optimizer.step()        # apply gradients
39 
40 out = net(input) #out是一个计算矩阵,可以用Fun.softmax(out)转化为概率矩阵
41 prediction = torch.max(out, 1)[1] # 1返回index  0返回原值
42 pred_y = prediction.data.numpy()
43 target_y = label.data.numpy()
44 accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size)
45 print("莺尾花预测准确率",accuracy)

 

鸢尾花数据集:

共150个分为 三种类别  setosa,versicolor,virginnica
花萼长度、花萼宽度,花瓣长度,花瓣宽度,种类

5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica




程序员灯塔
转载请注明原文链接:鸢尾花数据集—-决策树vs神经网络
喜欢 (0)