Spaces:
Running
Running
| #Base classes and utilities for TokenSHAP | |
| # SPDX-FileCopyrightText: 2023-2024 The TokenSHAP Authors | |
| import numpy as np | |
| import pandas as pd | |
| import re | |
| import random | |
| from typing import List, Dict, Optional, Tuple, Any | |
| from tqdm.auto import tqdm | |
| class ModelBase: | |
| """Base model interface""" | |
| def generate(self, **kwargs) -> str: | |
| """Generate a response for the given input""" | |
| raise NotImplementedError | |
| class BaseSHAP: | |
| """Base class for SHAP value calculation with Monte Carlo sampling""" | |
| def __init__(self, model: ModelBase, debug: bool = False): | |
| """ | |
| Initialize BaseSHAP | |
| Args: | |
| model: Model to analyze | |
| debug: Enable debug output | |
| """ | |
| self.model = model | |
| self.cache = {} # Cache for model responses | |
| self.debug = debug | |
| def _calculate_baseline(self, content: str) -> Dict[str, Any]: | |
| """Calculate baseline model response for full content""" | |
| # Content here should already have the prefix/suffix if needed | |
| baseline = self.model.generate(prompt=content) | |
| if self.debug: | |
| print(f"Baseline prediction: {baseline['label']}") | |
| return baseline | |
| def _prepare_generate_args(self, content: str, **kwargs) -> Dict: | |
| """Prepare arguments for model.generate()""" | |
| raise NotImplementedError | |
| def _get_samples(self, content: str) -> List[str]: | |
| """Get samples from content""" | |
| raise NotImplementedError | |
| def _prepare_combination_args(self, combination: List[str], original_content: str) -> Dict: | |
| """Prepare model arguments for a combination""" | |
| raise NotImplementedError | |
| def _get_combination_key(self, combination: List[str], indexes: Tuple[int, ...]) -> str: | |
| """Get unique key for combination""" | |
| raise NotImplementedError | |
| def _get_all_combinations(self, samples: List[str], sampling_ratio: float = 0.0, | |
| max_combinations: Optional[int] = None) -> Dict[str, Tuple[List[str], Tuple[int, ...]]]: | |
| """ | |
| Get all possible combinations of samples with their indices | |
| Args: | |
| samples: List of samples (e.g., tokens) | |
| sampling_ratio: Ratio of combinations to sample (0-1) | |
| max_combinations: Maximum number of combinations to generate | |
| Returns: | |
| Dictionary mapping combination keys to (combination, indices) tuples | |
| """ | |
| n = len(samples) | |
| # Always include combinations that exclude exactly one token | |
| essential_combinations = {} | |
| for i in range(n): | |
| combination = samples.copy() | |
| del combination[i] | |
| indices = tuple(j for j in range(n) if j != i) | |
| key = f"omit_{i}" | |
| essential_combinations[key] = (combination, indices) | |
| # Calculate total possible combinations and sampling count | |
| if sampling_ratio <= 0: | |
| # Just return essential combinations | |
| return essential_combinations | |
| total_combinations = 2**n - 1 # All non-empty combinations | |
| sample_count = int(total_combinations * sampling_ratio) | |
| if max_combinations is not None: | |
| sample_count = min(sample_count, max_combinations) | |
| if sample_count <= len(essential_combinations): | |
| return essential_combinations | |
| # Randomly sample additional combinations | |
| all_combinations = essential_combinations.copy() | |
| additional_needed = sample_count - len(essential_combinations) | |
| # Generate random combinations | |
| combinations_added = 0 | |
| max_attempts = additional_needed * 10 # Limit attempts to avoid infinite loop | |
| attempts = 0 | |
| while combinations_added < additional_needed and attempts < max_attempts: | |
| # Decide how many tokens to include | |
| subset_size = random.randint(1, n-1) # At least 1, at most n-1 | |
| # Randomly select indices | |
| indices = tuple(sorted(random.sample(range(n), subset_size))) | |
| # Create combination | |
| combination = [samples[i] for i in indices] | |
| key = f"random_{','.join(str(i) for i in indices)}" | |
| # Only add if not already present | |
| if key not in all_combinations: | |
| all_combinations[key] = (combination, indices) | |
| combinations_added += 1 | |
| attempts += 1 | |
| if self.debug and attempts >= max_attempts: | |
| print(f"Warning: Reached max attempts ({max_attempts}) when generating combinations") | |
| return all_combinations | |
| def _get_result_per_combination(self, content: str, sampling_ratio: float = 0.0, | |
| max_combinations: Optional[int] = None) -> Dict[str, Dict[str, Any]]: | |
| """ | |
| Get model responses for combinations of content | |
| Args: | |
| content: Original content | |
| sampling_ratio: Ratio of combinations to sample | |
| max_combinations: Maximum number of combinations | |
| Returns: | |
| Dictionary mapping combination keys to response data | |
| """ | |
| samples = self._get_samples(content) | |
| if self.debug: | |
| print(f"Found {len(samples)} samples in content") | |
| combinations = self._get_all_combinations(samples, sampling_ratio, max_combinations) | |
| if self.debug: | |
| print(f"Generated {len(combinations)} combinations") | |
| results = {} | |
| # Process each combination | |
| for key, (combination, indices) in tqdm(combinations.items(), desc="Processing combinations"): | |
| comb_args = self._prepare_combination_args(combination, content) | |
| comb_key = self._get_combination_key(combination, indices) | |
| # Check cache first | |
| if comb_key in self._cache: | |
| response = self._cache[comb_key] | |
| else: | |
| response = self.model.generate(**comb_args) | |
| self._cache[comb_key] = response | |
| # Store results | |
| results[key] = { | |
| "combination": combination, | |
| "indices": indices, | |
| "response": response | |
| } | |
| return results | |
| def analyze(self, content: str, sampling_ratio: float = 0.0, max_combinations: Optional[int] = None) -> pd.DataFrame: | |
| """ | |
| Analyze importance in content | |
| Args: | |
| content: Content to analyze | |
| sampling_ratio: Ratio of combinations to sample | |
| max_combinations: Maximum number of combinations | |
| Returns: | |
| DataFrame with analysis results | |
| """ | |
| raise NotImplementedError | |