神经网络准确但不可解释,决策树是可解释的,但在计算机视觉中是不准确的。对于这种问题,我们在本文有一个解决办法。
来自IEEE会员Cuntai Guan,他承认“many machine decisions are still poorly understood "。大多数论文甚至在准确性和可解释性之间提出严格的二分法.
Explainable AI (XAI)试图填补这个鸿沟,但正如我们下面所解释的,XAI在不直接解释模型的情况下证明了决策的合理性。这意味着在金融和医学等应用领域的实践者被迫陷入两难境地:在不可解释的、精确的模型,和不准确、可解释的模型之间选择一个。
定义计算机视觉的可解释性是有挑战性的:它甚至意味着如何解释像图像那样的高维输入的分类,正如我们下面讨论的,两种常见的定义涉及显著图和决策树,但这两种方法都有其弱点。
许多XAI方法生成显著性图,它突出了影响预测的重要输入像素。然而,显著性图侧重于输入但忽略了对模型如何做出决策的解释。
为了说明为什么显著性图不能完全解释模型预测的过程,这里有一个例子:下面两个显著性图是相同的,但是预测不同。即使两个显著性图都突出了正确的对象,但其中一个预测是不正确的。为什么?回答这个问题可以帮助我们改进模型,但是正如下图所示,显著图不能解释模型的决策过程。
另一种方法是用可解释的模型代替神经网络。在深度学习之前,决策树是准确性和可解释性的黄金标准。下面,我们解释决策树的可解释性,它通过将预测分解成一系列决策来工作。
然而,对于准确度而言,决策树在图像分类数据集²上落后于神经网络高达40%的准确度。神经网络和决策树混合使用也表现不佳,甚至不能在数据集CIFAR10上匹配神经网络,该数据集特征在于微小的32X32图像,如下图:
这种精度差距降低了可解释性:高精度,可解释的模型需要被用来解释高精度神经网络。
我们通过建立既可解释又准确的模型来挑战二分法。主要是将神经网络与决策树相结合,保持高层的可解释性,同时使用神经网络进行低层决策,如下所示。我们称这些模型为神经决策树(NBDTS)表明,它们可以保持神经网络的准确性,同时保留了决策树的可解释性。
NBDTs与决策树一样具有可解释性。不同于神经网络,NBDTs可以输出用于预测的中间决策。例如,给定图像,神经网络可以输出狗。然而,NBDT可以输出狗和动物,脊索动物,食肉动物(如下)。
NBDTs实现了神经网络的精度。与其他基于决策树的方法不同,NBDTS在3个图像分类数据集上匹配神经网络精度(小于1%的差异)。NBDTs 还实现了在ImageNet上的2%的精度,其中最大的图像分类数据集具有120 万个224x224图像。
此外,NBDTs为可解释的模型设置了新的现代精度。NBDT的ImageNet精度为75.30%,优于14%的基于决策树的最佳竞争方法 。为了把不可解释的神经网络的精度提高14%,花费了3年的时间去研究⁴。
最具洞察力的证据是模型从未见过的对象。例如,考虑一个NBDT (如下),并在一个斑马上运行推理。虽然这个模型从未见过斑马,但下面的中间决策是正确的。斑马是动物和蹄动物。对没见过的物体来说,对个体预测的合理性是必不可少的。
此外,我们发现,利用NBDTs,可解释性随着精度的提高而提高。这与引言中的二分法相反:NBDTs不仅具有准确性和可解释性,而且能够使准确性和可解释性都比较高。
例如,较低精度的ResNet⁶层次结构(左)的意义较低,因为其将青蛙、猫和飞机分组在一起,这是“不太明智的”,因为很难找到这三种类别共同的明显视觉特征。相比之下,WideResNet(右)的精度更高,更清晰地将动物从车辆中分离开来。
通过使用低维表格,决策树中的决策规则很容易解释,例如,如果盘子种包含圆面包,则选择正确的分支,如下所示。然而,决策规则并不像高维图像那样直接输入。
正如我们在论文(Sec 5.3) 中定性描述的那样,该模型的决策规则不仅基于对象类型,而且还基于上下文、形状和颜色。
为了定量地解释决策规则,我们利用一个称为WordNet⁷的层次结构,通过这个层次结构,我们可以发现类之间特定的相同特征。例如,给出猫和狗,WordNet将提供哺乳动物。在我们的论文(Sec 5.2) 和下面的图片,我们定量地验证这些WordNet假设。
注意,在具有10个类(即CIFAR10)的小数据集中,我们可以找到所有节点的WordNet假设。然而,在具有1000类( ImageNet)的大型数据集中,我们只能为节点子集找到WordNet假设。
对NBDT感兴趣?不用安装任何东西,你可以使用在线示例输出甚至多次尝试我们的网络演示,或者,使用命令行实运行推断(用pip安装nbdt)。下面,我们对猫的图片进行推理。
nbdt https://images.pexels.com/photos/126407/pexels-photo-126407.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=32 # this can also be a path to local image
既输出类预测,又输出所有中间决策。
Prediction: cat // Decisions: animal (99.47%), chordate (99.20%), carnivore (99.42%), cat (99.86%)
你可以使用几行Python代码加载预训练的NBDT。使用下面的方法开始,我们支持几个神经网络和数据集。
from nbdt.model import HardNBDT
from nbdt.models import wrn28_10_cifar10
model = wrn28_10_cifar10()
model = HardNBDT(
pretrained=True,
dataset='CIFAR10',
arch='wrn28_10_cifar10',
model=model)
作为参考,可以参见我们在上面运行的命令行脚本,只有20行代码直接参与转换输入和运行推理。有关示例的更多说明,请参见Github repository https://github.com/alvinwan/neural-backed-decision-trees
神经决策树的训练和推理过程可以分解为四个步骤。
1.构造决策树的层次结构。该层次结构决定了NBDT必须在哪些类之间进行决策。我们把这个层次称为诱导层。
2.这种层次结构产生了一个特殊的损失函数,我们称之为“树监督损失⁵”。使用这种新的损失训练原始神经网络,不用任何修改。
3.通过将样本传递到神经网络主干来开始推理。主干是最终完全连接层之前的所有神经网络层。
4.通过将最终完全连接层作为一系列决策规则来完成推理,我们称之为嵌入式决策规则。这些决策在最终的预测中达到顶峰。
XAI并不能完全解释神经网络是如何达到预测的:现有的方法能够解释图像对模型预测的影响,但不能解释决策过程。决策树能解决这个问题,但不幸的是,图像是决策树准确性的克星⁷。
因此,我们结合神经网络和决策树。不同于一些混合设计的前辈,我们的神经决策树(NBDTs)同时解决神经网络的不可解释(1)和决策树的高精度(2)。这为医学和金融等应用提供了一个新的精确、可解释的NBDTs。
作者Alvin Wan
译者:张森豪