# FBDF/backend/model_manager.py
import torch
from typing import Any, Dict, Tuple

from .models import load_bert_model, load_llama_model, BERTModel, LlamaModelWrapper
from .bias_analyzer import BiasAnalyzer

class ModelManager:
    """Manages loading and caching of financial models."""

    def __init__(self):
        self.loaded_models: Dict[str, Tuple[Any, Any]] = {}
        self.model_configs: Dict[str, Dict[str, str]] = {
            "FinBERT": {
                "model_id": "ProsusAI/finbert",
                "type": "bert"
            },
            "DeBERTa-v3": {
                "model_id": "mrm8488/deberta-v3-ft-financial-news-sentiment-analysis",
                "type": "bert"
            },
            "DistilRoBERTa": {
                "model_id": "mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis",
                "type": "bert"
            },
            # Heavier 7B Llama-based models, disabled to keep only lightweight models:
            # "FinMA": {
            #     "model_id": "ChanceFocus/finma-7b-full",
            #     "tokenizer_id": "ChanceFocus/finma-7b-full",
            #     "type": "llama"
            # },
            # "FinGPT": {
            #     "model_id": "oliverwang15/FinGPT_v32_Llama2_Sentiment_Instruction_LoRA_FT",
            #     "tokenizer_id": "meta-llama/Llama-2-7b-chat-hf",
            #     "type": "llama"
            # }
        }
        # Token IDs the Llama wrapper uses to read sentiment logits for each label
        self.label_ids = {
            "Positive": [6374],
            "Negative": [8178, 22198],
            "Neutral": [21104]
        }
    def load_model(self, model_name: str) -> Tuple[Any, Any]:
        """Load a model by name, caching it for subsequent calls."""
        if model_name in self.loaded_models:
            return self.loaded_models[model_name]

        if model_name not in self.model_configs:
            raise ValueError(f"Unknown model: {model_name}")
        config = self.model_configs[model_name]

        try:
            if config["type"] == "bert":
                model, tokenizer = load_bert_model(config["model_id"])
                wrapped_model = BERTModel(model, tokenizer)
            elif config["type"] == "llama":
                model, tokenizer = load_llama_model(
                    base_tokenizer_id=config["tokenizer_id"],
                    model_id=config["model_id"],
                    cache_dir="./cache"
                )
                wrapped_model = LlamaModelWrapper(model, tokenizer, self.label_ids)
            else:
                raise ValueError(f"Unsupported model type: {config['type']}")

            # Cache the wrapped model and tokenizer for reuse
            self.loaded_models[model_name] = (wrapped_model, tokenizer)
            return wrapped_model, tokenizer
        except Exception as e:
            raise RuntimeError(f"Failed to load {model_name}: {e}") from e
    def get_bias_analyzer(self, model_name: str) -> BiasAnalyzer:
        """Get a BiasAnalyzer for the specified model."""
        wrapped_model, tokenizer = self.load_model(model_name)

        # Create BiasAnalyzer with the wrapped model
        analyzer = BiasAnalyzer(
            model=wrapped_model,
            tokenizer=tokenizer,
            model_type=self.model_configs[model_name]["type"],
            splitter_type='string',
            batch_size=16,
            is_wrapped=True
        )
        return analyzer
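

# Minimal usage sketch, assuming the FinBERT weights can be fetched from the
# Hugging Face Hub and that the returned BiasAnalyzer is then driven through the
# methods defined in bias_analyzer.py (not shown in this module).
if __name__ == "__main__":
    manager = ModelManager()

    # First call downloads/loads FinBERT and caches it; later calls reuse the cache.
    model, tokenizer = manager.load_model("FinBERT")

    # Build a BiasAnalyzer configured for the BERT-style model.
    analyzer = manager.get_bias_analyzer("FinBERT")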