확률적 샘플링 방법 (Stochastic Sampling)

확률적 샘플링은 모델의 확률 분포에서 토큰을 무작위로 선택하는 방법으로, 다양성과 창의성을 제공합니다.

2.1 기본 샘플링 전략

Temperature Sampling

개념

  • softmax 분포에서 온도 매개변수로 무작위성 조절
  • 가장 기본적이고 널리 사용되는 확률적 샘플링 방법

특징

  • ✅ 구현이 간단
  • ✅ 무작위성 조절 가능
  • ✅ 창의적 출력 생성
  • ❌ 일관성 부족
  • ❌ 품질 예측 어려움

온도별 특성

  • 낮은 온도 (0에 가까움): 예측 가능하고 보수적인 출력
  • 높은 온도 (1 이상): 창의적이고 다양한 출력
  • 중간 온도 (0.7-0.9): 균형잡힌 출력

구현 예시

def temperature_sampling(model, input_ids, max_length, temperature=1.0):
    for _ in range(max_length):
        outputs = model(input_ids)
        logits = outputs.logits[:, -1, :] / temperature
        
        # softmax로 확률 분포 생성
        probs = torch.softmax(logits, dim=-1)
        
        # 확률에 따라 토큰 샘플링
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    
    return input_ids

온도별 비교

# 다양한 온도에서의 샘플링
temperatures = [0.1, 0.5, 1.0, 1.5, 2.0]
for temp in temperatures:
    result = temperature_sampling(model, input_ids, max_length, temp)
    print(f"Temperature {temp}: {tokenizer.decode(result)}")

Top-k Sampling

개념

  • 상위 k개의 가능한 토큰에서만 샘플링
  • 고정된 k값으로 후보 집합 제한

특징

  • ✅ 구현이 간단
  • ✅ 무작위성 제한
  • ✅ 품질 향상
  • ❌ 적응성 부족
  • ❌ k값 선택의 어려움

구현 예시

def top_k_sampling(model, input_ids, max_length, k=50):
    for _ in range(max_length):
        outputs = model(input_ids)
        logits = outputs.logits[:, -1, :]
        
        # 상위 k개 토큰 선택
        top_k_logits, top_k_indices = torch.topk(logits, k)
        
        # 선택된 토큰들에 대해서만 softmax 적용
        top_k_probs = torch.softmax(top_k_logits, dim=-1)
        
        # 샘플링
        selected_idx = torch.multinomial(top_k_probs, num_samples=1)
        next_token = top_k_indices[0, selected_idx]
        
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=-1)
    
    return input_ids

Top-p (Nucleus) Sampling

개념

  • 누적 확률이 p를 초과하는 최소 토큰 집합에서 샘플링
  • 동적으로 후보 집합 크기 조절

특징

  • ✅ top-k보다 더 유연하고 적응적
  • ✅ 모델의 확신도에 따른 동적 조절
  • ✅ 품질과 다양성의 균형
  • ❌ 구현이 약간 복잡

구현 예시

def top_p_sampling(model, input_ids, max_length, p=0.9):
    for _ in range(max_length):
        outputs = model(input_ids)
        logits = outputs.logits[:, -1, :]
        
        # 확률 분포 계산
        probs = torch.softmax(logits, dim=-1)
        
        # 확률을 내림차순으로 정렬
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        
        # 누적 확률 계산
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        
        # p를 초과하는 토큰들 제거
        sorted_indices_to_remove = cumulative_probs > p
        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
        sorted_indices_to_remove[0] = 0
        
        # 유효한 토큰들만 선택
        valid_indices = sorted_indices[~sorted_indices_to_remove]
        valid_probs = sorted_probs[~sorted_indices_to_remove]
        
        # 정규화된 확률로 샘플링
        normalized_probs = valid_probs / valid_probs.sum()
        selected_idx = torch.multinomial(normalized_probs, num_samples=1)
        next_token = valid_indices[selected_idx]
        
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=-1)
    
    return input_ids

2.2 고급 샘플링 방법

Min-p Sampling

개념

  • 상위 토큰의 확률에 따라 샘플링 임계값을 동적 조정
  • 모델의 확신도에 기반한 적응적 토큰 선택

특징

  • ✅ 높은 온도에서도 일관성과 다양성 균형 유지
  • ✅ 모델의 확신도에 따른 적응적 조절
  • ✅ 창의성과 품질의 균형
  • ❌ 구현 복잡도 증가

논문

  • “Min-p Sampling for Creative and Coherent LLM Outputs” (2024)

구현 예시

def min_p_sampling(model, input_ids, max_length, min_p=0.1):
    for _ in range(max_length):
        outputs = model(input_ids)
        logits = outputs.logits[:, -1, :]
        
        # 확률 분포 계산
        probs = torch.softmax(logits, dim=-1)
        
        # 최소 확률 임계값 적용
        min_prob = probs.max() * min_p
        valid_mask = probs >= min_prob
        
        if valid_mask.sum() == 0:
            valid_mask = probs >= probs.max()
        
        # 유효한 토큰들만 선택
        valid_probs = probs * valid_mask
        valid_probs = valid_probs / valid_probs.sum()
        
        # 샘플링
        next_token = torch.multinomial(valid_probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    
    return input_ids

Typical Sampling

개념

  • 통계적으로 “전형적인” 토큰 선택
  • 너무 예측 가능하거나 희귀하지 않은 토큰 선택

특징

  • ✅ 자연스러운 텍스트 생성
  • ✅ 극단적 확률 방지
  • ✅ 일관성 향상
  • ❌ 구현 복잡도

핵심 아이디어

def typical_sampling(model, input_ids, max_length, typical_p=0.9):
    for _ in range(max_length):
        outputs = model(input_ids)
        logits = outputs.logits[:, -1, :]
        
        # 확률 분포 계산
        probs = torch.softmax(logits, dim=-1)
        
        # 엔트로피 계산
        entropy = -torch.sum(probs * torch.log(probs + 1e-10))
        
        # typicality 점수 계산
        log_probs = torch.log(probs + 1e-10)
        typicality = torch.abs(log_probs + entropy)
        
        # typical_p에 해당하는 토큰들 선택
        sorted_typicality, sorted_indices = torch.sort(typicality)
        cutoff_idx = int(typical_p * len(sorted_typicality))
        
        valid_indices = sorted_indices[:cutoff_idx]
        valid_probs = probs[valid_indices]
        valid_probs = valid_probs / valid_probs.sum()
        
        # 샘플링
        selected_idx = torch.multinomial(valid_probs, num_samples=1)
        next_token = valid_indices[selected_idx]
        
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=-1)
    
    return input_ids

Mirostat Sampling

개념

  • top-k 토큰에서 샘플링하는 동안 텍스트의 perplexity 비율을 직접 제어
  • 목표 surprise 값을 유지하기 위해 피드백 기반으로 k값 조정

특징

  • ✅ 일정한 perplexity 유지
  • ✅ 일관된 품질 보장
  • ✅ 피드백 기반 적응
  • ❌ 구현 복잡도
  • ❌ 계산 오버헤드

논문

  • “mirostat: a neural text decoding algorithm” (2020)

구현 예시

def mirostat_sampling(model, input_ids, max_length, target_surprise=1.0, learning_rate=0.1):
    k = 50  # 초기 k값
    
    for _ in range(max_length):
        outputs = model(input_ids)
        logits = outputs.logits[:, -1, :]
        
        # top-k 샘플링
        top_k_logits, top_k_indices = torch.topk(logits, k)
        top_k_probs = torch.softmax(top_k_logits, dim=-1)
        
        # 샘플링
        selected_idx = torch.multinomial(top_k_probs, num_samples=1)
        next_token = top_k_indices[0, selected_idx]
        
        # surprise 계산 (선택된 토큰의 확률)
        selected_prob = top_k_probs[0, selected_idx]
        surprise = -torch.log(selected_prob)
        
        # k값 조정
        error = target_surprise - surprise
        k = max(1, min(100, k + learning_rate * error))
        
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=-1)
    
    return input_ids

η-Sampling

개념

  • 엔트로피 의존적 임계값 아래의 확률을 가진 단어를 잘라내기

특징

  • ✅ 엔트로피 기반 적응적 샘플링
  • ✅ 자연스러운 확률 분포 유지
  • ❌ 구현 복잡도

샘플링 방법 비교

성능 비교표

방법 속도 품질 다양성 일관성 구현 난이도
Temperature ⭐⭐⭐⭐⭐ ⭐⭐⭐ ⭐⭐⭐⭐⭐ ⭐⭐ ⭐⭐⭐⭐⭐
Top-k ⭐⭐⭐⭐ ⭐⭐⭐⭐ ⭐⭐⭐ ⭐⭐⭐ ⭐⭐⭐⭐
Top-p ⭐⭐⭐⭐ ⭐⭐⭐⭐ ⭐⭐⭐⭐ ⭐⭐⭐⭐ ⭐⭐⭐
Min-p ⭐⭐⭐ ⭐⭐⭐⭐⭐ ⭐⭐⭐⭐ ⭐⭐⭐⭐⭐ ⭐⭐
Typical ⭐⭐⭐ ⭐⭐⭐⭐ ⭐⭐⭐⭐ ⭐⭐⭐⭐ ⭐⭐
Mirostat ⭐⭐ ⭐⭐⭐⭐⭐ ⭐⭐⭐ ⭐⭐⭐⭐⭐

권장 사용 시나리오

  • Temperature: 창의적 글쓰기, 브레인스토밍
  • Top-k/Top-p: 일반적인 텍스트 생성
  • Min-p: 높은 온도에서도 일관성 필요한 경우
  • Typical: 자연스러운 대화, 스토리텔링
  • Mirostat: 일정한 품질이 요구되는 경우

하이브리드 접근법

여러 샘플링 방법을 조합하여 사용할 수 있습니다:

def hybrid_sampling(model, input_ids, max_length, method='top_p', **kwargs):
    if method == 'top_p':
        return top_p_sampling(model, input_ids, max_length, **kwargs)
    elif method == 'min_p':
        return min_p_sampling(model, input_ids, max_length, **kwargs)
    elif method == 'typical':
        return typical_sampling(model, input_ids, max_length, **kwargs)
    elif method == 'mirostat':
        return mirostat_sampling(model, input_ids, max_length, **kwargs)
    else:
        return temperature_sampling(model, input_ids, max_length, **kwargs)

This site uses Just the Docs, a documentation theme for Jekyll.