BP算法实例—鸢尾花的分类(Python)

2021-06-26 23:05

阅读:667

标签:python   hang   变量   init   .com   random   tps   输出   更新   

首先了解下Iris鸢尾花数据集:

       Iris数据集(https://en.wikipedia.org/wiki/Iris_flower_data_set)是常用的分类实验数据集,由Fisher,1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。
iris以鸢尾花的特征作为数据来源,常用在分类操作中。该数据集由3种不同类型的鸢尾花的50个样本数据构成。其中的一个种类与另外两个种类是线性可分离的,后两个种类是非线性可分离的。

该数据集包含了4个属性:
        Sepal.Length(花萼长度),单位是cm;
        Sepal.Width(花萼宽度),单位是cm;
        Petal.Length(花瓣长度),单位是cm;
        Petal.Width(花瓣宽度),单位是cm;
种类:Iris Setosa(1.山鸢尾)、Iris Versicolour(2.杂色鸢尾),以及Iris Virginica(3.维吉尼亚鸢尾)。

技术分享图片Python源码:

  1 from __future__ import division
  2 import math
  3 import random
  4 import pandas as pd
  5  
  6  
  7 flowerLables = {0: Iris-setosa,
  8                 1: Iris-versicolor,
  9                 2: Iris-virginica}
 10  
 11 random.seed(0)
 12  
 13  
 14 # 生成区间[a, b)内的随机数
 15 def rand(a, b):
 16     return (b - a) * random.random() + a
 17  
 18  
 19 # 生成大小 I*J 的矩阵,默认零矩阵
 20 def makeMatrix(I, J, fill=0.0):
 21     m = []
 22     for i in range(I):
 23         m.append([fill] * J)
 24     return m
 25  
 26  
 27 # 函数 sigmoid
 28 def sigmoid(x):
 29     return 1.0 / (1.0 + math.exp(-x))
 30  
 31  
 32 # 函数 sigmoid 的导数
 33 def dsigmoid(x):
 34     return x * (1 - x)
 35  
 36  
 37 class NN:
 38     """ 三层反向传播神经网络 """
 39  
 40     def __init__(self, ni, nh, no):
 41         # 输入层、隐藏层、输出层的节点(数)
 42         self.ni = ni + 1  # 增加一个偏差节点
 43         self.nh = nh + 1
 44         self.no = no
 45  
 46         # 激活神经网络的所有节点(向量)
 47         self.ai = [1.0] * self.ni
 48         self.ah = [1.0] * self.nh
 49         self.ao = [1.0] * self.no
 50  
 51         # 建立权重(矩阵)
 52         self.wi = makeMatrix(self.ni, self.nh)
 53         self.wo = makeMatrix(self.nh, self.no)
 54         # 设为随机值
 55         for i in range(self.ni):
 56             for j in range(self.nh):
 57                 self.wi[i][j] = rand(-0.2, 0.2)
 58         for j in range(self.nh):
 59             for k in range(self.no):
 60                 self.wo[j][k] = rand(-2, 2)
 61  
 62     def update(self, inputs):
 63         if len(inputs) != self.ni - 1:
 64             raise ValueError(与输入层节点数不符!)
 65  
 66         # 激活输入层
 67         for i in range(self.ni - 1):
 68             self.ai[i] = inputs[i]
 69  
 70         # 激活隐藏层
 71         for j in range(self.nh):
 72             sum = 0.0
 73             for i in range(self.ni):
 74                 sum = sum + self.ai[i] * self.wi[i][j]
 75             self.ah[j] = sigmoid(sum)
 76  
 77         # 激活输出层
 78         for k in range(self.no):
 79             sum = 0.0
 80             for j in range(self.nh):
 81                 sum = sum + self.ah[j] * self.wo[j][k]
 82             self.ao[k] = sigmoid(sum)
 83  
 84         return self.ao[:]
 85  
 86     def backPropagate(self, targets, lr):
 87         """ 反向传播 """
 88  
 89         # 计算输出层的误差
 90         output_deltas = [0.0] * self.no
 91         for k in range(self.no):
 92             error = targets[k] - self.ao[k]
 93             output_deltas[k] = dsigmoid(self.ao[k]) * error
 94  
 95         # 计算隐藏层的误差
 96         hidden_deltas = [0.0] * self.nh
 97         for j in range(self.nh):
 98             error = 0.0
 99             for k in range(self.no):
100                 error = error + output_deltas[k] * self.wo[j][k]
101             hidden_deltas[j] = dsigmoid(self.ah[j]) * error
102  
103         # 更新输出层权重
104         for j in range(self.nh):
105             for k in range(self.no):
106                 change = output_deltas[k] * self.ah[j]
107                 self.wo[j][k] = self.wo[j][k] + lr * change
108  
109         # 更新输入层权重
110         for i in range(self.ni):
111             for j in range(self.nh):
112                 change = hidden_deltas[j] * self.ai[i]
113                 self.wi[i][j] = self.wi[i][j] + lr * change
114  
115         # 计算误差
116         error = 0.0
117         error += 0.5 * (targets[k] - self.ao[k]) ** 2
118         return error
119  
120     def test(self, patterns):
121         count = 0
122         for p in patterns:
123             target = flowerLables[(p[1].index(1))]
124             result = self.update(p[0])
125             index = result.index(max(result))
126             print(p[0], :, target, ->, flowerLables[index])
127             count += (target == flowerLables[index])
128         accuracy = float(count / len(patterns))
129         print(accuracy: %-.9f % accuracy)
130  
131     def weights(self):
132         print(输入层权重:)
133         for i in range(self.ni):
134             print(self.wi[i])
135         print()
136         print(输出层权重:)
137         for j in range(self.nh):
138             print(self.wo[j])
139  
140     def train(self, patterns, iterations=1000, lr=0.1):
141         # lr: 学习速率(learning rate)
142         for i in range(iterations):
143             error = 0.0
144             for p in patterns:
145                 inputs = p[0]
146                 targets = p[1]
147                 self.update(inputs)
148                 error = error + self.backPropagate(targets, lr)
149             if i % 100 == 0:
150                 print(error: %-.9f % error)
151  
152  
153  
154 def iris():
155     data = []
156     # 读取数据
157     raw = pd.read_csv(iris.csv)
158     raw_data = raw.values
159     raw_feature = raw_data[0:, 0:4]
160     for i in range(len(raw_feature)):
161         ele = []
162         ele.append(list(raw_feature[i]))
163         if raw_data[i][4] == Iris-setosa:
164             ele.append([1, 0, 0])
165         elif raw_data[i][4] == Iris-versicolor:
166             ele.append([0, 1, 0])
167         else:
168             ele.append([0, 0, 1])
169         data.append(ele)
170     # 随机排列数据
171     random.shuffle(data)
172     training = data[0:100]
173     test = data[101:]
174     nn = NN(4, 7, 3)
175     nn.train(training, iterations=10000)
176     nn.test(test)
177  
178  
179 if __name__ == __main__:
180     iris()

 

BP算法实例—鸢尾花的分类(Python)

标签:python   hang   变量   init   .com   random   tps   输出   更新   

原文地址:https://www.cnblogs.com/duanhx/p/9655217.html


评论


亲,登录后才可以留言!