File size: 6,913 Bytes
5c7385e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# Base classes and utilities for TokenSHAP
# SPDX-FileCopyrightText: 2023-2024 The TokenSHAP Authors
import numpy as np
import pandas as pd
import re
import random
from typing import List, Dict, Optional, Tuple, Any
from tqdm.auto import tqdm

class ModelBase:
    """Base model interface.

    Subclasses override ``generate``. BaseSHAP calls it with keyword
    arguments only (e.g. ``prompt=...`` for the baseline, or whatever a
    subclass's ``_prepare_combination_args`` returns), and
    ``BaseSHAP._calculate_baseline`` annotates the result as
    ``Dict[str, Any]`` -- so the return type is deliberately not pinned to
    ``str`` here.
    """

    def generate(self, **kwargs) -> Any:
        """Generate a response for the given keyword inputs.

        Args:
            **kwargs: Model-specific inputs (e.g. ``prompt``).

        Returns:
            The model's response; its concrete type is defined by subclasses.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError


class BaseSHAP:
    """Base class for SHAP value calculation with Monte Carlo sampling.

    Subclasses provide the content-specific pieces (``_get_samples``,
    ``_prepare_combination_args``, ``_get_combination_key``, ``analyze``);
    this class builds the sample combinations and caches model responses.
    """

    def __init__(self, model: "ModelBase", debug: bool = False):
        """
        Initialize BaseSHAP

        Args:
            model: Model to analyze; its ``generate(**kwargs)`` is called for
                the baseline and for every combination.
            debug: Enable debug output (printed to stdout)
        """
        self.model = model
        # Cache for model responses, keyed by _get_combination_key()
        self.cache = {}
        self.debug = debug

    @property
    def _cache(self) -> Dict[str, Any]:
        """Alias of ``self.cache`` so code using either name shares one dict."""
        return self.cache

    @_cache.setter
    def _cache(self, value: Dict[str, Any]) -> None:
        # Assigning through the alias rebinds the real cache attribute.
        self.cache = value

    def _calculate_baseline(self, content: str) -> Dict[str, Any]:
        """Calculate the baseline model response for the full content.

        Args:
            content: Full content; it should already carry any prefix/suffix
                the subclass requires.

        Returns:
            Whatever ``self.model.generate(prompt=content)`` returns.
        """
        baseline = self.model.generate(prompt=content)
        if self.debug:
            # ModelBase does not fix the response type: show the "label" entry
            # when there is one, otherwise the raw response, instead of
            # crashing the debug output.
            if isinstance(baseline, dict) and "label" in baseline:
                shown = baseline["label"]
            else:
                shown = baseline
            print(f"Baseline prediction: {shown}")
        return baseline

    def _prepare_generate_args(self, content: str, **kwargs) -> Dict:
        """Prepare arguments for model.generate() (implemented by subclasses)."""
        raise NotImplementedError

    def _get_samples(self, content: str) -> List[str]:
        """Split content into samples, e.g. tokens (implemented by subclasses)."""
        raise NotImplementedError

    def _prepare_combination_args(self, combination: List[str], original_content: str) -> Dict:
        """Build ``model.generate`` keyword arguments for a combination (subclasses)."""
        raise NotImplementedError

    def _get_combination_key(self, combination: List[str], indexes: Tuple[int, ...]) -> str:
        """Return the cache key identifying a combination (implemented by subclasses)."""
        raise NotImplementedError

    def _get_all_combinations(self, samples: List[str], sampling_ratio: float = 0.0,
                              max_combinations: Optional[int] = None) -> Dict[str, Tuple[List[str], Tuple[int, ...]]]:
        """
        Get combinations of samples together with their indices.

        The n leave-one-out combinations (keys ``omit_<i>``) are always
        returned. When ``sampling_ratio > 0``, further distinct subsets of
        1..n-2 samples (keys ``random_<i>,<j>,...``) are added until the total
        reaches ``int((2**n - 1) * sampling_ratio)``, capped by
        ``max_combinations``. The leave-one-out set is never truncated, so the
        result can be larger than ``max_combinations`` when that is below n.

        Args:
            samples: List of samples (e.g., tokens)
            sampling_ratio: Ratio of combinations to sample (0-1)
            max_combinations: Maximum number of combinations to generate

        Returns:
            Dictionary mapping combination keys to (combination, indices)
            tuples; indices are sorted positions into ``samples``.
        """
        n = len(samples)

        # Always include combinations that exclude exactly one sample
        essential_combinations = {}
        for i in range(n):
            combination = samples.copy()
            del combination[i]
            indices = tuple(j for j in range(n) if j != i)
            essential_combinations[f"omit_{i}"] = (combination, indices)

        if sampling_ratio <= 0:
            return essential_combinations

        total_combinations = 2**n - 1  # All non-empty combinations
        try:
            sample_count = int(total_combinations * sampling_ratio)
        except OverflowError:
            # 2**n no longer fits in a float once n >= 1024; use exact math.
            from fractions import Fraction
            sample_count = int(total_combinations * Fraction(sampling_ratio))

        if max_combinations is not None:
            sample_count = min(sample_count, max_combinations)

        if sample_count <= len(essential_combinations):
            return essential_combinations

        all_combinations = essential_combinations.copy()
        additional_needed = sample_count - len(essential_combinations)

        # Extra subsets have 1..n-2 samples: every (n-1)-subset is already an
        # omit_<i> entry, so drawing those would only duplicate combinations.
        available = max(0, 2**n - 2 - n)
        if available == 0:
            return all_combinations

        if additional_needed >= available:
            # The request covers every remaining subset: enumerate them all
            # deterministically rather than relying on random draws.
            import itertools
            for size in range(1, n - 1):
                for indices in itertools.combinations(range(n), size):
                    key = f"random_{','.join(str(i) for i in indices)}"
                    all_combinations[key] = ([samples[i] for i in indices], indices)
            return all_combinations

        # Randomly sample additional distinct combinations
        max_attempts = additional_needed * 10  # Limit attempts to avoid infinite loop
        combinations_added = 0
        attempts = 0
        while combinations_added < additional_needed and attempts < max_attempts:
            attempts += 1
            subset_size = random.randint(1, n - 2)  # n >= 3 here, so range is valid
            indices = tuple(sorted(random.sample(range(n), subset_size)))
            key = f"random_{','.join(str(i) for i in indices)}"
            # Only add if not already present
            if key not in all_combinations:
                all_combinations[key] = ([samples[i] for i in indices], indices)
                combinations_added += 1

        if self.debug and combinations_added < additional_needed:
            print(f"Warning: Reached max attempts ({max_attempts}) when generating combinations")

        return all_combinations

    def _get_result_per_combination(self, content: str, sampling_ratio: float = 0.0,
                                    max_combinations: Optional[int] = None) -> Dict[str, Dict[str, Any]]:
        """
        Get model responses for combinations of content

        Args:
            content: Original content
            sampling_ratio: Ratio of combinations to sample
            max_combinations: Maximum number of combinations

        Returns:
            Dictionary mapping combination keys to dicts with "combination",
            "indices" and "response" entries
        """
        samples = self._get_samples(content)
        if self.debug:
            print(f"Found {len(samples)} samples in content")

        combinations = self._get_all_combinations(samples, sampling_ratio, max_combinations)
        if self.debug:
            print(f"Generated {len(combinations)} combinations")

        results = {}
        for key, (combination, indices) in tqdm(combinations.items(), desc="Processing combinations"):
            comb_args = self._prepare_combination_args(combination, content)
            comb_key = self._get_combination_key(combination, indices)

            # Use the cache created in __init__ (self.cache) so each distinct
            # combination key is sent to the model only once.
            if comb_key in self.cache:
                response = self.cache[comb_key]
            else:
                response = self.model.generate(**comb_args)
                self.cache[comb_key] = response

            results[key] = {
                "combination": combination,
                "indices": indices,
                "response": response
            }
        return results

    def analyze(self, content: str, sampling_ratio: float = 0.0, max_combinations: Optional[int] = None) -> pd.DataFrame:
        """
        Analyze importance in content (implemented by subclasses)

        Args:
            content: Content to analyze
            sampling_ratio: Ratio of combinations to sample
            max_combinations: Maximum number of combinations

        Returns:
            DataFrame with analysis results
        """
        raise NotImplementedError