Until recently, ANNs were unable to provide reliable measures of their prediction uncertainty and often suffered from over-confidence. Model ensembles can yield improvements in both accuracy and uncertainty estimation, but their higher computational cost motivates distilling them into a single network. While accuracy is usually preserved after distillation, the information about the diversity of the ensemble (the different types of uncertainty) is lost.
This paper presents a way to distill an ensemble of models while maintaining the learned uncertainty.
Consider an ensemble of M models in a k-class classification task: \(\hat{\mathcal{M}} = \left\{\mathcal{M}_1, ..., \mathcal{M}_M \right\}\) with parameters \(\hat{\theta} = \left\{\theta_1, ..., \theta_M \right\}\), where each \(\theta_m\) can be seen as a sample from some underlying parameter distribution: \(\theta_m \sim p(\theta \mid \mathcal{D})\).
Consider now the categorical distribution which the m-th model of the ensemble yields when a data-point \(x^\star\) is evaluated: \(\pi_m = \left[ P(y=y_1 \mid x^\star , \theta_m), ..., P(y=y_k \mid x^\star , \theta_m) \right]\). I.e. \(\mathcal{M}_m(x^\star) = \pi_m\).
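As a concrete illustration (a minimal numpy sketch with made-up softmax outputs standing in for real models, not the paper's code), evaluating every ensemble member on the same \(x^\star\) stacks the categorical distributions \(\pi_1, ..., \pi_M\) into an \(M \times k\) array:

```python
import numpy as np

rng = np.random.default_rng(0)

M, k = 5, 3  # ensemble size and number of classes (toy values)

# Stand-in for evaluating each model M_m(x*) -> pi_m: random logits followed
# by a softmax, so every row is a valid categorical distribution over k classes.
logits = rng.normal(size=(M, k))
pis = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

print(pis)              # shape (M, k); row m is pi_m
print(pis.sum(axis=1))  # each row sums to 1
```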
From these categorical distributions (\(\pi_m\)) we can see the sources of uncertainty of our ensemble. \(\pi_m\) can be understood as barycentric coordinates in a simplex with k corners, where each corner represents a different class of the classification task. Geometrically, the closer \(\pi\) lies to a given corner, the higher the probability it assigns to that class. For instance, in a 3-class classification problem, we can observe the following behaviors:
The closer to the simplex center a \(\pi_m\) distribution is, the higher the entropy \(\mathcal{H}\) of that distribution (more uncertainty).
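As a quick numerical check of this (toy numbers of my own, not from the paper), a distribution sitting near a corner has low entropy, while the simplex center (the uniform distribution) attains the maximum entropy \(\log k\):

```python
import numpy as np

def entropy(pi):
    pi = np.asarray(pi, dtype=float)
    return -np.sum(pi * np.log(pi + 1e-12))  # small epsilon avoids log(0)

print(entropy([0.98, 0.01, 0.01]))  # near a corner -> low entropy (~0.11 nats)
print(entropy([1/3, 1/3, 1/3]))     # simplex center -> maximal entropy, log(3) ~ 1.10 nats
```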
This interpretation helps us understand the following identities:
Expected data uncertainty: \(E_{p(\theta \mid \mathcal{D})} \left[ \mathcal{H} \left[ P(y \mid x^\star, \theta )\right] \right]\). That is: the average of the entropies of the individual models of the ensemble.
Total uncertainty: \(\mathcal{H} \left[ E_{p(\theta \mid \mathcal{D})} \left[ P(y \mid x^\star, \theta)\right] \right]\). That is: the entropy of the average of the predictions of the ensemble. It captures both the data uncertainty and the spread or “disagreement” between models.
Therefore, from the ensemble we can infer the model (knowledge) uncertainty as the mutual information \(\mathcal{MI}\), i.e. the difference between total and expected data uncertainty:
\begin{equation} \mathcal{MI} \left[ y, \theta \mid x^\star, \mathcal{D} \right] = \mathcal{H} \left[ E_{p(\theta \mid \mathcal{D})} \left[ P(y \mid x^\star, \theta)\right] \right] - E_{p(\theta \mid \mathcal{D})} \left[ \mathcal{H} \left[ P(y \mid x^\star, \theta )\right] \right] \end{equation}
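All three quantities are straightforward to compute from the stacked ensemble predictions (an \((M, k)\) array like the `pis` of the earlier sketch); a minimal numpy sketch with toy values of my own:

```python
import numpy as np

def entropy(p, axis=-1):
    return -np.sum(p * np.log(p + 1e-12), axis=axis)

def uncertainty_decomposition(pis):
    """pis: (M, k) array, one categorical distribution per ensemble member."""
    mean_pi = pis.mean(axis=0)                   # E_theta[P(y | x*, theta)]
    total = entropy(mean_pi)                     # total uncertainty
    expected_data = entropy(pis, axis=1).mean()  # expected data uncertainty
    mutual_info = total - expected_data          # knowledge (model) uncertainty
    return total, expected_data, mutual_info

# Members agree -> mutual information ~ 0; members disagree -> it grows.
agree = np.array([[0.9, 0.05, 0.05]] * 4)
disagree = np.array([[0.9, 0.05, 0.05], [0.05, 0.9, 0.05],
                     [0.05, 0.05, 0.9], [0.9, 0.05, 0.05]])
print(uncertainty_decomposition(agree))     # MI ~ 0
print(uncertainty_decomposition(disagree))  # MI > 0
```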
In order to maintain both the predictive accuracy and the diversity of the ensemble, the authors use prior networks. A prior network \(p(\pi \mid x; \phi)\) models a distribution over categorical output distributions \(\left( \{\pi_m \}_m \right)\). The Dirichlet distribution is chosen for its tractable analytic properties (it allows closed-form expressions):
\begin{equation} p(\pi \mid x; \hat \phi) = Dir (\pi \mid \hat \alpha) \end{equation}
Where \(\hat \alpha\) is the vector of concentration parameters, with \(\hat \alpha_c > 0\) and precision \(\hat \alpha_0 = \sum_{c=1}^k \hat \alpha_c\), and \(\hat \phi\) is the set of parameters which maps each input data-point \(x\) to its associated concentration parameters: \(\hat \alpha = f (x; \hat \phi)\). These parameters can be fitted using MLE on a dataset containing each input together with the outputs of every network of the ensemble, \(\mathcal{D_e} = \left\{ x_i, \pi_{i, 1:M} \right\}_{i=1}^N \sim \hat p (x, \pi)\). To do so, we simply minimize the following loss:
\begin{equation} \mathcal{L} \left(\phi, \mathcal{D_e} \right) = - E_{p(x)} \left[ E_{\hat p (\pi \mid x)} \left[ \log p (\pi \mid x ; \phi) \right] \right] \end{equation}
This means (if my interpretation is not mistaken) that we get the parameters as:
\begin{equation} \phi^\star = \arg \min_\phi - \sum_{i=1}^{N} \sum_{m=1}^{M} \log p(\pi_{i, m} \mid x_i ; \phi) \end{equation}
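Putting the pieces together, here is a minimal PyTorch sketch of what this could look like (my own toy architecture and hyper-parameters, not the authors' implementation): a small network maps \(x\) to strictly positive concentration parameters \(\hat \alpha = f(x; \hat \phi)\), and the loss is the negative Dirichlet log-density of every ensemble prediction \(\pi_{i, m}\) under the \(\hat \alpha\) predicted for \(x_i\):

```python
import torch
import torch.nn as nn
from torch.distributions import Dirichlet

class PriorNetwork(nn.Module):
    """Maps an input x to the concentration vector alpha of a Dirichlet."""
    def __init__(self, in_dim, k, hidden=64):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, k),
        )

    def forward(self, x):
        # exp guarantees alpha_c > 0; alpha_0 is simply the sum over classes
        return torch.exp(self.body(x))

def distillation_loss(alpha, ensemble_pis, eps=1e-6):
    """Negative log-likelihood of the ensemble's categorical predictions.

    alpha:        (batch, k)    concentration parameters predicted for each x_i
    ensemble_pis: (batch, M, k) categorical predictions of the M ensemble members
    """
    # The Dirichlet support is the open simplex, so clamp away exact 0/1 values.
    pis = ensemble_pis.clamp(eps, 1.0 - eps)
    pis = pis / pis.sum(dim=-1, keepdim=True)
    dist = Dirichlet(alpha.unsqueeze(1))   # broadcasts over the M members
    return -dist.log_prob(pis).mean()

# Toy usage: 8 inputs of dimension 2, an ensemble of 5 members, 3 classes.
net = PriorNetwork(in_dim=2, k=3)
x = torch.randn(8, 2)
pis = torch.softmax(torch.randn(8, 5, 3), dim=-1)  # stand-in for the D_e targets
loss = distillation_loss(net(x), pis)
loss.backward()
```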
Often the ensemble's output distributions are very sharp, concentrated on a corner of the simplex (as in figure 2.a), whereas at initialization the Dirichlet distribution places its mass closer to the center of the simplex (nothing is known yet). Training with this disparity can be challenging, so the authors introduce a temperature annealing schedule: during the first optimization steps the distributions are “heated” to make them more uncertain, and the temperature is then gradually decreased.
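My reading of this schedule (an interpretation of the idea, not necessarily the paper's exact recipe) is to smooth the ensemble targets with a temperature \(T > 1\) early on, which pulls them towards the simplex center, and anneal \(T\) back to 1 over the first steps:

```python
import numpy as np

def heat_targets(pis, T):
    """Temperature-smooth categorical targets: pi^(1/T), renormalized.
    T > 1 pulls distributions towards the simplex center; T = 1 leaves them unchanged."""
    heated = np.power(pis, 1.0 / T)
    return heated / heated.sum(axis=-1, keepdims=True)

def temperature(step, warmup_steps=1000, T_max=5.0):
    """Linear annealing from T_max down to 1 during the first warmup_steps (my own choice of schedule)."""
    frac = min(step / warmup_steps, 1.0)
    return T_max + frac * (1.0 - T_max)

pi = np.array([0.98, 0.01, 0.01])
print(heat_targets(pi, temperature(step=0)))     # early: a much flatter target
print(heat_targets(pi, temperature(step=1000)))  # after annealing: the original sharp target
```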
They experimented by generating the following two datasets:
They then train and assess the uncertainty of:
For reference, a single network had an error of \(13.21\%\).
We can visualize the uncertainty sources in the following plots:
To overcome that issue, the researchers sample the \(Aux\) data-points and label them with the predictions of the ensemble. These labelled points are then used (in combination with the original training data) to re-train the distilled model.
Without the \(Aux\) data the distilled model fails to capture knowledge uncertainty (dark areas).
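A sketch of how that augmentation could look (toy stand-ins of my own; the actual \(Aux\) sampling procedure is described in the paper, here the points are just drawn from a broader Gaussian): run the ensemble on the auxiliary points and append the resulting \((x, \pi_{1:M})\) pairs to the distillation set before re-training.

```python
import numpy as np

rng = np.random.default_rng(1)

def label_with_ensemble(ensemble, xs):
    """Run every ensemble member on every point; returns an (N, M, k) target array."""
    return np.stack([[model(x) for model in ensemble] for x in xs])

# Toy ensemble: each "model" is a softmax over a random linear map (a stand-in).
def make_toy_model(k=3, d=2):
    W = rng.normal(size=(d, k))
    return lambda x: np.exp(x @ W) / np.exp(x @ W).sum()

ensemble = [make_toy_model() for _ in range(5)]
x_train = rng.normal(size=(100, 2))
x_aux = rng.normal(scale=4.0, size=(50, 2))   # stand-in for points away from the training data

# Augmented distillation set: original points plus Aux points, all labelled by the ensemble.
x_all = np.concatenate([x_train, x_aux])
pi_all = label_with_ensemble(ensemble, x_all)  # shape (150, 5, 3)
```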
They then run similar tests on the CIFAR-10, CIFAR-100 and TIM datasets, reaching the following conclusions: