#!/usr/bin/env python3
"""
Inference script for the Textilindo AI Assistant.
Uses a model fine-tuned with LoRA.
"""
import os
import sys
import torch
import argparse
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def load_system_prompt(system_prompt_path):
    """Load system prompt from markdown file"""
    try:
        with open(system_prompt_path, 'r', encoding='utf-8') as f:
            content = f.read()

        # Extract SYSTEM_PROMPT from markdown
        if 'SYSTEM_PROMPT = """' in content:
            start = content.find('SYSTEM_PROMPT = """') + len('SYSTEM_PROMPT = """')
            end = content.find('"""', start)
            system_prompt = content[start:end].strip()
        else:
            # Fallback: use entire content
            system_prompt = content.strip()

        return system_prompt
    except Exception as e:
        logger.error(f"Error loading system prompt: {e}")
        return None
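
# Note: load_system_prompt() looks for a Python-style triple-quoted assignment
# inside the markdown file. An illustrative configs/system_prompt.md layout:
#
#   SYSTEM_PROMPT = """
#   You are Textilindo's customer service assistant. ...
#   """
#
# If the marker is absent, the whole file content is used as the prompt.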

def load_model(model_path, lora_path=None):
    """Load model with optional LoRA weights"""
    logger.info(f"Loading base model from: {model_path}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load base model
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )

    # Load LoRA weights if provided
    if lora_path and os.path.exists(lora_path):
        logger.info(f"Loading LoRA weights from: {lora_path}")
        model = PeftModel.from_pretrained(model, lora_path)
    else:
        logger.warning("No LoRA weights found, using base model")

    return model, tokenizer
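
# Note: for deployment you could merge the adapter into the base weights with
# model = model.merge_and_unload(), which removes the PEFT wrapper overhead at
# the cost of no longer being able to hot-swap adapters.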

def generate_response(model, tokenizer, user_input, system_prompt, max_new_tokens=512):
    """Generate response from the model"""
    # Create full prompt with system prompt
    full_prompt = f"<|system|>\n{system_prompt}\n<|user|>\n{user_input}\n<|assistant|>\n"
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,  # budget for generated tokens only, excluding the prompt
            temperature=0.7,
            top_p=0.9,
            top_k=40,
            repetition_penalty=1.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stop_strings=["<|end|>", "<|user|>"],
            tokenizer=tokenizer,  # transformers requires the tokenizer when stop_strings is used
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract only the assistant's response
    if "<|assistant|>" in response:
        assistant_response = response.split("<|assistant|>")[-1].strip()
        # Remove any remaining special tokens
        assistant_response = assistant_response.replace("<|end|>", "").strip()
        return assistant_response
    else:
        return response
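
# Note: the <|system|>/<|user|>/<|assistant|> markers above assume the model was
# fine-tuned on this exact template. For a stock chat model,
# tokenizer.apply_chat_template(messages, add_generation_prompt=True) would be
# the safer way to format the conversation.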

def interactive_chat(model, tokenizer, system_prompt):
    """Interactive chat mode"""
    print("🤖 Textilindo AI Assistant - Chat Mode")
    print("=" * 60)
    print("Type 'quit' to exit")
    print("-" * 60)

    while True:
        try:
            user_input = input("\n👤 Customer: ").strip()
            if user_input.lower() in ['quit', 'exit', 'q']:
                print("👋 Thank you! See you soon!")
                break
            if not user_input:
                continue

            print("\n🤖 Textilindo AI: ", end="", flush=True)
            response = generate_response(model, tokenizer, user_input, system_prompt)
            print(response)
        except KeyboardInterrupt:
            print("\n👋 Thank you! See you soon!")
            break
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            print(f"❌ Error: {e}")

def main():
    parser = argparse.ArgumentParser(description='Textilindo AI Assistant Inference')
    parser.add_argument('--model_path', type=str, default='./models/llama-3.1-8b-instruct',
                        help='Path to base model')
    parser.add_argument('--lora_path', type=str, default=None,
                        help='Path to LoRA weights')
    parser.add_argument('--system_prompt', type=str, default='configs/system_prompt.md',
                        help='Path to system prompt file')
    parser.add_argument('--prompt', type=str, default=None,
                        help='Single prompt to process')
    args = parser.parse_args()

    print("🤖 Textilindo AI Assistant - Inference")
    print("=" * 60)

    # Load system prompt
    system_prompt = load_system_prompt(args.system_prompt)
    if not system_prompt:
        print(f"❌ System prompt not found: {args.system_prompt}")
        sys.exit(1)

    # Check if model exists
    if not os.path.exists(args.model_path):
        print(f"❌ Base model not found: {args.model_path}")
        print("Run setup_textilindo_training.py first")
        sys.exit(1)

    try:
        # Load model
        print("1️⃣ Loading model...")
        model, tokenizer = load_model(args.model_path, args.lora_path)
        print("✅ Model loaded successfully!")

        if args.prompt:
            # Single prompt mode
            print(f"\n📝 Processing prompt: {args.prompt}")
            response = generate_response(model, tokenizer, args.prompt, system_prompt)
            print(f"\n🤖 Response: {response}")
        else:
            # Interactive mode
            interactive_chat(model, tokenizer, system_prompt)
    except Exception as e:
        logger.error(f"Error: {e}")
        print(f"❌ Error loading model: {e}")


if __name__ == "__main__":
    main()