ComFe: Interpretable Image Classifiers With Foundation Models, Transformers and Component Features (2024)

Evelyn Mannix
Melbourne Centre for Data Science
University of Melbourne
Melbourne, Australia 3010
evelyn.mannix@unimelb.edu.au
Howard Bondell
Melbourne Centre for Data Science
University of Melbourne
Melbourne, Australia 3010
howard.bondell@unimelb.edu.au

Abstract

Interpretable computer vision models can explain their reasoning by comparing the distances between image patch embeddings and prototypes within a latent space. However, many of these approaches introduce additional complexity, can require multiple training steps and often come at a performance cost in comparison to black-box approaches. In this work, we introduce Component Features (ComFe), a novel interpretable-by-design image classification approach that is highly scalable and can obtain better accuracy and robustness than non-interpretable methods. Inspired by recent developments in computer vision foundation models, ComFe uses a transformer-decoder head and a hierarchical mixture-modelling approach with a foundation model backbone to obtain higher accuracy than previous interpretable models across a range of fine-grained vision benchmarks, without the need to individually tune hyper-parameters for each dataset. With only global image labels and no segmentation or part annotations, ComFe can identify consistent component features within an image and determine which of these features are informative in making a prediction.

1 Introduction

From identifying disease in medical imaging [1], to species identification [2] and self-driving cars [3], deep learning is used in numerous contexts to solve computer vision problems. However, standard deep learning approaches are black boxes [4], and it can be challenging to determine when a prediction is being made based on the most relevant features in an image. For example, neural networks can often learn spurious correlations from image backgrounds [5].

Post-hoc interpretability techniques, such as class activation maps [6], can uncover this behaviour, and recent advances have considered neural network architectures that provide explainability using attention maps [7] or other specialised features [8]. However, these approaches are not transparent, in that they cannot be used to understand which parts of the training dataset might be leading to these relationships. Interpretable models, which are designed to reason in a logical fashion, can achieve this goal by identifying prototypical parts within a training dataset that provide evidence for a particular category [9, 10, 11].

Nevertheless, interpretable models can have a poor semantic correspondence between the embeddings of prototypes and their visual characteristics [12, 13], which is a problem that crosses over into the field of self-supervised learning (SSL). SSL techniques use an optimisation task defined by the data itself to train models, and neural networks trained using SSL approaches for computer vision tasks have shown an impressive ability to create an embedding space that reflects semantic similarity [14, 15, 16]. Recent advances in this area have led to the development of foundation models that can produce useful embeddings across a range of vision contexts, such as DINOv2 [17] and CLIP [18, 19].

To take advantage of these developments we present Component Features (ComFe), a novel interpretable-by-design image classification approach designed to be used with foundation models, which employs a transformer decoder [20] head and a hierarchical mixture modelling framework to make explainable predictions. As shown in Fig. 1, ComFe is able to identify a set of components within an image and compare them to a library of class prototypes that can be cross-referenced with the embeddings of the training dataset to answer why a particular prediction is made. This approach is inspired by the Detection Transformer [21], Mask2Former [22] and PlainSeg [23], in addition to recent advances in semi-supervised learning [24, 25, 26]. Our main contributions in this paper include:

  1. We present Component Features (ComFe), an interpretable image classification approach designed to be used with foundation models that is highly scalable and performs well on a wide range of datasets. ComFe uses a transformer decoder architecture trained with a hierarchical mixture modelling approach to summarise an image into a set of component features, which are then compared to class prototypes to make a prediction.

  2. We demonstrate the competitive performance of ComFe in comparison to other interpretable and non-interpretable approaches, and show that ComFe is more generalisable and robust in comparison to a non-interpretable linear head on the ImageNet dataset.

  3. We show that the components identified by ComFe are able to consistently detect particular regions of an image (e.g. a bird's head, body and wings) and that ComFe can identify informative versus non-informative (i.e. background) patches within an image.

2 Related work

Interpretable computer vision models. ProtoPNet [9] was the first method to show that deep learning computer vision models could be designed to explain their predictions and obtain competitive performance compared to non-interpretable approaches. It introduced prototypes for interpretability: visual representations of concepts from the training data that can be used to explain why a model made a particular prediction. Subsequent works, including ProtoTree [27], ProtoPShare [28], ProtoPool [29], ProtoPFormer [11] and PIP-Net [10], improve on this approach. However, these methods can be challenging to adapt to new datasets, introduce additional complexities and may require multiple training steps [7]. In contrast, the recent INTR [7] approach uses the cross-attention layer from a transformer decoder head to produce explanations of image classifications. While elegant, this moves away from prototypes associated with the training data, making these models less transparent than the ProtoP family.

Prototypical learning approaches. Prototypical deep learning methods learn a mapping from the input data to a metric or semi-metric space, where classification is performed by determining the distance of new inputs to a set of prototypical representations of each class [30]. Interpretable vision models that use a distance metric between prototypes and image patches are a form of prototypical learning, and other work has found prototypical approaches beneficial for improving out-of-distribution detection [31, 32], out-of-distribution generalisation [33] and semi-supervised learning [24, 25, 26]. They have also been used to improve model robustness [34], for few-shot learning [30] and in multi-modal applications [18].

Self-supervised learning and foundation models. In computer vision, self-supervised approaches train neural networks as functional mappings between the image domain and a representation or latent space that captures the semantic information contained within the image with minimal loss. State-of-the-art self-supervised approaches can obtain competitive performance on downstream tasks such as image classification, segmentation and object detection when compared to fully supervised approaches without any further fine-tuning [35, 17]. This is achieved by using techniques such as bootstrapping [36], masked autoencoders [37], or contrastive learning [14], which use image augmentations to design losses that encourage a neural network to project similar images into similar regions of the latent space. By combining these approaches on large scale datasets, foundation models such as DINOv2 [17] can be trained that perform well across a range of visual contexts. Language-image pretraining can also be an effective self-supervised approach, but requires large image datasets with captions to train models [18, 19].

3 Methodology

Motivation. The key idea of the ProtoPNet approach was to introduce a set of prototypes within the latent space of an encoder network, allowing classification to be based on the distance between image patch embeddings and these prototypes within the latent space [9]. To ensure the interpretability of these prototypes, ProtoPNet constrains each prototype to be the embedding of a patch from the training dataset. Later works relax this constraint, and instead take the training image patch that represents each prototype to be the closest one within the latent space, or the one with the highest probability under the model [10].

Underlying the reasoning process of these interpretable approaches is the encoder network. If this network performs poorly, then the explanations provided by ProtoP approaches will be nonsensical [13, 12]. Recent advances in computer vision have led to foundation models, like the DINOv2 family [17], that provide highly informative patch features. These models define an embedding space where the cosine similarity between image patches describes their degree of semantic similarity [17]. Given this, we ask: can we learn prototypes directly within this embedding space that represent our classes?

Notation. We refer to an image as $\mathbf{X} \in \mathbb{R}^{3 \times h \times w}$, where $w$ is the image width and $h$ is the height. Throughout this paper, we use bolded characters $\mathbf{A}$ to refer to matrices or arrays, and the notation $\mathbf{A}_{i:}$ refers to the $i^{\text{th}}$ row of a matrix. The respective lowercase letter with two subscripts, $a_{ij}$, refers to a specific element of the matrix $\mathbf{A}$. We consider a frozen backbone encoder model $f$ and a head network $g_\theta$ that make predictions, where $\theta$ refers to the model parameters.

Interpretability with component features. We consider the encoded patch representation of an image, $\mathbf{Z} = f(\mathbf{X}) \in \mathbb{R}^{N_Z \times d}$, where $d$ is the dimensionality of the latent space and $N_Z$ is the number of patches created by the backbone model $f$. For the ViT architecture [38], as used by the DINOv2 models [17], this depends on the input image resolution. We also consider class prototypes $\mathbf{C} \in \mathbb{R}^{N_C \times d}$, where we have $N_C$ class prototypes with a fixed association to each class as defined by a smoothed one-hot encoded matrix $\bm{\phi} \in \mathbb{R}^{N_C \times c}$, which can be written as

$$\phi_{ly} = 1 - \alpha + \frac{\alpha}{c} \ \text{ if } \left(l \bmod \tfrac{N_C}{c}\right) = y, \ \text{ else } \ \frac{\alpha}{c} \qquad (1)$$

where α𝛼\alphaitalic_α is the smoothing parameter and c𝑐citalic_c is the number of classes.
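
As a concrete illustration, the sketch below builds the smoothed assignment matrix of Eq. (1) in PyTorch, assuming prototypes are assigned to classes in a simple round-robin order; the exact indexing convention, function name and default smoothing value are our own and may differ from the released code.

```python
import torch

def smoothed_prototype_labels(num_prototypes: int, num_classes: int,
                              alpha: float = 0.1) -> torch.Tensor:
    """Smoothed one-hot matrix phi (Eq. 1) linking class prototypes to classes."""
    assert num_prototypes % num_classes == 0, "expect the same number of prototypes per class"
    # Hard assignment (assumed): prototype l belongs to class (l mod c).
    assignment = torch.arange(num_prototypes) % num_classes
    phi = torch.full((num_prototypes, num_classes), alpha / num_classes)
    phi[torch.arange(num_prototypes), assignment] = 1.0 - alpha + alpha / num_classes
    return phi

phi = smoothed_prototype_labels(num_prototypes=6, num_classes=3)
print(phi)  # each row sums to 1, with most mass on the prototype's own class
```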

Similarly to the ProtoP methods, the challenge here is to learn the class prototypes C, such that we identify the locations within this feature space that relate to particular concepts (e.g. a particular species of bird). This would allow us to undertake interpretable inference on an image by identifying the patches Z that are close within the latent space to particular class prototypes C. This can be posed as a clustering problem, where the goal is to learn distributions parameterised by the class prototypes C that cluster the patch embeddings Z into meaningful groups.

We note that while there can be a large number of embedded patches $\mathbf{Z}$ per image, many of them will be similar, particularly if they make up the same part of a particular object or scene. We can use this observation to reduce the complexity of the problem by introducing image prototypes $\mathbf{P} \in \mathbb{R}^{N_P \times d}$, which represent component features within an image, where $N_P$ is the number of desired image prototypes and is much smaller than the number of image patches $N_Z$. We pose finding the image prototypes for a particular image as another clustering problem, such that the prototype $\mathbf{P}_{j:}$ represents the centre of a cluster of embeddings from $\mathbf{Z}$. This allows the image prototypes $\mathbf{P}$ to be used to summarise an image, and facilitates interpretable inference that may be more natural than considering similarity heatmaps or individual patches.

Modelling framework. To formalise the model, we consider the following joint probability

$$p(\mathbf{Z}_{i:}, \mathbf{P}_{j:}, y) = p(\mathbf{Z}_{i:} \mid \mathbf{P}_{j:})\, p(\mathbf{P}_{j:} \mid y)\, p(y) \qquad (2)$$

where $p(y)$ is the prior distribution over the classes. This joint probability describes a generative model in which the classes $y$, parameterised by the class prototypes $\mathbf{C}$, generate the image prototypes $\mathbf{P}$, which in turn generate the patch embeddings $\mathbf{Z}$. This defines the image patches $\mathbf{Z}$ as independent of the class $y$ when conditioned on an image prototype $\mathbf{P}_{j:}$. We parameterise these distributions using a hierarchical mixture modelling approach [39, 40]

$$p(\mathbf{Z}_{i:} \mid \mathbf{P}_{j:}) = k(\mathbf{Z}_{i:}, \mathbf{P}_{j:}; \tau_1) \qquad (3)$$
$$p(\mathbf{P}_{j:} \mid y) = \sum_{l}^{N_C} \phi_{ly}\, k(\mathbf{P}_{j:}, \mathbf{C}_{l:}; \tau_2) \qquad (4)$$
$$k(\mathbf{Z}_{i:}, \mathbf{P}_{j:}; \tau_1) = C_d\!\left(\frac{1}{\tau_1}\right) \exp\!\left(\frac{1}{\tau_1}\, \frac{\mathbf{Z}_{i:} \cdot \mathbf{P}_{j:}}{\lVert\mathbf{Z}_{i:}\rVert\, \lVert\mathbf{P}_{j:}\rVert}\right) \qquad (5)$$

where $k$ is a von Mises-Fisher distribution [41] with normalising constant $C_d$ and concentration defined by the temperature parameters $\tau_1$ and $\tau_2$.
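
A minimal PyTorch sketch of the kernel in Eq. (5) and the class-conditional mixture in Eq. (4) is given below. The normalising constant $C_d$ is dropped because it cancels in the posteriors used later; the function names and the default temperature are illustrative.

```python
import torch
import torch.nn.functional as F

def vmf_kernel(a: torch.Tensor, b: torch.Tensor, tau: float) -> torch.Tensor:
    """Pairwise von Mises-Fisher kernel of Eq. (5), up to the constant C_d.

    a: (N_a, d), b: (N_b, d); returns exp(cosine similarity / tau) as (N_a, N_b).
    """
    a_hat = F.normalize(a, dim=-1)  # L2-normalise rows
    b_hat = F.normalize(b, dim=-1)
    return torch.exp(a_hat @ b_hat.T / tau)

def p_image_proto_given_class(P, C, phi, tau2: float = 0.02) -> torch.Tensor:
    """Eq. (4): mixture over the class prototypes assigned to each class.

    Returns an (N_P, c) matrix whose (j, y) entry is p(P_j | y), up to C_d.
    """
    return vmf_kernel(P, C, tau2) @ phi
```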

Training objective. To learn the image prototypes $\mathbf{P}$ and class prototypes $\mathbf{C}$ we minimise the negative log-likelihood of our model. While we do not know which patches $\mathbf{Z}$ belong to particular classes $y$, we can instead consider the global image labels $\nu$ as a one-hot encoded class vector. The probability of a class being present within an image can be defined as the maximum class probability over all image patches,

$$p(\nu_l \mid \mathbf{Z}) = \max_i\, p(y = l \mid \mathbf{Z}_{i:}) \qquad (6)$$

where we use Bayes’ rule and marginalise out the image prototypes to obtain the class prediction for a particular patch under our modelling framework

$$p(y \mid \mathbf{Z}_{i:}) = \sum_{j}^{N_P} p(y \mid \mathbf{P}_{j:})\, p(\mathbf{P}_{j:} \mid \mathbf{Z}_{i:}) \qquad (7)$$
$$p(y \mid \mathbf{Z}_{i:}) = \sum_{j}^{N_P} \sigma\!\left(\hat{\mathbf{P}}_{j:} \cdot \hat{\mathbf{C}} / \tau_2\right) \bm{\phi}_{:y}\; \sigma\!\left(\hat{\mathbf{Z}}_{i:} \cdot \hat{\mathbf{P}} / \tau_1\right)_j \qquad (8)$$

where $\sigma$ is the softmax function and $\hat{\mathbf{Z}}_{i:}$ is the L2-normalised form of $\mathbf{Z}_{i:}$. Our final training objective is given by minimising the negative log-likelihood

$$-\log p(\mathbf{Z}, \nu) = -\sum_{l}^{c} \log p(\nu_l \mid \mathbf{Z}) - \log p(\mathbf{Z}) \approx \mathcal{L} \qquad (9)$$

which we separate into a discriminative term using binary cross-entropy and a clustering term, together with additional auxiliary objectives that ensure the consistency of the image prototypes and the uniqueness of the class prototypes, and improve the fitting process.

$$\mathcal{L} = \mathcal{L}_{\text{discrim}} + \mathcal{L}_{\text{cluster}} + \mathcal{L}_{\text{aux}} \qquad (10)$$
$$\mathcal{L}_{\text{cluster}} = -\frac{1}{N_Z} \sum_{i}^{N_Z} \log \sum_{j}^{N_P} k(\mathbf{Z}_{i:}, \mathbf{P}_{j:}; \tau_1) \qquad (11)$$
$$\mathcal{L}_{\text{discrim}} = -\sum_{l}^{c} \nu_l \log p(\nu_l = 1 \mid \mathbf{Z}) + (1 - \nu_l) \log\!\left(1 - p(\nu_l = 1 \mid \mathbf{Z})\right) \qquad (12)$$

A more detailed description of our modelling approach and the auxiliary terms is provided in the supporting information, along with an ablation study.
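
To make the objective concrete, the sketch below pulls Eqs. (6)-(8), (11) and (12) together for a single image, using the unnormalised kernel as above and omitting the auxiliary terms, which are described only in the supporting information. The tensor names, shapes and clamping are our own.

```python
import torch
import torch.nn.functional as F

def comfe_objective(Z, P, C, phi, nu, tau1=0.1, tau2=0.02):
    """Clustering (Eq. 11) and discriminative (Eq. 12) terms for one image.

    Z: (N_Z, d) patch embeddings, P: (N_P, d) image prototypes,
    C: (N_C, d) class prototypes, phi: (N_C, c), nu: (c,) multi-hot labels.
    """
    Z_hat, P_hat, C_hat = (F.normalize(t, dim=-1) for t in (Z, P, C))

    # Patch-level class posterior, Eq. (8): marginalise over image prototypes.
    p_y_given_P = torch.softmax(P_hat @ C_hat.T / tau2, dim=-1) @ phi   # (N_P, c)
    p_P_given_Z = torch.softmax(Z_hat @ P_hat.T / tau1, dim=-1)         # (N_Z, N_P)
    p_y_given_Z = p_P_given_Z @ p_y_given_P                             # (N_Z, c)

    # Image-level class probability, Eq. (6): max over patches.
    p_nu = p_y_given_Z.max(dim=0).values.clamp(1e-6, 1 - 1e-6)          # (c,)

    # Eq. (11): every patch should be well explained by some image prototype.
    loss_cluster = -torch.logsumexp(Z_hat @ P_hat.T / tau1, dim=1).mean()

    # Eq. (12): binary cross-entropy against the image-level labels.
    loss_discrim = F.binary_cross_entropy(p_nu, nu.float())

    return loss_discrim + loss_cluster
```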

Fitting the prototypes. We consider the class prototypes $\mathbf{C}$ as a matrix of learnable parameters, while the image prototypes $\mathbf{P}$ are fit parametrically. For each image, we generate the image prototypes using a transformer decoder $g_\theta$ [20]

$$\mathbf{P}_{i:} = g_\theta(\mathbf{Z}, \mathbf{Q}_{i:}) \qquad (13)$$

where $\mathbf{Q} \in \mathbb{R}^{N_P \times d}$ is a learnable query matrix that prompts the decoder to calculate the image prototypes by considering the entire patch representation of the image $\mathbf{Z}$.
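
A sketch of how the image prototypes could be produced with a standard PyTorch transformer decoder, following Eq. (13) and the settings described in Section 4.1 (two layers, eight heads, five queries). The class name and default width are our own; this is not the released implementation.

```python
import torch
import torch.nn as nn

class ComFeHead(nn.Module):
    """Learnable queries cross-attend to the frozen patch embeddings Z to
    produce the image prototypes P (Eq. 13)."""

    def __init__(self, dim: int = 384, num_prototypes: int = 5,
                 num_layers: int = 2, num_heads: int = 8):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_prototypes, dim))  # Q
        layer = nn.TransformerDecoderLayer(d_model=dim, nhead=num_heads,
                                           batch_first=True)
        self.decoder = nn.TransformerDecoder(layer, num_layers=num_layers)

    def forward(self, Z: torch.Tensor) -> torch.Tensor:
        # Z: (batch, N_Z, dim) patch embeddings from the frozen backbone.
        Q = self.queries.unsqueeze(0).expand(Z.size(0), -1, -1)
        return self.decoder(tgt=Q, memory=Z)  # (batch, N_P, dim) image prototypes P
```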

Background classes. To allow the model to learn non-informative patches, we add an additional background class and assume that each image has one or more background patches. This turns a multi-class image classification problem into a multi-label one, where we add the background class to the label for each image and add $N_N$ additional background class prototypes

$$\nu_{w/b} = \left[\nu, 1\right] \qquad (14)$$
$$\bm{\phi}_{w/b} = \left[\left[\bm{\phi} > \tfrac{\alpha}{c},\, 0\right],\, \left[e_{c+1}\right] \times N_N\right](1 - \alpha) + \frac{\alpha}{c + 1} \qquad (15)$$

where $e_i \in \mathbb{R}^{c+1}$ is a unit vector with a one in the $i^{\text{th}}$ position. We also extend $\mathbf{C}$ to obtain $\mathbf{C}_{w/b} \in \mathbb{R}^{(N_C + N_N) \times d}$. This can be done without changing the form of the discriminative loss term in Eq. 12, as the binary cross-entropy loss naturally handles multi-label classification problems.
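
A rough sketch of this background extension, interpreting Eqs. (14)-(15) as recovering the hard assignments by thresholding the smoothed matrix, appending a background column and background-only prototype rows, and re-smoothing over the $c+1$ classes. The helper name and indexing details are illustrative assumptions.

```python
import torch

def add_background_class(phi: torch.Tensor, nu: torch.Tensor,
                         num_background: int, alpha: float = 0.1):
    """Extend labels and prototype-class matrix with a background class."""
    num_prototypes, c = phi.shape
    # Eq. (14): every image is assumed to contain background patches.
    nu_wb = torch.cat([nu.float(), torch.ones(1)])

    # Recover the hard assignment from the smoothed phi and append a zero
    # column for the new background class.
    hard = (phi > alpha / c).float()
    hard = torch.cat([hard, torch.zeros(num_prototypes, 1)], dim=1)

    # N_N extra prototypes assigned purely to the background class e_{c+1}.
    bg_rows = torch.zeros(num_background, c + 1)
    bg_rows[:, -1] = 1.0

    # Re-smooth over the c + 1 classes (Eq. 15).
    phi_wb = torch.cat([hard, bg_rows], dim=0) * (1 - alpha) + alpha / (c + 1)
    return phi_wb, nu_wb
```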

Comparison to closely related work. The ComFe head has a similar architecture to the INTR [7] approach, in that both use a transformer decoder on the outputs of a backbone network to make a prediction. However, INTR uses the query Q to define classes and relies on cross-attention within the decoder heads to interpret why the model is making a particular prediction. For ComFe, the query has a fixed size that does not depend on the number of classes, and we use the similarity of the image prototypes P and class prototypes C to explain the model outputs. This makes ComFe much more scalable, allowing us to easily train models on ImageNet with large batch sizes on a single GPU.

4 Experiments

4.1 Implementation details

The weights of our backbone model $f$ are frozen during training, and the parameters of the transformer decoder $g_\theta$ are randomly initialised, as are our queries $\mathbf{Q}$ and class prototypes $\mathbf{C}$. We use the AdamW optimiser [42], cosine learning rate decay with linear warmup [43, 44] and gradient clipping [45]. For image augmentations, we follow DINOv2 and other works, using random cropping, flipping, color distortion and random greyscale [14, 17].
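
As an illustration of this training setup, the sketch below wires together AdamW, linear warmup into cosine decay, and gradient clipping using standard PyTorch utilities. The learning rate, weight decay, warmup length and clipping norm are placeholder values, not those used in the paper; `ComFeHead` is the earlier sketch and 384 is the DINOv2 ViT-S/14 embedding width.

```python
import torch

head = ComFeHead(dim=384, num_prototypes=5)
# 6c class prototypes for c classes (3c class + 3c background), e.g. c = 1000.
class_prototypes = torch.nn.Parameter(torch.randn(6 * 1000, 384))
params = list(head.parameters()) + [class_prototypes]

optimizer = torch.optim.AdamW(params, lr=1e-4, weight_decay=0.05)
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=5)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=45)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[5])  # step once per epoch

# Inside the training loop, clip gradients before optimizer.step():
# torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
```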

For the architecture of the ComFe head $g_\theta$, we use two transformer decoder layers with eight attention heads. We calculate the loss after each decoder layer and take gradients on the average [23]. We feed the output of the last layer of the backbone model directly into the transformer decoder, which results in a higher input dimension in comparison to other works [7], but we find that this gives the best results. We use a total of five image prototypes (so $\mathbf{Q}$ has five rows) and $6c$ class prototypes $\mathbf{C}$: the first $3c$ class prototypes are assigned to the classes (three per class), while the remaining $3c$ are assigned to the background class. We use similar temperature parameters to previous works [24], with $\tau_1 = 0.1$, $\tau_2 = 0.02$ and $\tau_c = 0.02$. An ablation study on these hyperparameters is undertaken in Appendix D of the supporting information.

The same set of hyperparameters is used across all training runs, with the exception of the batch size, which we increase from 64 images to 1024 for the ImageNet dataset. We also reduce the number of epochs for ImageNet [46] from 50 per training run to 20. Training on ImageNet with the ViT-S model takes around 21 hours, and with the ViT-L model around 35 hours, on an 80GB Nvidia A100 GPU. Smaller datasets such as Oxford Pets [47] take less than fifteen minutes. For the ViT-S model, a batch size of 64 requires around 2GB of GPU memory, and a batch size of 1024 requires 20GB. While there is the potential for pretraining on a dataset with ComFe by dropping the discriminative loss, we have not found this necessary to obtain good performance. Further details on the hyperparameters are available in the codebase released on GitHub.

4.2 Datasets

Fine-grained image benchmarking datasets, including Oxford Pets (37 classes) [47], FGVC Aircraft (100 classes) [48], Stanford Cars (196 classes) [49] and CUB200 (200 classes) [50], have all previously been used to benchmark interpretable computer vision models. We use these to test the performance of ComFe, in addition to other datasets including ImageNet-1K (1000 classes) [46], CIFAR-10 (10 classes), CIFAR-100 (100 classes) [51], Flowers-102 (102 classes) [52] and Food-101 (101 classes) [53]. These cover a range of different domains and have previously been used to evaluate the linear fine-tuning performance of the DINOv2 models [17].

We also consider a number of different test datasets for ImageNet that are designed to measure model generalisability and robustness. ImageNet-V2 [54] follows the original data distribution but contains harder examples, Sketch [55] tests whether models generalise to sketches of the ImageNet classes, and ImageNet-R [56] tests whether models generalise to art, cartoons, graphics and other renditions. Finally, the ImageNet-A [57] test dataset consists of real-world, unmodified and naturally occurring images of the ImageNet classes that are often misclassified by ResNet models. While ImageNet-V2 and Sketch contain all of the ImageNet classes, ImageNet-A and ImageNet-R only provide images for a 200-class subset.

4.3 Main results

Table 1: Performance of ComFe and prior interpretable and non-interpretable approaches on fine-grained benchmarks (accuracy, %). (f) denotes a frozen backbone; head parameter counts are shown where reported.

| | Head | Head params | Backbone | Backbone params | CUB | Pets | Cars | Aircr |
|---|---|---|---|---|---|---|---|---|
| Non-interpretable | Linear [7] | – | ResNet-50 | 26M | 83.8 | 89.5 | 89.3 | 80.9 |
| | Linear [17] | – | DINOv2 ViT-S/14 (f) | 21M | 88.1 | 95.1 | 81.6 | 74.0 |
| Interpretable | ProtoPNet [9] | – | ResNet-34 | 22M | 79.2 | – | 86.1 | – |
| | ProtoTree [27] | – | ResNet-34 | 22M | 82.2 | – | 86.6 | – |
| | ProtoPShare [28] | – | ResNet-34 | 22M | 74.7 | – | 86.4 | – |
| | ProtoPool [29] | – | ResNet-50 | 26M | 85.5 | – | 88.9 | – |
| | ProtoPFormer [11] | – | CaiT-XXS-24 | 12M | 84.9 | – | 90.9 | – |
| | ProtoPFormer [11] | – | DeiT-S | 22M | 84.8 | – | 91.0 | – |
| | PIP-Net [10] | – | ConvNeXt-tiny | 29M | 84.3 | 92.0 | 88.2 | – |
| | PIP-Net [10] | – | ResNet-50 | 26M | 82.0 | 88.5 | 86.5 | – |
| | INTR [7] | 10M | ResNet-50 | 26M | 71.8 | 90.4 | 86.8 | 76.1 |
| | ComFe | 8M | DINOv2 ViT-S/14 (f) | 21M | 87.6 | 94.6 | 91.1 | 77.5 |

ComFe obtains better performance than previous interpretable approaches. Table 1 shows that ComFe achieves competitive performance in comparison to previous interpretable models of similar size. ComFe achieves these results using the same set of hyperparameters for all datasets, with the exception of the number of class prototypes C (which scales with the number of classes), whereas previous methods tune their hyperparameters for each dataset [7, 10, 11]. ComFe is also efficient to fit, as it uses a frozen backbone trained using self-supervised learning while other methods fit all of the backbone weights. While for some datasets this frozen backbone provides a performance boost when considering a linear probe (CUB200, Oxford Pets), for others it actually performs worse than fitting a full ResNet-50 model (Stanford Cars, FGVC Aircraft), and ComFe still outperforms other approaches in these cases.

Table 2: Accuracy (%) of ComFe and a non-interpretable linear head on frozen DINOv2 backbones across benchmark datasets. (f) denotes a frozen backbone.

| Head | Head params | Backbone | Backbone params | IN-1K | C10 | C100 | Food | CUB | Pets | Cars | Aircr | Flowers |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Linear [17] | – | DINOv2 ViT-S/14 (f) | 21M | 81.1 | 97.7 | 87.5 | 89.1 | 88.1 | 95.1 | 81.6 | 74.0 | 99.6 |
| Linear [17] | – | DINOv2 ViT-B/14 (f) | 86M | 84.5 | 98.7 | 91.3 | 92.8 | 89.6 | 96.2 | 88.2 | 79.4 | 99.6 |
| Linear [17] | – | DINOv2 ViT-L/14 (f) | 300M | 86.3 | 99.3 | 93.4 | 94.3 | 90.5 | 96.6 | 90.1 | 81.5 | 99.7 |
| ComFe | 8M | DINOv2 ViT-S/14 (f) | 21M | 83.0 | 98.3 | 89.2 | 92.1 | 87.6 | 94.6 | 91.1 | 77.5 | 99.0 |
| ComFe | 32M | DINOv2 ViT-B/14 (f) | 86M | 85.6 | 99.1 | 92.2 | 94.2 | 88.3 | 95.3 | 92.6 | 81.1 | 99.3 |
| ComFe | 57M | DINOv2 ViT-L/14 (f) | 300M | 86.7 | 99.4 | 93.6 | 94.6 | 89.2 | 95.9 | 93.6 | 83.9 | 99.4 |

ComFe obtains competitive performance compared to a non-interpretable linear head. Table 2 shows that ComFe outperforms a linear head on frozen features on a number of datasets, including ImageNet, CIFAR-10, CIFAR-100, Food-101, Stanford Cars and FGVC Aircraft. For smaller models, we observe that ComFe outperforms the linear head by a larger margin. For species identification datasets, such as CUB200, Oxford Pets and Flowers-102, ComFe results in slightly lower accuracy compared to the non-interpretable linear head.

Table 3: Accuracy (%) on ImageNet generalisation and robustness test sets for models trained on ImageNet-1K. (f) denotes a frozen backbone.

| Head | Head params | Backbone | Backbone params | IN-V2 | Sketch | IN-R | IN-A |
|---|---|---|---|---|---|---|---|
| Linear [17] | – | DINOv2 ViT-S/14 (f) | 21M | 70.9 | 41.2 | 37.5 | 18.9 |
| Linear [17] | – | DINOv2 ViT-B/14 (f) | 86M | 75.1 | 50.6 | 47.3 | 37.3 |
| Linear [17] | – | DINOv2 ViT-L/14 (f) | 300M | 78.0 | 59.3 | 57.9 | 52.0 |
| ComFe | 8M | DINOv2 ViT-S/14 (f) | 21M | 73.4 | 45.3 | 42.1 | 26.4 |
| ComFe | 32M | DINOv2 ViT-B/14 (f) | 86M | 77.3 | 54.8 | 51.9 | 43.2 |
| ComFe | 57M | DINOv2 ViT-L/14 (f) | 300M | 78.7 | 59.5 | 58.8 | 55.0 |

ComFe generalises better and is more robust than a linear head. Table 3 shows that ComFe improves performance on the ImageNet-V2 test set, as well as on the generalisation and robustness benchmarks considered. Previous work has shown that while self-supervised learning approaches such as CLIP can improve their performance on ImageNet and ImageNet-V2 by fine-tuning the model backbone, doing so reduces performance on Sketch, ImageNet-A and ImageNet-R [18]. This suggests that ComFe may be a viable option to improve model performance for some datasets without compromising the generalisability and robustness of the model.

[Figure 3: Qualitative examples. For each example image, rows show the input image, the image prototypes identified by ComFe, the resulting class prediction, and class prototype exemplars from the training set.]

The image prototypes learned by ComFe identify visual components across classes. In Fig. 3 it can be seen that the image prototypes P partition images containing different classes in a consistent manner. On the CUB200 dataset, the first image prototype (red) identifies the head of birds, while the second image prototype (yellow) identifies the wings and tail feathers. The third image prototype (green) identifies the body, and the final two image prototypes (blue and cyan) are always associated with the background. Similar patterns are observed across the other datasets, with particular image prototypes identifying the front, side, and upper cabin and roof of the car for Stanford Cars, and the wings, tail and body of the plane for FGVC Aircraft.

The class prototypes learned by ComFe identify the category to which image prototypes belong. Fig. 3 also shows how the image prototypes P are classified as informative (belonging to a specific class) or non-informative by the class prototypes C, as demonstrated by the class predictions. We observe that, when identifying birds, the head appears to be non-informative on the CUB200 dataset, and that the upper vehicle cabin and roof are non-informative in the Stanford Cars dataset, whereas the whole plane is regarded as informative on the FGVC Aircraft dataset. This reflects the image prototypes from the training set found to be class prototype exemplars, as shown in the last five rows of Fig. 3, and also in Fig. S5 in the supporting information, which shows exemplars for all of the class prototypes.

We match exemplars to class prototypes by looking for the image prototypes within the training dataset with the largest cosine similarity to each class prototype, similar to previous works [10]. When predictions are made using ComFe, the most relevant class prototype C can be found by looking for the one with the largest cosine similarity to the most confident image prototype given by $p(y \mid \mathbf{P})$ (excluding the background class and pruning unused prototypes), as shown in Eq. 24 in the supporting information. This finds the class prototype C from Eq. 8 that provides the largest contribution to the prediction.
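
A minimal sketch of this matching step, assuming the image prototypes from every training image have been stacked into a single tensor; the function name is illustrative.

```python
import torch
import torch.nn.functional as F

def find_exemplars(class_protos: torch.Tensor,
                   train_image_protos: torch.Tensor) -> torch.Tensor:
    """For each class prototype, return the index of the training-set image
    prototype with the highest cosine similarity."""
    C_hat = F.normalize(class_protos, dim=-1)
    P_hat = F.normalize(train_image_protos, dim=-1)
    sims = C_hat @ P_hat.T            # (N_C, N_train_protos) cosine similarities
    return sims.argmax(dim=1)         # exemplar index per class prototype
```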

Additionally, we note that the features that are found to be informative by ComFe can vary depending on the initialisation, as discussed in Appendix C of the supporting information.

[Figure 4: Qualitative examples showing that ComFe localises the image prototypes related to the classes of interest across a range of datasets.]

ComFe is able to localise salient image features across a range of datasets. Fig. 4 shows that ComFe learns to identify the image prototypes that relate to the classes of interest across a range of datasets, from small low-resolution datasets with few classes, like CIFAR-10, to large high-resolution datasets with a thousand classes, such as ImageNet. We find that this property of ComFe reflects the quality of the backbone encoder model $f$. We show in Appendix C of the supporting information that using the DINOv2 models with registers [58] improves the interpretability of ComFe with the larger ViT variants and results in higher accuracy.

5 Conclusion

In this work we have presented ComFe, a novel interpretable-by-design image classifier that uses a transformer decoder architecture and a hierarchical mixture modelling approach to identify the regions of the embedding space defined by a computer vision foundation model that relate to particular concepts. This allows the model to explain how an image is classified, by examining the semantic similarity between component features within the image and a set of class prototypes. The ComFe approach outperforms other interpretable methods across a range of benchmarking datasets, without the need to fine-tune the hyperparameters in each case. Additionally, we demonstrate that ComFe outperforms a non-interpretable linear head on a number of different datasets, and that ComFe also improves performance on robustness and generalisability benchmarks.

Acknowledgements

We acknowledge the Traditional Owners of the unceded land on which the research detailed in this paper was undertaken: the Wurundjeri Woi-wurrung and Bunurong peoples. This research was undertaken using the LIEF HPC-GPGPU Facility hosted at the University of Melbourne. This Facility was established with the assistance of LIEF Grant LE170100200. This research was also undertaken with the assistance of resources and services from the National Computational Infrastructure (NCI), which is supported by the Australian Government. Evelyn Mannix was supported by an Australian Government Research Training Program Scholarship to complete this work.

References

  • Zhou et al. [2023] S. Kevin Zhou, Hayit Greenspan, and Dinggang Shen. Deep learning for medical image analysis. Academic Press, 2023.
  • Beloiu et al. [2023] Mirela Beloiu, Lucca Heinzmann, Nataliia Rehush, Arthur Gessler, and Verena C. Griess. Individual tree-crown detection and species identification in heterogeneous forests using aerial RGB imagery and deep learning. Remote Sensing, 15(5):1463, 2023.
  • Lee and Liu [2023] Der-Hau Lee and Jinn-Liang Liu. End-to-end deep learning of lane detection and path prediction for real-time autonomous driving. Signal, Image and Video Processing, 17(1):199–205, 2023.
  • Rudin [2019] Cynthia Rudin. Stop explaining black box machine learning models for high stakes decisions and use interpretable models instead. Nature Machine Intelligence, 1(5):206–215, 2019.
  • Yang et al. [2022a] Yao-Yuan Yang, Chi-Ning Chou, and Kamalika Chaudhuri. Understanding rare spurious correlations in neural networks. In ICML 2022: Workshop on Spurious Correlations, Invariance and Stability, 2022a. URL https://openreview.net/forum?id=iHU9Ze_5X7n.
  • Zhou et al. [2016] Bolei Zhou, Aditya Khosla, Agata Lapedriza, Aude Oliva, and Antonio Torralba. Learning deep features for discriminative localization. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 2921–2929, 2016.
  • Paul et al. [2023] Dipanjyoti Paul, Arpita Chowdhury, Xinqi Xiong, Feng-Ju Chang, David Carlyn, Samuel Stevens, Kaiya Provost, Anuj Karpatne, Bryan Carstens, Daniel Rubenstein, et al. A simple interpretable transformer for fine-grained image classification and analysis. arXiv preprint arXiv:2311.04157, 2023.
  • Kim et al. [2022a] Sangwon Kim, Jaeyeal Nam, and Byoung Chul Ko. ViT-NeT: Interpretable vision transformers with neural tree decoder. In International Conference on Machine Learning, pages 11162–11172. PMLR, 2022a.
  • Chen et al. [2019] Chaofan Chen, Oscar Li, Daniel Tao, Alina Barnett, Cynthia Rudin, and Jonathan K. Su. This looks like that: deep learning for interpretable image recognition. Advances in Neural Information Processing Systems, 32, 2019.
  • Nauta et al. [2023] Meike Nauta, Jörg Schlötterer, Maurice van Keulen, and Christin Seifert. PIP-Net: Patch-based intuitive prototypes for interpretable image classification. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 2744–2753, 2023.
  • Xue et al. [2022] Mengqi Xue, Qihan Huang, Haofei Zhang, Lechao Cheng, Jie Song, Minghui Wu, and Mingli Song. ProtoPFormer: Concentrating on prototypical parts in vision transformers for interpretable image recognition. arXiv preprint arXiv:2208.10431, 2022.
  • Kim et al. [2022b] Sunnie S. Y. Kim, Nicole Meister, Vikram V. Ramaswamy, Ruth Fong, and Olga Russakovsky. HIVE: Evaluating the human interpretability of visual explanations. In European Conference on Computer Vision, pages 280–298. Springer, 2022b.
  • Hoffmann et al. [2021] Adrian Hoffmann, Claudio Fanconi, Rahul Rade, and Jonas Kohler. This looks like that… does it? Shortcomings of latent space prototype interpretability in deep networks. arXiv preprint arXiv:2105.02968, 2021.
  • Chen et al. [2020a] Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton. A simple framework for contrastive learning of visual representations. In International Conference on Machine Learning, pages 1597–1607. PMLR, 2020a.
  • Chen et al. [2021] Ting Chen, Calvin Luo, and Lala Li. Intriguing properties of contrastive losses. Advances in Neural Information Processing Systems, 34:11834–11845, 2021.
  • Chen et al. [2020b] Ting Chen, Simon Kornblith, Kevin Swersky, Mohammad Norouzi, and Geoffrey E. Hinton. Big self-supervised models are strong semi-supervised learners. Advances in Neural Information Processing Systems, 33:22243–22255, 2020b.
  • Oquab et al. [2023] Maxime Oquab, Timothée Darcet, Théo Moutakanni, Huy Vo, Marc Szafraniec, Vasil Khalidov, Pierre Fernandez, Daniel Haziza, Francisco Massa, Alaaeldin El-Nouby, et al. DINOv2: Learning robust visual features without supervision. arXiv preprint arXiv:2304.07193, 2023.
  • Radford et al. [2021] Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, et al. Learning transferable visual models from natural language supervision. In International Conference on Machine Learning, pages 8748–8763. PMLR, 2021.
  • Cherti et al. [2023] Mehdi Cherti, Romain Beaumont, Ross Wightman, Mitchell Wortsman, Gabriel Ilharco, Cade Gordon, Christoph Schuhmann, Ludwig Schmidt, and Jenia Jitsev. Reproducible scaling laws for contrastive language-image learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 2818–2829, 2023.
  • Vaswani et al. [2017] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in Neural Information Processing Systems, 30, 2017.
  • Carion et al. [2020] Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, and Sergey Zagoruyko. End-to-end object detection with transformers. In European Conference on Computer Vision, pages 213–229. Springer, 2020.
  • Cheng et al. [2022] Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, and Rohit Girdhar. Masked-attention mask transformer for universal image segmentation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 1290–1299, 2022.
  • Hong et al. [2023] Yuanduo Hong, Jue Wang, Weichao Sun, and Huihui Pan. Minimalist and high-performance semantic segmentation with plain vision transformers. arXiv preprint arXiv:2310.12755, 2023.
  • Assran et al. [2021] Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Armand Joulin, Nicolas Ballas, and Michael Rabbat. Semi-supervised learning of visual features by non-parametrically predicting view assignments with support samples. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 8443–8452, 2021.
  • Mo et al. [2023] Sangwoo Mo, Jong-Chyi Su, Chih-Yao Ma, Mido Assran, Ishan Misra, Licheng Yu, and Sean Bell. RoPAWS: Robust semi-supervised representation learning from uncurated data. arXiv preprint arXiv:2302.14483, 2023.
  • Mannix and Bondell [2023] Evelyn Mannix and Howard Bondell. Efficient out-of-distribution detection with prototypical semi-supervised learning and foundation models. arXiv preprint, 2023.
  • Nauta et al. [2021] Meike Nauta, Ron van Bree, and Christin Seifert. Neural prototype trees for interpretable fine-grained image recognition. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 14933–14943, 2021.
  • Rymarczyk et al. [2021] Dawid Rymarczyk, Łukasz Struski, Jacek Tabor, and Bartosz Zieliński. ProtoPShare: Prototypical parts sharing for similarity discovery in interpretable image classification. In Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining, pages 1420–1430, 2021.
  • Rymarczyk et al. [2022] Dawid Rymarczyk, Łukasz Struski, Michał Górszczak, Koryna Lewandowska, Jacek Tabor, and Bartosz Zieliński. Interpretable image classification with differentiable prototypes assignment. In European Conference on Computer Vision, pages 351–368. Springer, 2022.
  • Snell et al. [2017] Jake Snell, Kevin Swersky, and Richard Zemel. Prototypical networks for few-shot learning. Advances in Neural Information Processing Systems, 30, 2017.
  • Ming et al. [2022] Yifei Ming, Yiyou Sun, Ousmane Dia, and Yixuan Li. How to exploit hyperspherical embeddings for out-of-distribution detection? arXiv preprint arXiv:2203.04450, 2022.
  • Lu et al. [2024] Haodong Lu, Dong Gong, Shuo Wang, Jason Xue, Lina Yao, and Kristen Moore. Learning with mixture of prototypes for out-of-distribution detection. In The Twelfth International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=uNkKaD3MCs.
  • Bai et al. [2024] Haoyue Bai, Yifei Ming, Julian Katz-Samuels, and Yixuan Li. Provable out-of-distribution generalization in hypersphere. In The Twelfth International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=VXak3CZZGC.
  • Yang et al. [2018] Hong-Ming Yang, Xu-Yao Zhang, Fei Yin, and Cheng-Lin Liu. Robust classification with convolutional prototype learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 3474–3482, 2018.
  • Tukra et al. [2023] Samyakh Tukra, Frederick Hoffman, and Ken Chatfield. Improving visual representation learning through perceptual understanding. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 14486–14495, 2023.
  • Grill et al. [2020] Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre Richemond, Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Guo, Mohammad Gheshlaghi Azar, et al. Bootstrap your own latent: A new approach to self-supervised learning. Advances in Neural Information Processing Systems, 33:21271–21284, 2020.
  • He et al. [2022] Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, and Ross Girshick. Masked autoencoders are scalable vision learners. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 16000–16009, 2022.
  • Dosovitskiy et al. [2020] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020.
  • Marin et al. [2005] Jean-Michel Marin, Kerrie Mengersen, and Christian P. Robert. Bayesian modelling and inference on mixtures of distributions. Handbook of Statistics, 25:459–507, 2005.
  • Zhou et al. [2020] Shuo Zhou, Howard Bondell, Antoinette Tordesillas, Benjamin I. P. Rubinstein, and James Bailey. Early identification of an impending rockslide location via a spatially-aided Gaussian mixture model. The Annals of Applied Statistics, 14(2):977–992, 2020. doi: 10.1214/20-AOAS1326. URL https://doi.org/10.1214/20-AOAS1326.
  • Govindarajan et al. [2022] Hariprasath Govindarajan, Per Sidén, Jacob Roll, and Fredrik Lindsten. DINO as a von Mises-Fisher mixture model. In The Eleventh International Conference on Learning Representations, 2022.
  • Loshchilov and Hutter [2017] Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. arXiv preprint arXiv:1711.05101, 2017.
  • Loshchilov and Hutter [2016] Ilya Loshchilov and Frank Hutter. SGDR: Stochastic gradient descent with warm restarts. arXiv preprint arXiv:1608.03983, 2016.
  • Gotmare etal. [2018]Akhilesh Gotmare, NitishShirish Keskar, Caiming Xiong, and Richard Socher.A closer look at deep learning heuristics: Learning rate restarts, warmup and distillation.arXiv preprint arXiv:1810.13243, 2018.
  • Mikolov etal. [2012]Tomáš Mikolov etal.Statistical language models based on neural networks.Presentation at Google, Mountain View, 2nd April, 80(26), 2012.
  • Russakovsky etal. [2015]Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, AlexanderC. Berg, and LiFei-Fei.ImageNet Large Scale Visual Recognition Challenge.International Journal of Computer Vision (IJCV), 115(3):211–252, 2015.doi: 10.1007/s11263-015-0816-y.
  • Parkhi etal. [2012]OmkarM Parkhi, Andrea Vedaldi, Andrew Zisserman, and CVJawahar.Cats and dogs.In 2012 IEEE conference on computer vision and pattern recognition, pages 3498–3505. IEEE, 2012.
  • Maji etal. [2013]Subhransu Maji, Esa Rahtu, Juho Kannala, Matthew Blaschko, and Andrea Vedaldi.Fine-grained visual classification of aircraft.arXiv preprint arXiv:1306.5151, 2013.
  • Krause etal. [2013]Jonathan Krause, Michael Stark, Jia Deng, and LiFei-Fei.3d object representations for fine-grained categorization.In Proceedings of the IEEE international conference on computer vision workshops, pages 554–561, 2013.
  • Welinder etal. [2010]Peter Welinder, Steve Branson, Takeshi Mita, Catherine Wah, Florian Schroff, Serge Belongie, and Pietro Perona.Caltech-ucsd birds 200.2010.
  • Krizhevsky etal. [2009]Alex Krizhevsky, Geoffrey Hinton, etal.Learning multiple layers of features from tiny images.2009.
  • Nilsback and Zisserman [2008]Maria-Elena Nilsback and Andrew Zisserman.Automated flower classification over a large number of classes.In 2008 Sixth Indian conference on computer vision, graphics & image processing, pages 722–729. IEEE, 2008.
  • Bossard etal. [2014]Lukas Bossard, Matthieu Guillaumin, and Luc VanGool.Food-101–mining discriminative components with random forests.In Computer Vision–ECCV 2014: 13th European Conference, Zurich, Switzerland, September 6-12, 2014, Proceedings, Part VI 13, pages 446–461. Springer, 2014.
  • Recht etal. [2019]Benjamin Recht, Rebecca Roelofs, Ludwig Schmidt, and Vaishaal Shankar.Do imagenet classifiers generalize to imagenet?In International conference on machine learning, pages 5389–5400. PMLR, 2019.
  • Wang etal. [2019]Haohan Wang, Songwei Ge, Zachary Lipton, and EricP Xing.Learning robust global representations by penalizing local predictive power.Advances in Neural Information Processing Systems, 32, 2019.
  • Hendrycks etal. [2021a]Dan Hendrycks, Steven Basart, Norman Mu, Saurav Kadavath, Frank Wang, Evan Dorundo, Rahul Desai, Tyler Zhu, Samyak Parajuli, Mike Guo, etal.The many faces of robustness: A critical analysis of out-of-distribution generalization.In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 8340–8349, 2021a.
  • Hendrycks etal. [2021b]Dan Hendrycks, Kevin Zhao, Steven Basart, Jacob Steinhardt, and Dawn Song.Natural adversarial examples.In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 15262–15271, 2021b.
  • Darcet etal. [2023]Timothée Darcet, Maxime Oquab, Julien Mairal, and Piotr Bojanowski.Vision transformers need registers.arXiv preprint arXiv:2309.16588, 2023.
  • Silva and Rivera [2022]Thalles Silva and AdínRamírez Rivera.Representation learning via consistent assignment of views to clusters.In Proceedings of the 37th ACM/SIGAPP Symposium on Applied Computing, pages 987–994, 2022.
  • Jiang etal. [2021]Zi-Hang Jiang, Qibin Hou, LiYuan, Daquan Zhou, Yujun Shi, Xiaojie Jin, Anran Wang, and Jiashi Feng.All tokens matter: Token labeling for training better vision transformers.Advances in neural information processing systems, 34:18590–18602, 2021.
  • Zhao etal. [2023]Bingyin Zhao, Zhiding Yu, Shiyi Lan, Yutao Cheng, Anima Anandkumar, Yingjie Lao, and JoseM Alvarez.Fully attentional networks with self-emerging token labeling.In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 5585–5595, 2023.
  • Yang etal. [2022b]Jingkang Yang, Pengyun Wang, Dejian Zou, Zitang Zhou, Kunyuan Ding, Wenxuan Peng, Haoqi Wang, Guangyao Chen, BoLi, Yiyou Sun, etal.Openood: Benchmarking generalized out-of-distribution detection.Advances in Neural Information Processing Systems, 35:32598–32611, 2022b.
  • Sun etal. [2022]Yiyou Sun, Yifei Ming, Xiaojin Zhu, and Yixuan Li.Out-of-distribution detection with deep nearest neighbors.In International Conference on Machine Learning, pages 20827–20840. PMLR, 2022.
  • Liu etal. [2020]Weitang Liu, Xiaoyun Wang, John Owens, and Yixuan Li.Energy-based out-of-distribution detection.Advances in neural information processing systems, 33:21464–21475, 2020.
  • Park etal. [2023]Jaewoo Park, YoonGyo Jung, and Andrew BengJin Teoh.Nearest neighbor guidance for out-of-distribution detection.In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 1686–1695, 2023.
  • Wang etal. [2022]Haoqi Wang, Zhizhong Li, Litong Feng, and Wayne Zhang.Vim: Out-of-distribution with virtual-logit matching.In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pages 4921–4930, June 2022.
  • Bitterwolf etal. [2023]Julian Bitterwolf, Maximilian Müller, and Matthias Hein.In or out? fixing imagenet out-of-distribution detection evaluation, 2023.
  • Vaze etal. [2022]Sagar Vaze, Kai Han, Andrea Vedaldi, and Andrew Zisserman.Open-set recognition: A good closed-set classifier is all you need.In International Conference on Learning Representations, 2022.URL https://openreview.net/forum?id=5hLP5JY9S2d.
  • Horn etal. [2018]GrantVan Horn, OisinMac Aodha, Yang Song, Yin Cui, Chen Sun, Alex Shepard, Hartwig Adam, Pietro Perona, and Serge Belongie.The inaturalist species classification and detection dataset, 2018.
  • Cimpoi etal. [2013]Mircea Cimpoi, Subhransu Maji, Iasonas Kokkinos, Sammy Mohamed, and Andrea Vedaldi.Describing textures in the wild, 2013.
  • Al-Daoud and Roberts [1996]Moh’dB Al-Daoud and StuartA Roberts.New methods for the initialisation of clusters.Pattern Recognition Letters, 17(5):451–455, 1996.

Supplementary Material

Appendix A Detailed derivation of the learning objective

Modelling framework. Consider the matrix of image patch embeddings $\textbf{Z} = f(\textbf{X}) \in \mathbb{R}^{N_Z \times d}$ provided by the backbone foundation model $f$, and a matrix of image prototypes $\textbf{P} \in \mathbb{R}^{N_P \times d}$, where $N_Z$ is the number of patches, $N_P$ is the number of image prototypes and $d$ is the dimensionality of the representation space. Each prototype is generated by a class $y$, and we can write the joint probability as

$$ p(\textbf{Z}_{i:}, \textbf{P}_{j:}, y) = p(\textbf{Z}_{i:} \mid \textbf{P}_{j:}) \, p(\textbf{P}_{j:} \mid y) \, p(y) \qquad (16) $$

where $p(y)$ is the prior distribution for the classes, and $\textbf{Z}$ is conditionally independent of $y$ given an image prototype $\textbf{P}_{j:}$. Here, the image prototype $\textbf{P}_{j:}$ acts as a latent variable that relates the image patch $\textbf{Z}_{i:}$ to the class $y$ using a generative approach.

We parameterise these distributions using the following mixture model,

$$ p(\textbf{Z}_{i:} \mid \textbf{P}_{j:}) = k(\textbf{Z}_{i:}, \textbf{P}_{j:}; \tau_1) \qquad (17) $$
$$ p(\textbf{P}_{j:} \mid y) = \sum_{l}^{N_C} \phi_{ly} \, k(\textbf{P}_{j:}, \textbf{C}_{l:}; \tau_2) \qquad (18) $$
$$ k(\textbf{Z}_{i:}, \textbf{P}_{j:}; \tau_1) = C_d\!\left(\tfrac{1}{\tau_1}\right) \exp\!\left( \frac{1}{\tau_1} \frac{\textbf{Z}_{i:} \cdot \textbf{P}_{j:}}{\lVert \textbf{Z}_{i:} \rVert \, \lVert \textbf{P}_{j:} \rVert} \right) \qquad (19) $$

where $k$ is a von Mises-Fisher distribution with normalising constant $C_d$ and concentration defined by the temperature parameters $\tau_1$ and $\tau_2$. We consider $N_C$ components in the mixture model, parameterised with means $\textbf{C} \in \mathbb{R}^{N_C \times d}$, to describe the relationship between the image prototypes $\textbf{P}_{j:}$ and the classes $y$. The membership of each mixture component to a particular class is described by the smoothed one-hot encoded matrix $\bm{\phi} \in \mathbb{R}^{N_C \times c}$

$$ \phi_{ly} = \begin{cases} 1 - \alpha + \frac{\alpha}{c} & \text{if } \left(l \bmod \frac{N_C}{c}\right) = y \\ \frac{\alpha}{c} & \text{otherwise} \end{cases} \qquad (20) $$

where $\alpha$ is the smoothing parameter and $c$ is the number of classes.
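As a concrete illustration of Eqs. 19 and 20, the sketch below computes the scaled cosine similarities that define the von Mises-Fisher kernel (up to its normalising constant, which cancels in the softmax) and builds a smoothed class-assignment matrix. This is a minimal sketch rather than the paper's implementation: the function names are ours, and the contiguous block assignment of class prototypes to classes is an assumption about the indexing.

```python
import torch
import torch.nn.functional as F

def vmf_similarity(Z: torch.Tensor, P: torch.Tensor, tau: float) -> torch.Tensor:
    """Unnormalised von Mises-Fisher log-kernel between rows of Z and P (Eq. 19).

    Returns an (N_Z, N_P) matrix of cosine similarities scaled by 1/tau; the
    normalising constant C_d(1/tau) is omitted as it cancels in the softmax.
    """
    Z_hat = F.normalize(Z, dim=-1)
    P_hat = F.normalize(P, dim=-1)
    return Z_hat @ P_hat.T / tau

def class_assignment_matrix(num_class_prototypes: int, num_classes: int, alpha: float) -> torch.Tensor:
    """Smoothed one-hot matrix phi of shape (N_C, c) in the spirit of Eq. 20.

    Class prototypes are assigned to classes in contiguous blocks of N_C / c
    prototypes per class (an assumed indexing), then label-smoothed with alpha.
    """
    per_class = num_class_prototypes // num_classes
    proto_class = torch.arange(num_class_prototypes) // per_class   # class index of each prototype
    phi = torch.full((num_class_prototypes, num_classes), alpha / num_classes)
    phi[torch.arange(num_class_prototypes), proto_class] += 1.0 - alpha
    return phi
```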

Predicting patch class membership. We can use Bayes' rule to obtain the probability that a patch belongs to a particular class under this model, assuming that the prior distributions for the image prototypes $p(\textbf{P}_{j:})$ and the classes $p(y)$ are uniform. First, we consider the probability that a patch $\textbf{Z}_{i:}$ was generated by a specific image prototype $\textbf{P}_{j:}$

$$ p(\textbf{P}_{j:} \mid \textbf{Z}_{i:}) = \frac{p(\textbf{Z}_{i:} \mid \textbf{P}_{j:}) \, p(\textbf{P}_{j:})}{p(\textbf{Z}_{i:})} \qquad (21) $$
$$ = \sigma\!\left( \hat{\textbf{Z}}_{i:} \cdot \hat{\textbf{P}} / \tau_1 \right)_j \qquad (22) $$

where $\sigma$ is the softmax function, and $\hat{\textbf{Z}}_{i:}$ is the L2-normalised form of $\textbf{Z}_{i:}$. We may compute the probability that an image prototype $\textbf{P}_{j:}$ was generated by a particular class $y$ using the same approach

$$ p(y \mid \textbf{P}_{j:}) = \frac{p(\textbf{P}_{j:} \mid y) \, p(y)}{p(\textbf{P}_{j:})} \qquad (23) $$
$$ = \sigma\!\left( \hat{\textbf{P}}_{j:} \cdot \hat{\textbf{C}} / \tau_2 \right) \bm{\phi}_{:y} \qquad (24) $$

which allows us to determine the probability of a particular class $y$, given patch $\textbf{Z}_{i:}$, through

$$ p(y \mid \textbf{Z}_{i:}) = \sum_{j}^{N_P} p(y \mid \textbf{P}_{j:}) \, p(\textbf{P}_{j:} \mid \textbf{Z}_{i:}) \qquad (25) $$

where the image prototypes $\textbf{P}_{j:}$ are introduced as a latent variable which relates the patch $\textbf{Z}_{i:}$ and the class $y$, and are then marginalised from the expression.
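In practice, Eqs. 22, 24 and 25 amount to two softmax operations and two matrix products. The sketch below, assuming PyTorch tensors and illustrative temperature values, is one way to compute the patch class probabilities; the function name and defaults are ours.

```python
import torch
import torch.nn.functional as F

def patch_class_probabilities(Z, P, C, phi, tau1=0.1, tau2=0.02):
    """Sketch of Eqs. 22, 24 and 25: class probabilities for each patch.

    Z:   (N_Z, d) patch embeddings,   P:   (N_P, d) image prototypes,
    C:   (N_C, d) class prototypes,   phi: (N_C, c) class assignment matrix.
    Returns p(y | Z_i) with shape (N_Z, c).
    """
    Z_hat, P_hat, C_hat = (F.normalize(t, dim=-1) for t in (Z, P, C))
    p_P_given_Z = torch.softmax(Z_hat @ P_hat.T / tau1, dim=-1)        # Eq. 22, (N_Z, N_P)
    p_y_given_P = torch.softmax(P_hat @ C_hat.T / tau2, dim=-1) @ phi  # Eq. 24, (N_P, c)
    return p_P_given_Z @ p_y_given_P                                   # Eq. 25, (N_Z, c)
```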

Deriving the training objective. We wish to maximise the log-likelihood of our model; however, we do not know which patches $\textbf{Z}$ belong to which classes $y$. Instead, we consider the image-level label $\nu$, which we view as a one-hot encoded class vector. The probability of a class being present within an image can be defined as the maximum class probability over all image patches, creating a vector of Bernoulli distributions, of which the $l^{\text{th}}$ element is given by

$$ p(\nu_l \mid \textbf{Z}) = \max_i \, p(y = l \mid \textbf{Z}_{i:}) \qquad (26) $$

Using this approach, the negative log-likelihood of our model for a single image is

$$ -\log p(\textbf{Z}, \nu) = -\sum_{l}^{c} \log p(\nu_l \mid \textbf{Z}) - \log p(\textbf{Z}) \qquad (27) $$

where the clustering portion of the loss is given by the marginal distribution

$$ \mathcal{L}_{\text{cluster}} = -\log p(\textbf{Z}) \qquad (28) $$
$$ = -\frac{1}{N_Z} \sum_{i}^{N_Z} \log p(\textbf{Z}_{i:}) \qquad (29) $$
$$ = -\frac{1}{N_Z} \sum_{i}^{N_Z} \log \sum_{j}^{N_P} p(\textbf{Z}_{i:} \mid \textbf{P}_{j:}) \, p(\textbf{P}_{j:}) \qquad (30) $$
$$ = -\frac{1}{N_Z} \sum_{i}^{N_Z} \log \sum_{j}^{N_P} k(\textbf{Z}_{i:}, \textbf{P}_{j:}; \tau_1) + C \qquad (31) $$

with the constant $C$ resulting from the uniform prior on the image prototypes $\textbf{P}_{j:}$. Finally, the discriminative portion is given by a binary cross-entropy loss

$$ \mathcal{L}_{\text{discrim}} = -\sum_{l}^{c} \log p(\nu_l \mid \textbf{Z}) \qquad (32) $$
$$ = -\sum_{l}^{c} \nu_l \log p(\nu_l = 1 \mid \textbf{Z}) + (1 - \nu_l) \log\left(1 - p(\nu_l = 1 \mid \textbf{Z})\right) \qquad (33) $$
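Combining Eqs. 26 to 33, the clustering term reduces to a log-sum-exp over the patch-to-prototype similarities and the discriminative term to a binary cross-entropy on the max-pooled patch class probabilities. A minimal sketch, with our own function names and the constant from the uniform prior dropped, might look as follows.

```python
import torch
import torch.nn.functional as F

def cluster_loss(Z, P, tau1=0.1):
    """L_cluster (Eq. 31), dropping the constant from the uniform prior and C_d."""
    sims = F.normalize(Z, dim=-1) @ F.normalize(P, dim=-1).T / tau1   # (N_Z, N_P)
    return -torch.logsumexp(sims, dim=-1).mean()

def discrim_loss(p_y_given_Z, nu):
    """L_discrim (Eq. 33): binary cross-entropy on max-pooled patch class probabilities.

    p_y_given_Z: (N_Z, c) patch class probabilities,  nu: (c,) one-hot image label.
    """
    p_nu = p_y_given_Z.max(dim=0).values                # Eq. 26: class present if any patch supports it
    return F.binary_cross_entropy(p_nu.clamp(1e-6, 1 - 1e-6), nu.float())
```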

Auxiliary loss. We add three additional constraints to the model to improve model quality using an auxiliary loss term $\mathcal{L}_{\text{aux}}$. For the first constraint, we observe that through Eq. 25 we have

$$ p(\nu_l \mid \textbf{Z}) \leq p(\nu_l \mid \textbf{P}) \qquad (34) $$

where

$$ p(\nu_l \mid \textbf{P}) = \max_j \, p(y = l \mid \textbf{P}_{j:}) \qquad (35) $$

We propose that adding an additional discriminative loss term to maximise this probability using binary cross entropy

$$ \mathcal{L}_{\text{p-discrim}} = -\sum_{l}^{c} \log p(\nu_l \mid \textbf{P}) \qquad (36) $$
$$ = -\sum_{l}^{c} \nu_l \log p(\nu_l = 1 \mid \textbf{P}) + (1 - \nu_l) \log\left(1 - p(\nu_l = 1 \mid \textbf{P})\right) \qquad (37) $$

could result in a small performance improvement.

To ensure that the image and class prototypes have no redundancy, we add a second term to penalise duplicate image prototype vectors and class prototype vectors. We achieve this by using a contrastive loss term [14]

$$ \mathcal{L}_{\text{contrast}} = -\sum_{i}^{N_C} \log \frac{\exp(\textbf{C}_{i:} \cdot \textbf{C}_{i:} / \tau_c)}{\sum_{j}^{N_C} \exp(\textbf{C}_{i:} \cdot \textbf{C}_{j:} / \tau_c)} - \sum_{i}^{N_P} \log \frac{\exp(\textbf{P}_{i:} \cdot \textbf{P}_{i:} / \tau_c)}{\sum_{j}^{N_P} \exp(\textbf{P}_{i:} \cdot \textbf{P}_{j:} / \tau_c)} \qquad (38) $$

with the temperature parameter $\tau_c$. We note that this constraint is not memory intensive, as the contrastive loss is defined only between the prototypes that belong to a particular image.
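Eq. 38 is an InfoNCE-style objective in which each prototype is its own positive and the remaining prototypes in the same set act as negatives. The sketch below is one hedged reading of it: the L2 normalisation of the prototypes, averaging rather than summing over prototypes, and the default temperature value are our assumptions.

```python
import torch
import torch.nn.functional as F

def prototype_contrastive_loss(P, C, tau_c=0.1):
    """Sketch of L_contrast (Eq. 38): each prototype should be most similar to itself,
    which penalises duplicated image or class prototype vectors."""
    def self_contrast(X):
        X_hat = F.normalize(X, dim=-1)
        logits = X_hat @ X_hat.T / tau_c                       # (N, N) pairwise similarities
        targets = torch.arange(X.shape[0], device=X.device)    # the diagonal entry is the positive pair
        return F.cross_entropy(logits, targets)                # mean of -log softmax(logits)[i, i]
    return self_contrast(C) + self_contrast(P)
```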

The third constraint encourages consistency in the prototypes assigned to particular patches when the image is augmented. We use a loss term inspired by Consistent Assignment for Representation Learning (CARL) [10, 59]

$$ \mathcal{L}_{\text{CARL}} = -\frac{1}{N_Z} \sum_{i}^{N_Z} \sum_{j}^{N_P} \log p(\textbf{P}_{j:} \mid \textbf{Z}_{i:}) \, p(\textbf{P}_{j:}^{*} \mid \textbf{Z}_{i:}^{*}) \qquad (39) $$

where $\textbf{Z}^{*}$ and $\textbf{P}^{*}$ are the patch embeddings and image prototypes from an augmented view $\textbf{X}^{*}$ of the image $\textbf{X}$. This loss term is minimised when the softmax vectors given by Eq. 22 from the different views are consistent and confident. To create the augmented views we apply the same cropping and flipping augmentations, so that the patches cover the same image region, but different colour and blur augmentations.
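Since the loss is described as being minimised when the two softmax assignment vectors from Eq. 22 are consistent and confident, it can be read as maximising the agreement (dot product) between the two assignments, as in the CARL objective. The sketch below implements that reading; placing the sum over prototypes inside the logarithm is our interpretation.

```python
import torch

def carl_consistency_loss(p_P_given_Z: torch.Tensor, p_P_given_Z_aug: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Consistency loss between the prototype assignments of two augmented views.

    Both inputs have shape (N_Z, N_P) and hold the softmax assignment
    probabilities of Eq. 22, with patches aligned across the two views.
    """
    agreement = (p_P_given_Z * p_P_given_Z_aug).sum(dim=-1)   # dot product of the two assignment vectors
    return -torch.log(agreement + eps).mean()                 # minimised when both views give the same confident assignment
```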

Final training objective. Our final auxiliary loss term is given by adding these three component terms.

$$ \mathcal{L}_{\text{aux}} = \mathcal{L}_{\text{p-discrim}} + \mathcal{L}_{\text{contrast}} + \mathcal{L}_{\text{CARL}} \qquad (40) $$

which fully specifies the complete training objective for a single image, given by

$$ \mathcal{L} = \mathcal{L}_{\text{cluster}} + \mathcal{L}_{\text{discrim}} + \mathcal{L}_{\text{aux}} \qquad (41) $$

When fitting the loss over more than one image, we use a mini-batching approach and take the average over the batch as the final optimisation loss.

Comparison to ProtoP approaches. The ProtoP family considers the relationship between prototypes and classes to be represented by a learnable linear layer. Inference is done by max-pooling the prototype activations over an image, which are multiplied by a set of weights that aggregate them to make a final prediction based on the presence or absence of prototypes within an image [9, 10]. The ComFe approach deviates from this by using two sets of prototypes, image prototypes P that describe particular components of an image, and class prototypes C which represent regions of the latent space that are relevant for describing a particular concept (e.g. a type of car). In a sense, the image prototypes of ComFe are similar to the prototypes considered by ProtoPNet and PIP-Net, as they represent key attributes of an image, while the class prototypes in ComFe take the role of the final layer of these approaches that identifies which prototypes are relevant to particular classes.

Input: training set $T$, number of epochs $N_E$, backbone model $f$, augmentation strategy $\text{Aug}(\cdot)$ that creates two augmentations per image with the same cropping and flipping operations.

Randomly initialise the transformer decoder head $g_\theta$, input queries $\textbf{Q}$ and class prototypes $\textbf{C}$, and generate the class assignment matrix $\bm{\phi}$;
$i = 0$;
while $i < N_E$ do
    Randomly split $T$ into $B$ mini-batches;
    for $(x_b, y_b) \in \{T_1, \dots, T_b, \dots, T_B\}$ do
        $\textbf{X} = \text{Aug}(x_b)$;
        $\nu = \text{OneHot}(y_b)$;
        if using background class prototypes then
            $\nu = \{\nu, [1, \dots, 1]\}$;  ▷ Add the background class to all images.
        end if
        $\textbf{Z} = f(\textbf{X})$;
        $\textbf{P} = g_\theta(\textbf{Z}, \textbf{Q})$;
        $p(y \mid \textbf{P}) = \sigma(\hat{\textbf{P}} \cdot \hat{\textbf{C}} / \tau_2) \, \bm{\phi}_{:y}$;
        $p(\textbf{P} \mid \textbf{Z}) = \sigma(\hat{\textbf{Z}} \cdot \hat{\textbf{P}} / \tau_1)$;
        $p(y \mid \textbf{Z}) = \sum_{j}^{N_P} p(y \mid \textbf{P}_{j:}) \, p(\textbf{P}_{j:} \mid \textbf{Z})$;
        $p(\nu \mid \textbf{P}) = \text{MaxPool}(p(y \mid \textbf{P}))$;
        $p(\nu \mid \textbf{Z}) = \text{MaxPool}(p(y \mid \textbf{Z}))$;
        Compute the auxiliary loss $\mathcal{L}_{\text{aux}} = \mathcal{L}_{\text{p-discrim}} + \mathcal{L}_{\text{contrast}} + \mathcal{L}_{\text{CARL}}$;
        Compute the final loss $\mathcal{L} = \mathcal{L}_{\text{cluster}} + \mathcal{L}_{\text{discrim}} + \mathcal{L}_{\text{aux}}$;
        Minimise the loss $\mathcal{L}$ by updating $\theta$ and $\textbf{C}$;
    end for
    $i = i + 1$;
end while
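To make the control flow above concrete, the following is a compact, self-contained PyTorch sketch of one possible implementation, with random tensors standing in for a real dataset and a frozen linear layer standing in for the DINOv2 backbone. The module choices, sizes and learning rate are illustrative assumptions rather than the configuration used in the paper, and the contrastive and CARL terms are omitted for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d, N_P, c, n_per_class = 64, 5, 10, 3             # embedding dim, image prototypes, classes, class prototypes per class
N_C, N_Z, batch, tau1, tau2, alpha = c * n_per_class, 49, 8, 0.1, 0.02, 0.1

backbone = nn.Linear(32, d)                       # stand-in for the frozen foundation model f
for p in backbone.parameters():
    p.requires_grad_(False)

decoder = nn.TransformerDecoderLayer(d_model=d, nhead=4, batch_first=True)  # stand-in for g_theta
Q = nn.Parameter(torch.randn(N_P, d))             # learnable input queries
C = nn.Parameter(torch.randn(N_C, d))             # class prototypes

proto_class = torch.arange(N_C) // n_per_class    # assumed block assignment of class prototypes to classes
phi = torch.full((N_C, c), alpha / c)
phi[torch.arange(N_C), proto_class] += 1 - alpha  # smoothed one-hot matrix (Eq. 20)

opt = torch.optim.AdamW(list(decoder.parameters()) + [Q, C], lr=1e-3)

for epoch in range(2):
    for step in range(4):                         # a few random mini-batches
        X = torch.randn(batch, N_Z, 32)           # fake patch features standing in for images
        y = torch.randint(0, c, (batch,))
        nu = F.one_hot(y, c).float()

        Z = backbone(X)                                          # (batch, N_Z, d) patch embeddings
        P = decoder(Q.unsqueeze(0).expand(batch, -1, -1), Z)     # (batch, N_P, d) image prototypes

        Z_hat, P_hat, C_hat = F.normalize(Z, dim=-1), F.normalize(P, dim=-1), F.normalize(C, dim=-1)
        sims = Z_hat @ P_hat.transpose(1, 2) / tau1              # patch-to-prototype similarities
        p_P_given_Z = torch.softmax(sims, dim=-1)                # Eq. 22
        p_y_given_P = torch.softmax(P_hat @ C_hat.T / tau2, dim=-1) @ phi  # Eq. 24
        p_y_given_Z = p_P_given_Z @ p_y_given_P                  # Eq. 25

        p_nu_given_Z = p_y_given_Z.max(dim=1).values             # Eq. 26, max-pool over patches
        p_nu_given_P = p_y_given_P.max(dim=1).values             # Eq. 35, max-pool over prototypes

        L_cluster = -torch.logsumexp(sims, dim=-1).mean()        # Eq. 31, constants dropped
        L_discrim = F.binary_cross_entropy(p_nu_given_Z.clamp(1e-6, 1 - 1e-6), nu)
        L_p_discrim = F.binary_cross_entropy(p_nu_given_P.clamp(1e-6, 1 - 1e-6), nu)

        loss = L_cluster + L_discrim + L_p_discrim
        opt.zero_grad()
        loss.backward()
        opt.step()
```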

Appendix B Visualisation of class prototypes

[Figure S5: class prototype exemplars for all classes of the FGVC Aircraft, Stanford Cars and CUB200 datasets.]

In Fig. S5 we show the class prototype exemplars for all classes within the FGVC Aircraft, Stanford Cars and CUB200 datasets. The class prototype exemplars are given by the image prototypes from the training dataset with the smallest cosine distance to the class prototypes for each label. When visualising the class prototypes in this way, we observed that not all of them are used by ComFe to make a prediction. For example, for the CUB200 dataset only one class prototype is used to learn each particular bird species, and the other class prototypes are located in regions with a low cosine similarity to the rest of the training dataset. For the CIFAR-10 dataset, which has a much larger number of images per class, we find that two class prototypes per class are used. When the number of class prototypes is reduced in our ablation study in Appendix D, this has minimal impact on the accuracy of ComFe.

Appendix C Additional results

[Figure S6: rows show the input image, the ComFe image prototypes and the class prediction for a series of manipulated images.]

The predictions of salient image features provided by ComFe are faithful under image manipulation. Fig. S6 shows that when parts of an image that are not identified as important to classification are removed, ComFe provides similar output. It also shows that when the regions of an image that are used to make a prediction are removed, ComFe may classify the image as the background class (i.e. predict that no class is present).

Table S4:
Head | Backbone | Background prototypes | IN-1K | IN-V2 | Sketch | IN-R | IN-A
Linear [17] | DINOv2 ViT-S/14 (f) | – | 81.1 | 70.9 | 41.2 | 37.5 | 18.9
ComFe | DINOv2 ViT-S/14 (f) | 3000 | 83.0 | 73.4 | 45.3 | 42.1 | 26.4
ComFe | DINOv2 ViT-S/14 (f) | 0 | 82.6 | 73.0 | 44.4 | 41.6 | 24.0

Background class prototypes improve the generalisability and robustness of ComFe. Table S4 shows that both with and without background class prototypes, ComFe obtains better performance, generalisation and robustness on ImageNet compared to a non-interpretable linear head. As shown in previous work [60, 61], the performance and robustness of ViT models can be improved by considering the patch tokens as well as the class tokens. We also find that the use of background class prototypes slightly improves the performance of ComFe, in addition to providing a mechanism for identifying the salient regions of an image when making a prediction.

Table S5:
Backbone | IN-1K | C10 | C100 | Food | CUB | Pets | Cars | Aircr | Flowers
DINOv2 ViT-S/14 (f) | 83.0 | 98.3 | 89.2 | 92.1 | 87.6 | 94.6 | 91.1 | 77.5 | 99.0
DINOv2 ViT-S/14 (f) w/reg | 82.9 | 98.2 | 89.0 | 91.6 | 87.8 | 94.9 | 90.5 | 76.5 | 98.8
DINOv2 ViT-L/14 (f) | 86.7 | 99.4 | 93.6 | 94.6 | 89.2 | 95.9 | 93.6 | 83.9 | 99.4
DINOv2 ViT-L/14 (f) w/reg | 87.2 | 99.5 | 94.6 | 95.6 | 90.0 | 96.0 | 94.5 | 85.6 | 99.6

Including registers in the DINOv2 ViT model can improve the performance of ComFe for large models. Further work on the ViT backbone has found that including register tokens [58] can prevent embedding artefacts such as patch tokens with large norms, which improves results for larger models such as ViT-L. In Table S5 we observe that for small models including registers results in slightly poorer performance for ComFe, but that for the ViT-L model performance is improved. We also find in Fig. S7 that for the ViT-L backbone without registers, for some classes a single patch can denote the class and the image prototypes can be less uniform. However, when registers are used the image prototypes are more descriptive, and we were unable to find cases where relevant features from the image were not identified by the class prototypes.

[Figure S7: ComFe image prototypes and class predictions for the DINOv2 ViT-L/14 backbone, with and without registers.]

Table S6:
Method | IN-V2 | Sketch | IN-R | IN-A | IN-O | OpenOOD Near-OOD | OpenOOD Far-OOD
KNN+ [63] | 55.3 | 81.0 | 79.2 | 83.5 | 80.8 | 72.6 | 92.7
Energy [64] | 57.6 | 81.9 | 85.4 | 88.4 | 80.6 | 76.0 | 93.4
NN-Guide [65] | 57.4 | 84.3 | 85.4 | 89.4 | 83.4 | 76.6 | 94.5
ComFe - $p_{\text{max}}(y, \textbf{P})$ | 57.8 | 81.6 | 81.8 | 84.6 | 81.4 | 80.2 | 90.4

ComFe can be used for out-of-distribution detection. Our approach fits distributions to the data, allowing us to consider how well these might perform in detecting out-of-distribution (OOD) samples which may have lower likelihood under the model. We find that the best approach combines the joint distribution over the image prototypes

$$ p_{\text{max}}(y, \textbf{P}) = \max_{y, j} \; p(y \mid \textbf{P}_{j:}) \sum_{y'} p(\textbf{P}_{j:} \mid y') \qquad (42) $$

where the first term is defined by Eq. 24 and the second by marginalising out $y$ from Eq. 18. Table S6 compares this approach with other methods in the literature, using OOD datasets for ImageNet including ImageNet-O [57], OpenImage-O [66], SSB-hard [67], NINCO [68], iNaturalist [69] and Describable Textures [70], combined into Near-OOD and Far-OOD sets following the OpenOOD benchmark [62]. We find that while ComFe performs well for some datasets, specialised OOD approaches such as NN-Guide [65] perform better overall on the DINOv2 embeddings. Pruning unused class prototypes $\textbf{C}$ may potentially improve the OOD detection results for ComFe.
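A sketch of the score in Eq. 42 for a single image follows; the temperature default and the omission of the von Mises-Fisher normalising constant (which only rescales the score and leaves rankings unchanged) are our assumptions.

```python
import torch
import torch.nn.functional as F

def comfe_ood_score(P: torch.Tensor, C: torch.Tensor, phi: torch.Tensor, tau2: float = 0.02) -> torch.Tensor:
    """Sketch of the OOD score p_max(y, P) in Eq. 42 for a single image.

    P: (N_P, d) image prototypes, C: (N_C, d) class prototypes, phi: (N_C, c).
    Higher scores indicate more in-distribution samples.
    """
    P_hat, C_hat = F.normalize(P, dim=-1), F.normalize(C, dim=-1)
    sims = P_hat @ C_hat.T / tau2                        # (N_P, N_C) scaled cosine similarities
    p_y_given_P = torch.softmax(sims, dim=-1) @ phi      # Eq. 24, (N_P, c)
    support = torch.exp(sims).sum(dim=-1, keepdim=True)  # proportional to sum over y' of p(P_j | y')
    return (p_y_given_P * support).max()                 # maximum over prototypes j and classes y
```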

[Figure S8: reliability diagrams for the ComFe score and the NN-Guide-based confidence score.]

ComFe can detect images that are likely to have an unreliable prediction. We consider the problem of identifying whether a prediction made by a neural network is reliable. The OOD measure for ComFe, $p_{\text{max}}(y, \textbf{P})$, combines confidence as well as support under the mixture model. An analogous measure could also be created from the NN-Guide OOD score, by multiplying it with the class confidence from a linear head. Reliability diagrams for both of these approaches are shown in Fig. S8, where it can be seen that the ComFe score correlates better with the observed accuracy, particularly for the ImageNet-A benchmark.

[Figures S9 and S10: ComFe image prototypes and class predictions for models trained from different initialisation seeds on Food-101 and Oxford Pets.]

For some initialisations, ComFe transparently classifies using background features. Like other clustering algorithms [71], ComFe is sensitive to the initialisation of the class prototypes. As shown in Fig. S9 for the Food-101 dataset, for particular initialisation seeds ComFe can find background features that can be used to accurately classify the training and validation data. However, across both the Food-101 (Fig. S9) and Oxford Pets (Fig. S10) datasets, there appears to be a trend where initialisations that identify more salient image features obtain better performance.

[Figure S11: k-means clusters of the DINOv2 patch embeddings for two images containing a white plate.]

We observe that the DINOv2 patch embeddings are influenced by their context. For example, in Fig. S11 we use $k$-means clustering to visualise the DINOv2 patch embeddings of two images containing a white plate. We find that the surfaces the plates rest on both fall into the background cluster (yellow), but the plate containing churros is placed in the red cluster while the plate containing a burger is placed in the blue cluster. When ComFe obtains good performance in cases where the salient parts of an image are classified as background, this likely reflects context leakage from the backbone model.
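The visualisation described above can be reproduced roughly as follows: extract the patch embeddings for the two images, cluster them jointly with k-means, and reshape the cluster labels back onto the patch grid for plotting. This is an illustrative sketch; the number of clusters and the joint clustering of both images are our assumptions.

```python
import torch
from sklearn.cluster import KMeans

def cluster_patch_embeddings(patch_embeddings: torch.Tensor, n_clusters: int = 4) -> torch.Tensor:
    """Jointly cluster patch embeddings from several images with k-means.

    patch_embeddings: (n_images, n_patches, d) tensor of backbone patch embeddings.
    Returns (n_images, n_patches) cluster ids that can be reshaped onto the patch grid.
    """
    n_images, n_patches, d = patch_embeddings.shape
    flat = patch_embeddings.reshape(-1, d).cpu().numpy()
    ids = KMeans(n_clusters=n_clusters, n_init=10).fit_predict(flat)
    return torch.as_tensor(ids).reshape(n_images, n_patches)
```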

Prototypes being selected from the background of an image was a challenge in the original ProtoPNet approach, where a pruning process was employed to remove most of the non-informative prototypes from the reasoning process [9], and this can still occur in more recent approaches such as PIP-Net [10].

Table S7:
Backbone | Background prototypes | IN-R | IN-A
DINOv2 ViT-S/14 (f) | 3000 | 57.7 | 42.6
DINOv2 ViT-B/14 (f) | 3000 | 67.1 | 61.7
DINOv2 ViT-L/14 (f) | 3000 | 72.5 | 71.9

Appendix D Ablation studies

In this section we undertake an ablation study on the hyperparameters of ComFe using the FGVC Aircraft, Oxford Pets and Stanford Cars datasets. For each set of parameters we consider 5 runs and report the mean and standard deviation of the validation accuracy.

Table S8:
Transformer decoder layers | Aircr | Pets | Cars
1 | 75.3±1.3 | 94.5±0.4 | 90.3±0.4
2 | 77.2±0.9 | 94.7±0.4 | 91.1±0.1
4 | 77.1±0.6 | 94.8±0.4 | 91.6±0.2
6 | 77.9±0.7 | 94.4±0.3 | 91.2±0.2

Transformer decoder layers. Table S8 shows that for the Stanford Cars and FGVC Aircraft datasets, including more transformer decoder layers improves the average performance of ComFe across a number of seeds. However, for the Oxford Pets dataset there appears to be no significant relationship between accuracy and the size of the transformer decoder.

Table S9:
$\mathcal{L}_{\text{p-discrim}}$ | $\mathcal{L}_{\text{CARL}}$ | $\mathcal{L}_{\text{contrast}}$ | Aircr | Pets | Cars
N | Y | Y | 76.4±2.2 | 94.7±0.6 | 90.5±0.7
Y | N | Y | 76.8±1.0 | 94.8±0.6 | 90.9±0.5
Y | Y | N | 76.7±0.6 | 94.6±0.3 | 90.8±0.3
Y | Y | Y | 76.8±0.7 | 94.8±0.4 | 91.0±0.2

Loss terms. In Table S9 we explore the impact of removing loss terms on the performance of ComFe across the Stanford Cars, FGVC Aircraft and Oxford Pets datasets. We find that they have only a marginal impact on the performance of the ComFe models.

Table S10:
$\tau_1$ | $\tau_2$ | Aircr | Pets | Cars
0.05 | 0.02 | 76.8±1.5 | 94.9±0.4 | 90.5±0.5
0.10 | 0.01 | 76.5±0.9 | 95.0±0.3 | 90.6±0.0
0.10 | 0.02 | 77.0±0.4 | 94.9±0.4 | 91.0±0.3
0.10 | 0.05 | 77.8±0.5 | 94.9±0.5 | 91.1±0.6
0.20 | 0.02 | 77.4±1.1 | 95.1±0.2 | 91.2±0.1

Concentration parameters. In Table S10 we explore the impact of the concentration parameters $\tau_1$ and $\tau_2$ on the performance of ComFe. While there is some scope for optimising the choice of these parameters for particular datasets, the performance of ComFe does not appear to be particularly sensitive to them.

Table S11:
$\alpha$ | Aircr | Pets | Cars
0.0 | 77.2±0.6 | 94.8±0.3 | 91.0±0.3
0.1 | 77.3±0.7 | 94.8±0.4 | 91.0±0.2

Label smoothing. In Table S11 we explore the impact of the label smoothing parameter $\alpha$ on the performance of ComFe. We find that label smoothing has little impact on accuracy, and is not necessary for the success of the method.

Table S12:
$N_P$ | $N_C / c$ | Aircr | Pets | Cars
5 | 1 | 76.7±0.7 | 94.6±0.4 | 91.0±0.4
5 | 3 | 77.6±0.8 | 94.8±0.2 | 91.0±0.3
10 | 3 | 76.1±1.2 | 94.8±0.4 | 91.3±0.3

Number of prototypes. In Table S12 we explore the impact of the number of image prototypes $N_P$ and the number of class prototypes per label $N_C / c$ on the performance of ComFe. While there may be some marginal performance gains on particular datasets from choosing a particular number of prototypes, overall these settings have only a small impact on the accuracy of ComFe.
