# SPDX-FileCopyrightText: 2023-2024 The TokenSHAP Authors import pandas as pd import matplotlib.pyplot as plt import re from typing import List, Dict, Optional, Tuple, Any from tqdm.auto import tqdm from collections import defaultdict from .base import ModelBase, BaseSHAP from .splitters import Splitter from .helpers import build_full_prompt, jensen_shannon_distance class TokenSHAP(BaseSHAP): """Analyzes token importance in text prompts using SHAP values""" def __init__(self, model: ModelBase, splitter: Splitter, debug: bool = False, batch_size=16): """ Initialize TokenSHAP Args: model: Model to analyze splitter: Text splitter implementation debug: Enable debug output """ super().__init__(model, debug) self.splitter = splitter self.prompt_prefix = "" self.prompt_suffix = "" self.batch_size = batch_size def _get_samples(self, content: str) -> List[str]: """Get tokens from prompt""" return self.splitter.split(content) def _get_combination_key(self, combination: List[str], indexes: Tuple[int, ...]) -> str: return self.splitter.join(combination) def _prepare_combination_args(self, combination: List[str], original_content: str) -> Dict: prompt = f"{self.prompt_prefix}{self.splitter.join(combination)}{self.prompt_suffix}" return {"prompt": prompt} def _get_result_per_combination(self, content, sampling_ratio=0.0, max_combinations=None): """ Get model responses for combinations with batch processing Args: content: Original content sampling_ratio: Ratio of combinations to sample max_combinations: Maximum number of combinations Returns: Dictionary mapping combination keys to response data """ samples = self._get_samples(content) combinations = self._get_all_combinations(samples, sampling_ratio, max_combinations) # Prepare prompts for batch processing prompts = [] comb_keys = [] comb_indices = [] for key, (combination, indices) in combinations.items(): #Call with both parameters and extract prompt from returned dict comb_args = self._prepare_combination_args(combination, content) prompt = comb_args["prompt"] # Extract prompt from dict prompts.append(prompt) comb_keys.append(key) comb_indices.append(indices) # Batching with error handling all_results = [] for batch_start in range(0, len(prompts), self.batch_size): batch_end = min(batch_start + self.batch_size, len(prompts)) batch_prompts = prompts[batch_start:batch_end] try: batch_results = self.model.generate_batch(batch_prompts) all_results.extend(batch_results) except RuntimeError as e: if "stack expects each tensor to be equal size" in str(e): print(f"Error in batch {batch_start//self.batch_size}: {str(e)}") print("Falling back to individual processing for this batch") # Fall back to individual processing with generate for prompt in batch_prompts: try: single_result = self.model.generate(prompt) all_results.append(single_result) except Exception as inner_e: print(f"Individual processing also failed: {str(inner_e)}") # Provide fallback result with default values all_results.append({ "label": "NA", "probabilities": {"Positive": 0.33, "Negative": 0.33, "Neutral": 0.34} }) else: # Re-raise other RuntimeErrors raise except Exception as other_e: # Handle any other exceptions during batch processing print(f"Unexpected error in batch {batch_start//self.batch_size}: {str(other_e)}") # Fall back to individual processing for prompt in batch_prompts: try: single_result = self.model.generate(prompt) all_results.append(single_result) except Exception: # Provide fallback result all_results.append({ "label": "NA", "probabilities": {"Positive": 0.33, "Negative": 0.33, "Neutral": 0.34} }) # Attach back to combination keys results = {} for i, key in enumerate(comb_keys): results[key] = { "combination": combinations[key][0], "indices": comb_indices[i], "response": all_results[i] } return results def _get_df_per_combination(self, responses: Dict[str, Dict[str, Any]], baseline_response: Dict[str, Any]) -> pd.DataFrame: """ Create DataFrame with combination results using probability-based similarity Args: responses: Dictionary of combination responses baseline_response: Baseline model response Returns: DataFrame with results """ # Prepare data for DataFrame data = [] baseline_probs = baseline_response["probabilities"] baseline_label = baseline_response["label"] # Process each combination response for key, res in responses.items(): combination = res["combination"] indices = res["indices"] response_data = res["response"] response_probs = response_data["probabilities"] response_label = response_data["label"] # Calculate probability-based similarity (lower = more similar) prob_similarity = 1.0 - jensen_shannon_distance(baseline_probs, response_probs) # Track the probability of the baseline's predicted class baseline_class_prob = response_probs.get(baseline_label, 0.0) # Add to data data.append({ "key": key, "combination": combination, "indices": indices, "response_label": response_label, "similarity": prob_similarity, "baseline_class_prob": baseline_class_prob, "probabilities": response_probs }) # Create DataFrame df = pd.DataFrame(data) return df def _calculate_shapley_values(self, df: pd.DataFrame, content: str) -> Dict[str, Dict[str, float]]: """ Calculate Shapley values for each sample using probability distributions Args: df: DataFrame with combination results content: Original content Returns: Dictionary mapping sample names to various Shapley values """ samples = self._get_samples(content) n = len(samples) # Initialize counters for each sample with_count = defaultdict(int) without_count = defaultdict(int) with_similarity_sum = defaultdict(float) without_similarity_sum = defaultdict(float) with_baseline_prob_sum = defaultdict(float) without_baseline_prob_sum = defaultdict(float) # Process each combination for _, row in df.iterrows(): indices = row["indices"] similarity = row["similarity"] baseline_class_prob = row["baseline_class_prob"] # Update counters for each sample for i in range(n): if i in indices: with_similarity_sum[i] += similarity with_baseline_prob_sum[i] += baseline_class_prob with_count[i] += 1 else: without_similarity_sum[i] += similarity without_baseline_prob_sum[i] += baseline_class_prob without_count[i] += 1 # Calculate Shapley values for different metrics shapley_values = {} for i in range(n): # Similarity-based Shapley (distribution similarity) with_avg = with_similarity_sum[i] / with_count[i] if with_count[i] > 0 else 0 without_avg = without_similarity_sum[i] / without_count[i] if without_count[i] > 0 else 0 similarity_shapley = with_avg - without_avg # Baseline class probability-based Shapley with_prob_avg = with_baseline_prob_sum[i] / with_count[i] if with_count[i] > 0 else 0 without_prob_avg = without_baseline_prob_sum[i] / without_count[i] if without_count[i] > 0 else 0 prob_shapley = with_prob_avg - without_prob_avg shapley_values[f"{samples[i]}_{i}"] = { "similarity_shapley": similarity_shapley, "prob_shapley": prob_shapley } # Normalize each type of Shapley value separately norm_shapley = self._normalize_shapley_dict(shapley_values) return norm_shapley def _normalize_shapley_dict(self, shapley_dict: Dict[str, Dict[str, float]]) -> Dict[str, Dict[str, float]]: """Normalize each type of Shapley value separately""" # Get all metric types if not shapley_dict: return {} metrics = list(next(iter(shapley_dict.values())).keys()) normalized = {k: {} for k in shapley_dict} # Normalize each metric separately for metric in metrics: values = [v[metric] for v in shapley_dict.values()] min_val = min(values) max_val = max(values) value_range = max_val - min_val if value_range > 0: for k, v in shapley_dict.items(): normalized[k][metric] = (v[metric] - min_val) / value_range else: for k, v in shapley_dict.items(): normalized[k][metric] = 0.5 # Default to middle when no variance return normalized def get_tokens_shapley_values(self) -> Dict[str, float]: """ Returns a dictionary mapping each token to its Shapley value Returns: Dictionary with token text as keys and Shapley values as values """ if not hasattr(self, 'shapley_values'): raise ValueError("Must run analyze() before getting Shapley values") # Extract token texts without indices tokens = {} for key, value in self.shapley_values.items(): token = key.rsplit('_', 1)[0] # Remove index suffix tokens[token] = value return tokens # Add a method to get the Similarity-based Shapley values specifically def get_sim_shapley_values(self) -> Dict[str, float]: """ Returns a dictionary mapping each token to its similarity-based Shapley value Returns: Dictionary with token text as keys and similarity-based Shapley values as values """ if not hasattr(self, 'shapley_values'): raise ValueError("Must run analyze() before getting Shapley values") # Extract token texts without indices and get the similarity-based metric tokens = {} for key, value_dict in self.shapley_values.items(): token = key.rsplit('_', 1)[0] # Remove index suffix tokens[token] = value_dict["similarity_shapley"] return tokens def analyze(self, prompt: str, sampling_ratio: float = 0.0, max_combinations: Optional[int] = 1000) -> pd.DataFrame: """ Analyze token importance in a financial statement Args: prompt: Financial statement to analyze (without instructions) sampling_ratio: Ratio of combinations to sample (0-1) max_combinations: Maximum number of combinations to generate Returns: DataFrame with analysis results """ # Clean prompt prompt = prompt.strip() prompt = re.sub(r'\s+', ' ', prompt) # Get baseline using full prompt with instructions prefix = "Analyze the sentiment of this statement extracted from a financial news article. Provide your answer as either negative, positive, or neutral.. Text: " suffix = ".. Answer: " full_prompt = build_full_prompt(prompt, prefix, suffix) self.baseline_response = self._calculate_baseline(full_prompt) self.baseline_text = self.baseline_response["label"] # Process combinations (this function will add instructions to each combination) responses = self._get_result_per_combination( prompt, sampling_ratio=sampling_ratio, max_combinations=max_combinations ) # Create results DataFrame self.results_df = self._get_df_per_combination(responses, self.baseline_response) # Calculate Shapley values self.shapley_values = self._calculate_shapley_values(self.results_df, prompt) return self.results_df #To update def plot_colored_text(self, new_line: bool = False): """ Plot text visualization with importance colors Args: new_line: Whether to plot tokens on new lines """ if not hasattr(self, 'shapley_values'): raise ValueError("Must run analyze() before visualization") # Extract token texts without indices tokens = {} for key, value in self.shapley_values.items(): token = key.rsplit('_', 1)[0] # Remove index suffix tokens[token] = value num_items = len(tokens) fig_height = num_items * 0.5 + 1 if new_line else 2 fig, ax = plt.subplots(figsize=(10, fig_height)) ax.axis('off') y_pos = 0.9 x_pos = 0.1 step = 0.8 / (num_items) for token, value in tokens.items(): color = plt.cm.coolwarm(value) if new_line: ax.text( 0.5, y_pos, token, color=color, fontsize=14, ha='center', va='center', transform=ax.transAxes ) y_pos -= step else: ax.text( x_pos, y_pos, token, color=color, fontsize=14, ha='left', va='center', transform=ax.transAxes ) x_pos += len(token) * 0.015 + 0.02 # Adjust spacing based on token length sm = plt.cm.ScalarMappable( cmap=plt.cm.coolwarm, norm=plt.Normalize(vmin=0, vmax=1) ) sm.set_array([]) cbar = plt.colorbar(sm, ax=ax, orientation='horizontal', pad=0.05) cbar.ax.set_position([0.05, 0.02, 0.9, 0.05]) cbar.set_label('Importance (Shapley Value)', fontsize=12) plt.tight_layout() plt.show()