#!/usr/bin/env python3
"""
Inference script untuk Textilindo AI Assistant
Menggunakan model yang sudah di-fine-tune dengan LoRA
"""
import os
import sys
import torch
import argparse
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def load_system_prompt(system_prompt_path):
    """Load system prompt from markdown file"""
    try:
        with open(system_prompt_path, 'r', encoding='utf-8') as f:
            content = f.read()

        # Extract SYSTEM_PROMPT from markdown
        if 'SYSTEM_PROMPT = """' in content:
            start = content.find('SYSTEM_PROMPT = """') + len('SYSTEM_PROMPT = """')
            end = content.find('"""', start)
            system_prompt = content[start:end].strip()
        else:
            # Fallback: use entire content
            system_prompt = content.strip()

        return system_prompt
    except Exception as e:
        logger.error(f"Error loading system prompt: {e}")
        return None
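
# Note: load_system_prompt() looks for the prompt wrapped in a Python-style
# triple-quoted assignment inside the markdown file, e.g. (illustrative content,
# not the real configs/system_prompt.md):
#
#   SYSTEM_PROMPT = """
#   You are Textilindo's customer service assistant. ...
#   """
#
# If that marker is missing, the whole file is used verbatim as the system prompt.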


def load_model(model_path, lora_path=None):
    """Load model with optional LoRA weights"""
    logger.info(f"Loading base model from: {model_path}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load base model
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )

    # Load LoRA weights if provided
    if lora_path and os.path.exists(lora_path):
        logger.info(f"Loading LoRA weights from: {lora_path}")
        model = PeftModel.from_pretrained(model, lora_path)
    else:
        logger.warning("No LoRA weights found, using base model")

    return model, tokenizer
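
# Optional (not done here): once the adapter is final, the LoRA weights can be
# merged into the base model for slightly faster inference, e.g.:
#   model = model.merge_and_unload()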


def generate_response(model, tokenizer, user_input, system_prompt, max_new_tokens=512):
    """Generate response from the model"""
    # Create full prompt with system prompt
    full_prompt = f"<|system|>\n{system_prompt}\n<|user|>\n{user_input}\n<|assistant|>\n"
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            top_k=40,
            repetition_penalty=1.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stop_strings=["<|end|>", "<|user|>"],
            tokenizer=tokenizer  # transformers requires the tokenizer when stop_strings is used
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract only the assistant's response
    if "<|assistant|>" in response:
        assistant_response = response.split("<|assistant|>")[-1].strip()
        # Remove any remaining special tokens
        assistant_response = assistant_response.replace("<|end|>", "").strip()
        return assistant_response
    else:
        return response
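
# Example call (illustrative user input):
#   reply = generate_response(model, tokenizer,
#                             "Apakah kain drill ready stock?", system_prompt)
#   print(reply)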


def interactive_chat(model, tokenizer, system_prompt):
    """Interactive chat mode"""
    print("πŸ€– Textilindo AI Assistant - Chat Mode")
    print("=" * 60)
    print("Type 'quit' to exit")
    print("-" * 60)

    while True:
        try:
            user_input = input("\nπŸ‘€ Customer: ").strip()
            if user_input.lower() in ['quit', 'exit', 'q']:
                print("πŸ‘‹ Terima kasih! Sampai jumpa!")
                break
            if not user_input:
                continue

            print("\nπŸ€– Textilindo AI: ", end="", flush=True)
            response = generate_response(model, tokenizer, user_input, system_prompt)
            print(response)
        except KeyboardInterrupt:
            print("\nπŸ‘‹ Terima kasih! Sampai jumpa!")
            break
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            print(f"❌ Error: {e}")


def main():
    parser = argparse.ArgumentParser(description='Textilindo AI Assistant Inference')
    parser.add_argument('--model_path', type=str, default='./models/llama-3.1-8b-instruct',
                        help='Path to base model')
    parser.add_argument('--lora_path', type=str, default=None,
                        help='Path to LoRA weights')
    parser.add_argument('--system_prompt', type=str, default='configs/system_prompt.md',
                        help='Path to system prompt file')
    parser.add_argument('--prompt', type=str, default=None,
                        help='Single prompt to process')
    args = parser.parse_args()

    print("πŸ€– Textilindo AI Assistant - Inference")
    print("=" * 60)

    # Load system prompt
    system_prompt = load_system_prompt(args.system_prompt)
    if not system_prompt:
        print(f"❌ System prompt not found: {args.system_prompt}")
        sys.exit(1)

    # Check if model exists
    if not os.path.exists(args.model_path):
        print(f"❌ Base model not found: {args.model_path}")
        print("Run setup_textilindo_training.py first")
        sys.exit(1)

    try:
        # Load model
        print("1️⃣ Loading model...")
        model, tokenizer = load_model(args.model_path, args.lora_path)
        print("βœ… Model loaded successfully!")

        if args.prompt:
            # Single prompt mode
            print(f"\nπŸ“ Processing prompt: {args.prompt}")
            response = generate_response(model, tokenizer, args.prompt, system_prompt)
            print(f"\nπŸ€– Response: {response}")
        else:
            # Interactive mode
            interactive_chat(model, tokenizer, system_prompt)
    except Exception as e:
        logger.error(f"Error: {e}")
        print(f"❌ Error loading model: {e}")
if __name__ == "__main__":
main()