# FBDF/backend/bias_analyzer.py
# SPDX-FileCopyrightText: 2023-2024 The TokenSHAP Authors
import csv
import os
from typing import Optional

from .helpers import build_full_prompt
from .models import BERTModel, checkModelType
from .splitters import StringSplitter, TokenizerSplitter
from .tokenShap import TokenSHAP
class BiasAnalyzer:
"""Analyze bias in financial language models using TokenSHAP"""
    def __init__(self, model, tokenizer, model_type, splitter_type='string', batch_size=16, is_wrapped=False):
        """
        Initialize bias analyzer
        Args:
            model: Model to analyze, or an already-wrapped model exposing
                generate and generate_batch
            tokenizer: Tokenizer matching the model
            model_type: Declared model type; the actual type is inferred
                with checkModelType
            splitter_type: Type of splitter ('string' or 'tokenizer')
            batch_size: Batch size forwarded to TokenSHAP
            is_wrapped: Set True to skip wrapping and use the model directly
        """
        # Check if the model is already wrapped
        if is_wrapped or (hasattr(model, 'generate') and hasattr(model, 'generate_batch')):
            print("Using pre-wrapped model")
            self.model_wrapper = model  # Use the model directly
        else:
            # Detect whether the model is BERT- or Llama-based
            detected_type = checkModelType(model)
            if detected_type == 'bert':
                self.model_wrapper = BERTModel(model, tokenizer)
            elif detected_type == 'llama':
                # Llama models need label_ids, which BiasAnalyzer does not have
                raise ValueError("For Llama models, please wrap the model before passing to BiasAnalyzer or provide label_ids")
            else:
                raise ValueError(f"Unknown model type: {type(model)}. Only BERT and Llama models are supported.")
# Create appropriate splitter
if splitter_type == 'string':
self.splitter = StringSplitter()
elif splitter_type == 'tokenizer':
self.splitter = TokenizerSplitter(tokenizer)
else:
raise ValueError(f"Unknown splitter type: {splitter_type}")
# Initialize token SHAP
self.token_shap = TokenSHAP(self.model_wrapper, self.splitter, batch_size=batch_size)
    def compare_sentences(self, original: str, mutated: str, sampling_ratio: float = 0.1, max_combinations: int = 100):
        """
        Compare original and mutated sentences
        Args:
            original: Original financial sentence
            mutated: Mutated version of the sentence
            sampling_ratio: Ratio of combinations to sample
            max_combinations: Maximum number of combinations
        Returns:
            Dictionary with both analysis results, whether the prediction
            changed, the common bias tokens, and their rank changes
        """
# Analyze both sentences
original_result = self.analyze_sentence(original, sampling_ratio, max_combinations)
mutated_result = self.analyze_sentence(mutated, sampling_ratio, max_combinations)
        # Check whether the predicted label changed
        prediction_changed = mutated_result['prediction']['label'] != original_result['prediction']['label']
# Find common bias tokens
common_bias_tokens = set(original_result['Bias Token Ranks'].keys()) & set(mutated_result['Bias Token Ranks'].keys())
# Compare ranks for common bias tokens
bias_rank_changes = {}
for token in common_bias_tokens:
orig_rank = original_result['Bias Token Ranks'][token]['rank']
mut_rank = mutated_result['Bias Token Ranks'][token]['rank']
bias_rank_changes[token] = {
'original_rank': orig_rank,
'mutated_rank': mut_rank,
'rank_changed': orig_rank != mut_rank,
'rank_difference': mut_rank - orig_rank
}
return {
'original': original_result,
'mutated': mutated_result,
            'prediction_changed': prediction_changed,
'common_bias_tokens': list(common_bias_tokens),
'bias_rank_changes': bias_rank_changes
}
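    # Example use of compare_sentences (the sentence pair is illustrative; a
    # counterfactual mutation swaps a demographic marker):
    #
    #   comparison = analyzer.compare_sentences(
    #       "He was promoted to CFO after strong quarterly earnings.",
    #       "She was promoted to CFO after strong quarterly earnings.",
    #   )
    #   print(comparison['prediction_changed'])
    #   print(comparison['bias_rank_changes'])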
    def analyze_sentence(self, financial_statement: str, sampling_ratio: float = 0.5, max_combinations: int = 1000):
        """
        Analyze a single financial statement
        Args:
            financial_statement: Plain financial statement to analyze (without instructions)
            sampling_ratio: Ratio of combinations to sample
            max_combinations: Maximum number of combinations
        Returns:
            Dictionary with the sentence, the model prediction, the Shapley
            values, and the bias token ranks
        """
# Create the full prompt with instructions
prefix = "Analyze the sentiment of this statement extracted from a financial news article. Provide your answer as either negative, positive, or neutral.. Text: "
suffix = ".. Answer: "
full_prompt = build_full_prompt(financial_statement, prefix, suffix)
# Get baseline prediction using the FULL prompt
prediction = self.model_wrapper.generate(prompt=full_prompt)
# Store the prefix and suffix in TokenSHAP for use in combinations
self.token_shap.prompt_prefix = prefix
self.token_shap.prompt_suffix = suffix
# Store the original statement for multi-word bias detection
self.token_shap.original_statement = financial_statement
# Run TokenSHAP analysis on ONLY the financial statement
self.token_shap.analyze(financial_statement, sampling_ratio, max_combinations)
        # Get token importance values (similarity-based Shapley values)
        shapley_values_similarity = self.token_shap.get_sim_shapley_values()
bias_tokens_ranks = self.analyze_bias_tokens_importance('data/bias', original_text=financial_statement)
return {
'sentence': financial_statement,
'prediction': prediction,
'Shapley Values': shapley_values_similarity,
'Bias Token Ranks': bias_tokens_ranks
}
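    # Example use of analyze_sentence (statement and output values are
    # illustrative; key shapes follow the return dict above):
    #
    #   result = analyzer.analyze_sentence("The company reported record profits.")
    #   result['prediction']        # e.g. {'label': 'positive', ...}
    #   result['Shapley Values']    # {token: importance, ...}
    #   result['Bias Token Ranks']  # {token: {'rank': ..., 'percentile': ...}, ...}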
    def analyze_bias_tokens_importance(self, bias_files_dir: str, original_text: Optional[str] = None):
        """
        Analyze the importance of bias tokens in a financial statement
        Args:
            bias_files_dir: Directory containing files with bias terms
            original_text: Statement used for multi-word term matching; falls
                back to the statement stored on the TokenSHAP instance
        Returns:
            Dictionary with bias analysis results including rankings
        """
# Load bias terms from files
single_word_terms, multi_word_terms = self._load_bias_terms(bias_files_dir)
# Get the original sentence and token importance values
shapley_values_similarity = self.token_shap.get_sim_shapley_values()
# Rank ALL tokens by importance (highest to lowest)
all_tokens_ranked = sorted(shapley_values_similarity.items(), key=lambda x: x[1], reverse=True)
# Create rankings dictionary with positions
total_tokens = len(all_tokens_ranked)
token_rankings = {token: {'value': value, 'rank': idx + 1}
for idx, (token, value) in enumerate(all_tokens_ranked)}
# Get the original text - use parameter if provided, otherwise try to get from object
if original_text is None:
original_text = getattr(self.token_shap, 'original_statement', '')
# Original content in lowercase for case-insensitive matching
original_text_lower = original_text.lower()
# Identify bias tokens and their rankings
bias_tokens_with_rank = {}
# 1. Process single-word terms
for token, token_data in token_rankings.items():
if token.lower() in single_word_terms:
rank = token_data['rank']
value = token_data['value']
bias_tokens_with_rank[token] = {
'shapley_value': value,
'rank': rank,
'total_tokens': total_tokens,
'percentile': round((1 - (rank - 1) / total_tokens) * 100, 1),
'type': 'single_word'
}
# 2. Process multi-word terms by checking the original sentence
for multi_word_term in multi_word_terms:
# Case insensitive check if the term exists in the original content
if multi_word_term.lower() in original_text_lower:
# Split the multi-word term into individual words
term_words = multi_word_term.lower().split()
# Find matching tokens in our token rankings
matched_tokens = []
matched_values = []
# Look for each word in the tokenized tokens
for word in term_words:
for token, data in token_rankings.items():
# Case insensitive comparison
if word == token.lower():
matched_tokens.append(token)
matched_values.append(data['value'])
break
                # If we found at least one token, calculate an aggregate score
                if matched_tokens:
                    avg_value = sum(matched_values) / len(matched_values)
                    # Find the equivalent rank for the averaged value
                    for idx, (_, value) in enumerate(all_tokens_ranked):
                        if avg_value >= value:
                            equivalent_rank = idx + 1
                            break
                    else:
                        # Lower than every ranked token, so place it at the end
                        equivalent_rank = len(all_tokens_ranked) + 1
# Add the multi-word term to results
bias_tokens_with_rank[multi_word_term] = {
'shapley_value': avg_value,
'rank': equivalent_rank,
'total_tokens': total_tokens,
'percentile': round((1 - (equivalent_rank - 1) / total_tokens) * 100, 1),
'type': 'multi_word',
'constituent_tokens': matched_tokens,
'individual_values': dict(zip(matched_tokens, matched_values))
}
return bias_tokens_with_rank
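    # Shape of one returned entry (token and values are illustrative):
    #
    #   {'woman': {'shapley_value': 0.12, 'rank': 3, 'total_tokens': 14,
    #              'percentile': 85.7, 'type': 'single_word'}}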
    def _load_bias_terms(self, bias_files_dir: str) -> tuple:
        """
        Load bias terms from files in the specified directory
        Args:
            bias_files_dir: Directory whose subfolders contain semicolon-delimited
                CSV files of bias terms
        Returns:
            Tuple of (single_word_terms, multi_word_terms)
        """
single_word_terms = set()
multi_word_terms = set()
# Check if the directory exists
if not os.path.exists(bias_files_dir):
raise ValueError(f"Bias files directory {bias_files_dir} does not exist")
# Load terms from each file
for bias_folder in os.listdir(bias_files_dir):
folder_path = os.path.join(bias_files_dir, bias_folder)
if not os.path.isdir(folder_path):
continue
for file in os.listdir(folder_path):
file_path = os.path.join(folder_path, file)
if os.path.isfile(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
csv_reader = csv.reader(f, delimiter=';')
for row in csv_reader:
for term in row:
term = term.strip().lower()
if term:
if ' ' in term:
multi_word_terms.add(term)
else:
single_word_terms.add(term)
return single_word_terms, multi_word_terms
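
# A minimal end-to-end sketch, not part of the module's API: it assumes the
# transformers package is installed, the (illustrative) FinBERT checkpoint can
# be downloaded, and a data/bias directory of term files exists (see
# _load_bias_terms above).
if __name__ == "__main__":
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model = AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert")
    tokenizer = AutoTokenizer.from_pretrained("ProsusAI/finbert")
    analyzer = BiasAnalyzer(model, tokenizer, model_type='bert')

    # Small sampling settings keep the demo cheap; production runs would use
    # the defaults of analyze_sentence.
    result = analyzer.analyze_sentence(
        "She led the bank to record profits this quarter.",
        sampling_ratio=0.3,
        max_combinations=50,
    )
    print(result['prediction'])
    print(result['Bias Token Ranks'])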