# lora-demo/app.py
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
# ๋ชจ๋ธ ์„ค์ • (์—ฌ๊ธฐ๋ฅผ ์ˆ˜์ •ํ•˜์„ธ์š”!)
MODELS = {
# ========================================
# 03๋ฒˆ: ํ•œ๊ตญ์–ด ์š”์•ฝ (EXAONE-3.5)
# ========================================
"ํ•œ๊ตญ์–ด ์š”์•ฝ (03๋ฒˆ)": {
"base_model": "LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct",
"lora_path": "Sangjoo/exaone-summary-lora", # TODO: ๋ณธ์ธ ๊ฒฝ๋กœ๋กœ!
"prompt_template": "{input}\n\n์š”์•ฝ:",
"max_new_tokens": 60,
"placeholder": "๋‰ด์Šค ๊ธฐ์‚ฌ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”...",
"example": "์„œ์šธ์‹œ๊ฐ€ ๋‚ด๋…„๋ถ€ํ„ฐ ์ „๊ธฐ์ฐจ ์ถฉ์ „์†Œ๋ฅผ ๋Œ€ํญ ํ™•๋Œ€ํ•œ๋‹ค.",
},
# ========================================
# 05๋ฒˆ: ๋‹ค์ค‘ ๋ชจ๋ธ ๋น„๊ต (์„ ํƒ์‚ฌํ•ญ)
# ========================================
"Granite ์š”์•ฝ (05๋ฒˆ)": {
"base_model": "ibm-granite/granite-4.0-micro",
"lora_path": "Sangjoo/granite-summary-lora",
"prompt_template": "<|user|>\n{input}\n\n์œ„ ๊ธฐ์‚ฌ๋ฅผ ์š”์•ฝํ•ด์ฃผ์„ธ์š”.<|assistant|>\n",
"max_new_tokens": 60,
"placeholder": "๋‰ด์Šค ๊ธฐ์‚ฌ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”...",
"example": "์„œ์šธ์‹œ๊ฐ€ ๋‚ด๋…„๋ถ€ํ„ฐ ์ „๊ธฐ์ฐจ ์ถฉ์ „์†Œ๋ฅผ ๋Œ€ํญ ํ™•๋Œ€ํ•œ๋‹ค.",
},
# "Qwen3 ์š”์•ฝ (05๋ฒˆ)": {
# "base_model": "Qwen/Qwen3-4B-Instruct-2507",
# "lora_path": "Sangjoo/qwen3-summary-lora",
# "prompt_template": "<|im_start|>user\n{input}\n\n์œ„ ๊ธฐ์‚ฌ๋ฅผ ์š”์•ฝํ•ด์ฃผ์„ธ์š”.<|im_end|>\n<|im_start|>assistant\n",
# "max_new_tokens": 60,
# "placeholder": "๋‰ด์Šค ๊ธฐ์‚ฌ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”...",
# "example": "์„œ์šธ์‹œ๊ฐ€ ๋‚ด๋…„๋ถ€ํ„ฐ ์ „๊ธฐ์ฐจ ์ถฉ์ „์†Œ๋ฅผ ๋Œ€ํญ ํ™•๋Œ€ํ•œ๋‹ค.",
# },
# ========================================
# 06๋ฒˆ: ๊ฐ์ • ๋ถ„๋ฅ˜ + ์˜์–ด QA
# ========================================
"๊ฐ์ • ๋ถ„๋ฅ˜ (06๋ฒˆ)": {
"base_model": "LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct",
"lora_path": "Sangjoo/lora-sentiment",
"prompt_template": "๋‹ค์Œ ์˜ํ™” ๋ฆฌ๋ทฐ์˜ ๊ฐ์ •์„ ๋ถ„๋ฅ˜ํ•˜์„ธ์š”.\n\n๋ฆฌ๋ทฐ: {input}\n\n๊ฐ์ •:",
"max_new_tokens": 10,
"placeholder": "์˜ํ™” ๋ฆฌ๋ทฐ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”...",
"example": "This movie was amazing! Great story and excellent acting.",
},
"์˜์–ด QA (06๋ฒˆ)": {
"base_model": "LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct",
"lora_path": "Sangjoo/lora-qa",
"prompt_template": "Context: The Eiffel Tower is in Paris, France.\n\nQuestion: {input}\n\nAnswer:",
"max_new_tokens": 30,
"placeholder": "์งˆ๋ฌธ์„ ์ž…๋ ฅํ•˜์„ธ์š”...",
"example": "Where is the Eiffel Tower located?",
},
}
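# Cache of already-loaded (model, tokenizer, config) tuples, keyed by the display
# names above, so each model only has to be downloaded and loaded once per session.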
loaded_models = {}
def load_model(model_name):
    # Return the cached model if it was loaded earlier in this session.
    if model_name in loaded_models:
        return loaded_models[model_name]

    config = MODELS[model_name]

    tokenizer = AutoTokenizer.from_pretrained(config["base_model"], use_fast=False)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 4-bit NF4 quantization so the base model fits in limited GPU memory.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        config["base_model"],
        device_map="auto",
        trust_remote_code=True,
        quantization_config=quant_config,
    )

    # Attach the LoRA adapter weights on top of the quantized base model.
    model = PeftModel.from_pretrained(base_model, config["lora_path"])

    loaded_models[model_name] = (model, tokenizer, config)
    return model, tokenizer, config
def generate_response(model_name, user_input):
    try:
        model, tokenizer, config = load_model(model_name)

        prompt = config["prompt_template"].format(input=user_input)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=config["max_new_tokens"],
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Strip the prompt so only the newly generated answer is returned.
        if "요약:" in result:
            return result.split("요약:")[-1].strip()
        elif "감정:" in result:
            return result.split("감정:")[-1].strip()
        elif "Answer:" in result:
            return result.split("Answer:")[-1].strip()
        elif "<|assistant|>" in result:
            return result.split("<|assistant|>")[-1].strip()
        elif "<|im_start|>assistant" in result:
            return result.split("<|im_start|>assistant")[-1].replace("<|im_end|>", "").strip()
        else:
            return result[len(prompt):].strip()
    except Exception as e:
        return f"❌ 오류: {str(e)}"
with gr.Blocks(title="LoRA 모델 데모") as demo:
    gr.Markdown("# 🤖 LoRA 파인튜닝 모델 데모")
    gr.Markdown("Day 1에서 학습한 여러 LoRA 모델을 테스트해보세요!")

    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="📌 모델 선택",
            )
            input_text = gr.Textbox(label="💬 입력", lines=5)
            submit_btn = gr.Button("🚀 실행", variant="primary")
        with gr.Column():
            output_text = gr.Textbox(label="✨ 결과", lines=10)

    submit_btn.click(
        fn=generate_response,
        inputs=[model_dropdown, input_text],
        outputs=[output_text],
    )
demo.launch()