[Image recognition] NBDT- NEURAL-BACKED DECISION TREE Review

5 분 소요


NBDT - NEURAL-BACKED-DECISION TREE Review


title

본 리뷰는 논문 요약 보다는 제가 이해한 틀에 맞춰서 재구성하였습니다.


Abstract

  • Finance나 Medicine 분야는 정확하고, 믿을 수 있는 예측이 중요하다.

  • 지금까지의 Deep learning based decision tree들은 Accuracy와 Interpretability가 양립 관계에 놓여있었다.

  • NBDT는 Accuracy와 Interpretability를 둘 다 잡았다.

  • Acc면에서는 CIFAR Dataset에서 다른 모델보다 16%가 높았고(추후 비교 결과를 생각해 보았을 때 Decision tree를 이용한 모델들과의 비교인 듯 하다.), Backbone model의 성능을 2% 끌어올렸다.

  • Interpretability면에서는 model의 약점 및 Dataset debugging을 용이하게 함으로 써 더욱 향상시켰다.

  • Code는 github.com/alvinwan/neural-backed-decision-trees 에서 확인 가능하다.

Introduction

  • 지금까지의 Deep learning은 Blackbox였다.

  • Blackbox를 explain해보고자 한 방법은 크게 Saliency maps / Sequential decision process로 나뉜다.

  • Saliency map은 과정에서 잘못된 부분이 있었다고 해도, 결과가 바르게 나온다면 그 잘못된 것을 확인할 수 없는 한계점이 있다.

  • 따라서 본 논문은 model의 decision process를 rule-based의 decision tree로 나타냄으로써 model에 대한 insight를 향상시켰다.

  • 현 decision tree 관련 연구들은 다른 ML Model과 비교해 ACC가 낮고 / ACC 향상을 위해 Interpretability를 낮추고 / tree structure이 model의 안정성을 낮추는 한계가 있었다.

  • 현재는 Interpretability에 대한 Universial한 Definition이 없었다.

  • 따라서 이전 연구(Poursabzi-Sangdeh et al)의 정의에 따라 일정 수준 이상 불완전한 모델에서 사람이 제대로 Prediction을 판별할 수 있을 때 Interpretability가 좋다고 판단했다.

  • 따라서 본 논문에서는 model, dataset을 debugging할 수 있고, 사람의 신뢰를 얻을 수 있게 하는 것을 목표로 했다.

  • 이를 위해 NBDT에서는, 마지막 FC를 Differentiable oblique Decision Tree로 변경하고 Hierarchical softmax를 사용하지 않았다.

Hierarchical softmax

Contribution

  • Tree supervision loss를 Propose했다. 이 Loss를 사용해서 WideResnet, EfficientNet을 Outperform했다.

  • Oblique decision tree를 위한 alternative hierarchies를 Propose했다. Pretrained parameters를 활용했고, infomation gain 위주로 학습 된 모델을 Outperform했다.

  • NBDT가 Model의 Interpretability를 향상시킨다는 것을 확인했다.

  • Saliency maps

  • Transfer to explainable models

  • Hybrid models

  • Hierarchical Classification

Method

  • 모델은 Backbone Network - Hierarchical Decision Tree (FC를 대체)의 형태로 이루어진다.

Model - Decision tree configuration

title

  • 우선 Backbone network를 training 시킨다.

  • Backbone network의 FC layer ($R^(D*K)$)를 불러온다. (A)

  • 불러온 Parameter들에 대해서 (한 클래스당 $D*1$의 Row vector.) Agglomerative clustering을 수행하고, L2 Norm으로 나누어서 Normalize 시킨 후에 각 Row vector를 Leaf node로 삼는다. (B)

  • 참고로 논문에는 L2 Norm으로 나눈다는 것을 $\Sigma_{k\inL(i)}{w_k / L(i)}$로 나타나있는데, 여기서의 $L(i)$는 Hierarchical softmax와는 다르게 i번째 Inner node의 Subtree의 Leaf node들의 집합을 의미한다고 나는 이해했다.(논문에는 이에 대한 정확한 설명이 없어서 덧붙인다.)
  • Leafnode로 부터 출발해서 자식 노드의 평균으로 부모 노드를 설정한다. (C)

  • 이렇게 쭉 올라가서 Ancestor node까지 죽 올라간다. (D)

Model - inference

  1. Seed oblique decision rule weights with neural networks weights.

    • 앞서 설명한대로, FC의 마지막 Linear layer에서 K개의 Node를 추출해내고, 이를 Leaf node로 써서 N-K개의 inner node를 만든다.

    • 만드는 과정은 해당 Inner node에 속하는 모든 Node들의 평균을 취하는 것 같다.

  • 자식 노드의 확률은 Softmax inner product로 계산한다. (즉, 다음 노드가 될 수 있는 후보군(2개)에 대해서 부모 노드와 Inner product를 하고 Softmax를 취하는 듯 하다.)
- 자식 노드의 확률은 각 자식노드
  1. Pick a left using path probabilities

    • Class K의 Probability는 아래 공식과 같이 나타난다.

title

  • $p(k)는 Class K의 확률을, P_k$는 Root node의 Probability를 의미. $C_k(i)$는 i번째 자식 노드를 의미힌다. 즉, Class k의 확률은 Root node의 확률 * 그 다음 Node의 확률 … * 그 다음 다음 Node의 확률 … 이 된다는 것이다. (1)

title

  • Class의 결정은 모든 Leaf node, 즉 Class를 의미하는 node들 간의 Argmax로 결정한다. (2)

Labeling decision nodes with WORDNET

  • Subtree 구조를 완성하기 위해서, 우선 ancestor node부터 먼저 찾았다.

  • 이 “찾았다”는 의미는 아무래도, WordNet은 의미를 가진 구조체이지만, 실제 학습시에는 어떤 노드에 어떤 객체를 할당한다는 것이 불가능할테니까 일단 학습시키고 쭉 올라가서 공통된 노드들의 속성을 추정하는 식으로 구성한 것 같다.

  • 이유는 우리가 파악 가능한 Node의 속성은 오직 Class 뿐이고, 여러 Leaf node들의 공통된 속성 등은 학습시키는게 아니라 의도치 않게 학습되는 것 이기 때문이라고 생각한다.

  • 추가적으로 논문의 마지막 페이지에 보면, 몇몇 Inner node들은 이름이 붙여져 있지만,

  • 또한 추상적인 개념은 여기서는 반영하지 못했다.

Fine-Tuning with tree supervision loss

title

  • Normal cross-entropy loss는 각 Leaf Node. 즉, 각 Class에 대해 학습시킬 수는 있지만, 각 Inner-Node에 대해 학습시킬 수는 없다. ($L_(original)$)

  • 따라서 Tree supervision loss를 제안한다.($L_(soft)$)

  • 이 Loss는 Path probabilities에 대한 Loss이다.

  • 난 이 Loss 부분이 좀 이해가 잘 안갔다. 도대체 $L_(Original)$이랑 $L_(soft)$가 정확하게 뭐가 다른거지? 후자가 Path에 대한 것이라고 하더라도, 전자를 학습시키는 과정에서 어차피 Path에 있는 Node들도 훈련될텐데? 굳이 왜 나눈거지? 생각이 들었다.

  • 이것과 관련해서는 Appendix - C를 읽어보고 어느 정도 결론이 났는데, 결론은 $L_(Original)$은 마지막 K개의 Class에 대해서 일반적인 Cross-entropy를 적용하는 Loss를 의미하고

  • $L_(soft)$는 마지막 K개의 Class외, Inner node들 하나하나에서 다음 분기로 갈라지는 것에 대해서 Cross entropy를 각각 적용했다고 생각된다.

  • 이것은 본 Tree 구조가 Binary tree 구조라서 각 Class로 가는 정답 Path가 하나밖에 없기에, 우리는 정답 Path를 미리 알고있고, 그에 따라서 Labeling 및 Training이 가능했다고 생각된다.

  • 본 Loss는 Leaf node가 잘 학습되지 않은 Training 초기에는 Training에 좋지 않은 영향을 준다.

  • 따라서 CIFAR 100 Dataset으로 훈련시킬 때에는 $W_T$를 0부터 0.5까지 점진적으로 증가시키며 훈련시켰다.

  • $\beta_t$는 [0 1] 구간에 속한 Parameter로 선형적으로 감소시켰다.

  • Backbone 모델의 ACC보다 더 잘 나오지 않을 경우에는 $L_(soft)$만 활성화한채로 fine-tuning 시켰다.

  • Hierarchical softmax와는 다르게, path-probability cross entropy loss인 $L_(soft)$는 Hierarchy 초반에 불균형하게 Up-weighting해서 high-Level decision을 도왔다.

  • 이것은 Backbone network보다 Out-sample을 분류하는데서 16%의 정확도 향상을 만들어냈다.

Experiments

title

Interpretability

  • Decision tree의 직관적인 Interpretability에도 불구하고, NeuralNet에 들어가게 되면 중간 Process를 분석하는게 쉽지 않다.

  • 분석을 위해, Poursabzi-Sangdeh et al. (2018)의 연구를 차용해 Saliency map보다 Misclassification이 일어난 경위를 사람이 좀 더 잘 알수있도록 하였다. (Sec 5.1)

  • 그를 위해, 3가지 조건에 맞춰서 분석을 진행했다.

  • (1) discrete하고, sequential한 decisions. (One path is selected)

  • (2) Pure leaves, one class - one path (Loss 부분에 관련 설명이 있음)

  • (3) Non-ensembled predictions.

Interpretability - Survey : identifying faulty model predictions

  • 모델이 Misclassification 했을 때, 이를 사람이 분석할 수 있는가?

title

  • 600개의 설문 조사를 수행

  • 각 설문 조사는 최종 결과가 제외된 모델의 예측만을 이용해서 2개의 정답 및 1개의 오답 중 오답을 찾아내야 함.

  • Saliency map을 이용한 결과는 87/600 성공

  • NBDT는 237/600 성공

Interpretability - Survey : Explanation-Guided Image Classification

  • 판단이 어려운 상황에서 사람이 모델의 예측을 신뢰가능한가?

title

  • Blur 이미지를 주고, 모델의 정보를 이용해서 이미지를 분류해야 하는 Task

  • 모델의 정보는 NBDT는 맞추고 Deep learning model은 틀린 Case, NBDT는 틀리고 Deep learning model은 맞춘 Case, 둘 다 틀린 Case, 각 Case의 비율은 3:3:4 이다.

  • 결과는 255/600이 정답을 맞추는데 성공하였다.

  • 600 Case 중에, NBDT의 Prediction에 동의한 case는 312, Saliency-map Prediction에 동의한 case는 167이었다.

  • 이로써 NBDT가 좀 더 사람의 신뢰를 얻을 수 있는 Model이라는 것을 증명.

Interpretability - Survey : Human-Diagnosed level of Trust

  • 사람이 어떤 모델의 판단을 더 신뢰하는가?

  • 모델의 분류 이유를 하나는 Saliency map을 통해, 하나는 NBDT를 통해 나타내었다.

  • 374명의 응답자 중, 맞은 예측의 경우에는 65.9%의 응답자가, 틀린 예측의 경우에는 73.5%의 응답자가 NBDT의 Prediction 과정이 좀 더 믿을만하다고 하였다.

Conclusion

  • 이번 연구를 통해서, 우리는 NBDT를 제안하였고, BDT는 out-sample에 대해서 16% 향상된 결과를 (Backbone network에 비교해서)를 보였고, 전체적인 결과가 2% 향상됐고, 정답률은 0.15%, 다른 모델과 비교했을 때는 1%(CIFAR 10, 100에서 SOTA Model보다) 더 뛰어난 결과를 보였다.

  • 또한 우리의 Hierarchy를 통해서 사람의 모델 신뢰성을 향상시켰고, 애매한 Label에 대해서 path entropy가 활용될 수 있음을 보여주었다.

  • 이 연구는 Accuracy 및 Interpretability를 실제 depolyments 안에서 향상시켰다.

댓글남기기