前言
面对大数据和人工智能,已经跨过门槛的各路小伙伴们利用深度学习算法不断的炼丹,都已经在各行各业中大显身手,DL给我们带来的不仅是技术革命,更是对世界的重新认识,比如多维、古老而神秘的算法重燃青春、具备自我学习和进化的程序以及对庞大数据的特征提取及预测等等。煲完鸡汤,上干货~




直观的看,程序可以很健壮的识别磁共振图像序列,尽管因运动,图像出现较大伪影,程序仍能正确识别序列。
本文主要汇报的内容为利用PyTorch建立一个深度残差神经网络,并完成对MR图像序列的自动识别,网络输入为jpg或dicom格式的磁共振图像,网络输出结果为多分类数据,包括:T2、T1、T2 Flair。输入的磁共振图像不限部位,尽管最初创建模型时使用的是头部磁共振图像作为训练集和验证集,但是测试时为了检验下泛化能力,我用了一副腹部T2图像,也是完全可以正确识别的,这就是黑盒子的威力!利用交叉验证精确度约为93%,训练好的模型已经上传到了github上,下载地址在文末,本文仅演示使用Qt来可视化模型的应用,若需要了解更多信息,请移步“关于”联系作者。
简介
超简单的介绍下ResNet – 残差神经网络,网上关于此的内容很多,这里只说说自己的理解,深度学习的训练过程其实就是函数拟合过程,随着神经网络层数不断加深,尤其是卷积神经网络,动不动就几十层甚至上百层,随着层数的增加,优化梯度很容易消失或爆炸,导致模型训练饱和,甚至退化。2015年,微软大神何明凯的团队提出的残差网络获得了ILSVRC2015的分类任务第一名,使得尽管网络可以达到152层,不但梯度不会消失,训练误差也比传统卷积神经网络低很多,详见paper原文。以我自己的理解,就是用H(x)来表示最优解映射,但我们不去拟合这个函数,而却拟合另一个映射F(x),并另F(x)=H(x)-x,而F(x)是我们已知的拟合方程,我们只需要找到x的极小值就完成了函数拟合。
写这篇文章时,手贱更新了pytorch,从0.3.0更新到了0.3.1,这么微小的版本变动,竟然就不支持我的显卡了!怪不得深度学习门槛高,不但要有计算机语言门槛,还有硬件门槛,烧脑又烧钱啊!顺便说一句,以个人感受而言,建议涉猎深度学习的小伙伴们,离不开tensorflow是肯定的,但不妨关注下pytorch,首选炼丹神器!
实现
使用Qt作为前端,主要是尝试将模型与实际应用更好的结合起来,因为绝大部分的深度学习文章或应用都是基于matplotlib,pil,或者opencv,程序与用户不能做到很好的交互。
核心代码如下,其中已经包含了必要注释,就不过多解释了:
def do_recognition(self, filename):
# load model
model_ft = models.resnet152(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 3) # 输出几个分类
if use_gpu:
model_ft = model_ft.cuda()
model_ft.load_state_dict(torch.load('./mr_serials_recognition_params.pkl'))
self.label_2.setText("loaded ok!")
class_names = ['T1', 'T2', 'T2Flair']
data_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
file_ext = os.path.splitext(filename)[1]
if file_ext.upper() == '.JPG':
image_path = filename
image = QtGui.QImage(image_path)
image_data = QtGui.QPixmap(image)
# 拉伸图像,让图像自适应label大小,可按比例缩放
scaredPixmap = image_data.scaled(430, 360, aspectRatioMode=Qt.KeepAspectRatio)
self.label.setPixmap(scaredPixmap)
img = Image.open(image_path).convert('RGB') # 读取图像
img2 = data_transforms(img) # 归一化
# 因为是一幅图,所以将维度更新为 [1,3,512,512]
# input = torch.rand(1, 3, len(img2[1]), len(img2[2]))
input = img2[None, :, :, :] # 转换成 4 维
model_ft.eval()
if use_gpu:
input = Variable(input.cuda())
else:
input, labels = Variable(input), Variable(input)
outputs = model_ft(input)
_, preds = torch.max(outputs.data, 1)
self.label_2.setText('Recognition : ' + class_names[preds[0]] + ' weighted')
elif file_ext.upper() == '.DCM':
pass
ds = pydicom.read_file(filename)
dcm_image = self.get_LUT_value(ds.pixel_array, 150, 80)
im = Image.fromarray(dcm_image).convert('L')
img_tmp = ImageQt.ImageQt(im)
image = QtGui.QImage(img_tmp)
image_data = QtGui.QPixmap(image)
scaredPixmap = image_data.scaled(430, 360, aspectRatioMode=Qt.KeepAspectRatio)
self.label.setPixmap(scaredPixmap)
img = im.convert('RGB') # 读取图像
img2 = data_transforms(img) # 归一化
input = img2[None, :, :, :] # 转换成 4 维
model_ft.eval()
if use_gpu:
input = Variable(input.cuda())
else:
input, labels = Variable(input), Variable(input)
outputs = model_ft(input)
_, preds = torch.max(outputs.data, 1)
self.label_2.setText('Recognition : ' + class_names[preds[0]] + ' weighted')
遇到的问题
1、程序支持JPG和DICOM格式的图像识别,分别需要通过opencv和pydicom读取,并在Qt的label上显示,如果选取的图像是DICOM格式,那还好,普通dcm文件只是256位或者512位的灰度图像,即只有一个灰度通道,但是jpg文件会存在三个颜色通道,而不同的图像处理库对于通道处理顺序是不同的,OpenCV图像通道是BGR,而Qt或matplotlib或pillow等图像通道则是 RGB,处理时应小心。 2、pytorch和Qt原生支持pillow,即PIL,若使用pillow,在qt中可以直接赋值给QImage类,并通过像素映射QPixmap进行显示。但如果因为需要复杂的控制和处理图像而使用了opencv,必须进行颜色通道以及像素位数的转换,以供Qt使用,如下:
def Mat2QImage(img, imgtype=0):
height, width = img.shape[:2]
if imgtype == 0:
if img.ndim == 3:
rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
elif img.ndim == 2:
rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
else:
raise Exception("Incorrect data format!")
qimage = QImage(rgb.flatten(), width, height, QImage.Format_RGB888)
qpixmap = QPixmap.fromImage(qimage)
else:
qimage = QImage(img, width, height, QImage.Format_Grayscale8)
qpixmap = QPixmap.fromImage(qimage)
return qpixmap
def setPixmapBypydicom(self, img):
if img is not None:
img8 = cv2.convertScaleAbs(img) # 将 uint16 转换为 uint8
qpixmap = Mat2QImage(img8, 1)
self.label.setPixmap(qpixmap)
以上代码运行在 Ubuntu 16.04 + Python 3.6 + PyTorch 0.3.0 + Qt 5.6 下。
训练好的模型点击以下链接下载: http://www.douruixin.com/download/mr_serials_recognition_params.pkl 全部源码可以在github上下载: https://github.com/douruixin/HeadMRSerialsRecognition