| """Gradio UI setup""" | |
| import os | |
| import time | |
| import gradio as gr | |
| import spaces | |
| from config import TITLE, DESCRIPTION, CSS, MEDSWIN_MODELS, DEFAULT_MEDICAL_MODEL | |
| import config | |
| from indexing import create_or_update_index | |
| from pipeline import stream_chat | |
| from voice import transcribe_audio, generate_speech | |
| from models import ( | |
| initialize_medical_model, | |
| is_model_loaded, | |
| get_model_loading_state, | |
| set_model_loading_state, | |
| initialize_tts_model, | |
| initialize_whisper_model, | |
| TTS_AVAILABLE, | |
| SNAC_AVAILABLE, | |
| WHISPER_AVAILABLE, | |
| ) | |
| from logger import logger | |
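
# GPU time budget (in seconds) used as `max_duration` by the @spaces.GPU decorators
# that appear (currently commented out) further down in this module.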
MAX_DURATION = 120


def create_demo():
    """Create and return Gradio demo interface"""
    with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
        gr.HTML(TITLE)
        gr.HTML(DESCRIPTION)
        with gr.Row(elem_classes="main-container"):
            with gr.Column(elem_classes="upload-section"):
                file_upload = gr.File(
                    file_count="multiple",
                    label="Drag and Drop Files Here",
                    file_types=[".pdf", ".txt", ".doc", ".docx", ".md", ".json", ".xml", ".csv"],
                    elem_id="file-upload"
                )
                upload_button = gr.Button("Upload & Index", elem_classes="upload-button")
                status_output = gr.Textbox(
                    label="Status",
                    placeholder="Upload files to start...",
                    interactive=False
                )
                file_info_output = gr.HTML(
                    label="File Information",
                    elem_classes="processing-info"
                )
                upload_button.click(
                    fn=create_or_update_index,
                    inputs=[file_upload],
                    outputs=[status_output, file_info_output]
                )
            with gr.Column(elem_classes="chatbot-container"):
                chatbot = gr.Chatbot(
                    height=500,
                    placeholder="Chat with MedSwin... Type your question below.",
                    show_label=False,
                    type="messages"
                )
                with gr.Row(elem_classes="input-row"):
                    message_input = gr.Textbox(
                        placeholder="Type your medical question here...",
                        show_label=False,
                        container=False,
                        lines=1,
                        scale=10
                    )
                    mic_button = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="",
                        show_label=False,
                        container=False,
                        scale=1
                    )
                    submit_button = gr.Button("➤", elem_classes="submit-btn", scale=1)
                recording_timer = gr.Textbox(
                    value="",
                    label="",
                    show_label=False,
                    interactive=False,
                    visible=False,
                    container=False,
                    elem_classes="recording-timer"
                )
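
                # Voice input flow: start_recording reveals the timer readout; stop_recording hides it
                # and transcribes the recorded clip into the message box via transcribe_audio (Whisper-based ASR).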
                recording_start_time = [None]

                def handle_recording_start():
                    """Called when recording starts"""
                    recording_start_time[0] = time.time()
                    return gr.update(visible=True, value="Recording... 0s")

                def handle_recording_stop(audio):
                    """Called when recording stops"""
                    recording_start_time[0] = None
                    if audio is None:
                        return gr.update(visible=False, value=""), ""
                    transcribed = transcribe_audio(audio)
                    return gr.update(visible=False, value=""), transcribed

                mic_button.start_recording(
                    fn=handle_recording_start,
                    outputs=[recording_timer]
                )
                mic_button.stop_recording(
                    fn=handle_recording_stop,
                    inputs=[mic_button],
                    outputs=[recording_timer, message_input]
                )
                with gr.Row():
                    tts_button = gr.Button("🔊 Play Response", visible=False, size="sm")
                    tts_audio = gr.Audio(label="", visible=True, autoplay=True, show_label=False, container=False)

                def generate_speech_from_chat(history):
                    """Extract last assistant message and generate speech"""
                    if not history or len(history) == 0:
                        logger.warning("[TTS] No history available")
                        return None
                    last_msg = history[-1]
                    if last_msg.get("role") == "assistant":
                        text = last_msg.get("content", "").replace(" 🔊", "").strip()
                        if text:
                            logger.info(f"[TTS] Generating speech for text: {text[:100]}...")
                            try:
                                audio_path = generate_speech(text)
                                if audio_path and os.path.exists(audio_path):
                                    logger.info(f"[TTS] ✅ Generated audio successfully: {audio_path}")
                                    return audio_path
                                else:
                                    logger.warning(f"[TTS] ❌ Failed to generate audio or file doesn't exist: {audio_path}")
                                    return None
                            except Exception as e:
                                logger.error(f"[TTS] Error generating speech: {e}")
                                import traceback
                                logger.debug(f"[TTS] Traceback: {traceback.format_exc()}")
                                return None
                        else:
                            logger.warning("[TTS] Empty text extracted from assistant message")
                    else:
                        logger.warning(f"[TTS] Last message is not from assistant: {last_msg.get('role')}")
                    return None

                def update_tts_button(history):
                    if history and len(history) > 0 and history[-1].get("role") == "assistant":
                        return gr.update(visible=True)
                    return gr.update(visible=False)

                chatbot.change(
                    fn=update_tts_button,
                    inputs=[chatbot],
                    outputs=[tts_button]
                )
                tts_button.click(
                    fn=generate_speech_from_chat,
                    inputs=[chatbot],
                    outputs=[tts_audio]
                )
| with gr.Accordion("⚙️ Advanced Settings", open=False): | |
| with gr.Row(): | |
| disable_agentic_reasoning = gr.Checkbox( | |
| value=False, | |
| label="Disable agentic reasoning", | |
| info="Use MedSwin model alone without agentic reasoning, RAG, or web search" | |
| ) | |
| show_agentic_thought = gr.Button( | |
| "Show agentic thought", | |
| size="sm" | |
| ) | |
| enable_clinical_intake = gr.Checkbox( | |
| value=True, | |
| label="Enable clinical intake (max 5 Q&A)", | |
| info="Ask focused follow-up questions before breaking down the case" | |
| ) | |
| agentic_thoughts_box = gr.Textbox( | |
| label="Agentic Thoughts", | |
| placeholder="Internal thoughts from MedSwin and supervisor will appear here...", | |
| lines=8, | |
| max_lines=15, | |
| interactive=False, | |
| visible=False, | |
| elem_classes="agentic-thoughts" | |
| ) | |
| with gr.Row(): | |
| use_rag = gr.Checkbox( | |
| value=False, | |
| label="Enable Document RAG", | |
| info="Answer based on uploaded documents (upload required)" | |
| ) | |
| use_web_search = gr.Checkbox( | |
| value=False, | |
| label="Enable Web Search (MCP)", | |
| info="Fetch knowledge from online medical resources" | |
| ) | |
| medical_model = gr.Radio( | |
| choices=list(MEDSWIN_MODELS.keys()), | |
| value=DEFAULT_MEDICAL_MODEL, | |
| label="Medical Model", | |
| info="MedSwin DT (default), others download on selection" | |
| ) | |
| model_status = gr.Textbox( | |
| value="Checking model status...", | |
| label="Model Status", | |
| interactive=False, | |
| visible=True, | |
| lines=3, | |
| max_lines=3, | |
| elem_classes="model-status" | |
| ) | |
| system_prompt = gr.Textbox( | |
| value="As a medical specialist, provide detailed and accurate answers based on the provided medical documents and context. Ensure all information is clinically accurate and cite sources when available. Provide answers directly without conversational prefixes like 'Here is...', 'This is...', or 'To answer your question...'. Start with the actual content immediately.", | |
| label="System Prompt", | |
| lines=3 | |
| ) | |
| with gr.Tab("Generation Parameters"): | |
| temperature = gr.Slider( | |
| minimum=0, | |
| maximum=1, | |
| step=0.1, | |
| value=0.2, | |
| label="Temperature" | |
| ) | |
| max_new_tokens = gr.Slider( | |
| minimum=512, | |
| maximum=4096, | |
| step=128, | |
| value=2048, | |
| label="Max New Tokens", | |
| info="Increased for medical models to prevent early stopping" | |
| ) | |
| top_p = gr.Slider( | |
| minimum=0.0, | |
| maximum=1.0, | |
| step=0.1, | |
| value=0.7, | |
| label="Top P" | |
| ) | |
| top_k = gr.Slider( | |
| minimum=1, | |
| maximum=100, | |
| step=1, | |
| value=50, | |
| label="Top K" | |
| ) | |
| penalty = gr.Slider( | |
| minimum=0.0, | |
| maximum=2.0, | |
| step=0.1, | |
| value=1.2, | |
| label="Repetition Penalty" | |
| ) | |
| with gr.Tab("Retrieval Parameters"): | |
| retriever_k = gr.Slider( | |
| minimum=5, | |
| maximum=30, | |
| step=1, | |
| value=15, | |
| label="Initial Retrieval Size (Top K)" | |
| ) | |
| merge_threshold = gr.Slider( | |
| minimum=0.1, | |
| maximum=0.9, | |
| step=0.1, | |
| value=0.5, | |
| label="Merge Threshold (lower = more merging)" | |
| ) | |
| # MedSwin Model Links | |
| gr.Markdown( | |
| """ | |
| <div style="margin-top: 20px; padding: 15px; background-color: #f5f5f5; border-radius: 8px;"> | |
| <h4 style="margin-top: 0; margin-bottom: 10px;">🔗 MedSwin Models on Hugging Face</h4> | |
| <div style="display: flex; flex-wrap: wrap; gap: 10px;"> | |
| <a href="https://huggingface.co/MedSwin/MedSwin-Merged-DaRE-TIES-KD-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin DT</a> | |
| <a href="https://huggingface.co/MedSwin/MedSwin-Merged-NuSLERP-KD-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin Nsl</a> | |
| <a href="https://huggingface.co/MedSwin/MedSwin-Merged-DaRE-Linear-KD-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin DL</a> | |
| <a href="https://huggingface.co/MedSwin/MedSwin-Merged-TIES-KD-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin Ti</a> | |
| <a href="https://huggingface.co/MedSwin/MedSwin-Merged-TA-SFT-0.7" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin TA</a> | |
| <a href="https://huggingface.co/MedSwin/MedSwin-7B-SFT" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin SFT</a> | |
| <a href="https://huggingface.co/MedSwin/MedSwin-7B-KD" target="_blank" style="text-decoration: none; padding: 5px 10px; background-color: #1a73e8; color: white; border-radius: 4px; font-size: 12px;">MedSwin KD</a> | |
| </div> | |
| <p style="margin-top: 10px; margin-bottom: 0; font-size: 11px; color: #666;">Click any model name to view details on Hugging Face</p> | |
| </div> | |
| """ | |
| ) | |

        show_thoughts_state = gr.State(value=False)

        def toggle_thoughts_box(current_state):
            """Toggle visibility of agentic thoughts box"""
            new_state = not current_state
            return gr.update(visible=new_state), new_state

        show_agentic_thought.click(
            fn=toggle_thoughts_box,
            inputs=[show_thoughts_state],
            outputs=[agentic_thoughts_box, show_thoughts_state]
        )

        # GPU-decorated function to load any model (for user selection)
        # @spaces.GPU(max_duration=MAX_DURATION)
        def load_model_with_gpu(model_name):
            """Load medical model (GPU-decorated for ZeroGPU compatibility)"""
            try:
                if not is_model_loaded(model_name):
                    logger.info(f"Loading medical model: {model_name}...")
                    set_model_loading_state(model_name, "loading")
                    try:
                        initialize_medical_model(model_name)
                        logger.info(f"✅ Medical model {model_name} loaded successfully!")
                        return "✅ The model has been loaded successfully", True
                    except Exception as e:
                        logger.error(f"Failed to load medical model {model_name}: {e}")
                        set_model_loading_state(model_name, "error")
                        return f"❌ Error loading model: {str(e)[:100]}", False
                else:
                    logger.info(f"Medical model {model_name} is already loaded")
                    return "✅ The model has been loaded successfully", True
            except Exception as e:
                logger.error(f"Error loading model {model_name}: {e}")
                return f"❌ Error: {str(e)[:100]}", False

        def load_model_and_update_status(model_name):
            """Load model and update status, return status text and whether model is ready"""
            try:
                status_lines = []
                # Medical model status
                if is_model_loaded(model_name):
                    status_lines.append(f"✅ MedSwin ({model_name}): loaded and ready")
                else:
                    state = get_model_loading_state(model_name)
                    if state == "loading":
                        status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
                    elif state == "error":
                        status_lines.append(f"❌ MedSwin ({model_name}): error loading")
                    else:
                        # Use GPU-decorated function to load the model
                        try:
                            result = load_model_with_gpu(model_name)
                            if result and isinstance(result, tuple) and len(result) == 2:
                                status_text, is_ready = result
                                if is_ready:
                                    status_lines.append(f"✅ MedSwin ({model_name}): loaded and ready")
                                else:
                                    status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
                            else:
                                status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
                        except Exception as e:
                            logger.error(f"Error calling load_model_with_gpu: {e}")
                            status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
                # TTS model status (only show if available or if there's an issue)
                if SNAC_AVAILABLE:
                    if config.global_tts_model is not None:
                        status_lines.append("✅ TTS (maya1): loaded and ready")
                    else:
                        # TTS available but not loaded - optional feature
                        pass  # Don't show if not loaded, it's optional
                # Don't show TTS status if library not available (it's optional)
                # ASR (Whisper) model status
                if WHISPER_AVAILABLE:
                    if config.global_whisper_model is not None:
                        status_lines.append("✅ ASR (Whisper): loaded and ready")
                    else:
                        status_lines.append("⏳ ASR (Whisper): will load on first use")
                else:
                    status_lines.append("❌ ASR: library not available")
                status_text = "\n".join(status_lines)
                is_ready = is_model_loaded(model_name)
                return status_text, is_ready
            except Exception as e:
                return f"❌ Error: {str(e)[:100]}", False

        def check_model_status(model_name):
            """Check current model status without loading"""
            status_lines = []
            # Medical model status
            if is_model_loaded(model_name):
                status_lines.append(f"✅ MedSwin ({model_name}): loaded and ready")
            else:
                state = get_model_loading_state(model_name)
                if state == "loading":
                    status_lines.append(f"⏳ MedSwin ({model_name}): loading...")
                elif state == "error":
                    status_lines.append(f"❌ MedSwin ({model_name}): error loading")
                else:
                    status_lines.append(f"⚠️ MedSwin ({model_name}): not loaded")
            # TTS model status (only show if available and loaded)
            if SNAC_AVAILABLE:
                if config.global_tts_model is not None:
                    status_lines.append("✅ TTS (maya1): loaded and ready")
                # Don't show if TTS library available but model not loaded (optional feature)
            # Don't show TTS status if library not available (it's optional)
            # ASR (Whisper) model status
            if WHISPER_AVAILABLE:
                if config.global_whisper_model is not None:
                    status_lines.append("✅ ASR (Whisper): loaded and ready")
                else:
                    status_lines.append("⏳ ASR (Whisper): will load on first use")
            else:
                status_lines.append("❌ ASR: library not available")
            status_text = "\n".join(status_lines)
            is_ready = is_model_loaded(model_name)
            return status_text, is_ready

        # Startup loading of ONLY the medical model (TTS and Whisper load on-demand).
        # ZeroGPU best practices:
        # 1. Load models to CPU in global scope (no GPU decorator needed)
        # 2. Move models to GPU only in inference functions (with @spaces.GPU decorator)
        # However, for large models, loading to CPU and then moving to GPU uses more memory,
        # so a hybrid alternative is also provided below: load to GPU directly, but within a GPU-decorated function.
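        #
        # A minimal sketch of the pure "CPU-first" pattern described above (illustrative only;
        # MODEL_ID and tokenizer are placeholders and this sketch is not wired into the app):
        #
        #     model = AutoModelForCausalLM.from_pretrained(MODEL_ID)   # loaded to CPU at import time
        #
        #     @spaces.GPU(max_duration=MAX_DURATION)
        #     def generate(prompt: str) -> str:
        #         model.to("cuda")                                     # GPU attached only inside the decorated call
        #         inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
        #         output = model.generate(**inputs, max_new_tokens=256)
        #         return tokenizer.decode(output[0], skip_special_tokens=True)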
        def load_medical_model_on_startup_cpu():
            """
            Load model to CPU on startup (ZeroGPU best practice - no GPU decorator needed)
            Model will be moved to GPU during first inference
            """
            status_messages = []
            try:
                # Load only medical model (MedSwin) to CPU - TTS and Whisper load on-demand
                if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
                    logger.info(f"[STARTUP] Loading medical model to CPU: {DEFAULT_MEDICAL_MODEL}...")
                    set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
                    try:
                        # Load to CPU (no GPU decorator needed)
                        initialize_medical_model(DEFAULT_MEDICAL_MODEL, load_to_gpu=False)
                        # Verify model is actually loaded
                        if is_model_loaded(DEFAULT_MEDICAL_MODEL):
                            status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded to CPU")
                            logger.info(f"[STARTUP] ✅ Medical model {DEFAULT_MEDICAL_MODEL} loaded to CPU successfully!")
                        else:
                            status_messages.append(f"⚠️ MedSwin ({DEFAULT_MEDICAL_MODEL}): loading failed")
                            logger.warning(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} initialization completed but not marked as loaded")
                            set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
                    except Exception as e:
                        status_messages.append(f"❌ MedSwin ({DEFAULT_MEDICAL_MODEL}): error - {str(e)[:50]}")
                        logger.error(f"[STARTUP] Failed to load medical model: {e}")
                        import traceback
                        logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
                        set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
                else:
                    status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): already loaded")
                    logger.info(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} already loaded")
                # Add ASR status (will load on first use)
                if WHISPER_AVAILABLE:
                    status_messages.append("⏳ ASR (Whisper): will load on first use")
                else:
                    status_messages.append("❌ ASR: library not available")
                # Return status
                status_text = "\n".join(status_messages)
                logger.info(f"[STARTUP] ✅ Model loading complete. Status:\n{status_text}")
                return status_text
            except Exception as e:
                error_msg = str(e)
                logger.error(f"[STARTUP] Error loading model to CPU: {error_msg}")
                return f"⚠️ Error loading model: {error_msg[:100]}"

        # Alternative: Load directly to GPU (requires GPU decorator)
        # @spaces.GPU(max_duration=MAX_DURATION)
        def load_medical_model_on_startup_gpu():
            """
            Load model directly to GPU on startup (alternative approach)
            Uses GPU quota but model is immediately ready for inference
            """
            import torch
            status_messages = []
            try:
                # Clear GPU cache at start
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    logger.info("[STARTUP] Cleared GPU cache before model loading")
                # Load only medical model (MedSwin) - TTS and Whisper load on-demand
                if not is_model_loaded(DEFAULT_MEDICAL_MODEL):
                    logger.info(f"[STARTUP] Loading medical model to GPU: {DEFAULT_MEDICAL_MODEL}...")
                    set_model_loading_state(DEFAULT_MEDICAL_MODEL, "loading")
                    try:
                        # Load directly to GPU (within GPU-decorated function)
                        initialize_medical_model(DEFAULT_MEDICAL_MODEL, load_to_gpu=True)
                        # Verify model is actually loaded
                        if is_model_loaded(DEFAULT_MEDICAL_MODEL):
                            status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): loaded to GPU")
                            logger.info(f"[STARTUP] ✅ Medical model {DEFAULT_MEDICAL_MODEL} loaded to GPU successfully!")
                        else:
                            status_messages.append(f"⚠️ MedSwin ({DEFAULT_MEDICAL_MODEL}): loading failed")
                            logger.warning(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} initialization completed but not marked as loaded")
                            set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
                    except Exception as e:
                        status_messages.append(f"❌ MedSwin ({DEFAULT_MEDICAL_MODEL}): error - {str(e)[:50]}")
                        logger.error(f"[STARTUP] Failed to load medical model: {e}")
                        import traceback
                        logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
                        set_model_loading_state(DEFAULT_MEDICAL_MODEL, "error")
                else:
                    status_messages.append(f"✅ MedSwin ({DEFAULT_MEDICAL_MODEL}): already loaded")
                    logger.info(f"[STARTUP] Medical model {DEFAULT_MEDICAL_MODEL} already loaded")
                # Add ASR status (will load on first use)
                if WHISPER_AVAILABLE:
                    status_messages.append("⏳ ASR (Whisper): will load on first use")
                else:
                    status_messages.append("❌ ASR: library not available")
                # Clear cache after loading
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    logger.debug("[STARTUP] Cleared GPU cache after model loading")
                # Return status
                status_text = "\n".join(status_messages)
                logger.info(f"[STARTUP] ✅ Model loading complete. Status:\n{status_text}")
                return status_text
            except Exception as e:
                error_msg = str(e)
                # Check if it's a ZeroGPU quota/rate limit error
                is_quota_error = ("429" in error_msg or "Too Many Requests" in error_msg or
                                  "quota" in error_msg.lower() or "ZeroGPU" in error_msg or
                                  "runnning out" in error_msg.lower() or "running out" in error_msg.lower())
                if is_quota_error:
                    logger.warning(f"[STARTUP] ZeroGPU quota/rate limit error detected: {error_msg[:100]}")
                    # Return status message indicating quota error (will be handled by retry logic)
                    status_messages.append("⚠️ ZeroGPU quota error - will retry")
                    status_text = "\n".join(status_messages)
                    # Also add ASR status
                    if WHISPER_AVAILABLE:
                        status_text += "\n⏳ ASR (Whisper): will load on first use"
                    return status_text  # Return status instead of raising, let wrapper handle retry
                logger.error(f"[STARTUP] ❌ Error in model loading startup: {e}")
                import traceback
                logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
                # Clear cache on error
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                return f"⚠️ Startup loading error: {str(e)[:100]}"

        # Initialize status on load
        def init_model_status():
            try:
                result = check_model_status(DEFAULT_MEDICAL_MODEL)
                if result and isinstance(result, tuple) and len(result) == 2:
                    status_text, is_ready = result
                    return status_text
                else:
                    return "⚠️ Unable to check model status"
            except Exception as e:
                logger.error(f"Error in init_model_status: {e}")
                return f"⚠️ Error: {str(e)[:100]}"

        # Update status when model selection changes
        def update_model_status_on_change(model_name):
            try:
                result = check_model_status(model_name)
                if result and isinstance(result, tuple) and len(result) == 2:
                    status_text, is_ready = result
                    return status_text
                else:
                    return "⚠️ Unable to check model status"
            except Exception as e:
                logger.error(f"Error in update_model_status_on_change: {e}")
                return f"⚠️ Error: {str(e)[:100]}"

        # Handle model selection change
        def on_model_change(model_name):
            try:
                result = load_model_and_update_status(model_name)
                if result and isinstance(result, tuple) and len(result) == 2:
                    status_text, is_ready = result
                    submit_enabled = is_ready
                    return (
                        status_text,
                        gr.update(interactive=submit_enabled),
                        gr.update(interactive=submit_enabled)
                    )
                else:
                    error_msg = "⚠️ Unable to load model status"
                    return (
                        error_msg,
                        gr.update(interactive=False),
                        gr.update(interactive=False)
                    )
            except Exception as e:
                logger.error(f"Error in on_model_change: {e}")
                error_msg = f"⚠️ Error: {str(e)[:100]}"
                return (
                    error_msg,
                    gr.update(interactive=False),
                    gr.update(interactive=False)
                )

        # Update status display periodically or on model status changes
        def refresh_model_status(model_name):
            return update_model_status_on_change(model_name)

        medical_model.change(
            fn=on_model_change,
            inputs=[medical_model],
            outputs=[model_status, submit_button, message_input]
        )

        # GPU-decorated function to load Whisper ASR model on-demand
        # @spaces.GPU(max_duration=MAX_DURATION)
        def load_whisper_model_on_demand():
            """Load Whisper ASR model when needed"""
            try:
                if WHISPER_AVAILABLE and config.global_whisper_model is None:
                    logger.info("[ASR] Loading Whisper model on-demand...")
                    initialize_whisper_model()
                    if config.global_whisper_model is not None:
                        logger.info("[ASR] ✅ Whisper model loaded successfully!")
                        return "✅ ASR (Whisper): loaded"
                    else:
                        logger.warning("[ASR] ⚠️ Whisper model failed to load")
                        return "⚠️ ASR (Whisper): failed to load"
                elif config.global_whisper_model is not None:
                    return "✅ ASR (Whisper): already loaded"
                else:
                    return "❌ ASR: library not available"
            except Exception as e:
                logger.error(f"[ASR] Error loading Whisper model: {e}")
                return f"❌ ASR: error - {str(e)[:100]}"

        # Load medical model on startup and update status
        # Use a wrapper to handle GPU context properly with retry logic
        def load_startup_and_update_ui():
            """
            Load model on startup with retry logic (max 3 attempts) and return status with UI updates
            Uses CPU-first approach (ZeroGPU best practice):
            - Load model to CPU (no GPU decorator needed, avoids quota issues)
            - Model will be moved to GPU during first inference
            """
            import time
            max_retries = 3
            base_delay = 5.0  # Start with 5 seconds delay
            for attempt in range(1, max_retries + 1):
                try:
                    logger.info(f"[STARTUP] Attempt {attempt}/{max_retries} to load medical model to CPU...")
                    # Use CPU-first approach (no GPU decorator, avoids quota issues)
                    status_text = load_medical_model_on_startup_cpu()
                    # Check if model is ready and update submit button state
                    is_ready = is_model_loaded(DEFAULT_MEDICAL_MODEL)
                    if is_ready:
                        logger.info(f"[STARTUP] ✅ Model loaded successfully on attempt {attempt}")
                        return status_text, gr.update(interactive=is_ready), gr.update(interactive=is_ready)
                    else:
                        # Check if status text indicates quota error
                        if status_text and ("quota" in status_text.lower() or "ZeroGPU" in status_text or
                                            "429" in status_text or "runnning out" in status_text.lower() or
                                            "running out" in status_text.lower()):
                            if attempt < max_retries:
                                delay = base_delay * attempt
                                logger.warning(f"[STARTUP] Quota error detected in status, retrying in {delay} seconds...")
                                time.sleep(delay)
                                continue
                            else:
                                # Quota exhausted after retries - allow user to proceed, model will load on-demand
                                status_msg = "⚠️ ZeroGPU quota exhausted.\n⏳ Model will load automatically when you send a message.\n💡 You can also select a model from the dropdown."
                                logger.info("[STARTUP] Quota exhausted after retries - allowing user to proceed with on-demand loading")
                                return status_msg, gr.update(interactive=True), gr.update(interactive=True)
                        # Model didn't load, but no exception - might be a state issue
                        logger.warning(f"[STARTUP] Model not ready after attempt {attempt}, but no error")
                        if attempt < max_retries:
                            delay = base_delay * attempt  # Exponential backoff: 5s, 10s, 15s
                            logger.info(f"[STARTUP] Retrying in {delay} seconds...")
                            time.sleep(delay)
                            continue
                        else:
                            # Even if model didn't load, allow user to try selecting another model
                            return status_text + "\n⚠️ Model not loaded. Please select a model from dropdown.", gr.update(interactive=True), gr.update(interactive=True)
                except Exception as e:
                    error_msg = str(e)
                    is_quota_error = ("429" in error_msg or "Too Many Requests" in error_msg or
                                      "quota" in error_msg.lower() or "ZeroGPU" in error_msg or
                                      "runnning out" in error_msg.lower() or "running out" in error_msg.lower())
                    if is_quota_error and attempt < max_retries:
                        delay = base_delay * attempt  # Exponential backoff: 5s, 10s, 15s
                        logger.warning(f"[STARTUP] ZeroGPU rate limit/quota error on attempt {attempt}/{max_retries}")
                        logger.info(f"[STARTUP] Retrying in {delay} seconds...")
                        time.sleep(delay)
                        continue
                    else:
                        logger.error(f"[STARTUP] Error in load_startup_and_update_ui (attempt {attempt}/{max_retries}): {e}")
                        import traceback
                        logger.debug(f"[STARTUP] Full traceback: {traceback.format_exc()}")
                        if is_quota_error:
                            # If quota exhausted, allow user to proceed - model will load on-demand
                            error_display = "⚠️ ZeroGPU quota exhausted.\n⏳ Model will load automatically when you send a message.\n💡 You can also select a model from the dropdown."
                            logger.info("[STARTUP] Quota exhausted - allowing user to proceed with on-demand loading")
                            return error_display, gr.update(interactive=True), gr.update(interactive=True)
                        else:
                            error_display = f"⚠️ Startup error: {str(e)[:100]}"
                            if attempt >= max_retries:
                                logger.error(f"[STARTUP] Failed after {max_retries} attempts")
                            return error_display, gr.update(interactive=False), gr.update(interactive=False)
            # Should not reach here, but just in case
            return "⚠️ Startup failed after retries. Please select a model from dropdown.", gr.update(interactive=True), gr.update(interactive=True)

        demo.load(
            fn=load_startup_and_update_ui,
            inputs=None,
            outputs=[model_status, submit_button, message_input]
        )

        # Note: We removed the preload-on-focus functionality because:
        # 1. Model loading requires GPU access (device_map="auto" needs GPU in ZeroGPU)
        # 2. The startup function already loads the model with GPU decorator
        # 3. Preloading without GPU decorator would fail or cause conflicts
        # 4. If startup fails, user can select a model from dropdown to trigger loading

        # Wrap stream_chat - ensure model is loaded before starting (don't load inside stream_chat to save time)
        def stream_chat_with_model_check(
            message, history, system_prompt, temperature, max_new_tokens,
            top_p, top_k, penalty, retriever_k, merge_threshold,
            use_rag, medical_model_name, use_web_search,
            enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request: gr.Request = None
        ):
            # Check if model is loaded - if not, show error (don't load here to save stream_chat time)
            model_loaded = is_model_loaded(medical_model_name)
            if not model_loaded:
                loading_state = get_model_loading_state(medical_model_name)
                # Debug logging to understand why model check fails
                logger.debug(f"[STREAM_CHAT] Model check: name={medical_model_name}, loaded={model_loaded}, state={loading_state}, in_dict={medical_model_name in config.global_medical_models}, model_exists={config.global_medical_models.get(medical_model_name) is not None if medical_model_name in config.global_medical_models else False}")
                if loading_state == "loading":
                    error_msg = f"⏳ {medical_model_name} is still loading. Please wait until the model status shows 'loaded and ready' before sending messages."
                else:
                    error_msg = f"⚠️ {medical_model_name} is not loaded. Please wait a moment for the model to finish loading, or select a model from the dropdown to load it."
                updated_history = history + [{"role": "assistant", "content": error_msg}]
                yield updated_history, ""
                return
            # If request is None, create a mock request for compatibility
            if request is None:
                class MockRequest:
                    session_hash = "anonymous"
                request = MockRequest()
            # Model is loaded, proceed with stream_chat (no model loading here to save time)
            # Note: We handle "BodyStreamBuffer was aborted" errors by catching stream disconnections
            # and not attempting to yield after the client has disconnected
            last_result = None
            stream_aborted = False
            try:
                for result in stream_chat(
                    message, history, system_prompt, temperature, max_new_tokens,
                    top_p, top_k, penalty, retriever_k, merge_threshold,
                    use_rag, medical_model_name, use_web_search,
                    enable_clinical_intake, disable_agentic_reasoning, show_thoughts, request
                ):
                    last_result = result
                    try:
                        yield result
                    except (GeneratorExit, StopIteration, RuntimeError) as stream_error:
                        # Stream was closed/aborted by client - don't try to yield again
                        error_msg_lower = str(stream_error).lower()
                        if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
                            logger.info(f"[UI] Stream was aborted by client, stopping gracefully")
                            stream_aborted = True
                            break
                        raise
            except (GeneratorExit, StopIteration) as stream_exit:
                # Stream was closed - this is normal, just log and exit
                logger.info(f"[UI] Stream closed normally")
                stream_aborted = True
                return
            except Exception as e:
                # Handle any errors gracefully
                error_str = str(e)
                error_msg_lower = error_str.lower()
                # Check if this is a stream abort error
                is_stream_abort = (
                    'bodystreambuffer' in error_msg_lower
                    or ('stream' in error_msg_lower and 'abort' in error_msg_lower)
                    or ('connection' in error_msg_lower and 'abort' in error_msg_lower)
                    or (isinstance(e, (GeneratorExit, StopIteration, RuntimeError)) and 'abort' in error_msg_lower)
                )
                if is_stream_abort:
                    logger.info(f"[UI] Stream was aborted (BodyStreamBuffer or similar): {error_str[:100]}")
                    stream_aborted = True
                    # If we have a result, it was already yielded, so just return
                    return
                is_gpu_timeout = 'gpu task aborted' in error_msg_lower or 'timeout' in error_msg_lower
                logger.error(f"Error in stream_chat_with_model_check: {error_str}")
                import traceback
                logger.debug(f"Full traceback: {traceback.format_exc()}")
                # Check if we have a valid answer in the last result
                has_valid_answer = False
                if last_result is not None:
                    try:
                        last_history, last_thoughts = last_result
                        # Find the last assistant message in the history
                        if last_history and isinstance(last_history, list):
                            for msg in reversed(last_history):
                                if isinstance(msg, dict) and msg.get("role") == "assistant":
                                    assistant_content = msg.get("content", "")
                                    # Check if it's a valid answer (not empty, not an error message)
                                    if assistant_content and len(assistant_content.strip()) > 0:
                                        # Not an error message
                                        if not assistant_content.strip().startswith("⚠️") and not assistant_content.strip().startswith("⏳"):
                                            has_valid_answer = True
                                            break
                    except Exception as parse_error:
                        logger.debug(f"Error parsing last_result: {parse_error}")
                # If stream was aborted, don't try to yield - just return
                if stream_aborted:
                    logger.info(f"[UI] Stream was aborted, not yielding final result")
                    return
                # If we have a valid answer, use it (don't show error message)
                if has_valid_answer:
                    logger.info(f"[UI] Error occurred but final answer already generated, displaying it without error message")
                    try:
                        yield last_result
                    except (GeneratorExit, StopIteration, RuntimeError) as yield_error:
                        error_msg_lower = str(yield_error).lower()
                        if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
                            logger.info(f"[UI] Stream aborted while yielding final result, ignoring")
                        else:
                            raise
                    return
                # For GPU timeouts, try to use last result even if it's partial
                if is_gpu_timeout and last_result is not None:
                    logger.info(f"[UI] GPU timeout occurred, using last available result")
                    try:
                        yield last_result
                    except (GeneratorExit, StopIteration, RuntimeError) as yield_error:
                        error_msg_lower = str(yield_error).lower()
                        if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
                            logger.info(f"[UI] Stream aborted while yielding timeout result, ignoring")
                        else:
                            raise
                    return
                # Only show error for non-timeout errors when we have no valid answer
                # For GPU timeouts with no result, show empty message (not error)
                if is_gpu_timeout:
                    logger.info(f"[UI] GPU timeout with no result, showing empty assistant message")
                    updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": ""}]
                    try:
                        yield updated_history, ""
                    except (GeneratorExit, StopIteration, RuntimeError) as yield_error:
                        error_msg_lower = str(yield_error).lower()
                        if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
                            logger.info(f"[UI] Stream aborted while yielding empty message, ignoring")
                        else:
                            raise
                else:
                    # For other errors, show minimal error message only if no result
                    error_display = f"⚠️ An error occurred: {error_str[:200]}"
                    updated_history = history + [{"role": "user", "content": message}, {"role": "assistant", "content": error_display}]
                    try:
                        yield updated_history, ""
                    except (GeneratorExit, StopIteration, RuntimeError) as yield_error:
                        error_msg_lower = str(yield_error).lower()
                        if 'abort' in error_msg_lower or 'stream' in error_msg_lower or 'buffer' in error_msg_lower:
                            logger.info(f"[UI] Stream aborted while yielding error message, ignoring")
                        else:
                            raise

        submit_button.click(
            fn=stream_chat_with_model_check,
            inputs=[
                message_input,
                chatbot,
                system_prompt,
                temperature,
                max_new_tokens,
                top_p,
                top_k,
                penalty,
                retriever_k,
                merge_threshold,
                use_rag,
                medical_model,
                use_web_search,
                enable_clinical_intake,
                disable_agentic_reasoning,
                show_thoughts_state
            ],
            outputs=[chatbot, agentic_thoughts_box]
        )
        message_input.submit(
            fn=stream_chat_with_model_check,
            inputs=[
                message_input,
                chatbot,
                system_prompt,
                temperature,
                max_new_tokens,
                top_p,
                top_k,
                penalty,
                retriever_k,
                merge_threshold,
                use_rag,
                medical_model,
                use_web_search,
                enable_clinical_intake,
                disable_agentic_reasoning,
                show_thoughts_state
            ],
            outputs=[chatbot, agentic_thoughts_box]
        )
    return demo
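

# Minimal local entry point (a sketch; assumes the Space's real launcher lives elsewhere,
# e.g. in an app.py that imports create_demo - the launch arguments here are illustrative).
if __name__ == "__main__":
    demo = create_demo()
    demo.queue().launch()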