728x90
PEFT로 AI 모델 효율적으로 학습하기
오늘은 PEFT 알아보자.
오늘의 배움 |
|
1. PEFT (Parameter Efficient Fine-Tuning)
- 정의: 모델 전체를 재학습하지 않고, 일부 파라미터만 조정하여 학습 비용을 줄이는 방법
- 한 줄 요약: 대규모 AI 모델을 적은 자원으로 특정 작업에 맞게 효율적으로 조정하는 기술
- 특징:
- 전체 모델 중 일부 파라미터만 업데이트
- 메모리와 계산 비용 대폭 감소
- 기존 모델의 지식 유지하며 새로운 작업 수행 가능
- 필요성:
- 대규모 모델 전체 재학습 시 막대한 자원 소요
- 제한된 하드웨어로 최신 모델 활용 필요
- 다양한 도메인별 맞춤형 모델 구축 요구
- 장점/단점:
- 장점: 학습 비용 절감, 빠른 학습 속도, 성능 유지
- 단점: 일부 복잡한 태스크에서 풀 파인튜닝보다 성능 제한 가능
- 예시: LoRA, QLoRA, DoRA, Soft Prompts 등
📚 실제 예시로 이해하기[실무/현업 예시]
|
2. 핵심 개념 정리
2-1. LoRA (Low-Rank Adaptation)
- 정의: 기존 모델의 가중치를 고정한 상태에서 저랭크(Low-Rank) 매트릭스를 추가해 학습하는 방법
- 작동 원리:
- 기존 가중치 매트릭스 W에 저랭크 매트릭스 두 개의 곱인 ΔW를 더함
- W + ΔW = W + A⋅B 형태로 업데이트
- A와 B는 작은 크기의 매트릭스로, 학습 파라미터 수를 크게 줄임
- 특징:
- 랭크(r)를 조절해 학습 복잡도 조정 가능
- 원본 모델 가중치를 변경하지 않음
- 모델 추론 시 원본 가중치와 병합 가능
- 장점/단점:
- 장점: 메모리 사용량 감소, 학습 속도 향상, 원본 모델 능력 보존
- 단점: 최적의 랭크 설정에 실험 필요, 모든 층에 적용 시 효율성 차이
- 필요성: 제한된 컴퓨팅 자원으로 대규모 모델 조정 필요성 증가
- 예시
from peft import LoRAConfig, get_peft_model
from transformers import AutoModelForCausalLM
# 기본 모델 로드
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
# LoRA 설정
lora_config = LoRAConfig(
r=8, # 랭크 설정
lora_alpha=32, # 스케일링 파라미터
target_modules=["query", "value"], # 적용할 모듈 지정
lora_dropout=0.1 # 드롭아웃 설정
)
# LoRA 모델 생성
model = get_peft_model(base_model, lora_config)
# 학습 가능한 파라미터 확인
print(f"학습 가능한 파라미터: {model.num_parameters(True)}")
print(f"전체 파라미터: {model.num_parameters()}")
2-2. Q-LoRA (Quantized LoRA)
- 정의: LoRA 기법을 양자화(Quantization)하여 GPU 메모리 사용량을 더욱 줄이는 방법
- 작동 원리:
- 4-bit NormalFloat(NF4) 양자화로 모델 가중치 압축
- 양자화된 기본 모델에 LoRA 적용
- 페이징(Paging)과 이중 양자화(Double Quantization) 기법 활용
- 특징:
- 원본 모델의 1/4 수준의 메모리 사용
- 학습 중에도 양자화 상태 유지
- 양자화와 LoRA의 장점 결합
- 장점/단점:
- 장점: 극도로 낮은 메모리 사용량, 대형 모델 노트북에서도 학습 가능
- 단점: 양자화로 인한 약간의 정확도 손실 가능성
- 필요성: 65B+ 대형 모델을 소형 하드웨어에서 학습시키기 위함
- 예시:
from peft import LoRAConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch
# 4-bit 양자화 설정
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True
)
# 양자화된 모델 로드
base_model = AutoModelForCausalLM.from_pretrained(
"llama-7b",
quantization_config=quantization_config,
device_map="auto"
)
# 양자화 모델 준비
model = prepare_model_for_kbit_training(base_model)
# LoRA 설정 및 적용
lora_config = LoRAConfig(
r=8,
lora_alpha=32,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.1
)
model = get_peft_model(model, lora_config)
2-3. DoRA (Decomposed LoRA)
- 정의: LoRA 기법을 확장하여 더 정밀한 파인튜닝을 수행하는 방법
- 작동 원리:
- 기존 LoRA의 저랭크 분해를 더 세분화
- 가중치 매트릭스를 크기(magnitude)와 방향(direction) 요소로 분해
- 두 요소를 별도로 업데이트하여 더 정교한 조정 가능
- 특징:
- 기존 LoRA보다 더 세밀한 가중치 조정
- 다층 업데이트 구조로 복잡한 관계 학습 가능
- 같은 파라미터 수로 더 높은 성능 달성
- 장점/단점:
- 장점: 더 정확한 모델 조정, 성능 향상
- 단점: 구현 복잡성 증가, 약간의 추가 메모리 필요
- 필요성: 보다 미세한 조정이 필요한 특화 태스크에 적합
- 예시:
# DoRA 구현 예시 (개념적 코드)
def dora_update(W, A, B, magnitude_factor):
# 기존 가중치 W를 크기와 방향으로 분해
magnitude = torch.norm(W, dim=1, keepdim=True)
direction = W / (magnitude + 1e-9)
# 방향 업데이트 (LoRA 방식)
direction_update = torch.matmul(A, B)
new_direction = direction + direction_update
new_direction = new_direction / torch.norm(new_direction, dim=1, keepdim=True)
# 크기 업데이트
new_magnitude = magnitude * magnitude_factor
# 최종 가중치 업데이트
W_new = new_direction * new_magnitude
return W_new
2-4. Soft Prompts
- 정의: 기존 모델을 수정하지 않고 학습 가능한 임베딩 벡터를 입력에 추가하는 방법
- 작동 원리:
- 텍스트 입력 앞에 학습 가능한 "가상 토큰"을 추가
- 이 토큰들은 특정 작업을 수행하도록 학습됨
- 모델 파라미터는 고정된 상태에서 토큰만 업데이트
- 특징:
- 프롬프트 엔지니어링의 자동화 버전
- 매우 적은 수의 파라미터만 학습 (수천 개 정도)
- 다양한 작업에 빠르게 적응 가능
- 장점/단점:
- 장점: 극도로 낮은 메모리 사용량, 빠른 학습 속도
- 단점: 일부 복잡한 작업에서 성능 제한, 해석 어려움
- 필요성: 최소한의 자원으로 모델을 특정 작업에 맞게 조정
- 예시
from peft import PromptTuningConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
# 모델과 토크나이저 로드
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# 프롬프트 튜닝 설정
peft_config = PromptTuningConfig(
task_type="CAUSAL_LM",
num_virtual_tokens=20, # 가상 토큰 수
tokenizer=tokenizer,
prompt_tuning_init="TEXT", # 초기화 방식
prompt_tuning_init_text="분류 작업을 수행하세요:" # 초기 텍스트
)
# 프롬프트 튜닝 모델 생성
model = get_peft_model(model, peft_config)
# 학습 가능한 파라미터 확인
print(f"학습 가능한 파라미터: {model.num_parameters(True)}")
print(f"전체 파라미터: {model.num_parameters()}")
3. 비교분석표
구분 | LoRA | QLoRA | DoRA | Soft Prompts |
메모리 효율성 | 높음 | 매우 높음 | 높음 | 극도로 높음 |
학습 속도 | 빠름 | 중간 | 중간 | 매우 빠름 |
성능 유지도 | 좋음 | 좋음 | 매우 좋음 | 중간 |
구현 복잡성 | 낮음 | 중간 | 높음 | 낮음 |
학습 파라미터 비율 | ~1% | ~1% | ~1-2% | ~0.1% |
적합한 사용 사례 | 일반적인 태스크 | 대형 모델, 제한된 하드웨어 | 정밀한 조정이 필요한 태스크 | 간단한 작업, 빠른 적응 |
코드 예시 | lora_config = LoRAConfig(r=8) | load_in_4bit=True | 복잡한 분해 로직 | num_virtual_tokens=20 |
활용 사례 | 텍스트 생성, 분류 | 대규모 모델 학습 | 의료, 법률 등 전문 분야 | 간단한 지시 따르기 |
728x90
'Develop > AI' 카테고리의 다른 글
DPO를 알아보자. (0) | 2025.03.22 |
---|---|
RLHF을 알아보자. (1) | 2025.03.22 |
파인튜닝을 알아보자. (0) | 2025.03.16 |
프롬프트 엔지니어링 - Chain of Thought (CoT) 알아보자. (0) | 2025.03.16 |
RAG (검색 증강 생성) 알아보자. (0) | 2025.03.15 |