# SPDX-FileCopyrightText: 2023-2024 The TokenSHAP Authors
import pandas as pd
import matplotlib.pyplot as plt
import re
from typing import List, Dict, Optional, Tuple, Any
from tqdm.auto import tqdm
from collections import defaultdict
from .base import ModelBase, BaseSHAP
from .splitters import Splitter
from .helpers import build_full_prompt, jensen_shannon_distance
class TokenSHAP(BaseSHAP):
"""Analyzes token importance in text prompts using SHAP values"""
def __init__(self,
model: ModelBase,
splitter: Splitter,
debug: bool = False,
batch_size: int = 16):
"""
Initialize TokenSHAP
Args:
model: Model to analyze
splitter: Text splitter implementation
debug: Enable debug output
batch_size: Number of prompts sent to the model per batch
"""
super().__init__(model, debug)
self.splitter = splitter
self.prompt_prefix = ""
self.prompt_suffix = ""
self.batch_size = batch_size
def _get_samples(self, content: str) -> List[str]:
"""Get tokens from prompt"""
return self.splitter.split(content)
def _get_combination_key(self, combination: List[str], indexes: Tuple[int, ...]) -> str:
return self.splitter.join(combination)
def _prepare_combination_args(self, combination: List[str], original_content: str) -> Dict:
prompt = f"{self.prompt_prefix}{self.splitter.join(combination)}{self.prompt_suffix}"
return {"prompt": prompt}
def _get_result_per_combination(self, content: str, sampling_ratio: float = 0.0, max_combinations: Optional[int] = None) -> Dict[str, Dict[str, Any]]:
"""
Get model responses for combinations with batch processing
Args:
content: Original content
sampling_ratio: Ratio of combinations to sample
max_combinations: Maximum number of combinations
Returns:
Dictionary mapping combination keys to response data
"""
samples = self._get_samples(content)
combinations = self._get_all_combinations(samples, sampling_ratio, max_combinations)
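# combinations maps each combination key to a (token_subset, token_indices) tuple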
# Prepare prompts for batch processing
prompts = []
comb_keys = []
comb_indices = []
for key, (combination, indices) in combinations.items():
# Build the combination args and extract the assembled prompt
comb_args = self._prepare_combination_args(combination, content)
prompt = comb_args["prompt"] # Extract prompt from dict
prompts.append(prompt)
comb_keys.append(key)
comb_indices.append(indices)
# Batching with error handling
all_results = []
for batch_start in range(0, len(prompts), self.batch_size):
batch_end = min(batch_start + self.batch_size, len(prompts))
batch_prompts = prompts[batch_start:batch_end]
try:
batch_results = self.model.generate_batch(batch_prompts)
all_results.extend(batch_results)
except RuntimeError as e:
if "stack expects each tensor to be equal size" in str(e):
print(f"Error in batch {batch_start//self.batch_size}: {str(e)}")
print("Falling back to individual processing for this batch")
# Fall back to individual processing with generate
for prompt in batch_prompts:
try:
single_result = self.model.generate(prompt)
all_results.append(single_result)
except Exception as inner_e:
print(f"Individual processing also failed: {str(inner_e)}")
# Provide fallback result with default values
all_results.append({
"label": "NA",
"probabilities": {"Positive": 0.33, "Negative": 0.33, "Neutral": 0.34}
})
else:
# Re-raise other RuntimeErrors
raise
except Exception as other_e:
# Handle any other exceptions during batch processing
print(f"Unexpected error in batch {batch_start//self.batch_size}: {str(other_e)}")
# Fall back to individual processing
for prompt in batch_prompts:
try:
single_result = self.model.generate(prompt)
all_results.append(single_result)
except Exception:
# Provide fallback result
all_results.append({
"label": "NA",
"probabilities": {"Positive": 0.33, "Negative": 0.33, "Neutral": 0.34}
})
# Attach back to combination keys
results = {}
for i, key in enumerate(comb_keys):
results[key] = {
"combination": combinations[key][0],
"indices": comb_indices[i],
"response": all_results[i]
}
return results
def _get_df_per_combination(self, responses: Dict[str, Dict[str, Any]], baseline_response: Dict[str, Any]) -> pd.DataFrame:
"""
Create DataFrame with combination results using probability-based similarity
Args:
responses: Dictionary of combination responses
baseline_response: Baseline model response
Returns:
DataFrame with results
"""
# Prepare data for DataFrame
data = []
baseline_probs = baseline_response["probabilities"]
baseline_label = baseline_response["label"]
# Process each combination response
for key, res in responses.items():
combination = res["combination"]
indices = res["indices"]
response_data = res["response"]
response_probs = response_data["probabilities"]
response_label = response_data["label"]
# Calculate probability-based similarity (higher = more similar)
prob_similarity = 1.0 - jensen_shannon_distance(baseline_probs, response_probs)
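# e.g. identical distributions give a Jensen-Shannon distance of 0, hence similarity 1.0;
# increasingly different distributions push the similarity toward 0
# (assuming jensen_shannon_distance returns a distance normalized to [0, 1])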
# Track the probability of the baseline's predicted class
baseline_class_prob = response_probs.get(baseline_label, 0.0)
# Add to data
data.append({
"key": key,
"combination": combination,
"indices": indices,
"response_label": response_label,
"similarity": prob_similarity,
"baseline_class_prob": baseline_class_prob,
"probabilities": response_probs
})
# Create DataFrame
df = pd.DataFrame(data)
return df
def _calculate_shapley_values(self, df: pd.DataFrame, content: str) -> Dict[str, Dict[str, float]]:
"""
Calculate Shapley values for each sample using probability distributions
Args:
df: DataFrame with combination results
content: Original content
Returns:
Dictionary mapping sample names to various Shapley values
"""
samples = self._get_samples(content)
n = len(samples)
# Initialize counters for each sample
with_count = defaultdict(int)
without_count = defaultdict(int)
with_similarity_sum = defaultdict(float)
without_similarity_sum = defaultdict(float)
with_baseline_prob_sum = defaultdict(float)
without_baseline_prob_sum = defaultdict(float)
# Process each combination
for _, row in df.iterrows():
indices = row["indices"]
similarity = row["similarity"]
baseline_class_prob = row["baseline_class_prob"]
# Update counters for each sample
for i in range(n):
if i in indices:
with_similarity_sum[i] += similarity
with_baseline_prob_sum[i] += baseline_class_prob
with_count[i] += 1
else:
without_similarity_sum[i] += similarity
without_baseline_prob_sum[i] += baseline_class_prob
without_count[i] += 1
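# What follows is a marginal-contribution estimate over the sampled combinations:
#     shapley_i ≈ mean(metric | token i included) - mean(metric | token i excluded)
# This approximates the exact Shapley value, which would additionally weight coalitions by size.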
# Calculate Shapley values for different metrics
shapley_values = {}
for i in range(n):
# Similarity-based Shapley (distribution similarity)
with_avg = with_similarity_sum[i] / with_count[i] if with_count[i] > 0 else 0
without_avg = without_similarity_sum[i] / without_count[i] if without_count[i] > 0 else 0
similarity_shapley = with_avg - without_avg
# Baseline class probability-based Shapley
with_prob_avg = with_baseline_prob_sum[i] / with_count[i] if with_count[i] > 0 else 0
without_prob_avg = without_baseline_prob_sum[i] / without_count[i] if without_count[i] > 0 else 0
prob_shapley = with_prob_avg - without_prob_avg
shapley_values[f"{samples[i]}_{i}"] = {
"similarity_shapley": similarity_shapley,
"prob_shapley": prob_shapley
}
# Normalize each type of Shapley value separately
norm_shapley = self._normalize_shapley_dict(shapley_values)
return norm_shapley
def _normalize_shapley_dict(self, shapley_dict: Dict[str, Dict[str, float]]) -> Dict[str, Dict[str, float]]:
"""Normalize each type of Shapley value separately"""
# Get all metric types
if not shapley_dict:
return {}
metrics = list(next(iter(shapley_dict.values())).keys())
normalized = {k: {} for k in shapley_dict}
# Normalize each metric separately
for metric in metrics:
values = [v[metric] for v in shapley_dict.values()]
min_val = min(values)
max_val = max(values)
value_range = max_val - min_val
if value_range > 0:
for k, v in shapley_dict.items():
normalized[k][metric] = (v[metric] - min_val) / value_range
else:
for k, v in shapley_dict.items():
normalized[k][metric] = 0.5 # Default to middle when no variance
return normalized
def get_tokens_shapley_values(self) -> Dict[str, Dict[str, float]]:
"""
Returns a dictionary mapping each token to its Shapley value metrics
Returns:
Dictionary with token text as keys and metric dictionaries (similarity_shapley, prob_shapley) as values
"""
if not hasattr(self, 'shapley_values'):
raise ValueError("Must run analyze() before getting Shapley values")
# Extract token texts without indices
tokens = {}
for key, value in self.shapley_values.items():
token = key.rsplit('_', 1)[0] # Remove index suffix
tokens[token] = value
return tokens
# Convenience accessor for the similarity-based Shapley values
def get_sim_shapley_values(self) -> Dict[str, float]:
"""
Returns a dictionary mapping each token to its similarity-based Shapley value
Returns:
Dictionary with token text as keys and similarity-based Shapley values as values
"""
if not hasattr(self, 'shapley_values'):
raise ValueError("Must run analyze() before getting Shapley values")
# Extract token texts without indices and get the similarity-based metric
tokens = {}
for key, value_dict in self.shapley_values.items():
token = key.rsplit('_', 1)[0] # Remove index suffix
tokens[token] = value_dict["similarity_shapley"]
return tokens
def analyze(self, prompt: str,
sampling_ratio: float = 0.0,
max_combinations: Optional[int] = 1000) -> pd.DataFrame:
"""
Analyze token importance in a financial statement
Args:
prompt: Financial statement to analyze (without instructions)
sampling_ratio: Ratio of combinations to sample (0-1)
max_combinations: Maximum number of combinations to generate
Returns:
DataFrame with analysis results
"""
# Clean prompt
prompt = prompt.strip()
prompt = re.sub(r'\s+', ' ', prompt)
# Get baseline using full prompt with instructions
prefix = "Analyze the sentiment of this statement extracted from a financial news article. Provide your answer as either negative, positive, or neutral.. Text: "
suffix = ".. Answer: "
# Store the instruction wrapper so each combination prompt is built the same way as the baseline
self.prompt_prefix = prefix
self.prompt_suffix = suffix
full_prompt = build_full_prompt(prompt, prefix, suffix)
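# Assuming build_full_prompt simply concatenates prefix + prompt + suffix, full_prompt
# ends up as "Analyze the sentiment ... Text: <statement>.. Answer: "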
self.baseline_response = self._calculate_baseline(full_prompt)
self.baseline_text = self.baseline_response["label"]
# Process combinations (this function will add instructions to each combination)
responses = self._get_result_per_combination(
prompt,
sampling_ratio=sampling_ratio,
max_combinations=max_combinations
)
# Create results DataFrame
self.results_df = self._get_df_per_combination(responses, self.baseline_response)
# Calculate Shapley values
self.shapley_values = self._calculate_shapley_values(self.results_df, prompt)
return self.results_df
# Visualization based on the similarity-based Shapley values
def plot_colored_text(self, new_line: bool = False):
"""
Plot text visualization with importance colors
Args:
new_line: Whether to plot tokens on new lines
"""
if not hasattr(self, 'shapley_values'):
raise ValueError("Must run analyze() before visualization")
# Extract token texts without indices, using the similarity-based Shapley value for coloring
tokens = {}
for key, value in self.shapley_values.items():
token = key.rsplit('_', 1)[0]  # Remove index suffix
tokens[token] = value["similarity_shapley"]
num_items = len(tokens)
fig_height = num_items * 0.5 + 1 if new_line else 2
fig, ax = plt.subplots(figsize=(10, fig_height))
ax.axis('off')
y_pos = 0.9
x_pos = 0.1
step = 0.8 / max(num_items, 1)  # Guard against an empty token dictionary
for token, value in tokens.items():
color = plt.cm.coolwarm(value)
if new_line:
ax.text(
0.5, y_pos,
token,
color=color,
fontsize=14,
ha='center',
va='center',
transform=ax.transAxes
)
y_pos -= step
else:
ax.text(
x_pos, y_pos,
token,
color=color,
fontsize=14,
ha='left',
va='center',
transform=ax.transAxes
)
x_pos += len(token) * 0.015 + 0.02 # Adjust spacing based on token length
sm = plt.cm.ScalarMappable(
cmap=plt.cm.coolwarm,
norm=plt.Normalize(vmin=0, vmax=1)
)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, orientation='horizontal', pad=0.05)
cbar.ax.set_position([0.05, 0.02, 0.9, 0.05])
cbar.set_label('Importance (Shapley Value)', fontsize=12)
plt.tight_layout()
plt.show()
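# Illustrative usage sketch (commented out; the concrete ModelBase and Splitter
# implementations named below are hypothetical placeholders, not part of this module):
#
#     model = SomeSentimentModel()      # any ModelBase subclass providing generate / generate_batch
#     splitter = SomeWordSplitter()     # any Splitter subclass providing split / join
#     explainer = TokenSHAP(model, splitter, batch_size=8)
#     df = explainer.analyze("Company X reported record quarterly profits.",
#                            sampling_ratio=0.0, max_combinations=500)
#     print(explainer.get_sim_shapley_values())
#     explainer.plot_colored_text()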