#Base classes and utilities for TokenSHAP
# SPDX-FileCopyrightText: 2023-2024 The TokenSHAP Authors
import numpy as np
import pandas as pd
import re
import random
from typing import List, Dict, Optional, Tuple, Any
from tqdm.auto import tqdm
class ModelBase:
    """Abstract interface for models analyzed by the SHAP classes."""

    def generate(self, **kwargs) -> str:
        """Produce a model response for the supplied keyword arguments.

        Raises:
            NotImplementedError: concrete model wrappers must override this.
        """
        raise NotImplementedError
class BaseSHAP:
    """Base class for SHAP-style importance analysis with Monte Carlo sampling.

    Provides combination generation, a shared response cache, and the loop
    that queries the model once per combination.  Subclasses supply the
    content-specific hooks (`_get_samples`, `_prepare_combination_args`,
    `_get_combination_key`, `analyze`).
    """

    def __init__(self, model: ModelBase, debug: bool = False):
        """
        Initialize BaseSHAP

        Args:
            model: Model to analyze
            debug: Enable debug output
        """
        self.model = model
        # Maps combination keys to model responses so each unique
        # combination is sent to the model at most once.
        self.cache = {}
        self.debug = debug

    def _calculate_baseline(self, content: str) -> Dict[str, Any]:
        """Calculate the baseline model response for the full content.

        Args:
            content: Full input; should already carry any prefix/suffix.

        Returns:
            The raw model response.  NOTE(review): the debug print indexes
            the response with ['label'], although ModelBase.generate is
            annotated ``-> str`` — concrete models presumably return a
            mapping; confirm against the model wrappers.
        """
        baseline = self.model.generate(prompt=content)
        if self.debug:
            print(f"Baseline prediction: {baseline['label']}")
        return baseline

    def _prepare_generate_args(self, content: str, **kwargs) -> Dict:
        """Prepare arguments for model.generate().  Subclasses must override."""
        raise NotImplementedError

    def _get_samples(self, content: str) -> List[str]:
        """Split content into samples (e.g. tokens).  Subclasses must override."""
        raise NotImplementedError

    def _prepare_combination_args(self, combination: List[str], original_content: str) -> Dict:
        """Build model.generate() kwargs for one combination.  Subclasses must override."""
        raise NotImplementedError

    def _get_combination_key(self, combination: List[str], indexes: Tuple[int, ...]) -> str:
        """Return a unique cache key for a combination.  Subclasses must override."""
        raise NotImplementedError

    def _get_all_combinations(self, samples: List[str], sampling_ratio: float = 0.0,
                              max_combinations: Optional[int] = None) -> Dict[str, Tuple[List[str], Tuple[int, ...]]]:
        """
        Get combinations of samples with their indices.

        Always includes the n "leave-one-out" combinations (each omitting
        exactly one sample); optionally adds randomly sampled subsets on top.

        Args:
            samples: List of samples (e.g., tokens)
            sampling_ratio: Ratio of the 2**n - 1 non-empty combinations to
                sample (0-1); <= 0 means leave-one-out combinations only
            max_combinations: Maximum number of combinations to generate

        Returns:
            Dictionary mapping combination keys to (combination, indices)
            tuples.  Keys are "omit_<i>" for leave-one-out entries and
            "random_<i,j,...>" for randomly sampled ones.
        """
        n = len(samples)
        # Always include combinations that exclude exactly one token --
        # these are required for the per-sample importance estimates.
        essential_combinations = {}
        for i in range(n):
            combination = samples.copy()
            del combination[i]
            indices = tuple(j for j in range(n) if j != i)
            key = f"omit_{i}"
            essential_combinations[key] = (combination, indices)
        if sampling_ratio <= 0:
            # No Monte Carlo sampling requested: leave-one-out only.
            return essential_combinations
        total_combinations = 2**n - 1  # All non-empty subsets
        sample_count = int(total_combinations * sampling_ratio)
        if max_combinations is not None:
            sample_count = min(sample_count, max_combinations)
        if sample_count <= len(essential_combinations):
            return essential_combinations
        # Randomly sample additional combinations on top of the essentials.
        all_combinations = essential_combinations.copy()
        additional_needed = sample_count - len(essential_combinations)
        combinations_added = 0
        # Bound the number of draws so duplicate keys can't loop forever.
        max_attempts = additional_needed * 10
        attempts = 0
        while combinations_added < additional_needed and attempts < max_attempts:
            # Subset size: at least 1 sample, at most n-1 (never the full set).
            subset_size = random.randint(1, n-1)
            indices = tuple(sorted(random.sample(range(n), subset_size)))
            combination = [samples[i] for i in indices]
            key = f"random_{','.join(str(i) for i in indices)}"
            # Skip duplicate draws; only count genuinely new combinations.
            if key not in all_combinations:
                all_combinations[key] = (combination, indices)
                combinations_added += 1
            attempts += 1
        if self.debug and attempts >= max_attempts:
            print(f"Warning: Reached max attempts ({max_attempts}) when generating combinations")
        return all_combinations

    def _get_result_per_combination(self, content: str, sampling_ratio: float = 0.0,
                                    max_combinations: Optional[int] = None) -> Dict[str, Dict[str, Any]]:
        """
        Get model responses for combinations of content

        Args:
            content: Original content
            sampling_ratio: Ratio of combinations to sample
            max_combinations: Maximum number of combinations

        Returns:
            Dictionary mapping combination keys to dicts with
            "combination", "indices" and "response" entries
        """
        samples = self._get_samples(content)
        if self.debug:
            print(f"Found {len(samples)} samples in content")
        combinations = self._get_all_combinations(samples, sampling_ratio, max_combinations)
        if self.debug:
            print(f"Generated {len(combinations)} combinations")
        results = {}
        # Process each combination, reusing cached responses where possible.
        for key, (combination, indices) in tqdm(combinations.items(), desc="Processing combinations"):
            comb_args = self._prepare_combination_args(combination, content)
            comb_key = self._get_combination_key(combination, indices)
            # BUGFIX: this previously read self._cache, which is never
            # initialized (__init__ creates self.cache), so the first call
            # raised AttributeError.  Use the shared cache from __init__.
            if comb_key in self.cache:
                response = self.cache[comb_key]
            else:
                response = self.model.generate(**comb_args)
                self.cache[comb_key] = response
            results[key] = {
                "combination": combination,
                "indices": indices,
                "response": response
            }
        return results

    def analyze(self, content: str, sampling_ratio: float = 0.0, max_combinations: Optional[int] = None) -> pd.DataFrame:
        """
        Analyze importance in content

        Args:
            content: Content to analyze
            sampling_ratio: Ratio of combinations to sample
            max_combinations: Maximum number of combinations

        Returns:
            DataFrame with analysis results
        """
        raise NotImplementedError