# SPDX-FileCopyrightText: 2023-2024 The TokenSHAP Authors
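"""TokenSHAP: estimates token-level importance in text prompts by comparing
model responses across sampled token subsets (a Monte Carlo Shapley scheme)."""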
import pandas as pd
import matplotlib.pyplot as plt
import re
from typing import List, Dict, Optional, Tuple, Any
from tqdm.auto import tqdm
from collections import defaultdict
from .base import ModelBase, BaseSHAP
from .splitters import Splitter
from .helpers import build_full_prompt, jensen_shannon_distance

class TokenSHAP(BaseSHAP):
    """Analyzes token importance in text prompts using SHAP values"""
   
    def __init__(self,
                 model: ModelBase,
                 splitter: Splitter,
                 debug: bool = False,
                 batch_size: int = 16):
        """
        Initialize TokenSHAP

        Args:
            model: Model to analyze
            splitter: Text splitter implementation
            debug: Enable debug output
            batch_size: Number of prompts sent to the model per batch
        """
        super().__init__(model, debug)
        self.splitter = splitter
        # Instruction text wrapped around each combination; set by analyze()
        self.prompt_prefix = ""
        self.prompt_suffix = ""
        self.batch_size = batch_size

    def _get_samples(self, content: str) -> List[str]:
        """Get tokens from prompt"""
        return self.splitter.split(content)
   
    def _get_combination_key(self, combination: List[str], indexes: Tuple[int, ...]) -> str:
        """Use the joined combination text as its lookup key"""
        return self.splitter.join(combination)

    def _prepare_combination_args(self, combination: List[str], original_content: str) -> Dict:
        """Wrap the joined combination with the instruction prefix/suffix"""
        prompt = f"{self.prompt_prefix}{self.splitter.join(combination)}{self.prompt_suffix}"
        return {"prompt": prompt}

    def _get_result_per_combination(self, content: str, sampling_ratio: float = 0.0,
                                    max_combinations: Optional[int] = None) -> Dict[str, Dict[str, Any]]:
        """
        Get model responses for combinations with batch processing
        
        Args:
            content: Original content
            sampling_ratio: Ratio of combinations to sample
            max_combinations: Maximum number of combinations
            
        Returns:
            Dictionary mapping combination keys to response data
        """
        samples = self._get_samples(content)
        combinations = self._get_all_combinations(samples, sampling_ratio, max_combinations)
        
        # Prepare prompts for batch processing
        prompts = []
        comb_keys = []
        comb_indices = []
        
        for key, (combination, indices) in combinations.items():
            # Build the full prompt for this combination and extract it from the returned dict
            comb_args = self._prepare_combination_args(combination, content)
            prompt = comb_args["prompt"]
            
            prompts.append(prompt)
            comb_keys.append(key)
            comb_indices.append(indices)
        
        # Process prompts in batches, with per-prompt fallback on failure
        all_results = []
        for batch_start in tqdm(range(0, len(prompts), self.batch_size),
                                desc="Processing batches"):
            batch_end = min(batch_start + self.batch_size, len(prompts))
            batch_prompts = prompts[batch_start:batch_end]
            try:
                batch_results = self.model.generate_batch(batch_prompts)
                all_results.extend(batch_results)
            except RuntimeError as e:
                if "stack expects each tensor to be equal size" in str(e):
                    print(f"Error in batch {batch_start//self.batch_size}: {str(e)}")
                    print("Falling back to individual processing for this batch")
                    # Fall back to individual processing with generate
                    for prompt in batch_prompts:
                        try:
                            single_result = self.model.generate(prompt)
                            all_results.append(single_result)
                        except Exception as inner_e:
                            print(f"Individual processing also failed: {str(inner_e)}")
                            # Fall back to a near-uniform distribution over the task's labels
                            all_results.append({
                                "label": "NA",
                                "probabilities": {"Positive": 0.33, "Negative": 0.33, "Neutral": 0.34}
                            })
                else:
                    # Re-raise other RuntimeErrors
                    raise
            except Exception as other_e:
                # Handle any other exceptions during batch processing
                print(f"Unexpected error in batch {batch_start//self.batch_size}: {str(other_e)}")
                # Fall back to individual processing
                for prompt in batch_prompts:
                    try:
                        single_result = self.model.generate(prompt)
                        all_results.append(single_result)
                    except Exception:
                        # Provide fallback result
                        all_results.append({
                            "label": "NA",
                            "probabilities": {"Positive": 0.33, "Negative": 0.33, "Neutral": 0.34}
                        })
        
        # Attach back to combination keys
        results = {}
        for i, key in enumerate(comb_keys):
            results[key] = {
                "combination": combinations[key][0],
                "indices": comb_indices[i],
                "response": all_results[i]
            }
        
        return results 

    def _get_df_per_combination(self, responses: Dict[str, Dict[str, Any]], baseline_response: Dict[str, Any]) -> pd.DataFrame:
        """
        Create DataFrame with combination results using probability-based similarity
       
        Args:
            responses: Dictionary of combination responses
            baseline_response: Baseline model response
           
        Returns:
            DataFrame with results
        """
        # Prepare data for DataFrame
        data = []
       
        baseline_probs = baseline_response["probabilities"]
        baseline_label = baseline_response["label"]
       
        # Process each combination response
        for key, res in responses.items():
            combination = res["combination"]
            indices = res["indices"]
            response_data = res["response"]
            response_probs = response_data["probabilities"]
            response_label = response_data["label"]
           
            # Probability-based similarity: 1 - Jensen-Shannon distance, so
            # identical distributions give 1.0 (higher = more similar)
            prob_similarity = 1.0 - jensen_shannon_distance(baseline_probs, response_probs)
           
            # Track the probability of the baseline's predicted class
            baseline_class_prob = response_probs.get(baseline_label, 0.0)
           
            # Add to data
            data.append({
                "key": key,
                "combination": combination,
                "indices": indices,
                "response_label": response_label,
                "similarity": prob_similarity,
                "baseline_class_prob": baseline_class_prob,
                "probabilities": response_probs
            })
       
        # Create DataFrame
        df = pd.DataFrame(data)
        return df
   
    def _calculate_shapley_values(self, df: pd.DataFrame, content: str) -> Dict[str, Dict[str, float]]:
        """
        Calculate Shapley values for each sample using probability distributions
       
        Args:
            df: DataFrame with combination results
            content: Original content
           
        Returns:
            Dictionary mapping sample names to various Shapley values
        """
        samples = self._get_samples(content)
        n = len(samples)
       
        # Initialize counters for each sample
        with_count = defaultdict(int)
        without_count = defaultdict(int)
        with_similarity_sum = defaultdict(float)
        without_similarity_sum = defaultdict(float)
        with_baseline_prob_sum = defaultdict(float)
        without_baseline_prob_sum = defaultdict(float)
       
        # Process each combination
        for _, row in df.iterrows():
            indices = row["indices"]  
            similarity = row["similarity"]
            baseline_class_prob = row["baseline_class_prob"]
           
            # Update counters for each sample
            for i in range(n):
                if i in indices:
                    with_similarity_sum[i] += similarity
                    with_baseline_prob_sum[i] += baseline_class_prob
                    with_count[i] += 1
                else:
                    without_similarity_sum[i] += similarity
                    without_baseline_prob_sum[i] += baseline_class_prob
                    without_count[i] += 1
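
        # Monte Carlo Shapley estimate: for each sample i,
        #   phi_i ~= mean(payoff of combinations containing i)
        #          - mean(payoff of combinations excluding i)
        # computed separately for the similarity payoff and the
        # baseline-class-probability payoff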
        
        # Calculate Shapley values for different metrics
        shapley_values = {}
        for i in range(n):
            # Similarity-based Shapley (distribution similarity)
            with_avg = with_similarity_sum[i] / with_count[i] if with_count[i] > 0 else 0
            without_avg = without_similarity_sum[i] / without_count[i] if without_count[i] > 0 else 0
            similarity_shapley = with_avg - without_avg
           
            # Baseline class probability-based Shapley
            with_prob_avg = with_baseline_prob_sum[i] / with_count[i] if with_count[i] > 0 else 0
            without_prob_avg = without_baseline_prob_sum[i] / without_count[i] if without_count[i] > 0 else 0
            prob_shapley = with_prob_avg - without_prob_avg
           
            shapley_values[f"{samples[i]}_{i}"] = {
                "similarity_shapley": similarity_shapley,
                "prob_shapley": prob_shapley
            }
        
        # Normalize each type of Shapley value separately
        norm_shapley = self._normalize_shapley_dict(shapley_values)
       
        return norm_shapley
   
    def _normalize_shapley_dict(self, shapley_dict: Dict[str, Dict[str, float]]) -> Dict[str, Dict[str, float]]:
        """Normalize each type of Shapley value separately"""
        if not shapley_dict:
            return {}

        # Get all metric types from the first entry
        metrics = list(next(iter(shapley_dict.values())).keys())
        normalized = {k: {} for k in shapley_dict}
       
        # Normalize each metric separately
        for metric in metrics:
            values = [v[metric] for v in shapley_dict.values()]
            min_val = min(values)
            max_val = max(values)
            value_range = max_val - min_val
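            # Min-max scaling: (x - min_val) / value_range maps each metric into [0, 1]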
           
            if value_range > 0:
                for k, v in shapley_dict.items():
                    normalized[k][metric] = (v[metric] - min_val) / value_range
            else:
                for k, v in shapley_dict.items():
                    normalized[k][metric] = 0.5  # Default to middle when no variance
                   
        return normalized
   
    def get_tokens_shapley_values(self) -> Dict[str, Dict[str, float]]:
        """
        Returns a dictionary mapping each token to its Shapley value metrics

        Returns:
            Dictionary with token text as keys and per-metric Shapley values as values
        """
        if not hasattr(self, 'shapley_values'):
            raise ValueError("Must run analyze() before getting Shapley values")

        # Extract token texts without indices; duplicate tokens collapse to a
        # single entry (the last occurrence wins)
        tokens = {}
        for key, value in self.shapley_values.items():
            token = key.rsplit('_', 1)[0]  # Remove index suffix
            tokens[token] = value
       
        return tokens

    # Convenience accessor for the similarity-based Shapley values
    def get_sim_shapley_values(self) -> Dict[str, float]:
        """
        Returns a dictionary mapping each token to its similarity-based Shapley value
       
        Returns:
            Dictionary with token text as keys and similarity-based Shapley values as values
        """
        if not hasattr(self, 'shapley_values'):
            raise ValueError("Must run analyze() before getting Shapley values")

        # Extract token texts without indices and get the similarity-based metric
        tokens = {}
        for key, value_dict in self.shapley_values.items():
            token = key.rsplit('_', 1)[0]  # Remove index suffix
            tokens[token] = value_dict["similarity_shapley"]
       
        return tokens
   
    def analyze(self, prompt: str,
                sampling_ratio: float = 0.0,
                max_combinations: Optional[int] = 1000) -> pd.DataFrame:
        """
        Analyze token importance in a financial statement
       
        Args:
            prompt: Financial statement to analyze (without instructions)
            sampling_ratio: Ratio of combinations to sample (0-1)
            max_combinations: Maximum number of combinations to generate
           
        Returns:
            DataFrame with analysis results
        """
        # Clean prompt
        prompt = prompt.strip()
        prompt = re.sub(r'\s+', ' ', prompt)

        # Get baseline using the full prompt with instructions
        prefix = ("Analyze the sentiment of this statement extracted from a "
                  "financial news article. Provide your answer as either "
                  "negative, positive, or neutral. Text: ")
        suffix = ". Answer: "
        # Store the instructions so _prepare_combination_args wraps every
        # combination with the same prompt template as the baseline
        self.prompt_prefix = prefix
        self.prompt_suffix = suffix
        full_prompt = build_full_prompt(prompt, prefix, suffix)
        self.baseline_response = self._calculate_baseline(full_prompt)
        self.baseline_text = self.baseline_response["label"]
       
        # Process combinations (this function will add instructions to each combination)
        responses = self._get_result_per_combination(
            prompt,
            sampling_ratio=sampling_ratio,
            max_combinations=max_combinations
        )
       
        # Create results DataFrame
        self.results_df = self._get_df_per_combination(responses, self.baseline_response)
       
        # Calculate Shapley values
        self.shapley_values = self._calculate_shapley_values(self.results_df, prompt)

        return self.results_df
   
    # TODO: update
    def plot_colored_text(self, new_line: bool = False):
        """
        Plot text visualization with importance colors
       
        Args:
            new_line: Whether to plot tokens on new lines
        """
        if not hasattr(self, 'shapley_values'):
            raise ValueError("Must run analyze() before visualization")

        # Extract token texts without indices; color by the similarity-based
        # metric, which is already normalized to [0, 1]
        tokens = {}
        for key, value_dict in self.shapley_values.items():
            token = key.rsplit('_', 1)[0]  # Remove index suffix
            tokens[token] = value_dict["similarity_shapley"]

        num_items = len(tokens)
        fig_height = num_items * 0.5 + 1 if new_line else 2

        fig, ax = plt.subplots(figsize=(10, fig_height))
        ax.axis('off')

        y_pos = 0.9
        x_pos = 0.1
        step = 0.8 / max(num_items, 1)  # guard against an empty token dict

        for token, value in tokens.items():
            color = plt.cm.coolwarm(value)
            if new_line:
                ax.text(
                    0.5, y_pos,
                    token,
                    color=color,
                    fontsize=14,
                    ha='center',
                    va='center',
                    transform=ax.transAxes
                )
                y_pos -= step
            else:
                ax.text(
                    x_pos, y_pos,
                    token,
                    color=color,
                    fontsize=14,
                    ha='left',
                    va='center',
                    transform=ax.transAxes
                )
                x_pos += len(token) * 0.015 + 0.02  # Adjust spacing based on token length

        sm = plt.cm.ScalarMappable(
            cmap=plt.cm.coolwarm,
            norm=plt.Normalize(vmin=0, vmax=1)
        )
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=ax, orientation='horizontal', pad=0.05)
        cbar.ax.set_position([0.05, 0.02, 0.9, 0.05])
        cbar.set_label('Importance (Shapley Value)', fontsize=12)

        plt.tight_layout()
        plt.show()
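
# Example usage (a minimal sketch; `OllamaModel` and `StringSplitter` are
# hypothetical stand-ins for concrete ModelBase / Splitter implementations,
# which this module does not define):
#
#     model = OllamaModel("llama3")      # hypothetical ModelBase subclass
#     splitter = StringSplitter()        # hypothetical whitespace Splitter
#     ts = TokenSHAP(model, splitter, batch_size=8)
#     df = ts.analyze(
#         "Company profits rose sharply after the merger.",
#         sampling_ratio=0.1,
#         max_combinations=500,
#     )
#     print(ts.get_sim_shapley_values())  # token -> normalized importance
#     ts.plot_colored_text()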