Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import PeftModel | |
| import gradio as gr | |
| import asyncio | |
| import atexit | |
# Create and own a single event loop so teardown is clean on Spaces; it is
# installed as the current event loop of the importing thread.
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)


def _safe_loop_del(self):  # pragma: no cover
    """Best-effort destructor installed on ``asyncio.BaseEventLoop``.

    Closes the loop if it is still open and ignores any error raised while
    doing so, so that loop destruction never surfaces errors.
    """
    try:
        if not self.is_closed():
            self.close()
    except Exception:
        # Deliberately swallowed: a destructor must not raise. Errors here are
        # shutdown noise, e.g. closing a loop that is still running, or one whose
        # file descriptors are already gone during interpreter teardown.
        pass


# Make the loop destructor safe (prevents "Invalid file descriptor" spam on
# shutdown). NOTE: this patches the base class, so it applies to every asyncio
# event loop in the process, not only ``event_loop``.
asyncio.BaseEventLoop.__del__ = _safe_loop_del


def _close_event_loop():
    """Close the module-owned ``event_loop``; registered below to run at exit.

    Safe to call repeatedly. ``close()`` raises RuntimeError on a running loop,
    so a running loop only gets a thread-safe stop request and is left open
    instead of crashing the exit handler.
    """
    if event_loop.is_closed():
        return
    if event_loop.is_running():
        event_loop.call_soon_threadsafe(event_loop.stop)
        return
    event_loop.close()


atexit.register(_close_event_loop)
# Hugging Face Hub IDs: the chat base model and the LoRA adapter that is
# applied on top of it.
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
LORA_REPO = "mackenzietechdocs/ml-sensei-lora-tinyllama-1.1b"

# GPU + bfloat16 when CUDA is present, otherwise CPU + float32.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.bfloat16
else:
    DEVICE, DTYPE = "cpu", torch.float32
# --- One-time loading of tokenizer, base model and LoRA adapter -------------
print("πΉ Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
if tokenizer.pad_token is None:
    # No dedicated pad token in this checkpoint: reuse EOS so padding is defined.
    tokenizer.pad_token = tokenizer.eos_token

print("πΉ Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    low_cpu_mem_usage=True,
    torch_dtype=DTYPE,
    # Everything on one device (CPU on basic Spaces) -- no sharding.
    device_map={"": DEVICE},
)

print("πΉ Loading LoRA adapter...")
model = PeftModel.from_pretrained(
    base_model,
    LORA_REPO,
    torch_dtype=DTYPE,
    device_map={"": DEVICE},  # adapter lives next to the base model
).to(DEVICE)
# Inference only: switch off training-mode behaviour such as dropout.
model.eval()
# Persona text sent as the system turn at the start of every prompt.
SYSTEM_PROMPT = (
    "You are ML Sensei, a calm, friendly machine learning tutor. "
    + "Explain ML/AI concepts clearly using intuition, simple language, and examples."
)
def _build_messages(chat_history, user_message):
    """Return chat-template messages: system prompt, prior turns, then *user_message*.

    *chat_history* is an iterable of ``(user, assistant)`` text pairs; empty or
    ``None`` entries on either side are skipped so no blank turns are sent.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for user_text, assistant_text in chat_history:
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": user_message})
    return messages


def generate_reply(chat_history, user_message, max_new_tokens=512, temperature=0.7, top_p=0.9,
                   max_prompt_tokens=1536):
    """Generate ML Sensei's answer to *user_message* given the earlier turns.

    Args:
        chat_history: Sequence of ``(user, assistant)`` text pairs, oldest first
            (``None`` is treated as no history).
        user_message: The new user turn to answer.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        top_p: Nucleus-sampling mass (only used when sampling).
        max_prompt_tokens: Upper bound on prompt length in tokens. It is further
            lowered so that prompt + reply fit ``model.config.max_position_embeddings``
            when the model exposes that value.

    Returns:
        The decoded reply with special tokens removed and whitespace stripped.
    """
    history = list(chat_history or [])

    budget = max_prompt_tokens
    context_window = getattr(getattr(model, "config", None), "max_position_embeddings", None)
    if isinstance(context_window, int):
        # Prompt and reply share the same position budget.
        budget = min(budget, max(context_window - max_new_tokens, 1))

    # Drop the oldest exchanges until the templated prompt fits. Right-side
    # truncation would instead cut off the *newest* user message and the
    # assistant cue appended by add_generation_prompt.
    while True:
        prompt = tokenizer.apply_chat_template(
            _build_messages(history, user_message),
            tokenize=False,
            add_generation_prompt=True,
        )
        encoded = tokenizer(prompt, return_tensors="pt")
        if encoded["input_ids"].shape[1] <= budget or not history:
            break
        history.pop(0)

    # Last resort (a single oversized message): keep the most recent tokens so
    # the assistant cue at the end of the prompt survives. No-op when it fits.
    inputs = {name: tensor[:, -budget:].to(DEVICE) for name, tensor in encoded.items()}
    prompt_len = inputs["input_ids"].shape[1]

    if temperature > 0:
        decoding = {"do_sample": True, "temperature": temperature, "top_p": top_p}
    else:
        decoding = {"do_sample": False}

    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        **decoding,
    )
    # Decode only the newly generated tokens (everything after the prompt).
    new_tokens = output[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def _content_to_text(content):
    """Flatten a Gradio message ``content`` value into plain text.

    Handles plain strings, a single ``{"text": ...}`` dict, and lists of content
    blocks such as ``[{"type": "text", "text": "..."}]``. ``None`` becomes ``""``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = (c.get("text", "") if isinstance(c, dict) else str(c) for c in content)
        # Skip empty parts (e.g. non-text blocks) so they don't add stray spaces.
        return " ".join(part for part in parts if part)
    if isinstance(content, dict):
        return content.get("text", "") or str(content)
    return str(content)


def _history_to_pairs(history_messages):
    """Collapse Chatbot history into ``(user, assistant)`` text pairs, oldest first.

    Accepts legacy ``[user, assistant]`` items, message dicts with ``role`` and
    ``content`` keys, and objects exposing ``role``/``content`` attributes. A
    user turn that has no assistant reply is kept as ``(user, "")``.
    """
    pairs = []
    pending_user = None
    for msg in history_messages:
        if isinstance(msg, (list, tuple)) and len(msg) == 2:
            if pending_user is not None:
                pairs.append((pending_user, ""))
                pending_user = None
            pairs.append((_content_to_text(msg[0]), _content_to_text(msg[1])))
            continue
        if isinstance(msg, dict):
            role, raw = msg.get("role"), msg.get("content", "")
        else:
            role, raw = getattr(msg, "role", None), getattr(msg, "content", "")
        content = _content_to_text(raw)
        if role == "user":
            if pending_user is not None:
                pairs.append((pending_user, ""))
            pending_user = content
        elif role == "assistant":
            pairs.append((pending_user or "", content))
            pending_user = None
    if pending_user is not None:
        pairs.append((pending_user, ""))
    return pairs


def _history_to_messages(history_messages):
    """Return *history_messages* in Chatbot "messages" format.

    Legacy ``[user, assistant]`` items are expanded into role/content dicts
    (``None`` sides are skipped); all other items are passed through unchanged.
    """
    messages = []
    for msg in history_messages:
        if isinstance(msg, (list, tuple)) and len(msg) == 2:
            user_part, bot_part = msg
            if user_part is not None:
                messages.append({"role": "user", "content": _content_to_text(user_part)})
            if bot_part is not None:
                messages.append({"role": "assistant", "content": _content_to_text(bot_part)})
        else:
            messages.append(msg)
    return messages


def gradio_chat(user_message, history, max_new_tokens, temperature, top_p):
    """Gradio event handler: answer *user_message* and append the exchange.

    Args:
        user_message: Text from the input box.
        history: Current Chatbot value (message dicts, legacy pairs, or ``None``).
        max_new_tokens: Slider value, forwarded to ``generate_reply`` as ``int``.
        temperature: Slider value, forwarded as ``float``.
        top_p: Slider value, forwarded as ``float``.

    Returns:
        ``("", new_history)``. The empty string clears the input box, and
        ``new_history`` is in "messages" format (role/content dicts). That is
        the format ``gr.Chatbot`` uses in Gradio 6, which this app targets.
        NOTE(review): confirm against the pinned gradio version.
    """
    history_messages = list(history or [])
    if not user_message or not user_message.strip():
        # Nothing to ask: leave the conversation untouched instead of prompting
        # the model with an empty turn.
        return "", _history_to_messages(history_messages)

    reply = generate_reply(
        _history_to_pairs(history_messages),
        user_message,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
    )
    new_history = _history_to_messages(history_messages) + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": reply},
    ]
    return "", new_history
# Custom CSS for the messenger-style UI with a peach/pink theme.
# NOTE: many rules target Gradio's generated DOM (data-testid attributes and
# class-name substrings), so they may need revisiting when gradio is upgraded.
custom_css = """
/* Main gradient background */
.gradio-container {
    background: linear-gradient(135deg, #FFB88C 0%, #FF9A8B 50%, #FF6A88 100%) !important;
    font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
}
/* Chatbot container - transparent to show gradient */
.chatbot {
    border: none !important;
    background: transparent !important;
}
/* Make ALL chatbot area elements transparent except message bubbles */
[data-testid="chatbot"],
[data-testid="chatbot"] > *,
.chatbot div:not([data-testid="user"]):not([data-testid="bot"]),
.chatbot > div > div,
.chatbot [class*="wrap"],
.chatbot [class*="container"]:not([data-testid="user"]):not([data-testid="bot"]) {
    background: transparent !important;
    background-color: transparent !important;
    border: none !important;
    box-shadow: none !important;
}
/* User message bubble - light grey (only the outermost container) */
[data-testid="user"] {
    background: #E8E8E8 !important;
    color: #000000 !important;
    border: none !important;
    border-radius: 20px 20px 4px 20px !important;
    padding: 14px 18px !important;
    margin: 8px 20px 8px auto !important;
    max-width: 70% !important;
    min-width: 100px !important;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1) !important;
    outline: none !important;
    display: inline-block !important;
}
/* All children of user bubble - transparent, NO borders, NO padding */
[data-testid="user"] *,
[data-testid="user"] div,
[data-testid="user"] p,
[data-testid="user"] > div,
[data-testid="user"] [class*="message"],
[data-testid="user"] [class*="wrap"],
[data-testid="user"] [class*="container"] {
    background: transparent !important;
    background-color: transparent !important;
    border: none !important;
    border-width: 0 !important;
    border-style: none !important;
    padding: 0 !important;
    margin: 0 !important;
    outline: none !important;
    box-shadow: none !important;
}
/* Assistant message bubble - pink (only the outermost container) */
[data-testid="bot"] {
    background: linear-gradient(135deg, #FFB3D9 0%, #FFAAD1 100%) !important;
    color: #000000 !important;
    border: none !important;
    border-width: 0 !important;
    border-style: none !important;
    border-radius: 20px 20px 20px 4px !important;
    padding: 14px 18px !important;
    margin: 8px auto 8px 20px !important;
    max-width: 70% !important;
    min-width: 100px !important;
    box-shadow: 0 2px 8px rgba(255, 106, 136, 0.2) !important;
    outline: none !important;
    display: inline-block !important;
}
/* All children of bot bubble - transparent, NO borders, NO padding */
[data-testid="bot"] *,
[data-testid="bot"] div,
[data-testid="bot"] p,
[data-testid="bot"] > div,
[data-testid="bot"] [class*="message"],
[data-testid="bot"] [class*="wrap"],
[data-testid="bot"] [class*="container"] {
    background: transparent !important;
    background-color: transparent !important;
    border: none !important;
    border-width: 0 !important;
    border-style: none !important;
    padding: 0 !important;
    margin: 0 !important;
    outline: none !important;
    box-shadow: none !important;
}
/* Force all text to be black everywhere in chatbot */
.chatbot,
.chatbot *,
[data-testid="user"],
[data-testid="user"] *,
[data-testid="bot"],
[data-testid="bot"] * {
    color: #000000 !important;
}
/* Input container */
.input-row {
    background: rgba(255, 255, 255, 0.95) !important;
    border-radius: 28px !important;
    padding: 8px !important;
    box-shadow: 0 4px 16px rgba(0, 0, 0, 0.1) !important;
    border: none !important;
}
/* Remove all borders from EVERYTHING in input area */
.input-row,
.input-row *,
.input-row div,
.input-row label,
.input-row .block,
.input-row .wrap,
.input-row [class*="wrap"],
.input-row [class*="container"],
.input-row fieldset,
.input-row > * {
    border: none !important;
    outline: none !important;
    box-shadow: none !important;
}
/* Input box wrapper - absolutely no borders */
.input-box,
.input-box > *,
.input-box *,
.input-box div,
.input-box .wrap,
.input-box [class*="wrap"],
.input-box [class*="container"],
.input-box fieldset {
    border: none !important;
    outline: none !important;
    box-shadow: none !important;
    background: transparent !important;
}
.input-box textarea {
    background: white !important;
    color: #000000 !important;
    border-radius: 20px !important;
    padding: 12px 20px !important;
    font-size: 15px !important;
    box-shadow: 0 1px 4px rgba(0, 0, 0, 0.05) !important;
}
.input-box textarea::placeholder {
    color: #999 !important;
}
.input-box textarea:focus {
    border: 2px solid #FF6A88 !important;
    outline: none !important;
}
/* Send button */
.send-button {
    background: linear-gradient(135deg, #FF6A88 0%, #FF8C94 100%) !important;
    color: white !important;
    border: none !important;
    border-radius: 20px !important;
    padding: 12px 28px !important;
    font-weight: 600 !important;
    font-size: 15px !important;
    cursor: pointer !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 3px 10px rgba(255, 106, 136, 0.3) !important;
    min-width: 90px !important;
}
.send-button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 5px 14px rgba(255, 106, 136, 0.4) !important;
}
/* Settings panel */
.settings-panel {
    background: rgba(255, 255, 255, 0.95) !important;
    border-radius: 20px !important;
    padding: 24px !important;
    box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1) !important;
}
/* Force all settings panel children to have white/transparent backgrounds */
.settings-panel *,
.settings-panel div,
.settings-panel .block,
.settings-panel .wrap {
    background: transparent !important;
    background-color: transparent !important;
}
/* Force all text in settings panel to be black */
.settings-panel h3,
.settings-panel h3 *,
.settings-panel label,
.settings-panel label *,
.settings-panel span,
.settings-panel p,
.settings-panel .info,
.settings-panel * {
    color: #000000 !important;
}
.settings-panel h3 {
    margin-bottom: 20px !important;
}
/* Slider labels */
.settings-panel label {
    font-weight: 600 !important;
    font-size: 14px !important;
}
/* Slider info text */
.settings-panel .info {
    color: #666 !important;
    font-size: 12px !important;
}
/* Slider styling */
input[type="range"] {
    accent-color: #FF6A88 !important;
}
/* Slider containers - force transparent */
.settings-panel .slider-container,
.settings-panel [class*="slider"],
.settings-panel [class*="wrap"] {
    background: transparent !important;
    border: none !important;
}
/* Header */
.header-title {
    color: white !important;
    text-align: center !important;
    font-size: 2.5em !important;
    font-weight: 700 !important;
    margin-bottom: 8px !important;
    text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.15) !important;
}
.header-subtitle {
    color: rgba(255, 255, 255, 0.95) !important;
    text-align: center !important;
    font-size: 1.1em !important;
    margin-bottom: 20px !important;
    text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.1) !important;
}
/* Hide ALL progress bars everywhere - except our own static status row:
   the class-substring selectors would otherwise also match .progress-container */
.progress,
.progress-bar,
.generating,
[class*="progress"]:not(.progress-container),
div[class*="progress"]:not(.progress-container),
.chatbot .progress,
.input-row .progress,
.input-box .progress,
.wrap .progress,
.gradio-container .progress {
    display: none !important;
    visibility: hidden !important;
    opacity: 0 !important;
    height: 0 !important;
    overflow: hidden !important;
}
/* Static progress/status container */
.progress-container {
    margin-top: 12px !important;
    background: rgba(255, 255, 255, 0.9) !important;
    border-radius: 12px !important;
    padding: 10px 20px !important;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05) !important;
    min-height: 36px !important;
}
.progress-container div {
    background: transparent !important;
}
/* Scrollbar */
::-webkit-scrollbar {
    width: 8px;
}
::-webkit-scrollbar-track {
    background: rgba(255, 255, 255, 0.2);
    border-radius: 10px;
}
::-webkit-scrollbar-thumb {
    background: rgba(255, 106, 136, 0.6);
    border-radius: 10px;
}
::-webkit-scrollbar-thumb:hover {
    background: rgba(255, 106, 136, 0.8);
}
/* Footer text - make black */
footer,
footer *,
footer p,
footer a,
footer span,
.footer,
.footer *,
[class*="footer"],
[class*="footer"] * {
    color: #000000 !important;
}
/* Chat control buttons - down arrow and delete button - make pink */
.chatbot button,
.chatbot [role="button"],
.chatbot svg,
[data-testid="chatbot"] button,
[data-testid="chatbot"] [role="button"],
[data-testid="chatbot"] svg,
.chatbot button svg,
.chatbot button path,
[data-testid="chatbot"] button svg,
[data-testid="chatbot"] button path {
    color: #FF6A88 !important;
    fill: #FF6A88 !important;
    stroke: #FF6A88 !important;
}
/* Ensure button backgrounds are transparent */
.chatbot button,
[data-testid="chatbot"] button {
    background: transparent !important;
    border: none !important;
}
"""
with gr.Blocks() as demo:
    # Gradio 6.0.2 no longer takes `css=` on Blocks, so the stylesheet is
    # injected by hand as a <style> tag inside an HTML component.
    gr.HTML(f"<style>{custom_css}</style>")

    # Title banner.
    gr.HTML("""
<div style="text-align: center; margin-bottom: 20px;">
<h1 class="header-title">π₯ ML Sensei Chat</h1>
<p class="header-subtitle">Your friendly AI tutor for Machine Learning & Deep Learning</p>
</div>
""")

    with gr.Row():
        # Conversation pane (the wider column).
        with gr.Column(scale=3):
            chat = gr.Chatbot(
                show_label=False,
                elem_classes="chatbot",
                avatar_images=(None, "ml-chat.png"),
                height=550,
            )

        # Generation-settings sidebar; its slider values feed gradio_chat.
        with gr.Column(scale=1, elem_classes="settings-panel"):
            gr.Markdown("### βοΈ AI Settings")
            max_tokens = gr.Slider(
                64, 1536, value=800, step=16,
                label="π’ Max Tokens", info="Response length",
            )
            temperature = gr.Slider(
                0.1, 1.5, value=0.7, step=0.1,
                label="π‘οΈ Temperature", info="Creativity level",
            )
            top_p = gr.Slider(
                0.1, 1.0, value=0.9, step=0.05,
                label="π― Top-p", info="Diversity control",
            )
            gr.Markdown("---")
            gr.Markdown("""
<div style="text-align: center; color: #666; font-size: 0.85em;">
<p><strong>TinyLlama LoRA</strong></p>
<p>Powered by π¦</p>
</div>
""")

    # Input area below chat: message box plus send button.
    with gr.Row(elem_classes="input-row"):
        user_input = gr.Textbox(
            elem_classes="input-box",
            placeholder="π¬ Ask ML Sensei about ML / DL / AI...",
            show_label=False,
            container=False,
            scale=5,
        )
        send_btn = gr.Button("Send", elem_classes="send-button", scale=1)

    # Static status line below the input.
    with gr.Row(elem_classes="progress-container"):
        status = gr.HTML(value="<div style='text-align: center; color: #666; font-size: 13px;'>Ready to chat!</div>")

    # Clicking Send and pressing Enter in the textbox run the same handler,
    # which clears the textbox and updates the chat.
    for trigger in (send_btn.click, user_input.submit):
        trigger(
            gradio_chat,
            inputs=[user_input, chat, max_tokens, temperature, top_p],
            outputs=[user_input, chat],
        )
# Entry point: enable the request queue, then start the server. The main guard
# keeps imports of this module from launching the app, and it helps avoid async
# loop shutdown warnings on exit (HF Spaces).
if __name__ == "__main__":
    app = demo.queue()
    app.launch()