Agglomerator++: interpretable part-whole hierarchies and latent space representations in neural networks

¹University of Trento, ²CNIT

Abstract

Deep neural networks achieve outstanding results in a large variety of tasks, often outperforming human experts. However, a known limitation of current neural architectures is the difficulty of understanding and interpreting the network's response to a given input. This is directly related to the huge number of variables and the associated non-linearities of neural models, which are often used as black boxes. This lack of transparency, particularly in crucial areas like autonomous driving, security, and healthcare, can trigger skepticism and limit trust, despite the networks' high performance.

In this work, we aim to advance interpretability in neural networks. We present Agglomerator++, a framework capable of providing a representation of part-whole hierarchies from visual cues and of organizing the input distribution to match the conceptual-semantic hierarchical structure between classes. We evaluate our method on common datasets, such as SmallNORB, MNIST, FashionMNIST, CIFAR-10, and CIFAR-100, showing that our solution delivers a more interpretable model compared to other state-of-the-art approaches.

Pretraining

Pretraining. During the pre-training phase, a masked input is given to the network. The reconstruction loss 𝓛_recon depends on how well the masked patches of the input image are reconstructed. The loss is attached to level L^t_1 because the network presents more detailed features at the lower levels, while the representation becomes more abstract at higher levels, making them less suitable for reconstructing the input image. At the same time, enforcing the minimization of the regularization losses 𝓛_d on the last level L^t_K encourages the network to display more definite islands of agreement at the higher levels.
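As a rough sketch of how these two objectives could be combined in code, assuming a PyTorch-style model that exposes its per-level states and a patch decoder (the helper mask_patches, the methods model.decode and model.regularization, and the weight lambda_d are illustrative names, not the released implementation):

    import torch.nn.functional as F

    def pretraining_step(model, images, mask_ratio=0.5, lambda_d=1.0):
        """One masked-reconstruction pre-training step (illustrative sketch)."""
        # Mask a fraction of the input patches (hypothetical helper).
        masked_images, mask = mask_patches(images, mask_ratio)
        levels = model(masked_images)               # levels[k] holds L^t_{k+1}

        # L_recon: reconstruct the masked patches from the lowest level L_1,
        # where the representation is still close to the pixel content.
        recon = model.decode(levels[0])
        loss_recon = F.mse_loss(recon[mask], images[mask])

        # L_d: regularization on the last level L_K, pushing the columns to
        # settle into well-defined islands of agreement.
        loss_d = model.regularization(levels[-1])

        return loss_recon + lambda_d * loss_d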

Training for image classification

Training. During the training phase, image samples are fed through the network to obtain their neural representation. The information is routed for at least 2K iterations, so that it propagates from the bottom level upward through the bottom-up networks and back downward through the top-down networks. The classification loss 𝓛_2 is associated with the last level because the higher-level features are better suited to the classification task.
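A minimal sketch of this supervised phase, under the same assumptions about the model interface (the methods init_levels, step, and classifier are illustrative names, not the exact released code):

    import torch.nn.functional as F

    def classification_step(model, images, labels, K):
        """Supervised training step (illustrative sketch)."""
        levels = model.init_levels(images)     # column states at time t = 0
        for _ in range(2 * K):                 # at least 2K routing iterations
            levels = model.step(levels)        # one bottom-up / top-down update

        # The classification loss is attached to the last, most abstract level.
        logits = model.classifier(levels[-1])
        return F.cross_entropy(logits, labels)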

Architecture

Architecture of our Agglomerator++ model (center) with information routing (left) and detailed structure of the building elements (right). Each cube represents a level l^t_k. Top: (a) legend for the arrows in the figure, representing the top-down network N_TD(l^{t-1}_{k+1}) and the positional embedding p(h,w), the bottom-up network N_BU(l^{t-1}_{k-1}), the attention mechanism A(L^{t-1}_k), and the time step t. Left: (b) contribution to the value of level l^t_k given by l^{t-1}_k, N_TD(l^{t-1}_{k+1}) and N_BU(l^{t-1}_{k-1}). (c) The attention mechanism A(L^{t-1}_k) facilitates information sharing between the levels l^{t-1}_k ∈ L^{t-1}_k. The positional embedding p(h,w) is different for each column C(h,w); all levels belonging to the same hyper-column C(h,w) share the positional embedding p(h,w). Center: from bottom to top, the architecture consists of the tokenizer module, followed by the columns C(h,w), with each level l^t_k connected to its neighbors by N_TD(l^{t-1}_{k+1}) and N_BU(l^{t-1}_{k-1}). Right: (d) structure of the top-down network N_TD(l^{t-1}_{k+1}) and the bottom-up network N_BU(l^{t-1}_{k-1}).
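The level update the figure describes can be sketched as follows; the equal mixing weights and the exact way the positional embedding enters the top-down path are assumptions, not the paper's exact formulation:

    def update_level(l_prev, l_below, l_above, level_row,
                     N_BU, N_TD, attention, pos_emb,
                     w_prev=0.25, w_bu=0.25, w_td=0.25, w_att=0.25):
        """Update of a single level l^t_k from the states at time t-1 (sketch).

        l_prev    -- l^{t-1}_k, the level's own previous state
        l_below   -- l^{t-1}_{k-1}, fed through the bottom-up network N_BU
        l_above   -- l^{t-1}_{k+1}, fed through the top-down network N_TD
        level_row -- L^{t-1}_k, all levels at depth k, shared via attention
        pos_emb   -- positional embedding p(h,w), shared by every level of
                     the hyper-column C(h,w)
        """
        bottom_up = N_BU(l_below)
        top_down = N_TD(l_above + pos_emb)   # p(h,w) accompanies the top-down path
        shared = attention(level_row)        # A(L^{t-1}_k): sharing across columns
        return (w_prev * l_prev + w_bu * bottom_up
                + w_td * top_down + w_att * shared)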

Latent Spaces

2D latent space representation, obtained through PCA, of multiple methods trained on the CIFAR-10 dataset. PCA reduces the multi-dimensional latent space to 2D. The legend (f) groups the samples into the super-classes Vehicles and Animals, following the WordNet hierarchy (Miller, 1995). All methods (a, b, c, d, e) cluster samples into the two super-classes. The MLP-based methods (c, d) provide a better separation between super-classes, while placing similar samples close together. Our method (e) optimizes both inter-class and intra-class separability. The overlap percentage O indicates the regions where severe hierarchical errors (Bertinetto, 2020) are more likely to occur.
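A sketch of how such a projection can be produced with scikit-learn, assuming the latent vectors and labels have already been extracted to NumPy arrays (the file names are placeholders):

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA

    # Latent vectors (N, D) and CIFAR-10 class indices (N,) from a trained model.
    features = np.load("cifar10_latents.npy")   # placeholder file name
    labels = np.load("cifar10_labels.npy")      # placeholder file name

    # WordNet-style grouping of the 10 classes into Vehicles vs Animals:
    # airplane (0), automobile (1), ship (8), truck (9) are vehicles.
    is_vehicle = np.isin(labels, [0, 1, 8, 9])

    # Reduce the multi-dimensional latent space to 2D with PCA.
    coords = PCA(n_components=2).fit_transform(features)

    plt.scatter(coords[is_vehicle, 0], coords[is_vehicle, 1], s=4, label="Vehicles")
    plt.scatter(coords[~is_vehicle, 0], coords[~is_vehicle, 1], s=4, label="Animals")
    plt.legend()
    plt.show()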

Islands of agreement

Islands of agreement. Illustration of how the islands of agreement evolve across the K levels for samples from the MNIST and CIFAR-10 datasets. The agreement vectors are shown for each patch at each level k after 300 epochs of pre-training. Level k=1 acts as a feature extractor, with little agreement between neighboring patches. At the top level k=5, two islands emerge, representing the object and the background.
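One simple way to visualize such islands, sketched under the assumption that a level is available as an (H, W, D) grid of patch embeddings, is to measure how much each patch agrees with its neighbors:

    import torch
    import torch.nn.functional as F

    def agreement_map(level):
        """Mean cosine similarity between each patch embedding and its
        4-connected neighbours (illustrative sketch).

        level -- tensor of shape (H, W, D) for one level k; patches that
        agree with their neighbours form the islands shown in the figure.
        """
        x = F.normalize(level, dim=-1)
        sims = []
        for dh, dw in [(1, 0), (-1, 0), (0, 1), (0, -1)]:
            # torch.roll wraps around at the border, acceptable for a rough map.
            shifted = torch.roll(x, shifts=(dh, dw), dims=(0, 1))
            sims.append((x * shifted).sum(dim=-1))
        return torch.stack(sims).mean(dim=0)   # (H, W) agreement per patch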

BibTeX

Coming soon