00:00:00

Share Your Feedback 🏝️

Counterfactual Prompt Learning

Counterfactual Prompt Learning

MinWoo(Daniel) Park | Tech Blog

Read more
Previous: PL | Contrastive Preference Learning* Next: RAG, Survey | RAG Survey**

Counterfactual Prompt Learning

  • Related Project: Private
  • Category: Paper Review
  • Date: 2024-01-13

CPL: Counterfactual Prompt Learning for Vision and Language Models

  • url: https://arxiv.org/abs/2210.10362
  • pdf: https://arxiv.org/pdf/2210.10362
  • abstract: Prompt tuning is a new few-shot transfer learning technique that only tunes the learnable prompt for pre-trained vision and language models such as CLIP. However, existing prompt tuning methods tend to learn spurious or entangled representations, which leads to poor generalization to unseen concepts. Towards non-spurious and efficient prompt learning from limited examples, this paper presents a novel \(\underline{\textbf{C}}ounterfactual \underline{\textbf{P}}rompt \underline{\textbf{L}}earning (CPL)\) method for vision and language models, which simultaneously employs counterfactual generation and contrastive learning in a joint optimization framework. Particularly, CPL constructs counterfactual by identifying minimal non-spurious feature change between semantically-similar positive and negative samples that causes concept change, and learns more generalizable prompt representation from both factual and counterfactual examples via contrastive learning. Extensive experiments demonstrate that CPL can obtain superior few-shot performance on different vision and language tasks than previous prompt tuning methods on CLIP. On image classification, we achieve 3.55\% average relative improvement on unseen classes across seven datasets; on image-text retrieval and visual question answering, we gain up to 4.09\% and 25.08\% relative improvements across three few-shot scenarios on unseen test sets respectively.

Contents

TL;DR


  1. 반사실적 프롬프트 학습을 통한 비편향적 표현 학습
  2. 텍스트 기반 부정 샘플링과 제어 가능한 반사실적 생성 전략 제안
  3. 이미지 분류, 이미지-텍스트 검색, 시각 질의응답 태스크에서의 성능 향상

[서론]

대규모 시각-언어 사전학습 모델들은 개방형 시각-개념 매칭 태스크에서 우수한 성능을 보여주고 있습니다. 그러나 이런 모델들의 성능은 프롬프트 엔지니어링에 크게 의존하며, 수작업으로 프롬프트를 설계하는 것은 시간 소모적이고 최적의 해결책을 찾기 어렵습니다. 이런 문제를 해결하기 위해 프롬프트 튜닝 방법들이 제안되었으나, 경험적 위험 최소화(Empirical Risk Minimization, ERM)에 기반한 기존 방법들은 편향되거나 비효율적인 표현을 학습하는 경향이 있습니다.

본 논문에서는 이런 한계를 극복하기 위해 반사실적 프롬프트 학습(Counterfactual Prompt Learning, CPL)이라는 새로운 인과관계 기반 접근 방식을 제안합니다. CPL은 비편향적이고 효율적인 프롬프트 학습을 위해 반사실적 인퍼런스를 활용합니다.


[방법]

[텍스트 기반 부정 샘플링]

CPL의 첫 번째 단계는 텍스트 유사도를 기반으로 의미적으로 가장 유사한 부정 샘플을 찾는 것입니다. 이를 위해 BERTScore를 사용하여 프롬프트 간의 텍스트 유사도를 측정합니다.

$sim(i, j) = BERTScore(h_i, h_j)$

수식에서 $h_i$와 $h_j$는 각각 $i$번째와 $j$번째 프롬프트입니다.

[제어 가능한 반사실적 생성]

다음으로, CPL은 긍정 샘플과 부정 샘플 사이의 최소한의 비편향적 특징 변화를 식별하여 반사실적 예제를 생성합니다. 이 과정은 다음과 같은 수식으로 표현됩니다.

\[v' = (1 - u) \circ v + u \circ v^-\]
  • $v’$: 생성된 반사실적 이미지 특징
  • $v$: 긍정 이미지 특징
  • $v^-$: 부정 이미지 특징
  • $u$: 특징 변화를 제어하는 파라미터
  • $\circ$: 요소별 곱셈

[결합 최적화]

CPL은 반사실적 생성과 프롬프트 학습을 동시에 최적화하는 결합 최적화 프레임워크를 사용합니다. 이 프레임워크는 다음과 같은 목적함수를 최소화합니다.

\[\min_{p,u^*} L_{CE}(p) + \lambda \cdot L_{CL}(p, u^*) + \|u^*\|_1\]

subject to: $u^* = \arg\max_u D_{c^-}(v’)$ $v’ = (1 - u) \circ v + u \circ v^-$

  • $L_{CE}$: Cross-Entropy Loss
  • $L_{CL}$: Contrastive Learning Loss
  • $\lambda$: Contrastive Learning Loss의 가중치
  • $D_{c^-}$: 부정 클래스에 대한 판별기

Contrastive Learning Loss $L_{CL}$은 다음과 같이 정의됩니다.

\[L_{CL}(p, u^*) = -\log \left( \frac{e^{S(v,G(t))/\tau}}{e^{S(v,G(t))/\tau} + e^{S(v',G(t))/\tau}} \right)\]

수식에서 $S(\cdot,\cdot)$는 코사인 유사도 함수이고, $\tau$는 온도 파라미터


[실험 및 결과]

CPL의 성능을 평가하기 위해 세 가지 주요 태스크에서 실험을 진행했습니다.

  1. 이미지 분류: 7개의 공개 데이터셋(SUN397, Caltech101, ImageNet, OxfordPets, StanfordCars, Flowers102, Food101)을 사용하고, 16-shot 설정에서 훈련하고 전체 테스트 세트로 평가
  2. 이미지-텍스트 검색: MSCOCOFlickr30K 데이터셋을 사용하고, 0.5%, 1%, 3%의 training dataset를 사용한 few-shot 설정에서 실험을 진행
  3. 시각 질의응답: VQAv2 데이터셋을 사용하고, 마찬가지로 0.5%, 1%, 3%의 training dataset로 few-shot 학습을 수행


[주요 결과]

  1. 이미지 분류: CPL은 대부분의 데이터셋에서 CoCoOp를 능가했으며, 특히 unseen 클래스에서 평균 3.55%의 상대적 성능 향상을 보임.
  2. 이미지-텍스트 검색: CPL은 모든 few-shot 설정에서 zero-shot CLIP을 일관되게 능가했으며, 0.5% training dataset 사용 시 Recall@1 기준으로 최대 8.55%의 상대적 성능 향상을 달성
  3. 시각 질의응답: CPL은 1% training dataset 사용 시 CoCoOp 대비 최대 25.08%의 상대적 성능 향상을 보임.


[결론]

본 연구에서 제안한 반사실적 프롬프트 학습(CPL)은 비편향적이고 일반화 가능한 프롬프트 표현을 학습함으로써 다양한 시각-언어 태스크에서 기존 방법들을 능가하는 성능을 보여주었습니다. 특히 unseen 클래스나 개념에 대한 일반화 능력이 향상되었음을 확인할 수 있었습니다. 이는 반사실적 인퍼런스과 Contrastive Learning의Loss합이 편향되지 않은 의미론적 정보를 학습하는 데 효과적임을 시사합니다. 향후 연구에서는 CPL의 다양한 도메인과 언어에 대한 일반화 능력을 더욱 개선하고, 실시간 시나리오에서의 적용 가능성을 탐구할 계획이라고 합니다.

Previous: PL | Contrastive Preference Learning* Next: RAG, Survey | RAG Survey**

post contain ""

    No matching posts found containing ""