Conflict-Averse Gradient Descent for Multi-task Learning

Link
Abstract

The goal of multi-task learning is to enable more efficient learning than single-task learning by sharing model structures for a diverse set of tasks. A standard multi-task learning objective is to minimize the average loss across all tasks. While straightforward, using this objective often results in much worse final performance for each task than learning them independently. A major challenge in optimizing a multi-task model is the conflicting gradients, where gradients of different task objectives are not well aligned so that following the average gradient direction can be detrimental to specific tasks' performance. Previous work has proposed several heuristics to manipulate the task gradients for mitigating this problem. But most of them lack convergence guarantees and/or could converge to any Pareto-stationary point. In this paper, we introduce Conflict-Averse Gradient descent (CAGrad) which minimizes the average loss function, while leveraging the worst local improvement of individual tasks to regularize the algorithm trajectory. CAGrad balances the objectives automatically and still provably converges to a minimum over the average loss. It includes the regular gradient descent (GD) and the multiple gradient descent algorithm (MGDA) in the multi-objective optimization (MOO) literature as special cases. On a series of challenging multi-task supervised learning and reinforcement learning tasks, CAGrad achieves improved performance over prior state-of-the-art multi-objective gradient manipulation methods.

Synth

Problem:: In multi-task learning, conflicting gradients degrade individual task performance / existing methods can stall at any Pareto-stationary point / the convergence point varies unpredictably with the initial parameters

Solution:: Accounting for the nature of Pareto-stationary points, proposes minimizing the average loss while avoiding the worst local direction for individual tasks / provides a theoretical guarantee that the proposed CAGrad converges to the optimum among Pareto-stationary points / proposes using only a subset of tasks to reduce the cost of computing every task's loss gradient

Novelty:: Identifies the Pareto-stationary-point problem of existing methods / shows that previous methods are special cases of the proposed one / the algorithm's behavior can be controlled via the hyperparameter c

Note:: An approach that puts the overall benefit first while still giving some weight to the most opposing voice

Summary

Motivation

file-20250409212307781.png

Existing methods stop learning once they reach the Pareto set.
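This can be reproduced numerically. Below is a minimal two-task sketch (the helper name `mgda_direction` is my own, for illustration) showing that the min-norm/MGDA update vanishes at a Pareto-stationary point, where zero lies in the convex hull of the task gradients, even though the average gradient is nonzero and plain GD would keep moving.

```python
import numpy as np

def mgda_direction(g1, g2):
    """Min-norm (MGDA) update for two tasks: the shortest vector in the
    convex hull {w*g1 + (1-w)*g2 : w in [0, 1]}.  Uses the closed-form
    minimizer for the two-task case."""
    diff = g1 - g2
    denom = diff @ diff
    # Unconstrained minimizer of ||w*g1 + (1-w)*g2||^2, clipped to [0, 1].
    w = 0.5 if denom < 1e-12 else float(np.clip(-(g2 @ diff) / denom, 0.0, 1.0))
    return w * g1 + (1 - w) * g2

# A Pareto-stationary point: (2/3)*g1 + (1/3)*g2 = 0, so MGDA stalls here,
# while the average gradient g0 is still nonzero.
g1 = np.array([1.0, -0.5])
g2 = np.array([-2.0, 1.0])
print(mgda_direction(g1, g2))   # zero vector -> MGDA stops updating
print((g1 + g2) / 2)            # nonzero average gradient
```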

Method

file-20250409212400154.png

CAGrad (Conflict-Averse Gradient Descent)

Dual Problem
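As a concrete sketch of the dual, the code below computes the CAGrad update for two tasks: it minimizes the dual objective g_w^T g_0 + c||g_0||·||g_w|| over the simplex and recovers d = g_0 + (c||g_0||/||g_w||)·g_w. The function name and the brute-force grid search over w are illustrative assumptions (the actual implementation would use a proper solver); with c = 0 the update reduces to plain GD on the average loss, matching the special-case claim.

```python
import numpy as np

def cagrad_direction(grads, c=0.5):
    """Two-task CAGrad sketch: solve the 1-D dual by grid search over
    w in [0, 1], then recover the conflict-averse update direction.

    grads: array of shape (2, d) holding the per-task gradients.
    """
    g1, g2 = grads
    g0 = grads.mean(axis=0)                # average gradient
    phi_sqrt = c * np.linalg.norm(g0)      # trust-region radius c*||g0||
    best_w, best_val = 0.0, np.inf
    for w in np.linspace(0.0, 1.0, 1001):  # grid search on the simplex
        gw = w * g1 + (1 - w) * g2
        val = gw @ g0 + phi_sqrt * np.linalg.norm(gw)
        if val < best_val:
            best_w, best_val = w, val
    gw = best_w * g1 + (1 - best_w) * g2
    norm = np.linalg.norm(gw)
    if norm < 1e-12:                       # degenerate case: fall back to GD
        return g0
    return g0 + (phi_sqrt / norm) * gw     # d = g0 + c*||g0||/||gw|| * gw
```

With conflicting gradients, the returned direction improves the worst-case inner product min_i <g_i, d> relative to simply following the average gradient, which is exactly the "consider the most opposing voice" behavior noted above.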

Method Validation

Toy Example

file-20250409213155096.png

Experimental Results