U-Net网络实现医学图像分割

2021-02-14 06:16

阅读:605

标签:pat   conv2   告诉   check   ilo   pycha   data   begin   work   

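U-Net 是一种典型的编码器-解码器（encoder–decoder）结构网络：编码路径通过卷积与池化逐级提取特征，解码路径通过上采样逐级恢复分辨率，并借助跳跃连接（skip connection）融合同尺度的特征，常用于图像分割，尤其是医学图像分割。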

 

首先,准备数据

 

第一步：将 NIfTI（.nii）格式的医学图像逐层切片并转换为 PNG 格式，同时按给定的窗宽、窗位截断 CT 值。

  1 import numpy as np
  2 import cv2 as cv
  3 import nibabel as nib
  4 import os
  5 from PIL import Image
  6 import imageio
  7 
  8 
  9 def transform_ctdata(image, windowWidth, windowCenter, normal=False):
 10     """
 11     注意,这个函数的self.image一定得是float类型的,否则就无效!
 12     return: trucated image according to window center and window width
 13     """
 14     minWindow = float(windowCenter) - 0.5 * float(windowWidth)
 15     newimg = (image - minWindow) / float(windowWidth)
 16     newimg[newimg  0
 17     newimg[newimg > 1] = 1
 18     if not normal:
 19         newimg = (newimg * 255).astype(uint8)
 20     return newimg
 21 
 22 
 23 # nii文件存放路径
 24 train_path = "D:/pycharm_project/graduate_design_next_semester/dataset/data_nii/train"
 25 label_path = "D:/pycharm_project/graduate_design_next_semester/dataset/data_nii/label"
 26 # slice存放路径
 27 train_save_path = ./data_raw_slice_tumour/train
 28 label_save_path = ./data_raw_slice_tumour/label
 29 if not os.path.isdir(train_save_path):
 30     os.makedirs(train_save_path)
 31 if not os.path.isdir(label_save_path):
 32     os.makedirs(label_save_path)
 33 
 34 
 35 # 准备导入训练图像以及标签,并对图像和标签进行排序
 36 train_images = os.listdir(train_path)
 37 train_images.sort(key=lambda x: int(x.split(-)[1].split(.)[0]))
 38 label_images = os.listdir(label_path)
 39 label_images.sort(key=lambda x: int(x.split(-)[1].split(.)[0]))
 40 
 41 
 42 def create_train_label_slice():
 43     print(- * 30)
 44     print("同时分解volume与segmentation文件")
 45     for i in range(len(train_images)):
 46         train_image = nib.load(train_path + / + train_images[i])
 47         label_image = nib.load(label_path + / + label_images[i])
 48         # 获取每一个nii文件的行、列、切片数
 49         height, width, slice = train_image.shape
 50         print("" + str(i) + "个dir", "(", height, width, slice, ")")
 51         # 保存切片的小子文件夹序号,0,1,2等
 52         slice_save_path = train_images[i].split(-)[1].split(.)[0]
 53 
 54         train_slice_path = train_save_path + / + slice_save_path
 55         label_slice_path = label_save_path + / + slice_save_path
 56 
 57         # if not os.path.isdir(train_slice_path):
 58         #     os.makedirs(train_slice_path)
 59         # if not os.path.isdir(label_slice_path):
 60         #     os.makedirs(label_slice_path)
 61 
 62         img_fdata = label_image.get_fdata()
 63         for j in range(slice):
 64             train_img = train_image.dataobj[:, :, j]
 65             label_img = img_fdata[:, :, j]
 66 
 67             label_img[label_img == 1] = 0
 68             label_img[label_img == 2] = 1
 69 
 70             white_pixel = label_img == 1
 71             white_pixel_num = len(label_img[white_pixel])
 72 
 73             # # 判断是否为全黑的标签,这样没有意义,剔除
 74             # if label_img.max() != 0:
 75 
 76             # 肿瘤标签像素点数量应该大于50,才算作有效数据
 77             if white_pixel_num >= 50:
 78                 set_slice = np.array(train_img).copy()
 79                 set_slice = set_slice.astype("float32")
 80                 # 训练用的窗位窗宽
 81                 # set_slice = transform_ctdata(set_slice, 350, 25)
 82                 # 知乎上参考,肝脏是40~60
 83                 set_slice = transform_ctdata(set_slice, 200, 30)
 84 
 85                 # set_slice = set_slice.astype("float32")
 86                 # mean = set_slice.mean()
 87                 # std = np.std(set_slice)
 88                 # set_slice -= mean
 89                 # set_slice /= std
 90                 # set_slice = (set_slice - set_slice.min()) / (set_slice.max() - set_slice.min())
 91                 # set_slice *= 255
 92                 # # set_slice = transform_ctdata(set_slice, 250, 125)
 93                 # set_slice = set_slice.astype("uint8")
 94 
 95                 # 中值滤波,去除椒盐噪声
 96                 set_slice = cv.medianBlur(set_slice, 3)
 97 
 98                 if not os.path.isdir(train_slice_path):
 99                     os.makedirs(train_slice_path)
100                 if not os.path.isdir(label_slice_path):
101                     os.makedirs(label_slice_path)
102 
103                 # 加入直方图均衡处理
104                 # set_slice = cv.equalizeHist(set_slice)
105                 cv.imwrite(train_slice_path + / + str(j) + .png, set_slice)
106                 label_img = Image.fromarray(np.uint8(label_img * 255))
107                 imageio.imwrite(label_slice_path + / + str(j) + .png, label_img)
108             else:
109                 pass
110     print("Generating train data set done!")
111 
112 
113 if __name__ == "__main__":
114     create_train_label_slice()

 

第二步：计算每位病人所有切片（去除背景后）的平均亮度，为降低不同病人之间切片的亮度差异做准备。

 

 1 import os
 2 import cv2 as cv
 3 import numpy as np
 4 
 5 
 6 train_raw_path = "./data_raw_slice_tumour/train"
 7 label_raw_path = "./data_raw_slice_tumour/label"
 8 
 9 
10 def count_dir_mean_fun():
11     dirs = os.listdir(train_raw_path)
12     dirs.sort(key=lambda x: int(x))
13 
14     mean1 = []
15     mean2 = 0.0
16     mean3 = []
17 
18     for dir in dirs:
19         train_dir_path = os.path.join(train_raw_path, dir)
20         images = os.listdir(train_dir_path)
21         images.sort(key=lambda x: int(x.split(.)[0]))
22 
23         image_num = len(images)
24         mean = 0.0
25         for name in images:
26             image_path = os.path.join(train_dir_path, name)
27             image = cv.imread(image_path, 0)
28             image = image.astype("float32")
29             # image /= 255
30 
31             black_pixel = image 32             black_num = len(image[black_pixel])
33 
34             # mean += image.mean()
35             mean += image.mean() * 512 * 512 / (512 * 512 - black_num)
36 
37         mean /= image_num
38         mean1.append(mean)
39     mean2 = sum(mean1) / len(mean1)
40     mean3[:] = [x - mean2 for x in mean1]
41     print(mean1)
42     print(mean2)
43     print(mean3)
44     mean1 = np.array(mean1)
45     mean2 = np.array(mean2)
46     mean3 = np.array(mean3)
47     # mean_save_path = "./mean_array"
48     mean_save_path = "./mean_array_without_background"
49     if not os.path.isdir(mean_save_path):
50         os.makedirs(mean_save_path)
51     np.save(mean_save_path + / + "mean1.npy", mean1)
52     np.save(mean_save_path + / + "mean2.npy", mean2)
53     np.save(mean_save_path + / + "mean3.npy", mean3)
54 
55 
56 if __name__ == "__main__":
57     count_dir_mean_fun()

 

第三步：利用第二步得到的均值调整每位病人的切片亮度（先减去该病人的均值，再加上所有病人均值的平均值），以降低病人之间的亮度差异。

 

 1 import os
 2 import cv2 as cv
 3 import numpy as np
 4 
 5 
 6 train_raw_path = "./data_raw_slice_tumour/train"
 7 label_raw_path = "./data_raw_slice_tumour/label"
 8 
 9 
10 dirs = os.listdir(train_raw_path)
11 dirs.sort(key=lambda x: int(x))
12 
13 train_save_path = "./data_slice_tumour_modify_brightness/train"
14 label_save_path = "./data_slice_tumour_modify_brightness/label"
15 
16 
17 # mean1 = np.load("./mean_array/mean1.npy")
18 # mean2 = np.load("./mean_array/mean2.npy")
19 mean1 = np.load("./mean_array_without_background/mean1.npy")
20 mean2 = np.load("./mean_array_without_background/mean2.npy")
21 for i in range(len(dirs)):
22     dir_name = dirs[i]
23     mean_sub = mean1[i]
24     mean_add = mean2
25 
26     train_dir_path = os.path.join(train_raw_path, dir_name)
27     images = os.listdir(train_dir_path)
28     images.sort(key=lambda x: int(x.split(.)[0]))
29     for name in images:
30         image_path = os.path.join(train_dir_path, name)
31         image = cv.imread(image_path, 0)
32         image = image.astype("float32")
33         # image /= 255
34         image -= mean_sub
35         image += mean_add
36         # image *= 255
37         # image = image.astype("uint8")
38 
39         save_path = os.path.join(train_save_path, dir_name)
40         if not os.path.isdir(save_path):
41             os.makedirs(save_path)
42         cv.imwrite(save_path + / + name, image)
43     print("完成第{}个dir".format(i))

 

第四步：对切片进行裁剪，去除不必要的背景区域，并剔除裁剪后标签为空的切片。

 

 1 import os
 2 import cv2 as cv
 3 import numpy as np
 4 
 5 
 6 raw_slice_train_path = "./data_slice_tumour_modify_brightness/train"
 7 raw_slice_label_path = "./data_raw_slice_tumour/label"
 8 
 9 
10 train_clip_save_path = "./data_cv_clip_whole/train"
11 label_clip_save_path = "./data_cv_clip_whole/label"
12 
13 
14 dirs = os.listdir(raw_slice_train_path)
15 dirs.sort(key=lambda x: int(x))
16 
17 j = 0
18 for dir in dirs:
19     train_dir_path = os.path.join(raw_slice_train_path, dir)
20     label_dir_path = os.path.join(raw_slice_label_path, dir)
21 
22     names = os.listdir(train_dir_path)
23     names.sort(key=lambda x: int(x.split(.)[0]))
24 
25     i = 0
26     for name in names:
27         train_img_path = os.path.join(train_dir_path, name)
28         label_img_path = os.path.join(label_dir_path, name)
29 
30         train_img = cv.imread(train_img_path, 0)
31         label_img = cv.imread(label_img_path, 0)
32 
33         train_clip_img = train_img[0:400, 50:450]
34         label_clip_img = label_img[0:400, 50:450]
35         label_clip_img[label_clip_img == 255] = 255
36         label_clip_img[label_clip_img != 255] = 0
37 
38         if label_clip_img.max() == 0:
39             continue
40 
41         train_save_path = os.path.join(train_clip_save_path, dir)
42         label_save_path = os.path.join(label_clip_save_path, dir)
43 
44         if not os.path.isdir(train_save_path):
45             os.makedirs(train_save_path)
46         if not os.path.isdir(label_save_path):
47             os.makedirs(label_save_path)
48 
49         train_save_name = os.path.join(train_save_path, str(i) + ".png")
50         label_save_name = os.path.join(label_save_path, str(i) + ".png")
51 
52         cv.imwrite(train_save_name, train_clip_img)
53         cv.imwrite(label_save_name, label_clip_img)
54         i += 1
55     j += 1
56     print("完成第{}个dir".format(j))

 

第五步,以病人为单位,制作切片npy文件,用来上传服务器使用

 

 1 """
 2 本程序制作上传服务器的肝脏切片数据集
 3 一张张切片,按照dir分布
 4 """
 5 import os
 6 import cv2 as cv
 7 import numpy as np
 8 
 9 
# Cropped CT slices (windowed during slicing, brightness-adjusted, cropped).
train_png_path = "./data_cv_clip_whole/train"
# Matching 0/255 masks.
label_png_path = "./data_cv_clip_whole/label"
train_dir_npy_save_path = "./data_dir_npy/train"
label_dir_npy_save_path = "./data_dir_npy/label"
os.makedirs(train_dir_npy_save_path, exist_ok=True)
os.makedirs(label_dir_npy_save_path, exist_ok=True)

train_dirs = os.listdir(train_png_path)
train_dirs.sort(key=lambda x: int(x))

# Side length of every cropped slice (the crop step keeps a 400x400 window).
IMG_SIZE = 400

for j, dir_name in enumerate(train_dirs):
    train_dir_path = os.path.join(train_png_path, dir_name)
    label_dir_path = os.path.join(label_png_path, dir_name)

    train_imgs = os.listdir(train_dir_path)
    train_imgs.sort(key=lambda x: int(x.split(".")[0]))

    train_slices = []
    label_slices = []
    for img in train_imgs:
        # The mask shares the slice's file name in the label folder.
        train_img_path = os.path.join(train_dir_path, img)
        label_img_path = os.path.join(label_dir_path, img)
        train_img = cv.imread(train_img_path, 0)
        label_img = cv.imread(label_img_path, 0)
        if train_img is None or label_img is None:
            raise IOError("cannot read {} or {}".format(train_img_path, label_img_path))
        # Add a trailing channel axis; reshape also checks the expected size.
        train_slices.append(np.reshape(train_img, (IMG_SIZE, IMG_SIZE, 1)))
        label_slices.append(np.reshape(label_img, (IMG_SIZE, IMG_SIZE, 1)))

    # Stack only what was actually read, so no uninitialised rows are saved.
    empty_shape = (0, IMG_SIZE, IMG_SIZE, 1)
    train_dir_npy = np.stack(train_slices) if train_slices else np.zeros(empty_shape, np.uint8)
    label_dir_npy = np.stack(label_slices) if label_slices else np.zeros(empty_shape, np.uint8)

    # Output files are numbered sequentially (0.npy, 1.npy, ...) in folder order.
    np.save(os.path.join(train_dir_npy_save_path, str(j) + ".npy"), train_dir_npy)
    np.save(os.path.join(label_dir_npy_save_path, str(j) + ".npy"), label_dir_npy)
    print("第{}个文件夹".format(j + 1))

 

第六步,在服务器中拆解npy文件,以病人为单位保存在各自的dir中

 

 1 """
 2 本程序将服务器的dir---npy文件拆解成dir---png文件
 3 注意:这里标签是纯肝脏数据,肿瘤作为背景
 4 """
 5 import os
 6 import cv2 as cv
 7 import numpy as np
 8 
 9 
# NOTE(review): the module docstring says these labels are liver-only with the
# tumour as background. However, the slicing step keeps segmentation label 2,
# zeroes label 1, and writes to "*_tumour" folders, so the masks look like
# tumour masks - confirm.
train_png_path = "./data_dir_png/train"
label_png_path = "./data_dir_png/label"
train_npy_path = "./data_dir_npy/train"
label_npy_path = "./data_dir_npy/label"
os.makedirs(train_png_path, exist_ok=True)
os.makedirs(label_png_path, exist_ok=True)

# Packed files are named 0.npy, 1.npy, ...; process them in numeric order.
train_npys = os.listdir(train_npy_path)
train_npys.sort(key=lambda x: int(x.split(".")[0]))

for j, npy in enumerate(train_npys):
    # The label array shares the file name of the image array.
    train_npy = np.load(os.path.join(train_npy_path, npy))
    label_npy = np.load(os.path.join(label_npy_path, npy))
    if len(train_npy) != len(label_npy):
        raise ValueError("{}: {} images but {} labels".format(npy, len(train_npy), len(label_npy)))

    train_save_dir_path = os.path.join(train_png_path, str(j))
    label_save_dir_path = os.path.join(label_png_path, str(j))
    for i, (train_img, label_img) in enumerate(zip(train_npy, label_npy)):
        # Drop a trailing channel axis: (H, W, 1) -> (H, W), for any H and W.
        train_img = np.reshape(train_img, train_img.shape[:2])
        label_img = np.reshape(label_img, label_img.shape[:2])

        os.makedirs(train_save_dir_path, exist_ok=True)
        os.makedirs(label_save_dir_path, exist_ok=True)
        cv.imwrite(os.path.join(train_save_dir_path, str(i) + ".png"), train_img)
        cv.imwrite(os.path.join(label_save_dir_path, str(i) + ".png"), label_img)
    print("完成第{}个dir".format(j + 1))

 

然后,训练U-Net网络

 

U-Net网络程序

 

  1 import keras
  2 from keras.models import *
  3 from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Dropout
  4 from keras.optimizers import *
  5 
  6 from keras.layers import Concatenate
  7 
  8 from keras import backend as K
  9 
 10 from keras.callbacks import ModelCheckpoint
 11 from fit_generator import get_path_list, get_train_batch
 12 import matplotlib.pyplot as plt
 13 
 14 # 每次训练模型之前,需要修改的三个地方,训练数据地址、保存模型地址、保存训练曲线地址
 15 
 16 train_batch_size = 2
 17 epoch = 5
 18 img_size = 400
 19 
 20 data_train_path = "./data_dir_png/train"
 21 data_label_path = "./data_dir_png/label"
 22 
 23 
 24 train_path_list, label_path_list, count = get_path_list(data_train_path, data_label_path)
 25 
 26 
 27 # 写一个LossHistory类,保存loss和acc
 28 class LossHistory(keras.callbacks.Callback):
 29    def on_train_begin(self, logs={}):
 30        self.losses = {batch: [], epoch:[]}
 31        self.accuracy = {batch: [], epoch:[]}
 32        self.val_loss = {batch: [], epoch:[]}
 33        self.val_acc = {batch: [], epoch:[]}
 34 
 35    def on_batch_end(self, batch, logs={}):
 36        self.losses[batch].append(logs.get(loss))
 37        self.accuracy[batch].append(logs.get(dice_coef))
 38        self.val_loss[batch].append(logs.get(val_loss))
 39        self.val_acc[batch].append(logs.get(val_acc))
 40 
 41    def on_epoch_end(self, batch, logs={}):
 42        self.losses[epoch].append(logs.get(loss))
 43        self.accuracy[epoch].append(logs.get(dice_coef))
 44        self.val_loss[epoch].append(logs.get(val_loss))
 45        self.val_acc[epoch].append(logs.get(val_acc))
 46 
 47    def loss_plot(self, loss_type):
 48        iters = range(len(self.losses[loss_type]))
 49        plt.figure(1)
 50        # acc
 51        plt.plot(iters, self.accuracy[loss_type], r, label=train dice)
 52        if loss_type == epoch:
 53            # val_acc
 54            plt.plot(iters, self.val_acc[loss_type], b, label=val acc)
 55        plt.grid(True)
 56        plt.xlabel(loss_type)
 57        plt.ylabel(dice)
 58        plt.legend(loc="best")
 59 #       plt.savefig(‘./curve_figure/unet_pure_liver_raw_0_129_entropy_dice_curve.png‘)
 60        plt.savefig(./curve_figure/unet_tumour_dice.png)
 61        
 62        plt.figure(2)
 63        # loss
 64        plt.plot(iters, self.losses[loss_type], g, label=train loss)
 65        if loss_type == epoch:
 66            # val_loss
 67            plt.plot(iters, self.val_loss[loss_type], k, label=val loss)
 68        plt.grid(True)
 69        plt.xlabel(loss_type)
 70        plt.ylabel(loss)
 71        plt.legend(loc="best")
 72 #       plt.savefig(‘./curve_figure/unet_pure_liver_raw_0_129_entropy_loss_curve.png‘)
 73        plt.savefig(./curve_figure/unet_tumour_loss.png)
 74        plt.show()
 75 
 76 
 77 
 78 
 79 def dice_coef(y_true, y_pred):
 80     smooth = 1.
 81     y_true_f = K.flatten(y_true)
 82     y_pred_f = K.flatten(y_pred)
 83     intersection = K.sum(y_true_f * y_pred_f)
 84     return (2. * intersection + smooth) / (K.sum(y_true_f * y_true_f) + K.sum(y_pred_f * y_pred_f) + smooth)
 85 
 86 
 87 def dice_coef_loss(y_true, y_pred):
 88     return 1. - dice_coef(y_true, y_pred)
 89 
 90 
 91 def mycrossentropy(y_true, y_pred, e=0.1):
 92     nb_classes = 10
 93     loss1 = K.categorical_crossentropy(y_true, y_pred)
 94     loss2 = K.categorical_crossentropy(K.ones_like(y_pred) / nb_classes, y_pred)
 95     return (1 - e) * loss1 + e * loss2
 96 
 97 
 98 class myUnet(object):
 99     def __init__(self, img_rows=img_size, img_cols=img_size):
100         self.img_rows = img_rows
101         self.img_cols = img_cols
102 
103     def BN_operation(self, input):
104         output = keras.layers.normalization.BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001, center=True,
105                                                                scale=True,
106                                                                beta_initializer=zeros, gamma_initializer=ones,
107                                                                moving_mean_initializer=zeros,
108                                                                moving_variance_initializer=ones,
109                                                                beta_regularizer=None,
110                                                                gamma_regularizer=None, beta_constraint=None,
111                                                                gamma_constraint=None)(input)
112         return output
113 
114     def get_unet(self):
115         inputs = Input((self.img_rows, self.img_cols, 1))
116 
117         conv1 = Conv2D(64, 3, activation=relu, padding=same, kernel_initializer=he_normal)(inputs)
118         conv1 = Conv2D(64, 3, activation=relu, padding=same, kernel_initializer=he_normal)(conv1)
119         pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
120         # BN
121         # pool1 = self.BN_operation(pool1)
122 
123         conv2 = Conv2D(128, 3, activation=relu, padding=same, kernel_initializer=he_normal)(pool1)
124         conv2 = Conv2D(128, 3, activation=relu, padding=same, kernel_initializer=he_normal)(conv2)
125         pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
126         # BN
127         # pool2 = self.BN_operation(pool2)
128 
129         conv3 = Conv2D(256, 3, activation=relu, padding=same, kernel_initializer=he_normal)(pool2)
130         conv3 = Conv2D(256, 3, activation=relu, padding=same, kernel_initializer=he_normal)(conv3)
131         pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
132         # BN
133         # pool3 = self.BN_operation(pool3)
134 
135         conv4 = Conv2D(512, 3, activation=relu, padding=same, kernel_initializer=he_normal)(pool3)
136         conv4 = Conv2D(512, 3, activation=relu, padding=same, kernel_initializer=he_normal)(conv4)
137         drop4 = Dropout(0.5)(conv4)
138         pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
139         # BN
140         # pool4 = self.BN_operation(pool4)
141 
142         conv5 = Conv2D(1024, 3, activation=relu, padding=same, kernel_initializer=he_normal)(pool4)
143 
144         conv5 = Conv2D(1024, 3, activation=relu, padding=same, kernel_initializer=he_normal)(conv5)
145         drop5 = Dropout(0.5)(conv5)
146         # BN
147         # drop5 = self.BN_operation(drop5)
148 
149         up6 = Conv2D(512, 2, activation=relu, padding=same, kernel_initializer=he_normal)(
150             UpSampling2D(size=(2, 2))(drop5))
151         merge6 = Concatenate(axis=3)([drop4, up6])
152         conv6 = Conv2D(512, 3, activation=relu, padding=same, kernel_initializer=he_normal)(merge6)
153         conv6 = Conv2D(512, 3, activation=relu, padding=same, kernel_initializer=he_normal)(conv6)
154 
155         up7 = Conv2D(256, 2, activation=relu, padding=same, kernel_initializer=he_normal)(
156             UpSampling2D(size=(2, 2))(conv6))
157         merge7 = Concatenate(axis=3)([conv3, up7])
158         conv7 = Conv2D(256, 3, activation=relu, padding=same, kernel_initializer=he_normal)(merge7)
159         conv7 = Conv2D(256, 3, activation=&gt


评论


亲,登录后才可以留言!