#!/usr/bin/env python3
"""
Inference script untuk Textilindo AI Assistant
Menggunakan model yang sudah di-fine-tune dengan LoRA
"""

import os
import sys
import torch
import argparse
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def load_system_prompt(system_prompt_path):
    """Load system prompt from markdown file"""
    try:
        with open(system_prompt_path, 'r', encoding='utf-8') as f:
            content = f.read()
        
        # Extract the SYSTEM_PROMPT triple-quoted block from the markdown
        marker = 'SYSTEM_PROMPT = """'
        if marker in content:
            start = content.find(marker) + len(marker)
            end = content.find('"""', start)
            # Guard against a missing closing delimiter: find() returns -1,
            # and slicing with -1 would silently drop the last character
            if end == -1:
                end = len(content)
            system_prompt = content[start:end].strip()
        else:
            # Fallback: use the entire file content
            system_prompt = content.strip()
        
        return system_prompt
    except Exception as e:
        logger.error(f"Error loading system prompt: {e}")
        return None
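
# Expected configs/system_prompt.md layout (an assumption inferred from the
# parser above; only the text between the triple quotes is used):
#
#     SYSTEM_PROMPT = """
#     You are Textilindo's customer service assistant. ...
#     """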

def load_model(model_path, lora_path=None):
    """Load model with optional LoRA weights"""
    logger.info(f"Loading base model from: {model_path}")
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True
    )
    
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # Load base model in fp16; device_map="auto" spreads layers across the
    # available GPUs/CPU and requires the accelerate package
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    
    # Load LoRA weights if provided
    if lora_path and os.path.exists(lora_path):
        logger.info(f"Loading LoRA weights from: {lora_path}")
        model = PeftModel.from_pretrained(model, lora_path)
    else:
        logger.warning("No LoRA weights found, using base model")
    
    return model, tokenizer
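
# Memory note: an 8B-parameter model in fp16 needs roughly 16 GB of GPU
# memory for the weights alone. On smaller GPUs, 4-bit quantized loading is
# a common alternative (a sketch, assuming bitsandbytes is installed):
#
#     from transformers import BitsAndBytesConfig
#     bnb_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_compute_dtype=torch.float16,
#     )
#     model = AutoModelForCausalLM.from_pretrained(
#         model_path, quantization_config=bnb_config, device_map="auto"
#     )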

def generate_response(model, tokenizer, user_input, system_prompt, max_new_tokens=512):
    """Generate a response from the model"""
    # Build the full prompt; this template must match the one used during
    # fine-tuning, or the model will not follow it reliably
    full_prompt = f"<|system|>\n{system_prompt}\n<|user|>\n{user_input}\n<|assistant|>\n"
    
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # max_new_tokens bounds only the reply; max_length would also
            # count the prompt and can silently truncate generation
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            top_k=40,
            repetition_penalty=1.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stop_strings=["<|end|>", "<|user|>"],
            # stop_strings only takes effect when the tokenizer is
            # passed to generate() as well
            tokenizer=tokenizer
        )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Extract only the assistant's response
    if "<|assistant|>" in response:
        assistant_response = response.split("<|assistant|>")[-1].strip()
        # Remove any remaining special tokens
        assistant_response = assistant_response.replace("<|end|>", "").strip()
        return assistant_response
    else:
        return response
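
# If the base model ships a chat template (Llama 3.1 Instruct does), the
# hand-rolled template above could be replaced with the tokenizer's own
# formatting (a sketch; only appropriate if fine-tuning used that template):
#
#     messages = [{"role": "system", "content": system_prompt},
#                 {"role": "user", "content": user_input}]
#     input_ids = tokenizer.apply_chat_template(
#         messages, add_generation_prompt=True, return_tensors="pt"
#     ).to(model.device)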

def interactive_chat(model, tokenizer, system_prompt):
    """Interactive chat mode"""
    print("πŸ€– Textilindo AI Assistant - Chat Mode")
    print("=" * 60)
    print("Type 'quit' to exit")
    print("-" * 60)
    
    while True:
        try:
            user_input = input("\n👤 Customer: ").strip()
            
            if user_input.lower() in ['quit', 'exit', 'q']:
                print("πŸ‘‹ Terima kasih! Sampai jumpa!")
                break
            
            if not user_input:
                continue
            
            print("\nπŸ€– Textilindo AI: ", end="", flush=True)
            response = generate_response(model, tokenizer, user_input, system_prompt)
            print(response)
                
        except KeyboardInterrupt:
            print("\nπŸ‘‹ Terima kasih! Sampai jumpa!")
            break
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            print(f"❌ Error: {e}")

def main():
    parser = argparse.ArgumentParser(description='Textilindo AI Assistant Inference')
    parser.add_argument('--model_path', type=str, default='./models/llama-3.1-8b-instruct',
                        help='Path to base model')
    parser.add_argument('--lora_path', type=str, default=None,
                        help='Path to LoRA weights')
    parser.add_argument('--system_prompt', type=str, default='configs/system_prompt.md',
                        help='Path to system prompt file')
    parser.add_argument('--prompt', type=str, default=None,
                        help='Single prompt to process')
    
    args = parser.parse_args()
    
    print("πŸ€– Textilindo AI Assistant - Inference")
    print("=" * 60)
    
    # Load system prompt
    system_prompt = load_system_prompt(args.system_prompt)
    if not system_prompt:
        print(f"❌ System prompt tidak ditemukan: {args.system_prompt}")
        sys.exit(1)
    
    # Check if model exists
    if not os.path.exists(args.model_path):
        print(f"❌ Base model tidak ditemukan: {args.model_path}")
        print("Jalankan setup_textilindo_training.py terlebih dahulu")
        sys.exit(1)
    
    try:
        # Load model
        print("1️⃣ Loading model...")
        model, tokenizer = load_model(args.model_path, args.lora_path)
        print("βœ… Model loaded successfully!")
        
        if args.prompt:
            # Single prompt mode
            print(f"\nπŸ“ Processing prompt: {args.prompt}")
            response = generate_response(model, tokenizer, args.prompt, system_prompt)
            print(f"\nπŸ€– Response: {response}")
        else:
            # Interactive mode
            interactive_chat(model, tokenizer, system_prompt)
            
    except Exception as e:
        logger.error(f"Error: {e}")
        print(f"❌ Error: {e}")
        # Exit nonzero so callers and shell scripts can detect the failure
        sys.exit(1)

if __name__ == "__main__":
    main()
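
# Programmatic use (a sketch; the module name `inference` and the LoRA path
# below are placeholders, not paths confirmed by this repo):
#
#     from inference import load_model, load_system_prompt, generate_response
#     model, tok = load_model("./models/llama-3.1-8b-instruct", "./outputs/lora")
#     sys_prompt = load_system_prompt("configs/system_prompt.md")
#     print(generate_response(model, tok, "Hi, what fabrics do you sell?", sys_prompt))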