# --------------------------------------------------------------
# app.py – a Gradio chat UI for maya-research/maya1
# --------------------------------------------------------------
import json
import os
import uuid
from pathlib import Path
from typing import List, Tuple

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from huggingface_hub import HfApi
# ------------------- CONFIGURATION -----------------------------
MODEL_ID = "maya-research/maya1"   # the model you want to use
MAX_NEW_TOKENS = 256               # generation length
TEMPERATURE = 0.7
TOP_P = 0.9
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Folder inside the Space where we keep per-session JSON files
HISTORY_DIR = Path("history")
HISTORY_DIR.mkdir(exist_ok=True)
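# NOTE (assumption): a Space's local filesystem is ephemeral, so files under
# history/ only live until the container restarts – which is why save_history()
# below also pushes each file back to the Space repo.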
# ----------------------------------------------------------------
# 1️⃣ Load the model once (global, reused across requests)
# ----------------------------------------------------------------
print(f"🔧 Loading {MODEL_ID} …")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    do_sample=True,
)
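# Quick sanity check (hypothetical – uncomment to run one generation at startup
# and confirm the model loaded correctly):
# print(generator("Hello, who are you?")[0]["generated_text"])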
# --------------------------------------------------------------
# 2️⃣ Helper functions for history persistence
# --------------------------------------------------------------
def _history_path(session_id: str) -> Path:
    """File that stores a JSON list of (user, assistant) pairs."""
    safe_id = session_id.replace("/", "_")
    return HISTORY_DIR / f"{safe_id}.json"
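# Example (illustrative id): _history_path("1234-abcd") -> Path("history/1234-abcd.json")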

def load_history(session_id: str) -> List[Tuple[str, str]]:
    """Read JSON file → list of (user, assistant). Return [] if not present."""
    p = _history_path(session_id)
    if p.is_file():
        try:
            return json.loads(p.read_text(encoding="utf-8"))
        except Exception as e:
            print(f"⚠️ Failed to read history for {session_id}: {e}")
    return []

def save_history(session_id: str, chat: List[Tuple[str, str]]) -> None:
    """Write JSON file and push it back to the repo."""
    p = _history_path(session_id)
    p.write_text(json.dumps(chat, ensure_ascii=False, indent=2), encoding="utf-8")
    # -----------------------------------------------------------------
    # Push the new file to the repo (so it survives container restarts)
    # -----------------------------------------------------------------
    # NOTE: this only works if the Space has a token with write access
    # (set HF_TOKEN as a Space secret).
    try:
        api = HfApi()
        # `SPACE_ID` (owner/space-name) is set automatically on Spaces
        repo_id = os.getenv("SPACE_ID")
        if repo_id:
            api.upload_file(
                path_or_fileobj=str(p),
                path_in_repo=str(p),
                repo_id=repo_id,
                repo_type="space",
                token=os.getenv("HF_TOKEN"),
            )
    except Exception as exc:
        # Failing to push is not fatal – the file stays on the container.
        print(f"⚠️ Could not push history to hub: {exc}")

def list_sessions() -> List[str]:
    """Return a list of all stored session IDs (file names)."""
    return [f.stem for f in HISTORY_DIR.glob("*.json")]
# --------------------------------------------------------------
# 3️⃣ The generation function – called by Gradio
# --------------------------------------------------------------
def generate_reply(
    user_message: str,
    chat_history: List[Tuple[str, str]],
    session_id: str,
) -> Tuple[List[Tuple[str, str]], str]:
    """
    1. Append the user's new message.
    2. Build the full `messages` list in the format expected by the model's
       chat_template.
    3. Use the tokenizer's `apply_chat_template(..., add_generation_prompt=True)`
       to create the prompt.
    4. Run the pipeline, decode, and strip the extra tokens.
    5. Append the assistant answer and persist the whole chat.
    """
    # ----- 1️⃣ Append user message to history -----
    chat_history.append((user_message, ""))  # placeholder for assistant

    # ----- 2️⃣ Build the messages list for the template -----
    # Previous exchanges first, then the new user message last
    # (system messages are not needed here).
    messages = []
    for user, assistant in chat_history[:-1]:  # exclude the placeholder
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": user_message})

    # ----- 3️⃣ Render the prompt with the model's chat template -----
    prompt = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
    # Debug – uncomment if you want to see the raw prompt in logs
    # print("\n--- Prompt sent to model ---\n", prompt)

    # ----- 4️⃣ Generate the answer -----
    result = generator(prompt, max_new_tokens=MAX_NEW_TOKENS)[0]["generated_text"]
    # The pipeline returns the **whole** text (prompt + answer). Remove the prompt.
    answer = result[len(prompt):].strip()
    # Some models still emit special tokens like <|eot_id|>; strip them.
    for stop in ["<|eot_id|>", "</s>"]:
        answer = answer.replace(stop, "").strip()

    # ----- 5️⃣ Update history and persist -----
    chat_history[-1] = (user_message, answer)  # replace placeholder
    save_history(session_id, chat_history)
    return chat_history, answer
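# Local smoke test (hypothetical session id, not used by the UI):
# hist, reply = generate_reply("Hello!", [], "local-test")
# print(reply)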
# --------------------------------------------------------------
# 4️⃣ UI definition
# --------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Default()) as demo:
    # -----------------------------------------------------------------
    # Top bar – session selector + "New chat" button
    # -----------------------------------------------------------------
    with gr.Row():
        session_dropdown = gr.Dropdown(
            choices=list_sessions(),
            label="🗂️ Load previous chat",
            interactive=True,
            allow_custom_value=False,
        )
        new_chat_btn = gr.Button("🆕 New chat", variant="primary")
    status_txt = gr.Markdown("", visible=False)

    # -----------------------------------------------------------------
    # Main chat area
    # -----------------------------------------------------------------
    chatbot = gr.Chatbot(label="🗨️ Maya-1 Chat", height=600)
    txt = gr.Textbox(
        placeholder="Type your message and hit Enter …",
        label="Your message",
        container=False,
    )
    submit_btn = gr.Button("Send", variant="secondary")

    # -----------------------------------------------------------------
    # Hidden state – we keep the full list of (user, assistant) tuples
    # -----------------------------------------------------------------
    chat_state = gr.State([])     # List[Tuple[str, str]]
    session_state = gr.State("")  # current session_id (string)
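    # Each gr.State value is scoped to a single browser session, so two visitors
    # never share chat history or session IDs.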
    # -----------------------------------------------------------------
    # 5️⃣ Callbacks
    # -----------------------------------------------------------------
    # When the app loads, generate a fresh anonymous session ID
    def init_session():
        sid = str(uuid.uuid4())
        return sid, []  # session_state, chat_state

    demo.load(fn=init_session, outputs=[session_state, chat_state])

    # -----------------------------------------------------------------
    # New chat → reset everything and give a brand-new session ID
    # -----------------------------------------------------------------
    def new_chat():
        sid = str(uuid.uuid4())
        return sid, [], []  # session_id, empty chat_state, empty UI

    new_chat_btn.click(
        fn=new_chat,
        outputs=[session_state, chat_state, chatbot],
    )
    # -----------------------------------------------------------------
    # Load a saved session from the dropdown
    # -----------------------------------------------------------------
    def load_session(selected: str):
        if not selected:
            return "", [], []  # nothing selected → blank
        # The file name is the session_id we used when saving.
        session_id = selected
        history = load_history(session_id)
        # Convert List[Tuple] → format expected by gr.Chatbot
        ui_history = [(u, a) for u, a in history]
        return session_id, history, ui_history

    session_dropdown.change(
        fn=load_session,
        inputs=[session_dropdown],
        outputs=[session_state, chat_state, chatbot],
    )
    # -----------------------------------------------------------------
    # When the user hits "Enter" or clicks Send → generate a reply
    # -----------------------------------------------------------------
    def user_submit(user_msg: str, chat_hist: List[Tuple[str, str]], sid: str):
        # Call the generation function
        updated_hist, answer = generate_reply(user_msg, chat_hist, sid)
        # UI expects List[Tuple[user, assistant]]
        ui_hist = [(u, a) for u, a in updated_hist]
        return "", ui_hist, updated_hist, answer

    txt.submit(
        fn=user_submit,
        inputs=[txt, chat_state, session_state],
        outputs=[txt, chatbot, chat_state, status_txt],
    )
    submit_btn.click(
        fn=user_submit,
        inputs=[txt, chat_state, session_state],
        outputs=[txt, chatbot, chat_state, status_txt],
    )
    # -----------------------------------------------------------------
    # Keep the session dropdown up-to-date after each new save
    # -----------------------------------------------------------------
    def refresh_dropdown():
        # gr.update works in current Gradio; gr.Dropdown.update was removed in 4.x
        return gr.update(choices=list_sessions())

    # Whenever we save a new session (i.e., after every reply) we refresh the list.
    # (gr.State change events require a recent Gradio release.)
    chat_state.change(fn=refresh_dropdown, inputs=None, outputs=session_dropdown)
# --------------------------------------------------------------
# 6️⃣ Run the demo
# --------------------------------------------------------------
if __name__ == "__main__":
    demo.queue()  # enables concurrent users
    demo.launch()
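# Suggested requirements.txt for this Space (package list is an assumption – pin
# versions as needed): gradio, torch, transformers, accelerate, huggingface_hub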