# naazimsnh02's picture
# Initial deployment: Autonomous AI agent for code modernization
# ec4aa90
"""
Centralized AI Manager for multiple providers.
Supports Gemini, Nebius Token Factory, and other OpenAI-compatible providers.
"""
import os
import json
import logging
from typing import Dict, Any, Optional, List
from enum import Enum
from dotenv import load_dotenv
# Load environment variables (python-dotenv) so that the os.getenv() lookups
# performed later by AIManager can see values defined in a .env file.
load_dotenv()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class AIProvider(Enum):
    """Identifiers of the AI backends that AIManager can be configured with.

    Each member's value is the lowercase provider name used to select it,
    e.g. ``AIProvider("nebius")``.
    """

    # Google Gemini
    GEMINI = "gemini"
    # Nebius Token Factory
    NEBIUS = "nebius"
    # OpenAI
    OPENAI = "openai"
class AIManager:
    """
    Centralized manager for AI API calls across different providers.

    Provides a unified interface regardless of the underlying provider.
    The provider is chosen at construction time; ``generate_content`` then
    routes each request to either the Gemini client or an OpenAI-compatible
    client (used for both Nebius Token Factory and OpenAI).
    """

    # Default configurations
    DEFAULT_PROVIDER = "gemini"
    DEFAULT_GEMINI_MODEL = "gemini-2.5-flash"
    DEFAULT_NEBIUS_MODEL = "zai-org/GLM-4.5"
    DEFAULT_OPENAI_MODEL = "gpt-4"

    # Temperature settings for different use cases
    TEMPERATURE_PRECISE = 0.0  # For JSON schema responses
    TEMPERATURE_LOW = 0.1  # For code generation
    TEMPERATURE_MEDIUM = 0.2  # For transformations
    TEMPERATURE_HIGH = 0.7  # For creative tasks

    # Token limits
    MAX_OUTPUT_TOKENS_SMALL = 8192
    MAX_OUTPUT_TOKENS_MEDIUM = 16384
    MAX_OUTPUT_TOKENS_LARGE = 32768

    # Retry settings.
    # NOTE(review): no method of this class reads these two constants; they
    # are kept because external callers may rely on them.
    MAX_RETRIES = 3
    RETRY_DELAY = 1.0  # seconds

    # Environment variable that holds the API key for each provider name.
    _API_KEY_ENV_VARS = {
        "gemini": "GEMINI_API_KEY",
        "nebius": "NEBIUS_API_KEY",
        "openai": "OPENAI_API_KEY",
    }

    def __init__(self, provider: Optional[str] = None, model: Optional[str] = None):
        """
        Initialize AI Manager.

        Args:
            provider: AI provider to use (gemini, nebius, openai).
                If None, reads from AI_PROVIDER env var or uses default.
                Matching is case-insensitive and ignores surrounding
                whitespace; an unrecognised name falls back to Gemini.
            model: Model name to use. If None, reads from provider-specific
                env var, then the class default for that provider.

        Raises:
            ValueError: If the API key for the selected provider is not set.
        """
        # Determine provider. Stripping whitespace prevents a value such as
        # "nebius " from being treated as unknown and silently replaced.
        self.provider_name = (
            provider or os.getenv("AI_PROVIDER", self.DEFAULT_PROVIDER)
        ).strip().lower()

        try:
            self.provider = AIProvider(self.provider_name)
        except ValueError:
            logger.warning(
                "Unknown provider '%s', falling back to Gemini",
                self.provider_name,
            )
            self.provider = AIProvider.GEMINI
            self.provider_name = "gemini"

        # Initialize provider-specific client
        if self.provider == AIProvider.GEMINI:
            self._init_gemini(model)
        elif self.provider == AIProvider.NEBIUS:
            self._init_nebius(model)
        elif self.provider == AIProvider.OPENAI:
            self._init_openai(model)

        logger.info(
            "AIManager initialized with provider: %s, model: %s",
            self.provider_name,
            self.model_name,
        )

    @staticmethod
    def _require_api_key(env_var: str) -> str:
        """Return the value of ``env_var`` or raise ValueError if unset/empty."""
        api_key = os.getenv(env_var)
        if not api_key:
            raise ValueError(
                f"{env_var} not found in environment variables. "
                "Please set it in your .env file."
            )
        return api_key

    def _init_gemini(self, model: Optional[str] = None):
        """Initialize Gemini provider.

        Sets ``model_name``, ``client`` and ``provider_type`` ("gemini").

        Raises:
            ValueError: If GEMINI_API_KEY is not set.
        """
        # Imported lazily so the SDK is only needed when this provider is used.
        from google import genai

        api_key = self._require_api_key("GEMINI_API_KEY")
        self.model_name = model or os.getenv(
            "GEMINI_MODEL", self.DEFAULT_GEMINI_MODEL
        )
        self.client = genai.Client(api_key=api_key)
        self.provider_type = "gemini"

    def _init_nebius(self, model: Optional[str] = None):
        """Initialize Nebius Token Factory provider (OpenAI-compatible).

        Sets ``model_name``, ``client`` and ``provider_type``
        ("openai_compatible").

        Raises:
            ValueError: If NEBIUS_API_KEY is not set.
        """
        from openai import OpenAI

        api_key = self._require_api_key("NEBIUS_API_KEY")
        self.model_name = model or os.getenv(
            "NEBIUS_MODEL", self.DEFAULT_NEBIUS_MODEL
        )
        self.client = OpenAI(
            base_url="https://api.tokenfactory.nebius.com/v1/",
            api_key=api_key,
        )
        self.provider_type = "openai_compatible"

    def _init_openai(self, model: Optional[str] = None):
        """Initialize OpenAI provider.

        Sets ``model_name``, ``client`` and ``provider_type``
        ("openai_compatible").

        Raises:
            ValueError: If OPENAI_API_KEY is not set.
        """
        from openai import OpenAI

        api_key = self._require_api_key("OPENAI_API_KEY")
        self.model_name = model or os.getenv(
            "OPENAI_MODEL", self.DEFAULT_OPENAI_MODEL
        )
        self.client = OpenAI(api_key=api_key)
        self.provider_type = "openai_compatible"

    def generate_content(
        self,
        prompt: str,
        temperature: float = TEMPERATURE_LOW,
        max_tokens: int = MAX_OUTPUT_TOKENS_MEDIUM,
        response_format: Optional[str] = None,
        response_schema: Optional[Dict[str, Any]] = None,
        system_prompt: Optional[str] = None
    ) -> str:
        """
        Generate content using the configured AI provider.

        Args:
            prompt: The prompt to send to the AI
            temperature: Temperature setting (0.0-1.0)
            max_tokens: Maximum output tokens
            response_format: Response format ("json" or None)
            response_schema: JSON schema for structured responses (Gemini
                format). Only forwarded on the Gemini path.
            system_prompt: Optional system prompt. Only forwarded on the
                OpenAI-compatible path; the Gemini path does not receive it.

        Returns:
            Generated text content. An empty string is returned when the
            provider's response carries no text.
        """
        if self.provider_type == "gemini":
            return self._generate_gemini(
                prompt, temperature, max_tokens,
                response_format, response_schema
            )
        # Anything else is an OpenAI-compatible client (Nebius or OpenAI).
        return self._generate_openai_compatible(
            prompt, temperature, max_tokens,
            response_format, system_prompt
        )

    def _generate_gemini(
        self,
        prompt: str,
        temperature: float,
        max_tokens: int,
        response_format: Optional[str],
        response_schema: Optional[Dict[str, Any]]
    ) -> str:
        """Generate content using Gemini API.

        A supplied ``response_schema`` implies a JSON response; otherwise
        JSON output is requested only when ``response_format == "json"``.
        """
        config = {
            "temperature": temperature,
            "max_output_tokens": max_tokens,
            "top_p": 0.95,
        }

        # Add JSON schema if provided
        if response_schema:
            config["response_mime_type"] = "application/json"
            config["response_schema"] = response_schema
        elif response_format == "json":
            config["response_mime_type"] = "application/json"

        response = self.client.models.generate_content(
            model=self.model_name,
            contents=prompt,
            config=config
        )

        # The SDK's ``text`` accessor is optional; honour the ``-> str``
        # contract instead of leaking None to callers.
        text = response.text
        if text is None:
            logger.warning(
                "Gemini response contained no text; returning empty string"
            )
            return ""
        return text

    def _generate_openai_compatible(
        self,
        prompt: str,
        temperature: float,
        max_tokens: int,
        response_format: Optional[str],
        system_prompt: Optional[str]
    ) -> str:
        """Generate content using OpenAI-compatible API.

        Builds a chat-completions request with an optional system message
        followed by the user prompt, and returns the first choice's content.
        """
        messages = []

        # Add system prompt if provided
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        kwargs = {
            "model": self.model_name,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        # Add JSON mode if requested
        if response_format == "json":
            kwargs["response_format"] = {"type": "json_object"}

        response = self.client.chat.completions.create(**kwargs)

        # ``message.content`` is optional in the chat-completions schema;
        # honour the ``-> str`` contract instead of leaking None to callers.
        content = response.choices[0].message.content
        if content is None:
            logger.warning(
                "OpenAI-compatible response contained no content; "
                "returning empty string"
            )
            return ""
        return content

    def get_base_config(
        self,
        temperature: float = TEMPERATURE_LOW,
        max_tokens: int = MAX_OUTPUT_TOKENS_MEDIUM
    ) -> Dict[str, Any]:
        """
        Get base configuration for AI calls.

        Args:
            temperature: Temperature setting (0.0-1.0)
            max_tokens: Maximum output tokens

        Returns:
            Configuration dictionary with "temperature" and "max_tokens" keys
        """
        return {
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

    def get_json_config(
        self,
        schema: Optional[Dict[str, Any]] = None,
        temperature: float = TEMPERATURE_PRECISE,
        max_tokens: int = MAX_OUTPUT_TOKENS_MEDIUM
    ) -> Dict[str, Any]:
        """
        Get configuration for JSON schema-enforced responses.

        Args:
            schema: JSON schema dictionary (Gemini format)
            temperature: Temperature setting (default: 0.0 for precision)
            max_tokens: Maximum output tokens

        Returns:
            Configuration dictionary. "response_schema" is included only
            when a schema is given and the active provider is Gemini.
        """
        config = self.get_base_config(temperature, max_tokens)
        config["response_format"] = "json"
        if schema and self.provider_type == "gemini":
            config["response_schema"] = schema
        return config

    @classmethod
    def validate_config(cls) -> bool:
        """
        Validate that required configuration is present.

        Reads the provider from the AI_PROVIDER env var (default provider
        when unset) and checks that the matching API key is set.

        Returns:
            True if configuration is valid

        Raises:
            ValueError: If required configuration is missing
        """
        provider = os.getenv("AI_PROVIDER", cls.DEFAULT_PROVIDER).strip().lower()
        # Mirror __init__: an unrecognised provider falls back to Gemini, so
        # the Gemini key is the one that will actually be required.
        env_var = cls._API_KEY_ENV_VARS.get(
            provider, cls._API_KEY_ENV_VARS["gemini"]
        )
        cls._require_api_key(env_var)
        return True