U-Net网络实现医学图像分割
2021-02-14 06:16
标签:pat conv2 告诉 check ilo pycha data begin work
U-Net网络是典型的编解码网络,常用于图像分割。
首先,准备数据
第一步,将nii格式医学图像转换为png格式,按给定窗位窗宽截断
1 import numpy as np
2 import cv2 as cv
3 import nibabel as nib
4 import os
5 from PIL import Image
6 import imageio
7
8
def transform_ctdata(image, windowWidth, windowCenter, normal=False):
    """Truncate a CT image to an intensity window and rescale it.

    Intensities are mapped linearly so that
    ``windowCenter - windowWidth / 2`` becomes 0 and
    ``windowCenter + windowWidth / 2`` becomes 1; values outside the window
    are clamped to 0 or 1 respectively.

    Args:
        image: Array-like of intensities. The input is not modified; the
            subtraction below promotes integer data to floating point.
        windowWidth: Width of the intensity window (must be non-zero).
        windowCenter: Centre of the intensity window.
        normal: If True, return the floating-point image in [0, 1];
            otherwise scale to [0, 255] and return ``uint8``.

    Returns:
        The truncated image, same shape as ``image``.
    """
    minWindow = float(windowCenter) - 0.5 * float(windowWidth)
    newimg = (np.asarray(image) - minWindow) / float(windowWidth)
    # Clamp everything that falls outside the window.
    newimg = np.clip(newimg, 0, 1)
    if not normal:
        newimg = (newimg * 255).astype('uint8')
    return newimg
21
22
# Directories holding the source .nii volumes and their segmentations.
train_path = "D:/pycharm_project/graduate_design_next_semester/dataset/data_nii/train"
label_path = "D:/pycharm_project/graduate_design_next_semester/dataset/data_nii/label"
# Directories that receive the extracted slices.
train_save_path = './data_raw_slice_tumour/train'
label_save_path = './data_raw_slice_tumour/label'
os.makedirs(train_save_path, exist_ok=True)
os.makedirs(label_save_path, exist_ok=True)


def _case_number(file_name):
    """Return the integer between the first '-' and the next '.' of a name."""
    return int(file_name.split('-')[1].split('.')[0])


# List the volumes and segmentations and sort both numerically so that
# entries at the same index are paired by position.
train_images = os.listdir(train_path)
train_images.sort(key=_case_number)
label_images = os.listdir(label_path)
label_images.sort(key=_case_number)
40
41
def create_train_label_slice():
    """Split every volume/segmentation pair into per-slice PNG files.

    For each pair (matched by position in ``train_images`` and
    ``label_images``) the segmentation is remapped so that label 1 becomes
    background (0) and label 2 becomes foreground (1). Only slices with at
    least 50 foreground pixels are exported. The matching CT slice is
    windowed with ``transform_ctdata`` (width 200, centre 30), median
    filtered and written under ``train_save_path/<case>/<slice>.png``; the
    binary mask (0/255) is written under ``label_save_path/<case>/``.

    Raises:
        ValueError: If the numbers of volumes and segmentations differ.
    """
    if len(train_images) != len(label_images):
        raise ValueError(
            "volume/segmentation count mismatch: {} vs {}".format(
                len(train_images), len(label_images)))

    print('-' * 30)
    print("同时分解volume与segmentation文件")
    for i, (train_name, label_name) in enumerate(zip(train_images, label_images)):
        train_image = nib.load(train_path + '/' + train_name)
        label_image = nib.load(label_path + '/' + label_name)
        # Rows, columns and number of slices of this volume.
        height, width, slice_count = train_image.shape
        print("第" + str(i) + "个dir", "(", height, width, slice_count, ")")

        # Sub-folder name for this case (0, 1, 2, ...), taken from the file name.
        case_id = train_name.split('-')[1].split('.')[0]
        train_slice_path = train_save_path + '/' + case_id
        label_slice_path = label_save_path + '/' + case_id

        label_volume = label_image.get_fdata()
        for j in range(slice_count):
            label_img = label_volume[:, :, j]
            # Order matters: clear label 1 first, then promote label 2 to 1.
            label_img[label_img == 1] = 0
            label_img[label_img == 2] = 1

            # A slice counts as valid data only with >= 50 foreground pixels.
            if np.count_nonzero(label_img == 1) < 50:
                continue

            # Read the CT slice only once the label has passed the filter.
            set_slice = np.array(train_image.dataobj[:, :, j]).astype("float32")
            # Window width 200 / centre 30 (an earlier setting was 350 / 25;
            # the author's note cites a 40~60 level for liver).
            set_slice = transform_ctdata(set_slice, 200, 30)
            # Median filter to remove salt-and-pepper noise.
            set_slice = cv.medianBlur(set_slice, 3)

            # Folders are created lazily so cases without valid slices get none.
            os.makedirs(train_slice_path, exist_ok=True)
            os.makedirs(label_slice_path, exist_ok=True)

            cv.imwrite(train_slice_path + '/' + str(j) + '.png', set_slice)
            label_img = Image.fromarray(np.uint8(label_img * 255))
            imageio.imwrite(label_slice_path + '/' + str(j) + '.png', label_img)
    print("Generating train data set done!")


if __name__ == "__main__":
    create_train_label_slice()
第二步,计算每一位病人中切片的均值,为降低病人间切片亮度差异作准备
1 import os
2 import cv2 as cv
3 import numpy as np
4
5
train_raw_path = "./data_raw_slice_tumour/train"
label_raw_path = "./data_raw_slice_tumour/label"


def count_dir_mean_fun(black_threshold=1.0):
    """Compute the per-patient mean slice brightness, ignoring black pixels.

    For every numbered folder under ``train_raw_path`` each slice is read as
    greyscale and its pixel sum is divided by the number of pixels that are
    not "black" (value below ``black_threshold``). The slice values are
    averaged per folder. Three arrays are printed and saved to
    ``./mean_array_without_background``:

    * ``mean1.npy`` - one mean per folder, in numeric folder order
    * ``mean2.npy`` - the mean of ``mean1``
    * ``mean3.npy`` - ``mean1 - mean2``

    Args:
        black_threshold: Pixels strictly below this value count as
            background. NOTE(review): the original comparison was lost when
            the listing was extracted; 1.0 (only pure black) is an
            assumption - confirm against the original script.

    Raises:
        ValueError: If no folders exist, a folder has no slice with
            foreground pixels, or an image cannot be read.
    """
    dirs = os.listdir(train_raw_path)
    dirs.sort(key=lambda x: int(x))
    if not dirs:
        raise ValueError("no slice folders found in " + train_raw_path)

    mean1 = []
    for dir_name in dirs:
        train_dir_path = os.path.join(train_raw_path, dir_name)
        images = os.listdir(train_dir_path)
        images.sort(key=lambda x: int(x.split('.')[0]))

        slice_means = []
        for name in images:
            image_path = os.path.join(train_dir_path, name)
            image = cv.imread(image_path, 0)
            if image is None:
                raise ValueError("cannot read image: " + image_path)
            image = image.astype("float32")

            black_num = np.count_nonzero(image < black_threshold)
            foreground_num = image.size - black_num
            if foreground_num == 0:
                # An all-black slice has no foreground mean; skipping it
                # avoids a 0/0 that would turn every result into NaN.
                continue
            # Sum over all pixels divided by the foreground pixel count.
            slice_means.append(float(image.sum(dtype=np.float64)) / foreground_num)

        if not slice_means:
            raise ValueError("no usable slices in " + train_dir_path)
        mean1.append(sum(slice_means) / len(slice_means))

    mean2 = sum(mean1) / len(mean1)
    mean3 = [x - mean2 for x in mean1]
    print(mean1)
    print(mean2)
    print(mean3)

    mean_save_path = "./mean_array_without_background"
    os.makedirs(mean_save_path, exist_ok=True)
    np.save(mean_save_path + '/' + "mean1.npy", np.array(mean1))
    np.save(mean_save_path + '/' + "mean2.npy", np.array(mean2))
    np.save(mean_save_path + '/' + "mean3.npy", np.array(mean3))


if __name__ == "__main__":
    count_dir_mean_fun()
第三步,降低病人间亮度差异
1 import os
2 import cv2 as cv
3 import numpy as np
4
5
# Shift every slice of a patient by (overall mean - patient mean) so that
# brightness differences between patients are reduced.
train_raw_path = "./data_raw_slice_tumour/train"
label_raw_path = "./data_raw_slice_tumour/label"


dirs = os.listdir(train_raw_path)
dirs.sort(key=lambda x: int(x))

train_save_path = "./data_slice_tumour_modify_brightness/train"
label_save_path = "./data_slice_tumour_modify_brightness/label"


# mean1 holds one mean per folder, indexed by position in the numerically
# sorted `dirs`; mean2 is the single value added back to every image.
mean1 = np.load("./mean_array_without_background/mean1.npy")
mean2 = np.load("./mean_array_without_background/mean2.npy")
for i, dir_name in enumerate(dirs):
    mean_sub = mean1[i]
    mean_add = mean2

    train_dir_path = os.path.join(train_raw_path, dir_name)
    save_path = os.path.join(train_save_path, dir_name)
    os.makedirs(save_path, exist_ok=True)

    images = os.listdir(train_dir_path)
    images.sort(key=lambda x: int(x.split('.')[0]))
    for name in images:
        image_path = os.path.join(train_dir_path, name)
        image = cv.imread(image_path, 0)
        image = image.astype("float32")
        image -= mean_sub
        image += mean_add
        # PNG stores integer samples: round and clamp to the 8-bit range
        # explicitly instead of handing a float32 array to imwrite.
        image = np.clip(np.rint(image), 0, 255).astype("uint8")

        cv.imwrite(save_path + '/' + name, image)
    print("完成第{}个dir".format(i))
第四步,对切片进行裁剪,减少不必要背景
1 import os
2 import cv2 as cv
3 import numpy as np
4
5
# Crop every slice and its label to a 400x400 window to drop unneeded
# background, keeping only slices whose cropped label still has foreground.
raw_slice_train_path = "./data_slice_tumour_modify_brightness/train"
raw_slice_label_path = "./data_raw_slice_tumour/label"


train_clip_save_path = "./data_cv_clip_whole/train"
label_clip_save_path = "./data_cv_clip_whole/label"

# Crop window: rows 0..399, columns 50..449.
CROP_ROWS = slice(0, 400)
CROP_COLS = slice(50, 450)


dirs = os.listdir(raw_slice_train_path)
dirs.sort(key=lambda x: int(x))

for j, dir_name in enumerate(dirs, start=1):
    train_dir_path = os.path.join(raw_slice_train_path, dir_name)
    label_dir_path = os.path.join(raw_slice_label_path, dir_name)

    names = os.listdir(train_dir_path)
    names.sort(key=lambda x: int(x.split('.')[0]))

    # Kept slices are renumbered consecutively from 0 within each folder.
    i = 0
    for name in names:
        train_img = cv.imread(os.path.join(train_dir_path, name), 0)
        label_img = cv.imread(os.path.join(label_dir_path, name), 0)

        train_clip_img = train_img[CROP_ROWS, CROP_COLS]
        label_clip_img = label_img[CROP_ROWS, CROP_COLS]
        # Binarise the label: everything that is not exactly 255 becomes 0.
        label_clip_img[label_clip_img != 255] = 0

        # Skip slices whose foreground fell entirely outside the crop.
        if not label_clip_img.any():
            continue

        train_save_path = os.path.join(train_clip_save_path, dir_name)
        label_save_path = os.path.join(label_clip_save_path, dir_name)
        os.makedirs(train_save_path, exist_ok=True)
        os.makedirs(label_save_path, exist_ok=True)

        train_save_name = os.path.join(train_save_path, str(i) + ".png")
        label_save_name = os.path.join(label_save_path, str(i) + ".png")

        cv.imwrite(train_save_name, train_clip_img)
        cv.imwrite(label_save_name, label_clip_img)
        i += 1
    print("完成第{}个dir".format(j))
第五步,以病人为单位,制作切片npy文件,用来上传服务器使用
1 """
2 本程序制作上传服务器的肝脏切片数据集
3 一张张切片,按照dir分布
4 """
5 import os
6 import cv2 as cv
7 import numpy as np
8
9
# Train images that have already been windowed (window level / width).
train_png_path = "./data_cv_clip_whole/train"
# Label masks.
label_png_path = "./data_cv_clip_whole/label"
train_dir_npy_save_path = "./data_dir_npy/train"
label_dir_npy_save_path = "./data_dir_npy/label"
os.makedirs(train_dir_npy_save_path, exist_ok=True)
os.makedirs(label_dir_npy_save_path, exist_ok=True)

train_dirs = os.listdir(train_png_path)
label_dirs = os.listdir(label_png_path)
train_dirs.sort(key=lambda x: int(x))
label_dirs.sort(key=lambda x: int(x))

# Pack each patient folder into one (N, 400, 400, 1) uint8 array. Output
# files are numbered by position in the sorted folder list.
for j, dir_name in enumerate(train_dirs):
    train_dir_path = os.path.join(train_png_path, dir_name)
    label_dir_path = os.path.join(label_png_path, dir_name)

    train_imgs = os.listdir(train_dir_path)
    train_imgs.sort(key=lambda x: int(x.split('.')[0]))

    # Zero-initialised so that no uninitialised memory can reach the file.
    train_dir_npy = np.zeros((len(train_imgs), 400, 400, 1), dtype=np.uint8)
    label_dir_npy = np.zeros((len(train_imgs), 400, 400, 1), dtype=np.uint8)

    # The label of a slice is looked up under the same file name.
    for i, img in enumerate(train_imgs):
        train_img = cv.imread(os.path.join(train_dir_path, img), 0)
        label_img = cv.imread(os.path.join(label_dir_path, img), 0)

        train_dir_npy[i] = np.reshape(train_img, (400, 400, 1))
        label_dir_npy[i] = np.reshape(label_img, (400, 400, 1))

    np.save(train_dir_npy_save_path + "/" + str(j) + ".npy", train_dir_npy)
    np.save(label_dir_npy_save_path + "/" + str(j) + ".npy", label_dir_npy)
    print("第{}个文件夹".format(j + 1))
第六步,在服务器中拆解npy文件,以病人为单位保存在各自的dir中
1 """
2 本程序将服务器的dir---npy文件拆解成dir---png文件
3 注意:这里标签是纯肝脏数据,肿瘤作为背景
4 """
5 import os
6 import cv2 as cv
7 import numpy as np
8
9
train_png_path = "./data_dir_png/train"
label_png_path = "./data_dir_png/label"
train_npy_path = "./data_dir_npy/train"
label_npy_path = "./data_dir_npy/label"
os.makedirs(train_png_path, exist_ok=True)
os.makedirs(label_png_path, exist_ok=True)

train_npys = os.listdir(train_npy_path)
label_npys = os.listdir(label_npy_path)
train_npys.sort(key=lambda x: int(x.split(".")[0]))
label_npys.sort(key=lambda x: int(x.split(".")[0]))

# Unpack every per-patient .npy file into a numbered folder of PNG slices.
# The label array is loaded under the same file name as the train array.
for j, npy in enumerate(train_npys):
    train_npy = np.load(os.path.join(train_npy_path, npy))
    label_npy = np.load(os.path.join(label_npy_path, npy))

    # One output folder per patient, created once rather than per slice.
    train_save_dir_path = os.path.join(train_png_path, str(j))
    label_save_dir_path = os.path.join(label_png_path, str(j))
    os.makedirs(train_save_dir_path, exist_ok=True)
    os.makedirs(label_save_dir_path, exist_ok=True)

    for i, train_slice in enumerate(train_npy):
        # Drop the trailing channel axis before writing.
        train_img = np.reshape(train_slice, (400, 400))
        label_img = np.reshape(label_npy[i], (400, 400))

        cv.imwrite(train_save_dir_path + "/" + str(i) + ".png", train_img)
        cv.imwrite(label_save_dir_path + "/" + str(i) + ".png", label_img)
    print("完成第{}个dir".format(j + 1))
然后,训练U-Net网络
U-Net网络程序
1 import keras
2 from keras.models import *
3 from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Dropout
4 from keras.optimizers import *
5
6 from keras.layers import Concatenate
7
8 from keras import backend as K
9
10 from keras.callbacks import ModelCheckpoint
11 from fit_generator import get_path_list, get_train_batch
12 import matplotlib.pyplot as plt
13
# Three things to change before every training run: the training-data
# path, the model save path and the training-curve save path.

train_batch_size = 2
epoch = 5
img_size = 400

data_train_path = "./data_dir_png/train"
data_label_path = "./data_dir_png/label"


# get_path_list comes from the project's fit_generator module.
train_path_list, label_path_list, count = get_path_list(data_train_path, data_label_path)
25
26
# Callback that records loss and dice values during training.
class LossHistory(keras.callbacks.Callback):
    """Keras callback that stores loss/dice per batch and per epoch.

    Four dictionaries (``losses``, ``accuracy``, ``val_loss``, ``val_acc``)
    each hold a ``'batch'`` and an ``'epoch'`` list. A value missing from the
    ``logs`` Keras passes in is stored as ``None`` so the lists stay aligned.
    """

    def on_train_begin(self, logs=None):
        """Reset all recorded histories at the start of training."""
        self.losses = {'batch': [], 'epoch': []}
        self.accuracy = {'batch': [], 'epoch': []}
        self.val_loss = {'batch': [], 'epoch': []}
        self.val_acc = {'batch': [], 'epoch': []}

    def _record(self, scope, logs):
        """Append the values found in ``logs`` to the ``scope`` lists."""
        logs = logs or {}
        self.losses[scope].append(logs.get('loss'))
        self.accuracy[scope].append(logs.get('dice_coef'))
        self.val_loss[scope].append(logs.get('val_loss'))
        # The training metric is read as 'dice_coef', so its validation
        # counterpart is 'val_dice_coef'; 'val_acc' is kept as a fallback.
        self.val_acc[scope].append(
            logs.get('val_dice_coef', logs.get('val_acc')))

    def on_batch_end(self, batch, logs=None):
        """Record the metrics reported at the end of a batch."""
        self._record('batch', logs)

    def on_epoch_end(self, batch, logs=None):
        """Record the metrics reported at the end of an epoch."""
        self._record('epoch', logs)

    def loss_plot(self, loss_type):
        """Plot and save the recorded dice and loss curves.

        Args:
            loss_type: ``'batch'`` or ``'epoch'``; validation curves are
                drawn only for ``'epoch'``.

        The figures are written to ``./curve_figure/unet_tumour_dice.png``
        and ``./curve_figure/unet_tumour_loss.png`` and then shown.
        """
        import os

        # savefig does not create missing folders; make sure it exists so
        # the curves are not lost after a finished training run.
        os.makedirs('./curve_figure', exist_ok=True)

        iters = range(len(self.losses[loss_type]))

        # Figure 1: dice.
        plt.figure(1)
        plt.plot(iters, self.accuracy[loss_type], 'r', label='train dice')
        if loss_type == 'epoch':
            plt.plot(iters, self.val_acc[loss_type], 'b', label='val acc')
        plt.grid(True)
        plt.xlabel(loss_type)
        plt.ylabel('dice')
        plt.legend(loc="best")
        plt.savefig('./curve_figure/unet_tumour_dice.png')

        # Figure 2: loss.
        plt.figure(2)
        plt.plot(iters, self.losses[loss_type], 'g', label='train loss')
        if loss_type == 'epoch':
            plt.plot(iters, self.val_loss[loss_type], 'k', label='val loss')
        plt.grid(True)
        plt.xlabel(loss_type)
        plt.ylabel('loss')
        plt.legend(loc="best")
        plt.savefig('./curve_figure/unet_tumour_loss.png')
        plt.show()
75
76
77
78
def dice_coef(y_true, y_pred):
    """Return the smoothed Dice coefficient of two tensors.

    Both tensors are flattened and the result is
    ``(2 * sum(t * p) + s) / (sum(t * t) + sum(p * p) + s)`` with the
    smoothing term ``s = 1``, which keeps the ratio defined when both
    inputs are all zeros.
    """
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    denominator = K.sum(y_true_f * y_true_f) + K.sum(y_pred_f * y_pred_f) + smooth
    return (2. * intersection + smooth) / denominator
85
86
def dice_coef_loss(y_true, y_pred):
    """Return ``1 - dice_coef(y_true, y_pred)`` for use as a loss."""
    return 1. - dice_coef(y_true, y_pred)
89
90
def mycrossentropy(y_true, y_pred, e=0.1, nb_classes=10):
    """Categorical cross-entropy blended with a uniform-target term.

    Returns ``(1 - e) * CE(y_true, y_pred) + e * CE(u, y_pred)`` where ``u``
    is a tensor shaped like ``y_pred`` filled with ``1 / nb_classes``.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor.
        e: Weight of the uniform-target term.
        nb_classes: Number of classes used for the uniform target. The
            default of 10 reproduces the previously hard-coded value.
    """
    loss1 = K.categorical_crossentropy(y_true, y_pred)
    loss2 = K.categorical_crossentropy(K.ones_like(y_pred) / nb_classes, y_pred)
    return (1 - e) * loss1 + e * loss2
96
97
98 class myUnet(object):
99 def __init__(self, img_rows=img_size, img_cols=img_size):
100 self.img_rows = img_rows
101 self.img_cols = img_cols
102
103 def BN_operation(self, input):
104 output = keras.layers.normalization.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001, center=True,
105 scale=True,
106 beta_initializer=‘zeros‘, gamma_initializer=‘ones‘,
107 moving_mean_initializer=‘zeros‘,
108 moving_variance_initializer=‘ones‘,
109 beta_regularizer=None,
110 gamma_regularizer=None, beta_constraint=None,
111 gamma_constraint=None)(input)
112 return output
113
114 def get_unet(self):
115 inputs = Input((self.img_rows, self.img_cols, 1))
116
117 conv1 = Conv2D(64, 3, activation=‘relu‘, padding=‘same‘, kernel_initializer=‘he_normal‘)(inputs)
118 conv1 = Conv2D(64, 3, activation=‘relu‘, padding=‘same‘, kernel_initializer=‘he_normal‘)(conv1)
119 pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
120 # BN
121 # pool1 = self.BN_operation(pool1)
122
123 conv2 = Conv2D(128, 3, activation=‘relu‘, padding=‘same‘, kernel_initializer=‘he_normal‘)(pool1)
124 conv2 = Conv2D(128, 3, activation=‘relu‘, padding=‘same‘, kernel_initializer=‘he_normal‘)(conv2)
125 pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
126 # BN
127 # pool2 = self.BN_operation(pool2)
128
129 conv3 = Conv2D(256, 3, activation=‘relu‘, padding=‘same‘, kernel_initializer=‘he_normal‘)(pool2)
130 conv3 = Conv2D(256, 3, activation=‘relu‘, padding=‘same‘, kernel_initializer=‘he_normal‘)(conv3)
131 pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
132 # BN
133 # pool3 = self.BN_operation(pool3)
134
135 conv4 = Conv2D(512, 3, activation=‘relu‘, padding=‘same‘, kernel_initializer=‘he_normal‘)(pool3)
136 conv4 = Conv2D(512, 3, activation=‘relu‘, padding=‘same‘, kernel_initializer=‘he_normal‘)(conv4)
137 drop4 = Dropout(0.5)(conv4)
138 pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
139 # BN
140 # pool4 = self.BN_operation(pool4)
141
142 conv5 = Conv2D(1024, 3, activation=‘relu‘, padding=‘same‘, kernel_initializer=‘he_normal‘)(pool4)
143
144 conv5 = Conv2D(1024, 3, activation=‘relu‘, padding=‘same‘, kernel_initializer=‘he_normal‘)(conv5)
145 drop5 = Dropout(0.5)(conv5)
146 # BN
147 # drop5 = self.BN_operation(drop5)
148
149 up6 = Conv2D(512, 2, activation=‘relu‘, padding=‘same‘, kernel_initializer=‘he_normal‘)(
150 UpSampling2D(size=(2, 2))(drop5))
151 merge6 = Concatenate(axis=3)([drop4, up6])
152 conv6 = Conv2D(512, 3, activation=‘relu‘, padding=‘same‘, kernel_initializer=‘he_normal‘)(merge6)
153 conv6 = Conv2D(512, 3, activation=‘relu‘, padding=‘same‘, kernel_initializer=‘he_normal‘)(conv6)
154
155 up7 = Conv2D(256, 2, activation=‘relu‘, padding=‘same‘, kernel_initializer=‘he_normal‘)(
156 UpSampling2D(size=(2, 2))(conv6))
157 merge7 = Concatenate(axis=3)([conv3, up7])
158 conv7 = Conv2D(256, 3, activation=‘relu‘, padding=‘same‘, kernel_initializer=‘he_normal‘)(merge7)
159 conv7 = Conv2D(256, 3, activation=‘>