Spaces:
Sleeping
Sleeping
| # SPDX-FileCopyrightText: 2023-2024 The TokenSHAP Authors | |
| import numpy as np | |
| from typing import Dict | |
| from .models import checkModelType, BERTModel, LlamaModelWrapper | |
| from .helpers import build_full_prompt | |
| import os | |
| import csv | |
| from .splitters import StringSplitter, TokenizerSplitter | |
| from .tokenShap import TokenSHAP | |
class BiasAnalyzer:
    """Analyze bias in financial language models using TokenSHAP.

    Each statement is wrapped in a fixed sentiment-classification prompt,
    the model's prediction is recorded, TokenSHAP attributes importance to
    the statement's tokens, and every token/phrase that appears in the bias
    term files is reported together with its Shapley value and rank.
    """

    # Instruction text placed around the statement when querying the model.
    PROMPT_PREFIX = "Analyze the sentiment of this statement extracted from a financial news article. Provide your answer as either negative, positive, or neutral.. Text: "
    PROMPT_SUFFIX = ".. Answer: "
    # Directory scanned for bias term files when the caller does not pass one.
    DEFAULT_BIAS_FILES_DIR = 'data/bias'

    def __init__(self, model, tokenizer, model_type, splitter_type='string', batch_size=16, is_wrapped=False):
        """
        Initialize the bias analyzer.

        Args:
            model: Either an already-wrapped model (anything exposing both
                ``generate`` and ``generate_batch``, or any object when
                ``is_wrapped`` is True) or a raw model; raw models are only
                auto-wrapped when ``checkModelType`` reports ``'bert'``.
            tokenizer: Tokenizer passed to the BERT wrapper and, when
                ``splitter_type == 'tokenizer'``, to the splitter.
            model_type: Accepted for backward compatibility; not used here
                (the type is detected with ``checkModelType``).
            splitter_type: Type of splitter ('string' or 'tokenizer').
            batch_size: Batch size forwarded to TokenSHAP.
            is_wrapped: Treat ``model`` as a pre-wrapped model unconditionally.

        Raises:
            ValueError: For a raw Llama model, an unrecognised model type,
                or an unknown splitter type.
        """
        # A model that already exposes the wrapper interface is used directly.
        if is_wrapped or (hasattr(model, 'generate') and hasattr(model, 'generate_batch')):
            print("Using pre-wrapped model")
            self.model_wrapper = model
        else:
            detected_type = checkModelType(model)
            if detected_type == 'bert':
                self.model_wrapper = BERTModel(model, tokenizer)
            elif detected_type == 'llama':
                # Per the error message, Llama wrapping needs label_ids that
                # this class cannot supply, so the caller must wrap it.
                raise ValueError("For Llama models, please wrap the model before passing to BiasAnalyzer or provide label_ids")
            else:
                raise ValueError(f"Unknown model type: {type(model)}. Only BERT and Llama models are supported.")

        if splitter_type == 'string':
            self.splitter = StringSplitter()
        elif splitter_type == 'tokenizer':
            self.splitter = TokenizerSplitter(tokenizer)
        else:
            raise ValueError(f"Unknown splitter type: {splitter_type}")

        self.token_shap = TokenSHAP(self.model_wrapper, self.splitter, batch_size=batch_size)

    def compare_sentences(self, original: str, mutated: str, sampling_ratio: float = 0.1, max_combinations: int = 100,
                          bias_files_dir: str = DEFAULT_BIAS_FILES_DIR):
        """
        Compare an original sentence with a mutated version of it.

        Args:
            original: Original financial sentence.
            mutated: Mutated version of the sentence.
            sampling_ratio: Ratio of combinations to sample.
            max_combinations: Maximum number of combinations.
            bias_files_dir: Directory containing the bias term files.

        Returns:
            Dict with both per-sentence results, whether the predicted label
            changed, the (sorted) bias terms found in both sentences, and the
            rank movement of each such term.
        """
        original_result = self.analyze_sentence(original, sampling_ratio, max_combinations, bias_files_dir)
        mutated_result = self.analyze_sentence(mutated, sampling_ratio, max_combinations, bias_files_dir)

        # Assumes the wrapper's prediction is a mapping with a 'label' key.
        prediction_change = mutated_result['prediction']['label'] != original_result['prediction']['label']

        original_ranks = original_result['Bias Token Ranks']
        mutated_ranks = mutated_result['Bias Token Ranks']
        # Sorted so the reported list is deterministic across runs.
        common_bias_tokens = sorted(set(original_ranks) & set(mutated_ranks))

        bias_rank_changes = {}
        for token in common_bias_tokens:
            orig_rank = original_ranks[token]['rank']
            mut_rank = mutated_ranks[token]['rank']
            bias_rank_changes[token] = {
                'original_rank': orig_rank,
                'mutated_rank': mut_rank,
                'rank_changed': orig_rank != mut_rank,
                'rank_difference': mut_rank - orig_rank,
            }

        return {
            'original': original_result,
            'mutated': mutated_result,
            'prediction_changed': prediction_change,
            'common_bias_tokens': common_bias_tokens,
            'bias_rank_changes': bias_rank_changes,
        }

    def analyze_sentence(self, financial_statement: str, sampling_ratio: float = 0.5, max_combinations: int = 1000,
                         bias_files_dir: str = DEFAULT_BIAS_FILES_DIR):
        """
        Analyze a single financial statement.

        Args:
            financial_statement: Plain financial statement to analyze
                (without instructions).
            sampling_ratio: Ratio of combinations to sample.
            max_combinations: Maximum number of combinations.
            bias_files_dir: Directory containing the bias term files.

        Returns:
            Dict with the statement, the model's prediction on the full
            prompt, the similarity-based Shapley values and the bias term
            rankings.
        """
        prefix = self.PROMPT_PREFIX
        suffix = self.PROMPT_SUFFIX
        full_prompt = build_full_prompt(financial_statement, prefix, suffix)

        # Baseline prediction uses the full instruction prompt.
        prediction = self.model_wrapper.generate(prompt=full_prompt)

        # TokenSHAP re-wraps each token combination with the same prompt, and
        # keeps the raw statement for multi-word bias detection.
        self.token_shap.prompt_prefix = prefix
        self.token_shap.prompt_suffix = suffix
        self.token_shap.original_statement = financial_statement

        # Attribution is computed over the statement only, not the instructions.
        self.token_shap.analyze(financial_statement, sampling_ratio, max_combinations)

        shapley_values_similarity = self.token_shap.get_sim_shapley_values()
        bias_tokens_ranks = self.analyze_bias_tokens_importance(bias_files_dir, original_text=financial_statement)

        return {
            'sentence': financial_statement,
            'prediction': prediction,
            'Shapley Values': shapley_values_similarity,
            'Bias Token Ranks': bias_tokens_ranks,
        }

    @staticmethod
    def _equivalent_rank(value, ranked_items):
        """Return the 1-based rank *value* would take in *ranked_items*.

        ``ranked_items`` is a sequence of ``(token, value)`` pairs sorted by
        value descending. A tie ranks alongside the existing entry; a value
        below every entry ranks last (``len + 1``); an empty sequence gives 1.
        """
        for idx, (_, ranked_value) in enumerate(ranked_items):
            if value >= ranked_value:
                return idx + 1
        return len(ranked_items) + 1

    def analyze_bias_tokens_importance(self, bias_files_dir: str, original_text: str = None):
        """
        Rank the bias terms present in the last analyzed statement.

        Args:
            bias_files_dir: Directory containing files with bias terms.
            original_text: Statement to search for multi-word terms; falls back
                to ``self.token_shap.original_statement`` (or '') when None.

        Returns:
            Dict mapping each matched bias term to its Shapley value, rank,
            total token count, percentile and type. Single-word terms are
            keyed by the token as it appears in the Shapley values.
            Multi-word terms are keyed by the lowercase term, scored by the
            mean value of their constituent tokens, and include those tokens.

        Raises:
            ValueError: If ``bias_files_dir`` is not an existing directory.
        """
        import re

        single_word_terms, multi_word_terms = self._load_bias_terms(bias_files_dir)

        shapley_values_similarity = self.token_shap.get_sim_shapley_values()
        # Rank ALL tokens by importance, highest first (rank 1 = most important).
        all_tokens_ranked = sorted(shapley_values_similarity.items(), key=lambda item: item[1], reverse=True)
        total_tokens = len(all_tokens_ranked)
        token_rankings = {token: {'value': value, 'rank': idx + 1}
                          for idx, (token, value) in enumerate(all_tokens_ranked)}

        if original_text is None:
            original_text = getattr(self.token_shap, 'original_statement', '')
        original_text_lower = original_text.lower()

        def percentile(rank):
            # Only called when at least one token exists, so no division by zero.
            return round((1 - (rank - 1) / total_tokens) * 100, 1)

        bias_tokens_with_rank = {}

        # 1. Single-word terms: case-insensitive match against each token.
        for token, token_data in token_rankings.items():
            if token.lower() in single_word_terms:
                bias_tokens_with_rank[token] = {
                    'shapley_value': token_data['value'],
                    'rank': token_data['rank'],
                    'total_tokens': total_tokens,
                    'percentile': percentile(token_data['rank']),
                    'type': 'single_word',
                }

        # Lowercase spelling -> highest-ranked token with that spelling, so
        # each word lookup below is O(1) instead of a scan over all tokens.
        first_token_by_lower = {}
        for token, token_data in token_rankings.items():
            first_token_by_lower.setdefault(token.lower(), (token, token_data['value']))

        # 2. Multi-word terms: require the phrase in the original text on word
        #    boundaries (so "old man" does not match "bold man"), tolerating
        #    any run of whitespace between the words.
        for multi_word_term in sorted(multi_word_terms):
            term_words = multi_word_term.lower().split()
            if not term_words:
                continue
            pattern = r'(?<!\w)' + r'\s+'.join(re.escape(word) for word in term_words) + r'(?!\w)'
            if re.search(pattern, original_text_lower) is None:
                continue

            matches = [first_token_by_lower[word] for word in term_words if word in first_token_by_lower]
            if not matches:
                continue

            matched_tokens = [token for token, _ in matches]
            matched_values = [value for _, value in matches]
            avg_value = sum(matched_values) / len(matched_values)
            equivalent_rank = self._equivalent_rank(avg_value, all_tokens_ranked)

            bias_tokens_with_rank[multi_word_term] = {
                'shapley_value': avg_value,
                'rank': equivalent_rank,
                'total_tokens': total_tokens,
                'percentile': percentile(equivalent_rank),
                'type': 'multi_word',
                'constituent_tokens': matched_tokens,
                'individual_values': dict(zip(matched_tokens, matched_values)),
            }

        return bias_tokens_with_rank

    def _load_bias_terms(self, bias_files_dir: str) -> tuple:
        """
        Load bias terms from the category sub-folders of *bias_files_dir*.

        Every regular file inside each immediate sub-directory is read as
        ';'-delimited CSV; each non-empty cell is stripped and lowercased.
        Files directly in *bias_files_dir* are ignored.

        Args:
            bias_files_dir: Directory containing one sub-folder per bias
                category.

        Returns:
            Tuple of (single_word_terms, multi_word_terms) sets; a term is
            multi-word when it contains a space after stripping.

        Raises:
            ValueError: If *bias_files_dir* is not an existing directory.
        """
        single_word_terms = set()
        multi_word_terms = set()

        # isdir (not exists) so a plain file path fails with the same
        # ValueError instead of NotADirectoryError from listdir.
        if not os.path.isdir(bias_files_dir):
            raise ValueError(f"Bias files directory {bias_files_dir} does not exist")

        for bias_folder in os.listdir(bias_files_dir):
            folder_path = os.path.join(bias_files_dir, bias_folder)
            if not os.path.isdir(folder_path):
                continue
            for file_name in os.listdir(folder_path):
                file_path = os.path.join(folder_path, file_name)
                if not os.path.isfile(file_path):
                    continue
                # newline='' is what the csv module expects for file objects.
                with open(file_path, 'r', encoding='utf-8', newline='') as f:
                    for row in csv.reader(f, delimiter=';'):
                        for term in row:
                            term = term.strip().lower()
                            if not term:
                                continue
                            if ' ' in term:
                                multi_word_terms.add(term)
                            else:
                                single_word_terms.add(term)

        return single_word_terms, multi_word_terms