# FBDF / backend / tokenShap.py
# Author: Firas HADJ KACEM — created the interface
# Commit: 5c7385e
# SPDX-FileCopyrightText: 2023-2024 The TokenSHAP Authors
import pandas as pd
import matplotlib.pyplot as plt
import re
from typing import List, Dict, Optional, Tuple, Any
from tqdm.auto import tqdm
from collections import defaultdict
from .base import ModelBase, BaseSHAP
from .splitters import Splitter
from .helpers import build_full_prompt, jensen_shannon_distance
class TokenSHAP(BaseSHAP):
    """Analyzes token importance in text prompts using SHAP values.

    The prompt is split into tokens, the model is queried on sampled
    subsets (combinations) of those tokens, and each token's importance
    is estimated from how the model's output probability distribution
    shifts when the token is included vs. excluded.
    """

    def __init__(self,
                 model: ModelBase,
                 splitter: Splitter,
                 debug: bool = False,
                 batch_size: int = 16):
        """
        Initialize TokenSHAP

        Args:
            model: Model to analyze
            splitter: Text splitter implementation
            debug: Enable debug output
            batch_size: Number of prompts sent per generate_batch() call
        """
        super().__init__(model, debug)
        self.splitter = splitter
        # Instruction text wrapped around each combination prompt.
        # NOTE(review): analyze() adds explicit instructions to the
        # *baseline* prompt only; these stay empty unless a caller sets
        # them, so combinations are scored on the bare text — confirm
        # this asymmetry is intended.
        self.prompt_prefix = ""
        self.prompt_suffix = ""
        self.batch_size = batch_size

    def _get_samples(self, content: str) -> List[str]:
        """Split the prompt into the tokens whose importance is measured."""
        return self.splitter.split(content)

    def _get_combination_key(self, combination: List[str], indexes: Tuple[int, ...]) -> str:
        """Build a stable dictionary key for a token combination."""
        return self.splitter.join(combination)

    def _prepare_combination_args(self, combination: List[str], original_content: str) -> Dict:
        """Build the model-call kwargs for one token combination."""
        prompt = f"{self.prompt_prefix}{self.splitter.join(combination)}{self.prompt_suffix}"
        return {"prompt": prompt}

    @staticmethod
    def _fallback_response() -> Dict[str, Any]:
        """Return a fresh near-uniform response used when generation fails.

        A new dict is built on every call so individual results never
        alias one shared mutable object.
        """
        return {
            "label": "NA",
            "probabilities": {"Positive": 0.33, "Negative": 0.33, "Neutral": 0.34}
        }

    def _generate_single_with_fallback(self, prompt: str) -> Dict[str, Any]:
        """Generate for one prompt, substituting a neutral result on failure."""
        try:
            return self.model.generate(prompt)
        except Exception as inner_e:
            print(f"Individual processing also failed: {str(inner_e)}")
            return self._fallback_response()

    def _get_result_per_combination(self, content, sampling_ratio=0.0, max_combinations=None):
        """
        Get model responses for combinations with batch processing

        Args:
            content: Original content
            sampling_ratio: Ratio of combinations to sample
            max_combinations: Maximum number of combinations

        Returns:
            Dictionary mapping combination keys to response data
        """
        samples = self._get_samples(content)
        combinations = self._get_all_combinations(samples, sampling_ratio, max_combinations)

        # Flatten the combinations into parallel lists so prompts can be
        # batched while remembering which key/indices each belongs to.
        prompts = []
        comb_keys = []
        comb_indices = []
        for key, (combination, indices) in combinations.items():
            comb_args = self._prepare_combination_args(combination, content)
            prompts.append(comb_args["prompt"])
            comb_keys.append(key)
            comb_indices.append(indices)

        # Batch generation. A failing batch degrades to per-prompt calls,
        # and a failing prompt degrades to a neutral fallback response, so
        # all_results always stays 1:1 with prompts.
        all_results = []
        for batch_start in range(0, len(prompts), self.batch_size):
            batch_prompts = prompts[batch_start:batch_start + self.batch_size]
            try:
                all_results.extend(self.model.generate_batch(batch_prompts))
            except RuntimeError as e:
                if "stack expects each tensor to be equal size" not in str(e):
                    # Unrelated RuntimeErrors are real bugs — surface them.
                    raise
                print(f"Error in batch {batch_start//self.batch_size}: {str(e)}")
                print("Falling back to individual processing for this batch")
                for prompt in batch_prompts:
                    all_results.append(self._generate_single_with_fallback(prompt))
            except Exception as other_e:
                # Any other batch-level failure: same per-prompt fallback.
                print(f"Unexpected error in batch {batch_start//self.batch_size}: {str(other_e)}")
                for prompt in batch_prompts:
                    all_results.append(self._generate_single_with_fallback(prompt))

        # Re-attach the flat result list to the combination keys.
        results = {}
        for i, key in enumerate(comb_keys):
            results[key] = {
                "combination": combinations[key][0],
                "indices": comb_indices[i],
                "response": all_results[i]
            }
        return results

    def _get_df_per_combination(self, responses: Dict[str, Dict[str, Any]], baseline_response: Dict[str, Any]) -> pd.DataFrame:
        """
        Create DataFrame with combination results using probability-based similarity

        Args:
            responses: Dictionary of combination responses
            baseline_response: Baseline model response

        Returns:
            DataFrame with one row per combination (key, combination,
            indices, response_label, similarity, baseline_class_prob,
            probabilities)
        """
        data = []
        baseline_probs = baseline_response["probabilities"]
        baseline_label = baseline_response["label"]

        for key, res in responses.items():
            combination = res["combination"]
            indices = res["indices"]
            response_data = res["response"]
            response_probs = response_data["probabilities"]
            response_label = response_data["label"]

            # Similarity of the output distributions: 1 - JS distance,
            # so higher means the combination behaves like the baseline.
            prob_similarity = 1.0 - jensen_shannon_distance(baseline_probs, response_probs)

            # Probability this combination assigns to the baseline's
            # predicted class (0.0 if the class is absent).
            baseline_class_prob = response_probs.get(baseline_label, 0.0)

            data.append({
                "key": key,
                "combination": combination,
                "indices": indices,
                "response_label": response_label,
                "similarity": prob_similarity,
                "baseline_class_prob": baseline_class_prob,
                "probabilities": response_probs
            })

        return pd.DataFrame(data)

    def _calculate_shapley_values(self, df: pd.DataFrame, content: str) -> Dict[str, Dict[str, float]]:
        """
        Calculate Shapley values for each sample using probability distributions

        For every token, the score is the mean metric over combinations
        that contain the token minus the mean over combinations that do
        not (a Shapley approximation, not an exact Shapley computation).

        Args:
            df: DataFrame with combination results
            content: Original content

        Returns:
            Dictionary mapping "token_index" keys to normalized
            {"similarity_shapley", "prob_shapley"} metric dicts
        """
        samples = self._get_samples(content)
        n = len(samples)

        # Per-token accumulators split by presence/absence in a combination.
        with_count = defaultdict(int)
        without_count = defaultdict(int)
        with_similarity_sum = defaultdict(float)
        without_similarity_sum = defaultdict(float)
        with_baseline_prob_sum = defaultdict(float)
        without_baseline_prob_sum = defaultdict(float)

        for _, row in df.iterrows():
            indices = row["indices"]
            similarity = row["similarity"]
            baseline_class_prob = row["baseline_class_prob"]
            for i in range(n):
                if i in indices:
                    with_similarity_sum[i] += similarity
                    with_baseline_prob_sum[i] += baseline_class_prob
                    with_count[i] += 1
                else:
                    without_similarity_sum[i] += similarity
                    without_baseline_prob_sum[i] += baseline_class_prob
                    without_count[i] += 1

        shapley_values = {}
        for i in range(n):
            # Distribution-similarity based contribution.
            with_avg = with_similarity_sum[i] / with_count[i] if with_count[i] > 0 else 0
            without_avg = without_similarity_sum[i] / without_count[i] if without_count[i] > 0 else 0
            similarity_shapley = with_avg - without_avg

            # Baseline-class-probability based contribution.
            with_prob_avg = with_baseline_prob_sum[i] / with_count[i] if with_count[i] > 0 else 0
            without_prob_avg = without_baseline_prob_sum[i] / without_count[i] if without_count[i] > 0 else 0
            prob_shapley = with_prob_avg - without_prob_avg

            # Key carries the position so repeated tokens stay distinct.
            shapley_values[f"{samples[i]}_{i}"] = {
                "similarity_shapley": similarity_shapley,
                "prob_shapley": prob_shapley
            }

        # Min-max normalize each metric independently to [0, 1].
        return self._normalize_shapley_dict(shapley_values)

    def _normalize_shapley_dict(self, shapley_dict: Dict[str, Dict[str, float]]) -> Dict[str, Dict[str, float]]:
        """Min-max normalize each type of Shapley value separately to [0, 1]."""
        if not shapley_dict:
            return {}
        metrics = list(next(iter(shapley_dict.values())).keys())
        normalized = {k: {} for k in shapley_dict}
        for metric in metrics:
            values = [v[metric] for v in shapley_dict.values()]
            min_val = min(values)
            max_val = max(values)
            value_range = max_val - min_val
            if value_range > 0:
                for k, v in shapley_dict.items():
                    normalized[k][metric] = (v[metric] - min_val) / value_range
            else:
                # No variance: every token gets the neutral midpoint.
                for k, v in shapley_dict.items():
                    normalized[k][metric] = 0.5
        return normalized

    def get_tokens_shapley_values(self) -> Dict[str, Dict[str, float]]:
        """
        Returns a dictionary mapping each token to its Shapley metric dict

        Note: repeated tokens collapse to one key (last occurrence wins).

        Returns:
            Dictionary with token text as keys and
            {"similarity_shapley", "prob_shapley"} dicts as values

        Raises:
            ValueError: If analyze() has not been run yet
        """
        if not hasattr(self, 'shapley_values'):
            raise ValueError("Must run analyze() before getting Shapley values")
        tokens = {}
        for key, value in self.shapley_values.items():
            token = key.rsplit('_', 1)[0]  # strip the "_<index>" suffix
            tokens[token] = value
        return tokens

    def get_sim_shapley_values(self) -> Dict[str, float]:
        """
        Returns a dictionary mapping each token to its similarity-based Shapley value

        Note: repeated tokens collapse to one key (last occurrence wins).

        Returns:
            Dictionary with token text as keys and similarity-based
            Shapley values as values

        Raises:
            ValueError: If analyze() has not been run yet
        """
        if not hasattr(self, 'shapley_values'):
            raise ValueError("Must run analyze() before getting Shapley values")
        tokens = {}
        for key, value_dict in self.shapley_values.items():
            token = key.rsplit('_', 1)[0]  # strip the "_<index>" suffix
            tokens[token] = value_dict["similarity_shapley"]
        return tokens

    def analyze(self, prompt: str,
                sampling_ratio: float = 0.0,
                max_combinations: Optional[int] = 1000) -> pd.DataFrame:
        """
        Analyze token importance in a financial statement

        Side effects: stores baseline_response, baseline_text,
        results_df and shapley_values on the instance.

        Args:
            prompt: Financial statement to analyze (without instructions)
            sampling_ratio: Ratio of combinations to sample (0-1)
            max_combinations: Maximum number of combinations to generate

        Returns:
            DataFrame with analysis results
        """
        # Collapse whitespace so the splitter sees clean tokens.
        prompt = prompt.strip()
        prompt = re.sub(r'\s+', ' ', prompt)

        # Baseline uses the full instructed prompt.
        prefix = "Analyze the sentiment of this statement extracted from a financial news article. Provide your answer as either negative, positive, or neutral.. Text: "
        suffix = ".. Answer: "
        full_prompt = build_full_prompt(prompt, prefix, suffix)
        self.baseline_response = self._calculate_baseline(full_prompt)
        self.baseline_text = self.baseline_response["label"]

        # Combinations are wrapped with self.prompt_prefix/suffix (empty
        # by default) inside _prepare_combination_args.
        responses = self._get_result_per_combination(
            prompt,
            sampling_ratio=sampling_ratio,
            max_combinations=max_combinations
        )

        self.results_df = self._get_df_per_combination(responses, self.baseline_response)
        self.shapley_values = self._calculate_shapley_values(self.results_df, prompt)
        return self.results_df

    def plot_colored_text(self, new_line: bool = False):
        """
        Plot text visualization with importance colors

        Tokens are colored by their normalized similarity-based Shapley
        value on the coolwarm colormap.

        Args:
            new_line: Whether to plot tokens on new lines

        Raises:
            ValueError: If analyze() has not been run yet
        """
        if not hasattr(self, 'shapley_values'):
            raise ValueError("Must run analyze() before visualization")

        # Extract token texts without indices, reducing each metric dict
        # to the scalar used for coloring. (Passing the whole dict to the
        # colormap, as the previous version did, raises at runtime.)
        tokens = {}
        for key, value in self.shapley_values.items():
            token = key.rsplit('_', 1)[0]  # strip the "_<index>" suffix
            if isinstance(value, dict):
                value = value.get("similarity_shapley", 0.5)
            tokens[token] = value

        num_items = len(tokens)
        fig_height = num_items * 0.5 + 1 if new_line else 2
        fig, ax = plt.subplots(figsize=(10, fig_height))
        ax.axis('off')

        y_pos = 0.9
        x_pos = 0.1
        # Guard against an empty token dict (would divide by zero).
        step = 0.8 / max(num_items, 1)

        for token, value in tokens.items():
            color = plt.cm.coolwarm(value)
            if new_line:
                ax.text(
                    0.5, y_pos,
                    token,
                    color=color,
                    fontsize=14,
                    ha='center',
                    va='center',
                    transform=ax.transAxes
                )
                y_pos -= step
            else:
                ax.text(
                    x_pos, y_pos,
                    token,
                    color=color,
                    fontsize=14,
                    ha='left',
                    va='center',
                    transform=ax.transAxes
                )
                x_pos += len(token) * 0.015 + 0.02  # Adjust spacing based on token length

        # Horizontal colorbar acting as the importance legend.
        sm = plt.cm.ScalarMappable(
            cmap=plt.cm.coolwarm,
            norm=plt.Normalize(vmin=0, vmax=1)
        )
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=ax, orientation='horizontal', pad=0.05)
        cbar.ax.set_position([0.05, 0.02, 0.9, 0.05])
        cbar.set_label('Importance (Shapley Value)', fontsize=12)

        plt.tight_layout()
        plt.show()