fltech - Fujitsu Research Technology Blog

A technical blog in which Fujitsu Research researchers write about a variety of topics

"Selective Mixup Fine-Tuning for Optimizing Non-Decomposable Objectives" presented at ICLR 2024

Overview

Hi! I am Sho Takemori from the AI Innovation Core Project, Artificial Intelligence Laboratory. Recently, in joint research with the Indian Institute of Science (IISc), we developed an AI technology termed SelMix that enables efficient optimization of complex, real-world performance measures for classification, and we presented our paper at ICLR 2024 as a spotlight presentation.

Paper

(Here, FRIPL: Fujitsu Research of India Private Limited, FRJ: Fujitsu Research Japan)

(Figure: accepted paper)

Motivation and Problem Formulation

In the paper, we propose a fine-tuning method for optimizing non-decomposable objectives in both semi-supervised and supervised learning setups. In the following, we detail our motivation and problem formulation.

Long-tailed datasets and Non-decomposable objectives

Classification accuracy is the conventional metric for deep neural classifiers in computer vision tasks. However, in practical applications, datasets are often imbalanced (i.e., the distribution of class labels is long-tailed), and we have to optimize more nuanced metrics to ensure the robustness of the classifier. More precisely, we consider optimization of non-decomposable objectives (metrics) [Narasimhan et al., 2022]. Here, a classification metric is called decomposable if there exists a score function defined on each instance such that the metric for a set of instances can be computed from the scores of the individual instances. Otherwise, we call the metric non-decomposable. A typical example of a decomposable metric is accuracy. Some practical metrics are non-decomposable; examples include the mean, minimum, and harmonic mean of class-wise recalls, F1, and AUC. Optimizing these metrics is of practical importance, especially when the dataset is long-tailed.

More specifically, we consider non-decomposable metrics that can be written as a function of (the entries of) the confusion matrix  C[h] of a classifier  h (or of a score function associated with the classifier). For a confusion matrix  C of a classifier, we assume the non-decomposable metric is given as  \psi(C) with a function  \psi. We list examples of non-decomposable metrics as functions of the entries of a confusion matrix  C[h] in the following table. The mean, min, and H-mean recall in the table represent the average, minimum, and harmonic mean of class-wise recalls, respectively. The coverage of a classifier for a given class label is the probability that the classifier predicts that label for an instance, and coverage constraints are useful for representing fairness. For a constrained objective involving coverage, we consider an unconstrained objective obtained by introducing Lagrange multipliers.

(Table: examples of non-decomposable metrics)
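To make the metrics above concrete, here is a minimal NumPy sketch that builds an empirical confusion matrix and evaluates a few of them. The function names and the toy labels are illustrative assumptions, not code from the paper.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    """C[i, j] = fraction of instances with true label i that are predicted as j."""
    C = np.zeros((num_classes, num_classes))
    for t, p in zip(y_true, y_pred):
        C[t, p] += 1.0
    return C / len(y_true)

def recalls(C):
    # Recall of class i = C[i, i] / sum_j C[i, j]
    return np.diag(C) / C.sum(axis=1)

def mean_recall(C):
    return recalls(C).mean()

def min_recall(C):
    return recalls(C).min()

def hmean_recall(C):
    # Harmonic mean of class-wise recalls (assumes every recall is positive)
    r = recalls(C)
    return len(r) / np.sum(1.0 / r)

def coverage(C):
    # Coverage of class j = probability of predicting j = sum_i C[i, j]
    return C.sum(axis=0)

# Toy long-tailed example: 6 / 2 / 2 instances for classes 0 / 1 / 2
y_true = [0, 0, 0, 0, 0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 0, 0, 0, 0, 1, 0, 2, 0]
C = confusion_matrix(y_true, y_pred, num_classes=3)
accuracy = np.trace(C)
```

In this toy example the accuracy is 0.8 while the minimum recall is only 0.5, which illustrates why accuracy alone can hide poor performance on tail classes.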

A classifier with a feature extractor

In our previous work [Rangwani et al., 2022], we considered optimization of non-decomposable metrics (in the semi-supervised learning setting) and proposed a framework termed cost-sensitive self-training (CSST). However, CSST needs to train deep neural networks from scratch and is computationally expensive. In this work, for computational efficiency, we fine-tune a pre-trained model by "selective mixup", which we explain later. Let  K be the number of classes and  \mathcal{X} the instance space. We assume the classifier  F is computed through a score function  h that consists of a feature extractor  g : \mathcal{X} \rightarrow \mathbb{R}^{d} and a linear layer parametrized by a matrix  W \in \mathbb{R}^{d\times K}, i.e., the classifier  F:\mathcal{X} \rightarrow [K] is given as  F(x) = \mathrm{argmax}_i h_i(x), where  h(x) = W^{T} g(x) for each instance  x \in \mathcal{X}.
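As a small illustration of this classifier structure, the sketch below applies a linear layer to precomputed feature vectors. The feature extractor  g is not modelled here; the array `feats` stands in for its outputs, and the function names are assumptions made for this example.

```python
import numpy as np

def score(W, feats):
    """h(x) = W^T g(x) for a batch: feats is (n, d), W is (d, K), result is (n, K)."""
    return feats @ W

def classify(W, feats):
    """F(x) = argmax_i h_i(x), returned as 0-based class indices."""
    return np.argmax(score(W, feats), axis=1)

# d = 3 feature dimensions, K = 2 classes
W = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [0.0, 0.0]])
feats = np.array([[2.0, 1.0, 0.5],   # stands in for g(x_1)
                  [0.0, 3.0, 1.0]])  # stands in for g(x_2)
pred = classify(W, feats)
```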

Semi-supervised and supervised learning

Similarly to the previous work [Rangwani et al., 2022], our method is able to utilize unlabeled instances (i.e., semi-supervised learning setting). Although our main focus is the semi-supervised learning (SSL) setting, our method can be easily extended to the supervised learning setting and we evaluate our method in both SSL and supervised settings.

Proposed Method

Feature Mixup and (i, j)-mixup

Mixup is a data augmentation method, and existing works show that it improves generalization, calibration, and robustness, especially in the standard setting where the metric is accuracy [Zhang et al., 2018, Zhang et al., 2021, Zhong et al., 2021]. In its most basic form, for two samples  (x, y), (x', y') \in \mathcal{X} \times \mathbb{R}^{K}, mixup creates a new sample by a convex combination, i.e., the new sample is defined as  (\beta x + (1 -\beta)x', \beta y + (1 - \beta)y'). Here, we identify a class label with a one-hot vector in  \mathbb{R}^{K}, and  \beta \in (0, 1). Our method is based on feature mixup [Verma et al., 2019], and we consider a convex combination of feature vectors  \beta g(x) + (1 - \beta)g(x') instead of a convex combination of instances. Given two samples  (x, y), (x', y'), we define the mixup loss as follows:

 \mathcal{L}_{mixup}(g(x), g(x'), y; W) = \mathcal{L}_{SCE}(W^{T} (\beta g(x) + (1 -\beta) g(x')), y),

where  \mathcal{L}_{SCE} denotes the softmax cross-entropy loss and  \beta \sim U(\beta_{min}, 1).
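The mixup loss can be sketched in a few lines of NumPy. This is an illustrative sketch only: the feature vectors are random stand-ins for  g(x) and  g(x'), and the value of `beta_min` is a hypothetical choice, not one taken from the paper.

```python
import numpy as np

def softmax_cross_entropy(logits, label):
    """Softmax cross-entropy for one instance, using a stable log-softmax."""
    shifted = logits - logits.max()
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return -log_probs[label]

def mixup_loss(feat_x, feat_x2, label, W, beta):
    """L_mixup(g(x), g(x'), y; W): only the features are mixed, the label stays y."""
    mixed = beta * feat_x + (1.0 - beta) * feat_x2
    return softmax_cross_entropy(W.T @ mixed, label)

rng = np.random.default_rng(0)
d, K = 4, 3
W = rng.normal(size=(d, K))
g_x, g_x2 = rng.normal(size=d), rng.normal(size=d)  # stand-ins for g(x), g(x')
beta_min = 0.6                                      # hypothetical value
beta = rng.uniform(beta_min, 1.0)                   # beta ~ U(beta_min, 1)
loss = mixup_loss(g_x, g_x2, label=1, W=W, beta=beta)
```

Because  \beta is drawn from  U(\beta_{min}, 1), the mixed feature stays closer to  g(x) than to  g(x') whenever  \beta_{min} \ge 1/2, which is consistent with keeping the label  y of the first sample.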

Conventional mixup methods draw two samples  (x, y), (x', y') uniformly at random. However, since our objective is more complex (optimization of non-decomposable metrics), such a naive sampling strategy leads to sub-optimal performance. Therefore, we consider  (i, j)-mixups that take into account the class labels of the samples  (x, y), (x', y'). We call the feature mixup  (\beta g(x) + (1 -\beta)g(x'), y) an  (i, j)-mixup if  y=i and  y' = j. Note that we take a convex combination of the features only and use the class label  y of the first sample for the mixup sample. For each  y \in [K], we denote by  D_y \subset \mathcal{X} the subset of instances with the label  y. In the SSL setup, we denote by  \widetilde{D}_y the subset of instances with the pseudo label  y. For each  y \in [K], we denote by  z_y the mean feature vector  \mathbb{E}_{x \sim D_y}[g(x)]. Then, the representative of the expected loss due to  (i, j)-mixups is defined as  \mathcal{L}^{mix}_{(i, j)} = \mathcal{L}_{mixup}(z_i, z_j, i; W).
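The class means  z_y and the representative loss can be sketched as follows. The text does not say how  \beta is handled inside the representative loss, so using a fixed `beta` here is an assumption of this sketch, as are the function names and the toy data.

```python
import numpy as np

def class_means(features, labels, num_classes):
    """z_y = mean of g(x) over instances with (pseudo) label y; each class must be non-empty."""
    return np.stack([features[labels == y].mean(axis=0)
                     for y in range(num_classes)])

def representative_mixup_loss(Z, W, i, j, beta):
    """L^mix_(i,j) = L_mixup(z_i, z_j, i; W), evaluated with a fixed beta (assumption)."""
    logits = W.T @ (beta * Z[i] + (1.0 - beta) * Z[j])
    shifted = logits - logits.max()
    return -(shifted[i] - np.log(np.exp(shifted).sum()))

# Toy features with d = 2 and K = 2
features = np.array([[1.0, 0.0], [3.0, 0.0], [0.0, 2.0], [0.0, 4.0]])
labels = np.array([0, 0, 1, 1])
Z = class_means(features, labels, num_classes=2)  # rows are z_0 = (2, 0), z_1 = (0, 3)
W = np.eye(2)
loss_01 = representative_mixup_loss(Z, W, i=0, j=1, beta=0.8)
loss_10 = representative_mixup_loss(Z, W, i=1, j=0, beta=0.8)
```

The two losses differ because an  (i, j)-mixup is not symmetric: the label always comes from the first class.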

Gain due to  (i, j)-mixup and approximated gain

If we perform an  (i, j)-mixup, we update the weights  W of the linear layer by the directional vector  V_{ij}:= -\partial{\mathcal{L}_{(i,j)}^{mix}}/{\partial{W}}. We consider the change of the metric  \psi(C[h]) due to the update of the weights  W. Let  \eta > 0 be a small scalar. Then, by the Taylor expansion, we have

 \psi(C[W^{T} g + \eta V_{ij}^{T} g])=\psi(C[W^{T} g]) + \eta \nabla_{V_{ij}}\psi(C[W^{T} g]) + O(\eta^{2} |V_{ij}|^{2}),

where  \nabla_{V_{ij}} denotes the directional derivative along  V_{ij}. We define the gain  G_{ij} due to  (i, j)-mixups by  G_{ij} = \nabla_{V_{ij}}\psi(C[W^{T} g]) with  V_{ij}= -\partial{\mathcal{L}_{(i,j)}^{mix}}/{\partial{W}}, i.e.,  G_{ij} represents the first-order change of the metric due to  (i, j)-mixups.
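For the softmax cross-entropy loss, the direction  V_{ij} has a closed form: with mixed feature  f = \beta z_i + (1-\beta) z_j and  p = \mathrm{softmax}(W^{T} f), the gradient is  \partial \mathcal{L}/\partial W = f (p - e_i)^{T}. The sketch below computes  V_{ij} this way and checks a first-order Taylor expansion on the loss itself; the gain  G_{ij} applies the same kind of expansion to the metric  \psi. The random inputs and the fixed  \beta are assumptions of the sketch.

```python
import numpy as np

def softmax(a):
    e = np.exp(a - a.max())
    return e / e.sum()

def mixup_loss(W, f, i):
    """Softmax cross-entropy of the logits W^T f against class i."""
    return -np.log(softmax(W.T @ f)[i])

def mixup_direction(W, f, i):
    """V = -dL/dW. Since dL/dW = outer(f, softmax(W^T f) - e_i), V flips the sign."""
    e_i = np.zeros(W.shape[1])
    e_i[i] = 1.0
    return np.outer(f, e_i - softmax(W.T @ f))

rng = np.random.default_rng(0)
d, K, beta = 5, 3, 0.8
W = rng.normal(size=(d, K))
Z = rng.normal(size=(K, d))            # stand-ins for the class means z_1, ..., z_K
i, j = 0, 2
f = beta * Z[i] + (1.0 - beta) * Z[j]  # mixed representative feature
V_ij = mixup_direction(W, f, i)

# Moving W by eta * V_ij should lower the loss by about eta * ||V_ij||^2.
eta = 1e-3
loss_before = mixup_loss(W, f, i)
loss_after = mixup_loss(W + eta * V_ij, f, i)
first_order = loss_before - eta * np.sum(V_ij ** 2)
```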

Selective Mixup

Next, we explain how to select a pair  (i, j). Suppose we select a mixup pair  (i, j) following a distribution  \mathcal{P} on  [K] \times [K]. Then, the expected change of the weights is given as  V = \sum_{i, j} \mathcal{P}(i, j)V_{ij} and the expected gain induced by the distribution  \mathcal{P} is given as  \mathbb{E}[G] = \sum_{(i, j) \in [K]\times [K]} \mathcal{P}(i, j)G_{ij}. To increase the expected gain, we introduce a sampling strategy termed SelMix defined as  \mathcal{P}_{selmix}(i, j) = \mathrm{softmax}((sG_{ij})_{1 \le i, j\le K}), i.e.,  \mathcal{P}_{selmix}(i, j) \propto \exp(sG_{ij}), where  s > 0 is an inverse temperature parameter.
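The SelMix distribution is a softmax over all  K^2 entries of the gain matrix. A minimal sketch, with a made-up gain matrix and assumed function names, could look like this:

```python
import numpy as np

def selmix_distribution(G, s):
    """P(i, j) proportional to exp(s * G[i, j]), normalized over all K * K pairs."""
    logits = s * G
    logits = logits - logits.max()  # shift for numerical stability
    P = np.exp(logits)
    return P / P.sum()

def sample_pair(P, rng):
    """Draw one mixup pair (i, j) from the K x K distribution P."""
    K = P.shape[0]
    flat = rng.choice(K * K, p=P.ravel())
    return divmod(flat, K)          # (row, column) of the flattened index

G = np.array([[0.1, -0.2],
              [0.5, 0.0]])          # made-up gains for K = 2
P = selmix_distribution(G, s=10.0)
pair = sample_pair(P, np.random.default_rng(0))
```

A larger  s concentrates the probability mass on the pair with the largest gain, while  s close to zero makes the selection nearly uniform.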

Estimation of Gains

To implement our algorithm, we need to compute the gains  G_{ij}. However, the definition of the gains involves the confusion matrix  C_{ij}[h] = \mathbb{E}_{x, y}[1(y=i, \mathrm{argmax}_l h_l(x)=j)], which may be non-smooth w.r.t. the weights  W (we assume that the directional derivative of  \psi(C[h]) w.r.t.  W exists, or that the objective  \psi is defined through an approximation of the confusion matrix). We assume there exists a matrix  \widetilde{C}[W^{T} g], called the unconstrained confusion matrix, such that  C_i[W^{T} g] is (approximately) equal to  \pi_i \mathrm{softmax}(\widetilde{C}_i[W^{T} g]), where  C_i denotes the  i-th row of a matrix  C and  \pi_i = P(y=i).

Then, we can approximately compute the gain  G_{ij} as follows (this is an informal statement, and we refer to Sec. 4 and Sec. D of our paper for more precise statements):

Theorem 4.1 Assume  |V_{ij}| is sufficiently small. Then, we have the following approximation formula:

 G_{ij} \approx \sum_{1 \le k, l \le K} \frac{\partial \psi(\widetilde{C})}{\partial \widetilde{C}_{kl}}((V_{ij})_l^{T} \cdot z_k).
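The approximation formula is a single contraction once the class means are stacked into a matrix  Z \in \mathbb{R}^{K \times d}: the  (k, l) entry of  Z V_{ij} equals  (V_{ij})_l^{T} \cdot z_k. The sketch below evaluates it for the mean recall, for which  \partial\psi/\partial\widetilde{C}_{kl} = P_{kk}(\delta_{kl} - P_{kl})/K with  P_k = \mathrm{softmax}(\widetilde{C}_k). As an assumption made only for this illustration,  \widetilde{C} is modelled as  Z W (the logits of the class-mean features), which makes the perturbation of  \widetilde{C}_{kl} under  W \to W + \eta V equal to  \eta (V)_l^{T} z_k; the paper's precise construction may differ.

```python
import numpy as np

def softmax_rows(A):
    E = np.exp(A - A.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

def mean_recall(C_tilde):
    """psi for mean recall, where recall_k = softmax(C_tilde_k)_k."""
    return np.mean(np.diag(softmax_rows(C_tilde)))

def mean_recall_grad(C_tilde):
    """Entry (k, l) is d psi / d C_tilde_kl = P_kk * (delta_kl - P_kl) / K."""
    P = softmax_rows(C_tilde)
    K = P.shape[0]
    return np.diag(P)[:, None] * (np.eye(K) - P) / K

def approx_gain(V, Z, dpsi):
    """Sum over k, l of dpsi[k, l] * (V[:, l] . Z[k]); V is (d, K), Z is (K, d)."""
    return np.sum(dpsi * (Z @ V))

rng = np.random.default_rng(0)
K, d = 3, 4
Z = rng.normal(size=(K, d))        # stand-ins for the class means
W = rng.normal(size=(d, K))
V = 0.1 * rng.normal(size=(d, K))  # stand-in for some direction V_ij

gain = approx_gain(V, Z, mean_recall_grad(Z @ W))

# Finite-difference check under the modelling assumption C_tilde = Z @ W
eta = 1e-6
finite_diff = (mean_recall(Z @ (W + eta * V)) - mean_recall(Z @ W)) / eta
```

Evaluating `approx_gain` for the  K^2 directions  V_{ij} gives the full gain matrix at the cost of a few small matrix products, without re-evaluating the classifier on the dataset for every pair.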

Overview of Algorithm

In our algorithm, we perform  (i, j)-mixups and update the weights  W using the gradient of the mixup loss  \mathcal{L}_{mixup}. For each iteration  t=1,\dots, T, the algorithm computes the SelMix distribution  \mathcal{P}_{selmix}^{(t)} using the approximation formula and samples a mixup pair  (Y_1, Y_2) from the distribution. Then, instances  X_1, X_2 are sampled uniformly from  D_{Y_1} and  D_{Y_2}, respectively, and the weights  W are updated using the gradient of the mixup loss  \mathcal{L}_{mixup}. We refer to our paper for pseudo code.
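The loop described above can be sketched end to end on synthetic data. This is a simplified illustration under several assumptions, not the paper's pseudo code: the features are fixed (only  W is trained), the objective is the mean recall,  \widetilde{C} is modelled as  Z W, the gains use the mean of  U(\beta_{min}, 1) as a fixed  \beta, and all hyperparameter values are made up.

```python
import numpy as np

def softmax(A):
    """Softmax over the last axis."""
    E = np.exp(A - A.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

def mean_recall_grad(C_tilde):
    """d(mean recall)/dC_tilde with recall_k = softmax(C_tilde_k)_k."""
    P = softmax(C_tilde)
    K = P.shape[0]
    return np.diag(P)[:, None] * (np.eye(K) - P) / K

def direction(W, f, i):
    """V = -dL/dW for the softmax cross-entropy of W^T f against label i."""
    target = np.zeros(W.shape[1])
    target[i] = 1.0
    return np.outer(f, target - softmax(W.T @ f))

def selmix_finetune(features, labels, W, K, steps, lr, s, beta_min, rng):
    Z = np.stack([features[labels == y].mean(axis=0) for y in range(K)])
    beta_bar = 0.5 * (beta_min + 1.0)  # mean of U(beta_min, 1), used for the gains
    for _ in range(steps):
        dpsi = mean_recall_grad(Z @ W)  # assumption: C_tilde modelled as Z @ W
        G = np.zeros((K, K))
        for i in range(K):
            for j in range(K):
                f = beta_bar * Z[i] + (1.0 - beta_bar) * Z[j]
                G[i, j] = np.sum(dpsi * (Z @ direction(W, f, i)))
        P = softmax(s * G.ravel())      # SelMix distribution over K * K pairs
        y1, y2 = divmod(rng.choice(K * K, p=P), K)
        x1 = features[rng.choice(np.flatnonzero(labels == y1))]
        x2 = features[rng.choice(np.flatnonzero(labels == y2))]
        beta = rng.uniform(beta_min, 1.0)
        W = W + lr * direction(W, beta * x1 + (1.0 - beta) * x2, y1)
    return W

def mean_recall(features, labels, W, K):
    pred = np.argmax(features @ W, axis=1)
    return np.mean([(pred[labels == y] == y).mean() for y in range(K)])

# Synthetic long-tailed features: 200 / 40 / 10 instances per class
rng = np.random.default_rng(0)
K, d, counts = 3, 3, [200, 40, 10]
labels = np.repeat(np.arange(K), counts)
features = 4.0 * np.eye(K)[labels] + rng.normal(size=(len(labels), d))
W0 = np.zeros((d, K))
W1 = selmix_finetune(features, labels, W0, K, steps=300, lr=0.1,
                     s=10.0, beta_min=0.6, rng=rng)
recall_before = mean_recall(features, labels, W0, K)
recall_after = mean_recall(features, labels, W1, K)
```

With  W initialized to zero, every instance is assigned to class 0, so the mean recall starts at 1/3; comparing `recall_before` and `recall_after` shows the effect of the fine-tuning loop on this toy problem.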

Theoretical Results

We briefly introduce a theoretical result (convergence analysis) that justifies our proposed method.

In the setting of our theoretical analysis, we assume that for each iteration  t, our algorithm selects a mixup pair  (i, j) from the SelMix distribution and updates the weights  W by  W^{(t + 1)} = W^{(t)} + \eta_t \widetilde{V}^{(t)}, where  \widetilde{V}^{(t)} is the normalized directional vector  V_{ij}/|V_{ij}|. We assume that  \psi(C[W^{T} g]), which we abbreviate as  \psi(W), is  \gamma-smooth and concave w.r.t.  W, and that  \widetilde{V}^{(t)} is a reasonable directional vector; more precisely, we assume there exists a constant  c>0 such that  \mathbb{E}[\widetilde{V}^{(t)}] \cdot \partial{\psi(W^{(t)})}/\partial{W} > c|\partial{\psi(W^{(t)})}/\partial{W}| holds for all  t \ge 1. Let  W^{*} be an optimal parameter  \mathrm{argmax}_W \psi(W). Then, under additional boundedness and mild assumptions, a standard argument proves the following:

Theorem For any  t \ge 1, we have  \psi(W^{*}) - \psi(W^{(t)}) = O(1/t).

Experimental Results

To show the effectiveness of our method, we conducted experiments on optimizing various non-decomposable objectives in both semi-supervised and supervised learning settings. Here, we only introduce experimental results on a long-tailed version of the CIFAR-10 dataset (we refer to our paper for further experimental results).

Semi-Supervised Learning Setup

In the experiments, we consider optimization of the minimum coverage across classes and of the minimum (as well as the mean, the harmonic mean, and the geometric mean) of recalls across classes, which are important metrics for long-tailed datasets. The following figure compares our method (SelMix) with SoTA semi-supervised methods, namely CSST [Rangwani et al., 2022], DASO [Oh et al., 2022], ABC [Lee et al., 2021], and FixMatch [Sohn et al., 2020] with the logit-adjusted loss. The results show that our method significantly improves on these SoTA methods.

(Figure: results in the semi-supervised setting)

Supervised Learning Setup

The following figure shows the comparison in the supervised setting. We compare our method with MiSLAS [Zhong et al., 2021], one of the SoTA methods. Similarly to SelMix, MiSLAS decouples the learning procedure into two stages: representation learning (i.e., learning of the feature extractor) and classifier learning (i.e., fine-tuning). In the experiments, we fine-tune the pre-trained model obtained from Stage 1 of MiSLAS. The experimental results again show a significant improvement over the SoTA method.

(Figure: results in the supervised setting)

Conclusion

We proposed a novel fine-tuning method termed SelMix for optimization of non-decomposable objectives, proved its validity, and demonstrated the effectiveness of the proposed method across various experimental settings. Our paper was accepted at ICLR 2024 as a spotlight presentation, and we presented our joint work at the conference.

References

  1. Harikrishna Narasimhan and Aditya K Menon. Training over-parameterized models with non-decomposable objectives. Advances in Neural Information Processing Systems, 34, 2021.
  2. Harikrishna Narasimhan, Harish G Ramaswamy, Shiv Kumar Tavker, Drona Khurana, Praneeth Netrapalli, and Shivani Agarwal. Consistent multiclass algorithms for complex metrics and constraints. arXiv preprint arXiv:2210.09695, 2022.
  3. Hyuck Lee, Seungjae Shin, and Heeyoung Kim. ABC: Auxiliary balanced classifier for class-imbalanced semi-supervised learning. Advances in Neural Information Processing Systems, 34:7082–7094, 2021.
  4. Youngtaek Oh, Dong-Jin Kim, and In So Kweon. DASO: Distribution-aware semantics-oriented pseudo-label for imbalanced semi-supervised learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9786–9796, 2022.
  5. Kihyuk Sohn, David Berthelot, Nicholas Carlini, Zizhao Zhang, Han Zhang, Colin A Raffel, Ekin Dogus Cubuk, Alexey Kurakin, and Chun-Liang Li. FixMatch: Simplifying semi-supervised learning with consistency and confidence. Advances in Neural Information Processing Systems, 33:596–608, 2020.
  6. Harsh Rangwani, Shrinivas Ramasubramanian, Sho Takemori, Kato Takashi, Yuhei Umeda, and Venkatesh Babu Radhakrishnan. Cost-sensitive self-training for optimizing non-decomposable metrics. Advances in Neural Information Processing Systems, 2022.
  7. Vikas Verma, Alex Lamb, Christopher Beckham, Amir Najafi, Ioannis Mitliagkas, David Lopez-Paz, and Yoshua Bengio. Manifold mixup: Better representations by interpolating hidden states. In International Conference on Machine Learning, pp. 6438–6447. PMLR, 2019.
  8. L Zhang, Z Deng, K Kawaguchi, A Ghorbani, and J Zou. How does mixup help with robustness and generalization? In International Conference on Learning Representations, 2021.
  9. Hongyi Zhang, Moustapha Cisse, Yann N Dauphin, and David Lopez-Paz. mixup: Beyond empirical risk minimization. In International Conference on Learning Representations, 2018.
  10. Zhisheng Zhong, Jiequan Cui, Shu Liu, and Jiaya Jia. Improving calibration for long-tailed recognition. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 16489–16498, 2021.