借力AI,让普通X线技术不普通

赶在2021年的最后一天po上一篇文章,即是对过去的一年的无比敬畏,也是对未来2022年的殷切期盼!

今天跟大家分享的内容是关于如何利用深度学习算法对普通胸部X线图像进行辅助诊断。肺部肿瘤和肺部感染是最常见的两种肺部疾病,发病率及死亡率都逐年升高,尤其是2020年初的新冠肺炎疫情更是席卷全世界,截止到目前全球已累计确诊2.81亿例,死亡541万例,并仍在持续扩散中。对于新冠肺炎而言,胸部CT扫描是首选的医学影像学检查手段,而胸部X线(CXR)检查在影像学检查中的角色和地位随着科技的发展和临床的需求,逐渐被CT及MRI等所取代。其实CXR仍然是呼吸系统疾病的重要筛查方式,因辐射剂量低且便于移动,广泛应用于体检、入院筛查和危重症床旁检查。CXR检查结果中正常者占大多数,如何快速、便捷和准确的从大量检查中自动筛选出少量异常者,这即是本文的研究目的:借力AI对胸部X线进行辅助诊断。

数据集、模型及验证

采用全球最大的公共胸部X线数据集(ChestXRay14)进行模型训练和验证,该数据集包含30805个患者112120幅正位胸部X线图像(https://nihcc.app.box.com/v/ChestXray-NIHCC ),每一幅图像都对14种胸肺部疾病进行标注,将图像输入模型之前统一下采样为224×224的分辨率,并根据ImageNet训练集中的平均值和标准差对所有胸部X线图像进行标准化,采用随机水平翻转对训练数据进行数据扩张。

参考吴恩达在2017年发表的文章中提出的CheXNet技术,即使用ChestXRay14数据集训练的121层的全卷积神经网络,在论文中,通过CheXNet技术对胸部X线进行肺炎的识别,准确率已经和人类放射科医生持平甚至更高。本文利用PyTorch对论文中的模型进行复现,将其扩展为对肺部肿瘤和肺部炎症的自动识别,将网络最后一层的全连接层替换为二进制输出,并连接一个Sigmoid输出概率值,采用SGD+Momentum的优化算法代替CheXNet中的Adam。使用小的批处理(batch=16)和初始学习率(lr=0.001)训练模型,模型的性能最终稳定在F1=0.812。并对我院随机抽取的胸部X线进行回顾性辅助诊断,与高年资放射科医生的诊断结果比较,一致性较好(肺部肿瘤Kappa=0.939,肺部炎症Kappa=0.959,P<0.05)。

讨论

肺部肿瘤和肺部感染是最常见的两种肺部疾病,据WHO估计,全世界2/3的胸部X线被检者缺乏有效的诊断,主要归因于缺乏解读胸部X线的放射诊断医师和专家,甚至有些地方虽然配备了先进的X线机,反而因误诊或漏诊导致死亡率不降反升。CXR辐射剂量低、移动方便,仍是呼吸系统疾病筛查的首选影像学检查方式,但检查存在大量正常者,进行自动分类诊断对减轻放射诊断工作量及实施临床决策都至关重要。另外,在床旁CXR的诊断中,由于为了避免辐射污染,投照条件都会低于普通CXR检查,且由于卧床患者不能很好的配合,X线投照距离和角度等参数不能做到统一,以致图像变形、灰阶不一致,尤其是对需要反复多次复查床旁CXR的患者,极易导致诊断结果过于严重,甚至误诊。危重症患者因长期卧床或心肺功能较差,容易引起坠积性肺炎或肺水肿等,故相关肺部感染性疾病较多。借助CheXNet诊断时,上述情形对于模型影响甚微,既无漏诊病例亦无误诊病例。近年来,相关学者对于卷积神经网络的视觉可解释性进行了很多探索,本研究利用Grad‐CAM输出可视化技术,对不同类别的特征权重加权求和得到热图,通过热图可以对网络模型分类进行可视化的解释。深度学习是一个黑盒子,为模型输入数据,模型输出类别或回归值等,中间过程却不得而知,如何打开黑盒子,让黑盒子变成灰盒甚至白盒?因此就有了深度学习可解释性这一领域,而CAM(Class Activation Mapping)技术就是其中之一。CAM全称为类别热力图或显著性图等,是一张和原始图片同等大小的图,该图片上每个位置的像素取值范围从0到1,一般用0到255的灰度图表示。可以理解为对预测输出的贡献分布,分数越高的地方表示原始图片对应区域对模型的响应越高、贡献越大。为了更直观表达,一般将灰度图转化为彩色图,再利用OpenCV进行转换和叠加,输出原图+CAM热图的叠加图,Grad-CAM是CAM更通用的做法,适用的网络更多。

talk is cheap, show me the code!

以CPU版本为例的核心代码如下:

# 初始化模型
# 全连接神经网络 DenseNet121
nnArchitecture == 'DENSE-NET-121':
model = DenseNet121(nnClassCount, True)
model = torch.nn.DataParallel(model)    
model.module.densenet121.features    
model.eval()    
weights = list(self.model.parameters())[-2]
# 图像变换,下采样
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
transformList = []
transformList.append(transforms.Resize(transCrop))
transformList.append(transforms.ToTensor())
transformList.append(normalize)
....
....
input = torch.autograd.Variable(imageData)
output = self.model(input)
heatmap = None
for i in range(0, len(self.weights)):
map = output[0, i, :, :]
if i == 0:
heatmap = self.weights[i] * map
else:
heatmap += self.weights[i] * map

npHeatmap = heatmap.cpu().data.numpy()
imgOriginal = cv2.imread(pathImageFile, 1)
imgOriginal = cv2.resize(imgOriginal, (transCrop, transCrop))
cam = npHeatmap / np.max(npHeatmap)
cam = cv2.resize(cam, (transCrop, transCrop))
heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
img = heatmap * 0.5 + imgOriginal
fused_img = img / np.max(img)
fused_img = 255 * fused_img
fused_img = fused_img.astype(np.uint8)
fused_img = cv2.cvtColor(fused_img, cv2.COLOR_BGR2RGB)
plt.imshow(fused_img)
plt.show()
....
....
modelCheckpoint = torch.load(pathModel, map_location='cpu')
img = Image.open(imagePath).convert('RGB')  # 读取图像
data_transforms = transforms.Compose([
   transforms.Resize(256),
   transforms.CenterCrop(224),
   transforms.ToTensor(),
   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
img2 = data_transforms(img)  # 归一化
# 因为是一幅图,所以将维度更新为 [1,3,256,256]
input = img2[None, :, :, :]
model.eval()
input = Variable(input)
output = model(input)
output = output.cpu().data.numpy()
maxIndex = output[0].argsort()[-3:][::-1]
probability = str(CLASS_NAMES[maxIndex[0]] + ":" + format(output[0][maxIndex[0]], '.2%') + ", " + CLASS_NAMES[maxIndex[1]] + ":" + format(output[0][maxIndex[1]], '.2%') + ", " + CLASS_NAMES[maxIndex[2]] + ":" + format(output[0][maxIndex[2]], '.2%'))
h = HeatmapGenerator(pathModel, nnArchitecture, nnClassCount, imgtransCrop)
h.generate(imagePath, imgtransCrop, probability)
print(probability)

全部源码请访问 https://github.com/douruixin/mychexnet ,成功运行后在原有胸部X线图像上会叠加特征分布热图,即病灶位置,并给出肺部肿瘤和肺部炎症的分类概率。

参考文献:

https://arxiv.org/pdf/1711.05225.pdf https://arxiv.org/pdf/1512.04150.pdf https://arxiv.org/pdf/1610.02391.pdf https://zhuanlan.zhihu.com/p/269702192

借力AI,让普通X线技术不普通

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

滚动到顶部