File size: 10,365 Bytes
ec4aa90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
"""
Centralized AI Manager for multiple providers.
Supports Gemini, Nebius Token Factory, and other OpenAI-compatible providers.
"""

import os
import json
import logging
from typing import Dict, Any, Optional, List
from enum import Enum
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

logger = logging.getLogger(__name__)


class AIProvider(Enum):
    """Supported AI providers.

    Each member's value is the lower-case provider name. ``AIManager.__init__``
    resolves a name with ``AIProvider(name)`` and treats the resulting
    ``ValueError`` for an unknown name as "fall back to ``GEMINI``".
    """
    GEMINI = "gemini"  # initialised by AIManager._init_gemini
    NEBIUS = "nebius"  # initialised by AIManager._init_nebius
    OPENAI = "openai"  # initialised by AIManager._init_openai


class AIManager:
    """
    Centralized manager for AI API calls across different providers.

    Provides a unified interface regardless of the underlying provider:
    callers use ``generate_content`` and the manager routes the request to
    either the Gemini client or an OpenAI-compatible chat-completions client.

    Instance attributes set during construction:
        provider_name: Lower-case name of the provider actually in use.
        provider: The matching ``AIProvider`` member.
        model_name: Model identifier sent with every request.
        client: The provider SDK client object.
        provider_type: ``"gemini"`` or ``"openai_compatible"``; selects the
            request path taken by ``generate_content``.
    """

    # Default configurations
    DEFAULT_PROVIDER = "gemini"
    DEFAULT_GEMINI_MODEL = "gemini-2.5-flash"
    DEFAULT_NEBIUS_MODEL = "zai-org/GLM-4.5"
    DEFAULT_OPENAI_MODEL = "gpt-4"

    # Temperature settings for different use cases
    TEMPERATURE_PRECISE = 0.0  # For JSON schema responses
    TEMPERATURE_LOW = 0.1      # For code generation
    TEMPERATURE_MEDIUM = 0.2   # For transformations
    TEMPERATURE_HIGH = 0.7     # For creative tasks

    # Token limits
    MAX_OUTPUT_TOKENS_SMALL = 8192
    MAX_OUTPUT_TOKENS_MEDIUM = 16384
    MAX_OUTPUT_TOKENS_LARGE = 32768

    # Retry settings
    # NOTE: no method of this class reads these two values; they are kept as
    # public constants for callers that implement their own retry loop.
    MAX_RETRIES = 3
    RETRY_DELAY = 1.0  # seconds

    # Provider used when the requested name is not recognised.
    _FALLBACK_PROVIDER = "gemini"

    # Provider name -> environment variable that must hold its API key.
    # Single source of truth for both the constructor and validate_config().
    _API_KEY_ENV_VARS = {
        "gemini": "GEMINI_API_KEY",
        "nebius": "NEBIUS_API_KEY",
        "openai": "OPENAI_API_KEY",
    }

    def __init__(self, provider: Optional[str] = None, model: Optional[str] = None):
        """
        Initialize AI Manager.

        Args:
            provider: AI provider to use (gemini, nebius, openai).
                     If None, reads from AI_PROVIDER env var or uses default.
                     Matching ignores case and surrounding whitespace; an
                     unrecognised name logs a warning and falls back to Gemini.
            model: Model name to use. If None, reads from provider-specific env var.

        Raises:
            ValueError: If the API key for the selected provider is not set.
        """
        # Determine provider
        requested = provider or os.getenv("AI_PROVIDER", self.DEFAULT_PROVIDER)
        # Strip so a value such as "nebius " from a .env file is not treated
        # as an unknown provider and silently replaced by the fallback.
        self.provider_name = requested.strip().lower()

        try:
            self.provider = AIProvider(self.provider_name)
        except ValueError:
            logger.warning(
                "Unknown provider '%s', falling back to Gemini",
                self.provider_name,
            )
            self.provider = AIProvider.GEMINI
            self.provider_name = self._FALLBACK_PROVIDER

        # Initialize provider-specific client
        if self.provider == AIProvider.GEMINI:
            self._init_gemini(model)
        elif self.provider == AIProvider.NEBIUS:
            self._init_nebius(model)
        elif self.provider == AIProvider.OPENAI:
            self._init_openai(model)

        logger.info(
            "AIManager initialized with provider: %s, model: %s",
            self.provider_name,
            self.model_name,
        )

    @classmethod
    def _require_api_key(cls, provider_name: str) -> str:
        """
        Return the API key for ``provider_name`` from the environment.

        Args:
            provider_name: A key of ``_API_KEY_ENV_VARS``.

        Returns:
            The non-empty value of the provider's API-key variable.

        Raises:
            ValueError: If the variable is unset or empty.
        """
        env_var = cls._API_KEY_ENV_VARS[provider_name]
        api_key = os.getenv(env_var)
        if not api_key:
            raise ValueError(
                f"{env_var} not found in environment variables. "
                "Please set it in your .env file."
            )
        return api_key

    def _init_gemini(self, model: Optional[str] = None):
        """Initialize Gemini provider.

        Args:
            model: Model name; when None, GEMINI_MODEL or the class default.

        Raises:
            ValueError: If GEMINI_API_KEY is not set.
        """
        # Imported lazily so the SDK is only required when Gemini is selected.
        from google import genai

        api_key = self._require_api_key("gemini")

        self.model_name = (
            model or
            os.getenv("GEMINI_MODEL", self.DEFAULT_GEMINI_MODEL)
        )

        self.client = genai.Client(api_key=api_key)
        self.provider_type = "gemini"

    def _init_nebius(self, model: Optional[str] = None):
        """Initialize Nebius Token Factory provider (OpenAI-compatible).

        Args:
            model: Model name; when None, NEBIUS_MODEL or the class default.

        Raises:
            ValueError: If NEBIUS_API_KEY is not set.
        """
        # Imported lazily so the SDK is only required when it is selected.
        from openai import OpenAI

        api_key = self._require_api_key("nebius")

        self.model_name = (
            model or
            os.getenv("NEBIUS_MODEL", self.DEFAULT_NEBIUS_MODEL)
        )

        # Same client class as OpenAI, pointed at the Nebius endpoint.
        self.client = OpenAI(
            base_url="https://api.tokenfactory.nebius.com/v1/",
            api_key=api_key
        )
        self.provider_type = "openai_compatible"

    def _init_openai(self, model: Optional[str] = None):
        """Initialize OpenAI provider.

        Args:
            model: Model name; when None, OPENAI_MODEL or the class default.

        Raises:
            ValueError: If OPENAI_API_KEY is not set.
        """
        # Imported lazily so the SDK is only required when it is selected.
        from openai import OpenAI

        api_key = self._require_api_key("openai")

        self.model_name = (
            model or
            os.getenv("OPENAI_MODEL", self.DEFAULT_OPENAI_MODEL)
        )

        self.client = OpenAI(api_key=api_key)
        self.provider_type = "openai_compatible"

    def generate_content(
        self,
        prompt: str,
        temperature: float = TEMPERATURE_LOW,
        max_tokens: int = MAX_OUTPUT_TOKENS_MEDIUM,
        response_format: Optional[str] = None,
        response_schema: Optional[Dict[str, Any]] = None,
        system_prompt: Optional[str] = None
    ) -> str:
        """
        Generate content using the configured AI provider.

        Args:
            prompt: The prompt to send to the AI
            temperature: Temperature setting (0.0-1.0)
            max_tokens: Maximum output tokens
            response_format: Response format ("json" or None)
            response_schema: JSON schema for structured responses (Gemini
                format). Only forwarded on the Gemini path; it is not passed
                to OpenAI-compatible providers.
            system_prompt: Optional system prompt (for OpenAI-compatible
                providers). Not forwarded on the Gemini path.

        Returns:
            Generated text content
        """
        if self.provider_type == "gemini":
            return self._generate_gemini(
                prompt, temperature, max_tokens,
                response_format, response_schema
            )
        else:  # openai_compatible
            return self._generate_openai_compatible(
                prompt, temperature, max_tokens,
                response_format, system_prompt
            )

    def _generate_gemini(
        self,
        prompt: str,
        temperature: float,
        max_tokens: int,
        response_format: Optional[str],
        response_schema: Optional[Dict[str, Any]]
    ) -> str:
        """Generate content using Gemini API.

        Args:
            prompt: Text passed as ``contents``.
            temperature: Sampling temperature.
            max_tokens: Sent as ``max_output_tokens``.
            response_format: ``"json"`` requests a JSON MIME type.
            response_schema: When truthy, requests JSON and attaches the schema.

        Returns:
            The ``text`` attribute of the SDK response.
        """
        config = {
            "temperature": temperature,
            "max_output_tokens": max_tokens,
            "top_p": 0.95,  # fixed for every Gemini request
        }

        # A schema implies JSON output; plain "json" requests JSON without one.
        if response_schema:
            config["response_mime_type"] = "application/json"
            config["response_schema"] = response_schema
        elif response_format == "json":
            config["response_mime_type"] = "application/json"

        response = self.client.models.generate_content(
            model=self.model_name,
            contents=prompt,
            config=config
        )

        return response.text

    def _generate_openai_compatible(
        self,
        prompt: str,
        temperature: float,
        max_tokens: int,
        response_format: Optional[str],
        system_prompt: Optional[str]
    ) -> str:
        """Generate content using OpenAI-compatible API.

        Args:
            prompt: Sent as the single user message.
            temperature: Sampling temperature.
            max_tokens: Sent as ``max_tokens``.
            response_format: ``"json"`` enables ``{"type": "json_object"}``.
            system_prompt: When truthy, sent as a leading system message.

        Returns:
            The message content of the first choice in the response.
        """
        messages = []

        # Add system prompt if provided
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        messages.append({"role": "user", "content": prompt})

        kwargs = {
            "model": self.model_name,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        # Add JSON mode if requested
        if response_format == "json":
            kwargs["response_format"] = {"type": "json_object"}

        response = self.client.chat.completions.create(**kwargs)

        return response.choices[0].message.content

    def get_base_config(
        self,
        temperature: float = TEMPERATURE_LOW,
        max_tokens: int = MAX_OUTPUT_TOKENS_MEDIUM
    ) -> Dict[str, Any]:
        """
        Get base configuration for AI calls.

        The keys match ``generate_content`` keyword names.

        Args:
            temperature: Temperature setting (0.0-1.0)
            max_tokens: Maximum output tokens

        Returns:
            Configuration dictionary
        """
        return {
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

    def get_json_config(
        self,
        schema: Optional[Dict[str, Any]] = None,
        temperature: float = TEMPERATURE_PRECISE,
        max_tokens: int = MAX_OUTPUT_TOKENS_MEDIUM
    ) -> Dict[str, Any]:
        """
        Get configuration for JSON schema-enforced responses.

        Args:
            schema: JSON schema dictionary (Gemini format)
            temperature: Temperature setting (default: 0.0 for precision)
            max_tokens: Maximum output tokens

        Returns:
            Configuration dictionary with ``response_format`` set to "json";
            ``response_schema`` is included only when a schema is given and
            the active provider type is Gemini.
        """
        config = self.get_base_config(temperature, max_tokens)
        config["response_format"] = "json"

        if schema and self.provider_type == "gemini":
            config["response_schema"] = schema

        return config

    @classmethod
    def validate_config(cls, provider: Optional[str] = None) -> bool:
        """
        Validate that required configuration is present.

        Resolves the provider the same way the constructor does: explicit
        argument, then AI_PROVIDER, then the class default, ignoring case and
        surrounding whitespace. An unrecognised name is checked as Gemini,
        because that is the provider the constructor would fall back to.

        Args:
            provider: Provider name to validate. If None, reads from
                AI_PROVIDER env var or uses default.

        Returns:
            True if configuration is valid

        Raises:
            ValueError: If required configuration is missing
        """
        requested = provider or os.getenv("AI_PROVIDER", cls.DEFAULT_PROVIDER)
        provider_name = requested.strip().lower()

        if provider_name not in cls._API_KEY_ENV_VARS:
            provider_name = cls._FALLBACK_PROVIDER

        cls._require_api_key(provider_name)
        return True