Task Arithmetic in the Tangent Space: Improved Editing of Pre-Trained Models

Link
Abstract

Task arithmetic has recently emerged as a cost-effective and scalable approach to edit pre-trained models directly in weight space: By adding the fine-tuned weights of different tasks, the model's performance can be improved on these tasks, while negating them leads to task forgetting. Yet, our understanding of the effectiveness of task arithmetic and its underlying principles remains limited. We present a comprehensive study of task arithmetic in vision-language models and show that weight disentanglement is the crucial factor that makes it effective. This property arises during pre-training and manifests when distinct directions in weight space govern separate, localized regions in function space associated with the tasks. Notably, we show that fine-tuning models in their tangent space by linearizing them amplifies weight disentanglement. This leads to substantial performance improvements across multiple task arithmetic benchmarks and diverse models. Building on these findings, we provide theoretical and empirical analyses of the neural tangent kernel (NTK) of these models and establish a compelling link between task arithmetic and the spatial localization of the NTK eigenfunctions. Overall, our work uncovers novel insights into the fundamental mechanisms of task arithmetic and offers a more reliable and effective approach to edit pre-trained models through the NTK linearization.

Synth

Problem:: Task Arithmetic의 작동 메커니즘에 대한 불명확한 이해

Solution:: Task Arithmetic의 핵심 메커니즘으로 Weight Disentanglement 개념 제시/Tangent Space에서 직접 파인튜닝하는 새로운 접근법 개발

Novelty:: Task Arithmetic이 선형성이 아닌 Weight Disentanglement에 의해 작동함을 규명/사전 학습 과정에서 Eigenfunction 국소화가 자연스럽게 발생함을 발견

Note:: 사전 학습은 의미적으로 관련된 특성들을 구조화하는 과정

Summary

Motivation

Task Arithmetic과 핵심 특성

기존 가설: Linear Regime

Neural Tangent Kernel (NTK)

f(x;θ)f(x;θ0)+(θθ0)Tθf(x;θ0) kNTK(x,x)=θf(x;θ0)Tθf(x;θ0) f(x;θ)f(x;θ0)+(θθ0)Tθf(x;θ0)=f(x;θ0)+i=1nαikNTK(xi,x)

실제 원인: Weight Disentanglement

file-20250326221210352.png

Weight Disentanglement에 대한 시각적 설명

f(x;θ0+t=1Tαtτt)=t=1Tgt(x;αtτt)+g0(x)=t=1Tf(x;θ0+αtτt)1(xDt)+f(x;θ0)1(xt[T]Dt)
  1. gt(x;αtτt)=f(x;θ0+αtτt)=0 for xDt and t=1,,T
  2. g0(x)=0 for xt[T]Dt

이론적 설명

Method

file-20250326231545536.png|675

f(θ): Non-Linear FT, flin(θ): Post-Hoc, flin(θlin): Linearized FT

  1. 비선형 파인튜닝(Non-Linear FT):
    • 사전 학습된 모델 파라미터 θ0에서 시작하여 직접 파인튜닝
    • 결과 모델: f(x;θ0+τ), 여기서 τ=θθ0
    • 모델 편집: f(x;θ0+tαtτt)
    • 비선형 레짐에서 작동하여 Weight Disentanglement가 제한적
  2. 사후 선형화(Post-Hoc Linearization):
    • 비선형 파인튜닝된 모델을 선형 근사
    • 결과 모델: flin(x;θ0+τ)=f(x;θ0)+τTθf(x;θ0)
    • 모델 편집: flin(x;θ0+tαtτt)=f(x;θ0)+(tαtτt)Tθf(x;θ0)
    • 단일 태스크 성능은 저하되지만 Weight Disentanglement는 향상
  3. 선형화된 파인튜닝(Linearized FT):
    • 원본 모델 파라미터 θ0를 고정하고 선형화된 모델을 직접 파인튜닝
    • 결과 모델: flin(x;θ0+τlin)=f(x;θ0)+τlinTθf(x;θ0)
    • 모델 편집: flin(x;θ0+tαtτlin,t)=f(x;θ0)+(tαtτlin,t)Tθf(x;θ0)
    • 원본 모델 파라미터 동결(freeze)하고 차이값 τlin만 학습
    • 실제 구현 방법
      • 원본 모델 파라미터 θ0 고정(freeze)
      • 새로운 파라미터 τlin을 도입하여 이를 학습
      • 각 순전파 과정에서 초기 출력 f(x;θ0)에 그래디언트-벡터 곱 τlinTθf(x;θ0)을 더하여 계산
      • Jacobian-vector product 구현을 통해 계산 효율성 확보

Method 검증