from typing import Any, Dict, Tuple

from .models import load_bert_model, load_llama_model, BERTModel, LlamaModelWrapper
from .bias_analyzer import BiasAnalyzer


class ModelManager:
    """Manages loading and caching of financial models."""

    def __init__(self):
        # Cache of already-loaded models: name -> (wrapped_model, tokenizer)
        self.loaded_models: Dict[str, Tuple[Any, Any]] = {}
        self.model_configs = {
            "FinBERT": {
                "model_id": "ProsusAI/finbert",
                "type": "bert"
            },
            "DeBERTa-v3": {
                "model_id": "mrm8488/deberta-v3-ft-financial-news-sentiment-analysis",
                "type": "bert"
            },
            "DistilRoBERTa": {
                "model_id": "mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis",
                "type": "bert"
            },
            # "FinMA": {
            #     "model_id": "ChanceFocus/finma-7b-full",
            #     "tokenizer_id": "ChanceFocus/finma-7b-full",
            #     "type": "llama"
            # },
            # "FinGPT": {
            #     "model_id": "oliverwang15/FinGPT_v32_Llama2_Sentiment_Instruction_LoRA_FT",
            #     "tokenizer_id": "meta-llama/Llama-2-7b-chat-hf",
            #     "type": "llama"
            # }
        }
        # Token IDs the Llama models use for each sentiment label
        self.label_ids = {
            "Positive": [6374],
            "Negative": [8178, 22198],
            "Neutral": [21104]
        }

    def load_model(self, model_name: str) -> Tuple[Any, Any]:
        """Load a model by name, caching the (wrapped_model, tokenizer) pair."""
        if model_name in self.loaded_models:
            return self.loaded_models[model_name]

        if model_name not in self.model_configs:
            raise KeyError(f"Unknown model: {model_name!r}. "
                           f"Available: {list(self.model_configs)}")

        config = self.model_configs[model_name]
        try:
            if config["type"] == "bert":
                model, tokenizer = load_bert_model(config["model_id"])
                wrapped_model = BERTModel(model, tokenizer)
            elif config["type"] == "llama":
                model, tokenizer = load_llama_model(
                    base_tokenizer_id=config["tokenizer_id"],
                    model_id=config["model_id"],
                    cache_dir="./cache"
                )
                wrapped_model = LlamaModelWrapper(model, tokenizer, self.label_ids)
            else:
                # Guard against configs with an unrecognized type; previously
                # wrapped_model would have been unbound here
                raise ValueError(f"Unsupported model type: {config['type']!r}")

            # Cache the loaded model so repeated requests are served instantly
            self.loaded_models[model_name] = (wrapped_model, tokenizer)
            return wrapped_model, tokenizer
        except Exception as e:
            raise RuntimeError(f"Failed to load {model_name}: {e}") from e

    def get_bias_analyzer(self, model_name: str) -> BiasAnalyzer:
        """Get a BiasAnalyzer for the specified model."""
        wrapped_model, tokenizer = self.load_model(model_name)

        # The model is already wrapped, so tell the analyzer not to wrap it again
        analyzer = BiasAnalyzer(
            model=wrapped_model,
            tokenizer=tokenizer,
            model_type=self.model_configs[model_name]["type"],
            splitter_type="string",
            batch_size=16,
            is_wrapped=True
        )
        return analyzer
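

# --- Usage sketch (illustrative, not part of the module's public API) ---
# A minimal demonstration of the caching flow above. Assumptions: network
# access to download "ProsusAI/finbert" from the Hugging Face Hub, and that
# this file lives inside a package, so it must be run as
# `python -m <package>.<this_module>` for the relative imports to resolve.
if __name__ == "__main__":
    manager = ModelManager()

    # First call downloads/loads FinBERT; the second is served from the cache.
    wrapped_model, tokenizer = manager.load_model("FinBERT")
    cached_model, _ = manager.load_model("FinBERT")
    print("cache hit:", cached_model is wrapped_model)  # True

    # Build an analyzer on top of the same cached model.
    analyzer = manager.get_bias_analyzer("FinBERT")
    print("analyzer type:", type(analyzer).__name__)  # BiasAnalyzer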