Causality Inspired Representation Learning for Domain Generalization

Link
Abstract

Domain generalization (DG) is essentially an out-of-distribution problem, aiming to generalize the knowledge learned from multiple source domains to an unseen target domain. The mainstream approach is to leverage statistical models to model the dependence between data and labels, intending to learn representations independent of domain. Nevertheless, statistical models are superficial descriptions of reality since they are only required to model dependence rather than the intrinsic causal mechanism. When the dependence changes with the target distribution, statistical models may fail to generalize. In this regard, we introduce a general structural causal model to formalize the DG problem. Specifically, we assume that each input is constructed from a mix of causal factors (whose relationship with the label is invariant across domains) and non-causal factors (category-independent), and only the former cause the classification judgments. Our goal is to extract the causal factors from inputs and then reconstruct the invariant causal mechanisms. However, the theoretical idea is far from practical for DG since the required causal/non-causal factors are unobserved. We highlight that ideal causal factors should meet three basic properties: separated from the non-causal ones, jointly independent, and causally sufficient for the classification. Based on that, we propose a Causality Inspired Representation Learning (CIRL) algorithm that enforces the representations to satisfy the above properties and then uses them to simulate the causal factors, which yields improved generalization ability. Extensive experimental results on several widely used datasets verify the effectiveness of our approach.

Synth

Problem:: Existing Domain Generalization methods only model statistical dependence, giving merely superficial descriptions / Generalization fails when the target distribution shifts / They learn correlations rather than causal mechanisms (e.g., giraffe–grassland)

Solution:: Introduce a Structural Causal Model to formalize the DG problem / Causal Intervention Module (intervenes on non-causal factors via Fourier transform) + Causal Factorization Module (enforces independence) + Adversarial Mask Module (exploits all dimensions); see the sketches under Method below

Novelty:: First to explicitly introduce a causal perspective into Domain Generalization / Three complementary modules turn the theory into a practical implementation / Unlike MatchDG, performs explicit causal factorization

Note:: Could likely be connected to Mechanistic Interpretability

Summary

Motivation

file-20250611194742353.png|406

Conventional training models only the statistical dependence between X and Y (dashed line) → instead, consider both the causal factors S and the non-causal factors U

Method

Three Properties of Causal Factors

file-20250611194909011.png|550

(a) non-ideal causal factors, (b) ideal causal factors

The CIRL (Causality Inspired Representation Learning) Algorithm

file-20250611195047622.png

1. Causal Intervention Module
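A minimal sketch of the Fourier-based intervention on non-causal factors: the amplitude spectrum (assumed to carry domain/style information) is mixed with that of another image while the phase spectrum is kept intact, matching the Fourier-transform intervention mentioned in the Solution note. The function name `fourier_intervention` and the mixing ratio `lam` are illustrative choices, not the paper's exact implementation.

```python
import numpy as np

def fourier_intervention(img, img_other, lam=0.5):
    """Perturb the non-causal (amplitude) component of `img` by mixing its
    amplitude spectrum with that of `img_other`, while keeping the phase
    spectrum (assumed to carry the causal/semantic content) untouched.
    img, img_other: float arrays of shape (H, W, C) with values in [0, 1]."""
    fft = np.fft.fft2(img, axes=(0, 1))
    fft_other = np.fft.fft2(img_other, axes=(0, 1))

    amp, phase = np.abs(fft), np.angle(fft)
    amp_other = np.abs(fft_other)

    # Linearly interpolate amplitude spectra: the intervention on non-causal factors.
    amp_mixed = (1.0 - lam) * amp + lam * amp_other

    # Recombine the mixed amplitude with the original phase and invert the FFT.
    fft_mixed = amp_mixed * np.exp(1j * phase)
    img_aug = np.real(np.fft.ifft2(fft_mixed, axes=(0, 1)))
    return np.clip(img_aug, 0.0, 1.0)
```

The intervened image keeps the original semantics (phase) but a perturbed style (amplitude), so representations that stay stable under this augmentation are treated as candidate causal factors.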

2. Causal Factorization Module
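A sketch of one way to enforce the factorization constraints with a Barlow-Twins-style surrogate: the dimension-wise cross-correlation between the representations of an original image and its intervened version is pushed toward the identity, so each dimension stays invariant under the intervention (diagonal → 1) while different dimensions stay decorrelated (off-diagonal → 0). The function name `factorization_loss` and the squared-error weighting are assumptions, not the paper's exact loss.

```python
import torch

def factorization_loss(z_orig, z_aug, eps=1e-8):
    """Encourage each representation dimension to (i) stay invariant under the
    causal intervention (diagonal of the cross-correlation -> 1) and
    (ii) be independent of the other dimensions (off-diagonal -> 0).
    z_orig, z_aug: (batch, dim) representations of original / intervened inputs."""
    # Standardize each dimension over the batch before correlating.
    z_o = (z_orig - z_orig.mean(0)) / (z_orig.std(0) + eps)
    z_a = (z_aug - z_aug.mean(0)) / (z_aug.std(0) + eps)

    n = z_o.size(0)
    corr = (z_o.T @ z_a) / n  # (dim, dim) cross-correlation matrix

    on_diag = (torch.diagonal(corr) - 1.0).pow(2).sum()                # invariance term
    off_diag = (corr - torch.diag(torch.diagonal(corr))).pow(2).sum()  # independence term
    return on_diag + off_diag
```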

3. Adversarial Mask Module
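A rough sketch of the adversarial masking idea: a learned masker scores every representation dimension and splits the feature into a "superior" and an "inferior" view, and both views are required to remain classifiable so that all dimensions end up carrying causal information. The soft sigmoid masking, the class/function names, and the single combined loss are simplifications; the paper ranks dimensions, takes a top-κ split, and updates the masker adversarially on the inferior branch, which is only noted in the comments here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdversarialMask(nn.Module):
    """Scores each representation dimension and splits the feature into a
    'superior' and an 'inferior' view (soft relaxation of a top-k ranking)."""
    def __init__(self, dim):
        super().__init__()
        self.scorer = nn.Linear(dim, dim)  # per-dimension importance scores

    def forward(self, z):
        scores = torch.sigmoid(self.scorer(z))  # (batch, dim) in [0, 1]
        return z * scores, z * (1.0 - scores)   # superior / inferior views

def mask_classification_loss(z, y, masker, clf_sup, clf_inf):
    """Both masked views should stay classifiable (causal sufficiency of every
    dimension). The feature extractor and classifiers minimize this loss, while
    the masker is updated with an opposing objective on the inferior branch
    (optimizer setup omitted in this sketch)."""
    z_sup, z_inf = masker(z)
    return F.cross_entropy(clf_sup(z_sup), y) + F.cross_entropy(clf_inf(z_inf), y)

# Illustrative usage with hypothetical dimensions and linear classifiers.
if __name__ == "__main__":
    dim, num_classes, batch = 512, 7, 16
    masker = AdversarialMask(dim)
    clf_sup, clf_inf = nn.Linear(dim, num_classes), nn.Linear(dim, num_classes)
    z, y = torch.randn(batch, dim), torch.randint(0, num_classes, (batch,))
    print(mask_classification_loss(z, y, masker, clf_sup, clf_inf))
```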

Method Validation

Experimental Setup

Main Experimental Results

Ablation Study

Additional Analysis