import os, json, re, time
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException, Header
from pydantic import BaseModel

# -------------------------------------------------------------------
# Choose a writable cache dir *before* importing huggingface_hub
# -------------------------------------------------------------------
def first_writable(paths: List[Optional[str]]) -> str:
    for p in paths:
        if not p:
            continue
        try:
            os.makedirs(p, exist_ok=True)
            testfile = os.path.join(p, ".write_test")
            with open(testfile, "w") as f:
                f.write("ok")
            os.remove(testfile)
            return p
        except Exception:
            continue
    # final fallback
    p = "/tmp/app_cache"
    os.makedirs(p, exist_ok=True)
    return p

CACHE_BASE = first_writable([
    os.getenv("SPACE_CACHE_DIR"),  # optional override via Settings → Variables
    "/app/.cache",                 # WORKDIR is usually writable on HF Spaces
    "/home/user/.cache",           # typical home dir
    "/tmp/app_cache",              # safe fallback
])

HF_HOME = os.path.join(CACHE_BASE, "huggingface")
os.environ["HF_HOME"] = HF_HOME
os.environ["HF_HUB_CACHE"] = os.path.join(HF_HOME, "hub")
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
os.makedirs(HF_HOME, exist_ok=True)

MODELS_DIR = os.path.join(CACHE_BASE, "models")
os.makedirs(MODELS_DIR, exist_ok=True)

# Only now import libs that read the env vars
from huggingface_hub import snapshot_download
from llama_cpp import Llama
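# Assumed runtime dependencies for this Space (versions not pinned here; adjust as needed):
#   fastapi, uvicorn, pydantic, huggingface_hub, llama-cpp-python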
# -------------------------------------------------------------------
# Config (can be overridden in Settings → Variables)
# -------------------------------------------------------------------
MODEL_REPO = os.getenv("MODEL_REPO", "Qwen/Qwen2.5-3B-Instruct-GGUF")
MODEL_FILE = os.getenv("MODEL_FILE", "qwen2.5-3b-instruct-q4_k_m.gguf")  # hint; not mandatory
MODEL_REV = os.getenv("MODEL_REV")  # optional commit SHA to pin

# Tuning (lower if memory is tight: N_CTX=1024, N_BATCH=32)
N_CTX = int(os.getenv("N_CTX", 2048))
N_BATCH = int(os.getenv("N_BATCH", 64))
N_THREADS = os.cpu_count() or 2

# Optional bearer auth for endpoints
API_SECRET = os.getenv("API_SECRET")  # set in Settings → Variables if you want auth

# -------------------------------------------------------------------
# App + globals
# -------------------------------------------------------------------
app = FastAPI(title="Qwen Planner API (CPU)")
llm: Optional[Llama] = None
model_loaded: bool = False
chosen_model_path: Optional[str] = None

# -------------------------------------------------------------------
# Model loader (lazy, robust gguf discovery)
# -------------------------------------------------------------------
def ensure_model() -> None:
    """
    Lazy-load the model. Downloads any .gguf if needed, then auto-selects one:
      1) the exact MODEL_FILE if present,
      2) else a *q4*.gguf,
      3) else the first .gguf found.
    Surfaces clear errors to the HTTP layer.
    """
    global llm, model_loaded, chosen_model_path
    if llm is not None:
        return
    try:
        local_dir = snapshot_download(
            repo_id=MODEL_REPO,
            revision=MODEL_REV,
            allow_patterns=["*.gguf"],      # flexible on filenames
            local_dir=MODELS_DIR,
            local_dir_use_symlinks=False,   # copy instead of symlink
        )
        # find gguf files
        ggufs: List[str] = []
        for root, _, files in os.walk(local_dir):
            for f in files:
                if f.endswith(".gguf"):
                    ggufs.append(os.path.join(root, f))
        if not ggufs:
            raise FileNotFoundError("No .gguf files found after download.")
        # choose file
        path = None
        if MODEL_FILE:
            cand = os.path.join(local_dir, MODEL_FILE)
            if os.path.exists(cand):
                path = cand
        if path is None:
            q4 = [p for p in ggufs if "q4" in os.path.basename(p).lower()]
            path = (q4 or ggufs)[0]
        chosen_model_path = path
        print(f"[loader] Using GGUF: {path}")
        # load model (CPU)
        llm = Llama(
            model_path=path,
            n_ctx=N_CTX,
            n_threads=N_THREADS,
            n_batch=N_BATCH,
            logits_all=False,
            n_gpu_layers=0,   # ensure CPU
        )
        model_loaded = True
    except Exception as e:
        raise RuntimeError(f"ensure_model failed: {e}")
# -------------------------------------------------------------------
# Helpers
# -------------------------------------------------------------------
def require_auth(authorization: Optional[str]) -> None:
    if API_SECRET and authorization != f"Bearer {API_SECRET}":
        raise HTTPException(status_code=401, detail="Unauthorized")

def extract_json_block(txt: str) -> str:
    m = re.search(r"\{.*\}\s*$", txt, flags=re.S)
    if not m:
        raise ValueError("No JSON object found in output.")
    return m.group(0)
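# Illustration (hypothetical model output) of extract_json_block: only the trailing
# JSON object is kept, so any prose the model emits before it is discarded.
#   extract_json_block('Sure, here is the plan: {"scaling":"none"}')
#   -> '{"scaling":"none"}'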
# -------------------------------------------------------------------
# Schemas
# -------------------------------------------------------------------
SYSTEM_PROMPT_CHAT = "You are a concise assistant. Reply briefly in plain text."

class ChatReq(BaseModel):
    prompt: str

class PlanRequest(BaseModel):
    profile: Dict[str, Any]
    sample_rows: List[Dict[str, Any]]
    goal: str = "auto"              # "classification" | "regression" | "auto"
    constraints: Dict[str, Any] = {}

SYSTEM_PROMPT_PLAN = """You are a data-planning assistant.
Return ONLY minified JSON matching exactly this schema:
{
 "cleaning": [{"op":"impute_mean|impute_mode|drop_col|clip","cols":["..."],"params":{}}],
 "encoding": [{"op":"one_hot|ordinal|hash|target","cols":["..."],"params":{}}],
 "scaling": "none|standard|robust|minmax",
 "target": {"name":"<col_or_empty>","type":"classification|regression|auto"},
 "split": {"strategy":"random|stratified","test_size":0.2,"cv":5},
 "metric": "f1|roc_auc|accuracy|mae|rmse|r2",
 "models": ["lgbm","rf","xgb","logreg","ridge","catboost"],
 "notes":"<short justification>"
}
No prose. No markdown. JSON only."""
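# Hypothetical example of a plan matching the schema above (illustrative only,
# not actual model output):
# {"cleaning":[{"op":"impute_mean","cols":["age"],"params":{}}],
#  "encoding":[{"op":"one_hot","cols":["city"],"params":{}}],
#  "scaling":"standard",
#  "target":{"name":"churn","type":"classification"},
#  "split":{"strategy":"stratified","test_size":0.2,"cv":5},
#  "metric":"roc_auc",
#  "models":["lgbm","logreg"],
#  "notes":"small tabular set; tree + linear baselines"}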
# -------------------------------------------------------------------
# Routes (paths mirror the function names: /healthz, /chat, /plan)
# -------------------------------------------------------------------
@app.get("/healthz")
def healthz():
    return {
        "status": "ok",
        "loaded": model_loaded,
        "cache_base": CACHE_BASE,
        "model_repo": MODEL_REPO,
        "model_file_hint": MODEL_FILE,
        "chosen_model_path": chosen_model_path,
        "n_ctx": N_CTX,
        "n_batch": N_BATCH,
        "threads": N_THREADS,
    }
@app.post("/chat")
def chat(req: ChatReq, authorization: Optional[str] = Header(default=None)):
    require_auth(authorization)
    try:
        ensure_model()   # first call may take minutes (download + load)
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"loading_error: {e}")
    try:
        full_prompt = (
            f"<|system|>\n{SYSTEM_PROMPT_CHAT}\n</|system|>\n"
            f"<|user|>\n{req.prompt}\n</|user|>\n"
        )
        out = llm(
            prompt=full_prompt,
            temperature=0.2,
            top_p=0.9,
            max_tokens=256,
            stop=["</s>"],
        )
        return {"response": out["choices"][0]["text"].strip()}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"infer_error: {e}")
@app.post("/plan")
def plan(req: PlanRequest, authorization: Optional[str] = Header(default=None)):
    require_auth(authorization)
    try:
        ensure_model()
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"loading_error: {e}")
    try:
        # Keep inputs small for free tier
        sample = req.sample_rows[:200]
        profile_json = json.dumps(req.profile)[:8000]
        sample_json = json.dumps(sample)[:8000]
        constraints_json = json.dumps(req.constraints)[:2000]
        user_block = (
            f"Goal:{req.goal}\n"
            f"Constraints:{constraints_json}\n"
            f"Profile:{profile_json}\n"
            f"Sample:{sample_json}\n"
        )
        full_prompt = (
            f"<|system|>\n{SYSTEM_PROMPT_PLAN}\n</|system|>\n"
            f"<|user|>\n{user_block}\n</|user|>\n"
        )
        out = llm(
            prompt=full_prompt,
            temperature=0.2,
            top_p=0.9,
            max_tokens=512,
            stop=["</s>"],
        )
        text = out["choices"][0]["text"]
        payload = extract_json_block(text)
        data = json.loads(payload)
        return {"plan": data}
    except ValueError as e:
        raise HTTPException(status_code=422, detail=f"bad_json: {e}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"infer_error: {e}")
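# -------------------------------------------------------------------
# Local run / smoke test (sketch; on HF Spaces the container CMD is assumed
# to start uvicorn, so this block only matters when running the file directly).
# Example calls once the server is up (no Authorization header if API_SECRET is unset):
#   curl http://localhost:7860/healthz
#   curl -X POST http://localhost:7860/chat -H "Content-Type: application/json" \
#        -d '{"prompt":"Say hello"}'
#   curl -X POST http://localhost:7860/plan -H "Content-Type: application/json" \
#        -d '{"profile":{"n_rows":100},"sample_rows":[{"a":1,"b":"x"}],"goal":"auto"}'
# -------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn  # assumed available alongside FastAPI
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))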