第五讲 卷积神经网络 - Resnet--cifar10
2021-03-09 18:29
标签:pad str otl use ide __init__ and cross nump
原文地址:https://www.cnblogs.com/wbloger/p/12862315.html
1 import tensorflow as tf
2 import os
3 import numpy as np
4 from matplotlib import pyplot as plt
5 from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPooling2D, Dropout, Flatten, Dense, GlobalAveragePooling2D
6 from tensorflow.keras import Model
7
# Print NumPy arrays in full (no "..." truncation); the weight dump written
# later in this script relies on str(array) containing every element.
np.set_printoptions(threshold=np.inf)


# Load the CIFAR-10 train/test splits through the Keras dataset helper.
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Scale the images by 1/255 (presumably 8-bit pixels, giving values in [0, 1]).
x_train, x_test = x_train / 255.0, x_test / 255.0
14
15
16
class ResnetBlock(Model):
    """Residual block with two 3x3 convolutions and a shortcut connection.

    The main path is Conv3x3 -> BatchNorm -> ReLU -> Conv3x3 -> BatchNorm.
    Its output is added to the shortcut branch and the sum goes through a
    final ReLU.

    Args:
        filters: Number of filters used by both 3x3 convolutions and, when
            present, by the 1x1 shortcut convolution.
        strides: Stride of the first 3x3 convolution and of the 1x1 shortcut
            convolution. The second 3x3 convolution always uses stride 1.
        residual_path: If True, the shortcut is a 1x1 convolution followed by
            batch normalization so that it matches the shape of the main
            path; if False, the block input is added unchanged.
    """

    def __init__(self, filters, strides=1, residual_path=False):
        super(ResnetBlock, self).__init__()
        self.filters = filters
        self.strides = strides
        self.residual_path = residual_path

        # First conv stage; this is the only 3x3 conv that may down-sample.
        self.c1 = Conv2D(filters, (3, 3), strides=strides, padding='same', use_bias=False)
        self.b1 = BatchNormalization()
        self.a1 = Activation('relu')

        # Second conv stage; the activation is applied after the addition.
        self.c2 = Conv2D(filters, (3, 3), strides=1, padding='same', use_bias=False)
        self.b2 = BatchNormalization()

        # When residual_path is True the input is down-sampled with a 1x1
        # convolution so that x has the same dimensions as F(x) and the two
        # can be added.
        if residual_path:
            self.down_c1 = Conv2D(filters, (1, 1), strides=strides, padding='same', use_bias=False)
            self.down_b1 = BatchNormalization()

        self.a2 = Activation('relu')

    def call(self, inputs):
        """Apply the block to ``inputs`` and return the activated sum."""
        residual = inputs  # identity shortcut by default: residual = x

        x = self.c1(inputs)
        x = self.b1(x)
        x = self.a1(x)

        x = self.c2(x)
        y = self.b2(x)

        if self.residual_path:
            # Projection shortcut: residual = W x
            residual = self.down_c1(inputs)
            residual = self.down_b1(residual)

        # Output is the sum of both branches, F(x) + x or F(x) + Wx, passed
        # through the final activation.
        out = self.a2(y + residual)
        return out
53
54
55
class ResNet18(Model):
    """ResNet-style image classifier assembled from ``ResnetBlock`` layers.

    Layout: Conv3x3 -> BatchNorm -> ReLU stem, then the residual stages,
    global average pooling and a 10-unit softmax dense layer.

    Args:
        block_list: Sequence of ints; entry ``i`` is the number of
            ``ResnetBlock`` layers in stage ``i``.
        initial_filters: Filter count of the stem and of the first stage.
            The count is doubled after every stage.
    """

    def __init__(self, block_list, initial_filters=64):
        super(ResNet18, self).__init__()
        self.num_blocks = len(block_list)  # number of stages
        self.block_list = block_list
        self.out_filters = initial_filters

        # Stem.
        self.c1 = Conv2D(self.out_filters, (3, 3), strides=1, padding='same', use_bias=False)
        self.b1 = BatchNormalization()
        self.a1 = Activation('relu')

        # Build the residual stages.
        self.blocks = tf.keras.models.Sequential()
        for block_id, num_layers in enumerate(block_list):
            for layer_id in range(num_layers):
                if block_id != 0 and layer_id == 0:
                    # First layer of every stage except the first one
                    # down-samples its input (stride 2, projection shortcut).
                    block = ResnetBlock(self.out_filters, strides=2, residual_path=True)
                else:
                    block = ResnetBlock(self.out_filters, residual_path=False)
                self.blocks.add(block)  # append the finished block to the network
            # The next stage uses twice as many filters as this one.
            self.out_filters *= 2

        self.p1 = tf.keras.layers.GlobalAveragePooling2D()
        self.f1 = tf.keras.layers.Dense(10, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())

    def call(self, inputs):
        """Run the stem, residual stages, pooling and classifier on ``inputs``."""
        x = self.c1(inputs)
        x = self.b1(x)
        x = self.a1(x)
        x = self.blocks(x)
        x = self.p1(x)
        y = self.f1(x)
        return y
87
88
89
# Build the network with four stages of two residual blocks each.
model = ResNet18([2, 2, 2, 2])

# The final layer already applies softmax, hence from_logits=False.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])


# NOTE(review): the checkpoint is named "Inception10" although the model is
# ResNet18 - looks like a copy-paste leftover. Path kept unchanged so that
# previously saved weights are still found.
checkpoint_save_path = "./checkpoint/Inception10.ckpt"
# Resume from saved weights when the checkpoint's index file is present.
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model---------------')
    model.load_weights(checkpoint_save_path)

# Save weights only, and only when the monitored metric improves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

# Train, validating on the test split after every epoch.
history = model.fit(x_train, y_train, batch_size=32, epochs=5,
                    validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()


# Dump the name, shape and values of every trainable variable as text.
with open('./weights.txt', 'w') as f:
    for v in model.trainable_variables:
        f.write(str(v.name) + '\n')
        f.write(str(v.shape) + '\n')
        f.write(str(v.numpy()) + '\n')
117
118
def plot_acc_loss_curve(history):
    """Show training/validation accuracy and loss curves side by side.

    Args:
        history: Object whose ``history`` attribute is a mapping with the keys
            ``'sparse_categorical_accuracy'``,
            ``'val_sparse_categorical_accuracy'``, ``'loss'`` and
            ``'val_loss'`` (in this script, the value returned by
            ``model.fit``).
    """
    from matplotlib import pyplot as plt

    acc = history.history['sparse_categorical_accuracy']
    val_acc = history.history['val_sparse_categorical_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(15, 5))

    # Left panel: accuracy.
    plt.subplot(1, 2, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()

    # Right panel: loss.
    plt.subplot(1, 2, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()

    plt.show()
142
# Plot the curves recorded during training.
plot_acc_loss_curve(history)
文章标题:第五讲 卷积神经网络 - Resnet--cifar10
文章链接:http://soscw.com/index.php/essay/62403.html