Spaces:
Running
Running
| import torch | |
| from typing import Dict, Any, Optional | |
| from .models import load_bert_model, load_llama_model, BERTModel, LlamaModelWrapper | |
| from .bias_analyzer import BiasAnalyzer | |
class ModelManager:
    """Manage loading, wrapping and caching of financial sentiment models.

    ``model_configs`` maps a display name to the Hugging Face model id and the
    wrapper ``type`` ("bert" or "llama") used to load it. Loaded models are
    cached per instance in ``loaded_models`` so repeated requests for the same
    name reuse the already-loaded objects.
    """

    # Wrapper types that load_model knows how to build.
    SUPPORTED_TYPES = ("bert", "llama")

    def __init__(self, cache_dir: str = "./cache"):
        """Initialise the model registry and an empty model cache.

        Args:
            cache_dir: Directory passed as ``cache_dir`` to
                ``load_llama_model``. Defaults to ``"./cache"``.
        """
        self.cache_dir = cache_dir
        # model name -> (wrapped_model, tokenizer); filled lazily by load_model.
        self.loaded_models: Dict[str, tuple] = {}
        self.model_configs: Dict[str, Dict[str, str]] = {
            "FinBERT": {
                "model_id": "ProsusAI/finbert",
                "type": "bert",
            },
            "DeBERTa-v3": {
                "model_id": "mrm8488/deberta-v3-ft-financial-news-sentiment-analysis",
                "type": "bert",
            },
            "DistilRoBERTa": {
                "model_id": "mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis",
                "type": "bert",
            },
            # Llama-based models are currently disabled. A "llama" entry also
            # needs a "tokenizer_id" (see load_model).
            # "FinMA": {
            #     "model_id": "ChanceFocus/finma-7b-full",
            #     "tokenizer_id": "ChanceFocus/finma-7b-full",
            #     "type": "llama"
            # },
            # "FinGPT": {
            #     "model_id": "oliverwang15/FinGPT_v32_Llama2_Sentiment_Instruction_LoRA_FT",
            #     "tokenizer_id": "meta-llama/Llama-2-7b-chat-hf",
            #     "type": "llama"
            # }
        }
        # Token ids for each sentiment label, handed to LlamaModelWrapper.
        # NOTE(review): presumably vocabulary ids of the label words in the
        # Llama tokenizer — confirm against the tokenizer actually in use.
        self.label_ids: Dict[str, list] = {
            "Positive": [6374],
            "Negative": [8178, 22198],
            "Neutral": [21104],
        }

    def load_model(self, model_name: str) -> tuple:
        """Return ``(wrapped_model, tokenizer)`` for *model_name*, loading on first use.

        Args:
            model_name: A key of ``self.model_configs``.

        Returns:
            A ``(wrapped_model, tokenizer)`` tuple. Later calls with the same
            name return the cached tuple.

        Raises:
            KeyError: If *model_name* is not a configured model.
            ValueError: If the model's configured ``type`` is not supported.
            RuntimeError: If loading or wrapping the model fails; the original
                exception is chained as ``__cause__``.
        """
        if model_name in self.loaded_models:
            return self.loaded_models[model_name]

        try:
            config = self.model_configs[model_name]
        except KeyError:
            available = ", ".join(sorted(self.model_configs))
            raise KeyError(
                f"Unknown model {model_name!r}; available models: {available}"
            ) from None

        model_type = config.get("type")
        if model_type not in self.SUPPORTED_TYPES:
            # Without this check an unknown type would leave wrapped_model
            # unbound and surface as a confusing UnboundLocalError.
            raise ValueError(
                f"Unsupported model type {model_type!r} for {model_name!r}; "
                f"expected one of {self.SUPPORTED_TYPES}"
            )

        try:
            if model_type == "bert":
                model, tokenizer = load_bert_model(config["model_id"])
                wrapped_model = BERTModel(model, tokenizer)
            else:  # "llama"
                model, tokenizer = load_llama_model(
                    base_tokenizer_id=config["tokenizer_id"],
                    model_id=config["model_id"],
                    cache_dir=self.cache_dir,
                )
                wrapped_model = LlamaModelWrapper(model, tokenizer, self.label_ids)
        except Exception as e:
            raise RuntimeError(f"Failed to load {model_name}: {e}") from e

        # Cache the loaded model so later calls skip the expensive load.
        self.loaded_models[model_name] = (wrapped_model, tokenizer)
        return self.loaded_models[model_name]

    def get_bias_analyzer(
        self,
        model_name: str,
        batch_size: int = 16,
        splitter_type: str = "string",
    ) -> "BiasAnalyzer":
        """Build a BiasAnalyzer around the (cached) wrapped model *model_name*.

        Args:
            model_name: A key of ``self.model_configs``.
            batch_size: Passed through to ``BiasAnalyzer``. Defaults to 16.
            splitter_type: Passed through to ``BiasAnalyzer``. Defaults to
                ``"string"``.

        Returns:
            A new ``BiasAnalyzer`` constructed with ``is_wrapped=True``.

        Raises:
            Same as ``load_model``.
        """
        wrapped_model, tokenizer = self.load_model(model_name)
        return BiasAnalyzer(
            model=wrapped_model,
            tokenizer=tokenizer,
            model_type=self.model_configs[model_name]["type"],
            splitter_type=splitter_type,
            batch_size=batch_size,
            is_wrapped=True,
        )