深度学习算法对磁共振图像序列的识别

前言

面对大数据和人工智能,已经跨过门槛的各路小伙伴们利用深度学习算法不断的炼丹,都已经在各行各业中大显身手,DL给我们带来的不仅是技术革命,更是对世界的重新认识,比如多维、古老而神秘的算法重燃青春、具备自我学习和进化的程序以及对庞大数据的特征提取及预测等等。煲完鸡汤,上干货~

直观的看,程序可以很健壮的识别磁共振图像序列,尽管因运动,图像出现较大伪影,程序仍能正确识别序列。

本文主要汇报的内容为利用PyTorch建立一个深度残差神经网络,并完成对MR图像序列的自动识别,网络输入为jpg或dicom格式的磁共振图像,网络输出结果为多分类数据,包括:T2、T1、T2 Flair。输入的磁共振图像不限部位,尽管最初创建模型时使用的是头部磁共振图像作为训练集和验证集,但是测试时为了检验下泛化能力,我用了一副腹部T2图像,也是完全可以正确识别的,这就是黑盒子的威力!利用交叉验证精确度约为93%,训练好的模型已经上传到了github上,下载地址在文末,本文仅演示使用Qt来可视化模型的应用,若需要了解更多信息,请移步“关于”联系作者。

简介

超简单的介绍下ResNet – 残差神经网络,网上关于此的内容很多,这里只说说自己的理解,深度学习的训练过程其实就是函数拟合过程,随着神经网络层数不断加深,尤其是卷积神经网络,动不动就几十层甚至上百层,随着层数的增加,优化梯度很容易消失或爆炸,导致模型训练饱和,甚至退化。2015年,微软大神何明凯的团队提出的残差网络获得了ILSVRC2015的分类任务第一名,使得尽管网络可以达到152层,不但梯度不会消失,训练误差也比传统卷积神经网络低很多,详见paper原文。以我自己的理解,就是用H(x)来表示最优解映射,但我们不去拟合这个函数,而却拟合另一个映射F(x),并另F(x)=H(x)-x,而F(x)是我们已知的拟合方程,我们只需要找到x的极小值就完成了函数拟合。

写这篇文章时,手贱更新了pytorch,从0.3.0更新到了0.3.1,这么微小的版本变动,竟然就不支持我的显卡了!怪不得深度学习门槛高,不但要有计算机语言门槛,还有硬件门槛,烧脑又烧钱啊!顺便说一句,以个人感受而言,建议涉猎深度学习的小伙伴们,离不开tensorflow是肯定的,但不妨关注下pytorch,首选炼丹神器!

实现

使用Qt作为前端,主要是尝试将模型与实际应用更好的结合起来,因为绝大部分的深度学习文章或应用都是基于matplotlib,pil,或者opencv,程序与用户不能做到很好的交互。

核心代码如下,其中已经包含了必要注释,就不过多解释了:

def do_recognition(self, filename):
   # load model
   model_ft = models.resnet152(pretrained=True)
   num_ftrs = model_ft.fc.in_features
   model_ft.fc = nn.Linear(num_ftrs, 3)  # 输出几个分类
   if use_gpu:
       model_ft = model_ft.cuda()
   model_ft.load_state_dict(torch.load('./mr_serials_recognition_params.pkl'))
   self.label_2.setText("loaded ok!")
   class_names = ['T1', 'T2', 'T2Flair']
   data_transforms = transforms.Compose([
     transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  ])

   file_ext = os.path.splitext(filename)[1]
   if file_ext.upper() == '.JPG':
       image_path = filename
       image = QtGui.QImage(image_path)
       image_data = QtGui.QPixmap(image)
       # 拉伸图像,让图像自适应label大小,可按比例缩放
       scaredPixmap = image_data.scaled(430, 360, aspectRatioMode=Qt.KeepAspectRatio)
       self.label.setPixmap(scaredPixmap)
       img = Image.open(image_path).convert('RGB')  # 读取图像
       img2 = data_transforms(img)  # 归一化
       # 因为是一幅图,所以将维度更新为 [1,3,512,512]
       # input = torch.rand(1, 3, len(img2[1]), len(img2[2]))
       input = img2[None, :, :, :]  # 转换成 4 维
       model_ft.eval()
       if use_gpu:
           input = Variable(input.cuda())
       else:
           input, labels = Variable(input), Variable(input)
       outputs = model_ft(input)
       _, preds = torch.max(outputs.data, 1)
       self.label_2.setText('Recognition : ' + class_names[preds[0]] + ' weighted')
   elif file_ext.upper() == '.DCM':
       pass
       ds = pydicom.read_file(filename)
       dcm_image = self.get_LUT_value(ds.pixel_array, 150, 80)
       im = Image.fromarray(dcm_image).convert('L')
       img_tmp = ImageQt.ImageQt(im)
       image = QtGui.QImage(img_tmp)
       image_data = QtGui.QPixmap(image)
       scaredPixmap = image_data.scaled(430, 360, aspectRatioMode=Qt.KeepAspectRatio)
       self.label.setPixmap(scaredPixmap)
       img = im.convert('RGB')  # 读取图像
       img2 = data_transforms(img)  # 归一化
       input = img2[None, :, :, :]  # 转换成 4 维
       model_ft.eval()
       if use_gpu:
           input = Variable(input.cuda())
       else:
           input, labels = Variable(input), Variable(input)
       outputs = model_ft(input)
       _, preds = torch.max(outputs.data, 1)
       self.label_2.setText('Recognition : ' + class_names[preds[0]] + ' weighted')

遇到的问题

1、程序支持JPG和DICOM格式的图像识别,分别需要通过opencv和pydicom读取,并在Qt的label上显示,如果选取的图像是DICOM格式,那还好,普通dcm文件只是256位或者512位的灰度图像,即只有一个灰度通道,但是jpg文件会存在三个颜色通道,而不同的图像处理库对于通道处理顺序是不同的,OpenCV图像通道是BGR,而Qt或matplotlib或pillow等图像通道则是 RGB,处理时应小心。 2、pytorch和Qt原生支持pillow,即PIL,若使用pillow,在qt中可以直接赋值给QImage类,并通过像素映射QPixmap进行显示。但如果因为需要复杂的控制和处理图像而使用了opencv,必须进行颜色通道以及像素位数的转换,以供Qt使用,如下:

def Mat2QImage(img, imgtype=0):
   height, width = img.shape[:2]
   if imgtype == 0:
       if img.ndim == 3:
           rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
       elif img.ndim == 2:
           rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
       else:
           raise Exception("Incorrect data format!")
       qimage = QImage(rgb.flatten(), width, height, QImage.Format_RGB888)
       qpixmap = QPixmap.fromImage(qimage)
   else:
       qimage = QImage(img, width, height, QImage.Format_Grayscale8)
       qpixmap = QPixmap.fromImage(qimage)
   return qpixmap

def setPixmapBypydicom(self, img):
   if img is not None:
       img8 = cv2.convertScaleAbs(img)  # 将 uint16 转换为 uint8
       qpixmap = Mat2QImage(img8, 1)
       self.label.setPixmap(qpixmap)

以上代码运行在 Ubuntu 16.04 + Python 3.6 + PyTorch 0.3.0 + Qt 5.6 下。

训练好的模型点击以下链接下载: http://www.douruixin.com/download/mr_serials_recognition_params.pkl 全部源码可以在github上下载: https://github.com/douruixin/HeadMRSerialsRecognition

深度学习算法对磁共振图像序列的识别

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

滚动到顶部