Deceptive Risk Minimization: Out-of-Distribution Generalization by Deceiving Distribution Shift Detectors
This paper proposes deception as a mechanism for out-of-distribution (OOD) generalization: by learning data representations that make training data appear independent and identically distributed (iid) to an observer, we can identify stable features that eliminate spurious correlations and generalize to unseen domains. We refer to this principle as deceptive risk minimization (DRM) and instantiate it with a practical differentiable objective. The objective learns features that eliminate distribution shifts from the perspective of a detector based on conformal martingales, while simultaneously minimizing a task-specific loss. In contrast to domain adaptation or prior invariant representation learning methods, DRM does not require access to test data or a partitioning of training data into a finite number of data-generating domains. We demonstrate the efficacy of DRM on numerical experiments with concept shift and on a simulated imitation learning setting with covariate shift across the environments in which a robot is deployed.
An illustrative example
The starting point for our approach is the observation that, in the robot imitation learning setting (where the table color changes across training environments), the non-iid nature of the data can be inferred from the sequence of training environments. The figure above shows the output of a distribution shift detector based on conformal martingales, computed on observations from the training environments. The detector spikes sharply when the table color changes. Crucially, it also spikes when given the sequence of latent features from a policy trained via ERM, indicating that the policy's latent representation encodes color information.
Now consider a policy that eliminates the distribution shift from the perspective of an observer who is only presented with latent representations from the policy. Intuitively, the data can be made to "appear iid" by eliminating sensitivity to the table color, which leads to OOD generalization to different table colors.
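To make the detector behavior concrete, here is a minimal sketch of a conformal test martingale. It is a stand-in built on assumptions, not the paper's detector. The nonconformity score (Euclidean distance to the running mean), the mixture of power betting functions, and the synthetic "color" and "position" features are all hypothetical choices for illustration. Under exchangeability, randomized conformal p-values are independent and uniform, so by Ville's inequality the martingale ever exceeds 10^4 with probability at most 10^-4. A sequence whose color coordinate shifts every 100 samples drives it far higher, while the same sequence with color removed does not.

```python
import math
import random


def conformal_pvalues(features, rng):
    """Randomized conformal p-values for a sequence of feature vectors.

    Hypothetical nonconformity score: Euclidean distance from each point to the
    mean of all points seen so far (a symmetric function of the bag, as required).
    """
    pvals = []
    for t in range(1, len(features) + 1):
        bag = features[:t]
        mean = [sum(col) / t for col in zip(*bag)]
        scores = [math.dist(v, mean) for v in bag]
        a_t = scores[-1]
        greater = sum(1 for a in scores if a > a_t)
        ties = sum(1 for a in scores if a == a_t)  # includes the newest point itself
        theta = 1.0 - rng.random()  # uniform on (0, 1], so every p-value is > 0
        pvals.append((greater + theta * ties) / t)
    return pvals


def log_martingale_path(pvals, epsilons=(0.6, 0.7, 0.8, 0.9)):
    """log M_t for an equal-weight mixture of power martingales prod_s eps * p_s**(eps - 1)."""
    logs = [0.0] * len(epsilons)  # log of each component martingale
    path = []
    for p in pvals:
        for j, eps in enumerate(epsilons):
            logs[j] += math.log(eps) + (eps - 1.0) * math.log(p)
        top = max(logs)  # log-sum-exp for a numerically stable mixture
        path.append(top + math.log(sum(math.exp(v - top) for v in logs) / len(logs)))
    return path


rng = random.Random(0)
# Hypothetical stand-in for three training environments of 100 samples each:
# a "color" coordinate whose mean changes between environments, and a
# "position" coordinate that is iid throughout.
color = [rng.gauss(mu, 0.1) for mu in (0.0, 1.0, 2.0) for _ in range(100)]
position = [rng.gauss(0.0, 0.1) for _ in range(300)]

color_sensitive = [[c, x] for c, x in zip(color, position)]  # like ERM features
color_invariant = [[x] for x in position]                    # color removed

peak_sensitive = max(log_martingale_path(conformal_pvalues(color_sensitive, rng)))
peak_invariant = max(log_martingale_path(conformal_pvalues(color_invariant, rng)))
print(f"peak log M_t, color-sensitive features: {peak_sensitive:.1f}")
print(f"peak log M_t, color-invariant features: {peak_invariant:.1f}")
```

Dropping the color coordinate is the kind of invariance the illustrative example describes: once the representation no longer changes with the environment, the detector has nothing to react to.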
Algorithmic Approach
Our goal is to find features that eliminate the distribution shift between training and test settings. Since the learner is only provided with the sequence of training data, we use distribution shifts observed in this data as a proxy. Specifically, we learn features that are stable along the training data sequence (in the sense that they appear iid to an observer) while also supporting minimization of the task-specific loss.
We formulate this learning mechanism as an adversarial game, which we refer to as deceptive risk minimization (DRM). An encoder network learns to generate representations that support minimization of a task-specific loss while simultaneously eliminating distribution shifts from the perspective of a detector presented with a sequence of learned representations. We present a practical instantiation of DRM that uses conformal martingales (CMs) for distribution shift detection. Concretely, the CM approach computes a quantity that grows quickly under distribution shift but remains small when the data are exchangeable. We derive an end-to-end differentiable loss that penalizes the conformal martingale computed on encoded inputs; this loss trains the encoder to learn representations that eliminate distribution shifts from the perspective of the CM-based detector.
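The paper's exact differentiable loss is not reproduced here. What follows is a hypothetical sketch of one way to relax a conformal martingale into a smooth penalty. It replaces the hard indicator in the conformal p-value with a sigmoid of temperature `tau` and uses squared distance to the running mean as the nonconformity score. It then builds a single power martingale with bet `eps * p**(eps - 1)` and penalizes `log(1 + sum_t M_t)`. The function names `soft_pvalues`, `cm_penalty`, and `drm_objective` are illustrative, as are the defaults for `eps`, `tau`, and `lam`, the placeholder task loss, and the synthetic features. All of these are assumptions.

```python
import math
import random


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def soft_pvalues(features, tau=0.05):
    """Smoothed conformal p-values (hypothetical relaxation).

    The hard count of scores a_i > a_t is replaced by sigmoid((a_i - a_t) / tau).
    For i == t the term is sigmoid(0) = 0.5, the expected value of the randomized
    tie-break, which also keeps every p-value >= 0.5 / t > 0.
    """
    pvals = []
    for t in range(1, len(features) + 1):
        bag = features[:t]
        mean = [sum(col) / t for col in zip(*bag)]
        # Squared distance to the running mean: smooth everywhere, same ranking as distance.
        scores = [sum((vi - mi) ** 2 for vi, mi in zip(v, mean)) for v in bag]
        a_t = scores[-1]
        pvals.append(sum(sigmoid((a - a_t) / tau) for a in scores) / t)
    return pvals


def cm_penalty(features, eps=0.7, tau=0.05):
    """Smooth shift penalty log(1 + sum_t M_t) for a power martingale on soft p-values."""
    log_m, log_path = 0.0, []
    for p in soft_pvalues(features, tau):
        log_m += math.log(eps) + (eps - 1.0) * math.log(p)
        log_path.append(log_m)
    # Stable log(1 + sum_t exp(log_m_t)), i.e. log-sum-exp with an extra 0 term.
    top = max(0.0, max(log_path))
    return top + math.log(math.exp(-top) + sum(math.exp(v - top) for v in log_path))


def drm_objective(task_loss, features, lam=1.0, eps=0.7, tau=0.05):
    """Hypothetical DRM-style objective: task loss plus a weighted detector penalty."""
    return task_loss + lam * cm_penalty(features, eps, tau)


rng = random.Random(1)
color = [rng.gauss(mu, 0.1) for mu in (0.0, 1.0, 2.0) for _ in range(100)]
position = [rng.gauss(0.0, 0.1) for _ in range(300)]
shifted = [[c, x] for c, x in zip(color, position)]  # encodes an environment-dependent color
invariant = [[x] for x in position]                 # same data with color removed

task_loss = 0.3  # placeholder value for illustration
print(drm_objective(task_loss, shifted), drm_objective(task_loss, invariant))
```

Every operation here (means, squared distances, sigmoids, logs, exponentials) is smooth. Writing the same computation in an autodiff framework such as PyTorch or JAX would therefore give gradients of the penalty with respect to the encoder outputs; this pure-Python version only evaluates the objective. Squared distance avoids the non-differentiable point of the norm at zero. Because it is a monotone transform of distance, it leaves the hard conformal ranks unchanged. As `tau` shrinks, the soft p-values approach the hard ones with a fixed tie-break of 0.5.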
Experiments
We evaluate DRM in three sets of experiments involving concept shift and covariate shift. Across all three examples, we find that DRM leads to strong OOD generalization in settings involving spurious correlations or irrelevant distractors. In contrast, standard empirical risk minimization (ERM) fails to generalize. We also compare to invariant risk minimization (IRM), which assumes oracle access to a partitioning of training data into a finite number of domains; we find that DRM is able to match the performance of IRM without requiring such an oracle.
Concept shift: Toy 2D example
Concept shift: Colored-MNIST
Next, we consider the Colored-MNIST task introduced by Arjovsky et al. (2019). The goal is to classify MNIST images in which the digits have been colored either red or green. As in the toy 2D example, the color is assigned so that it has a strong (but spurious) correlation with the label. As a result, ERM-based methods that only minimize training loss exploit the color information to make predictions; when the correlation between color and label is reversed at test time, their performance degrades significantly. In contrast, DRM learns to ignore color information, leading to strong OOD generalization. The t-SNE plot above illustrates this: ERM clusters the data by color (R / G), whereas DRM clusters the data by label.
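The sketch below reproduces the label-and-color construction from Arjovsky et al. (2019) without the images, to show why a color-based rule looks attractive during training. The preliminary label comes from the digit (0 for digits 0-4, 1 for 5-9). The final label flips it with probability 0.25, and the color flips the final label with an environment-dependent probability (0.2 and 0.1 in training, 0.9 at test). These probabilities come from the original IRM paper. The exact settings used in the DRM experiments may differ, and the helper names here are hypothetical.

```python
import random


def colored_mnist_labels(n, color_flip_prob, rng, label_noise=0.25):
    """Label/color construction of Colored-MNIST (Arjovsky et al., 2019), images omitted.

    Returns (preliminary_label, label, color) triples; color 1 = red, 0 = green.
    """
    samples = []
    for _ in range(n):
        digit = rng.randrange(10)
        y_tilde = 0 if digit < 5 else 1                   # label from digit shape
        y = y_tilde ^ int(rng.random() < label_noise)     # flip with prob 0.25
        color = y ^ int(rng.random() < color_flip_prob)   # spurious color feature
        samples.append((y_tilde, y, color))
    return samples


def accuracy(samples, predict):
    """Fraction of samples whose final label y matches predict(sample)."""
    return sum(predict(s) == s[1] for s in samples) / len(samples)


rng = random.Random(0)
# Two training environments (color flipped with prob 0.2 and 0.1) and a test
# environment in which the color-label correlation is reversed (prob 0.9).
train = colored_mnist_labels(10_000, 0.2, rng) + colored_mnist_labels(10_000, 0.1, rng)
test = colored_mnist_labels(10_000, 0.9, rng)


def by_color(s):
    return s[2]  # what a color-exploiting predictor would output


def by_digit(s):
    return s[0]  # uses only the digit-derived label


for name, rule in [("color", by_color), ("digit", by_digit)]:
    print(f"{name}: train {accuracy(train, rule):.3f}, test {accuracy(test, rule):.3f}")
```

In expectation, the color rule reaches about 85% training accuracy but only about 10% at test. The digit rule stays near 75% on both. A predictor that latches onto color, as ERM tends to, inherits that collapse.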
Covariate shift: Imitation learning with distractors
We consider an imitation learning setting with covariate shift across the environments in which the robot is trained and deployed. The task is to pick up a red block and place it into a bowl using observations from an RGB camera. The training data consists of 300 expert demonstrations of pick-and-place locations, provided in different environments. A third of the demonstrations use one table-and-bowl color combination, the next third a slightly different combination, and the final third another combination. At test time, the bowl and table background colors are changed to a novel combination that significantly exaggerates the changes seen during training. As a result, the performance of behavior cloning (ERM) collapses, while DRM performs nearly as well as it does in distribution.
Acknowledgements
This work was supported by the Office of Naval Research (N00014-23-1-2148). The author is grateful to May Mei, Ola Shorinwa, Asher Hancock, David Snyder, and Apurva Badithela for helpful feedback on the paper. The website template is modified from Code as Policies.