QAway-to committed
Commit 8d4e786 · 1 Parent(s): 995a334

f3nsmart/TinyLlama-MBTI-Interviewer-LoRA. v1.0

Files changed (2):
  1. app.py +18 -42
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,78 +1,64 @@
 import gradio as gr
+import torch
 from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
     AutoModelForSequenceClassification,
     pipeline
 )
+from peft import PeftModel  # 👈 important for the LoRA adaptation
 
 # ===============================================================
 # 1️⃣ Settings and models
 # ===============================================================
-
-# Fine-tuned MBTI classifier (your model)
 MBTI_MODEL = "f3nsmart/MBTIclassifier"
-mbti_pipe = pipeline("text-classification", model=MBTI_MODEL, return_all_scores=True)
+INTERVIEWER_BASE = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+INTERVIEWER_LORA = "f3nsmart/TinyLlama-MBTI-Interviewer-LoRA"
 
-# Interviewer model
-INTERVIEWER_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
+# --- MBTI classifier ---
+mbti_pipe = pipeline("text-classification", model=MBTI_MODEL, return_all_scores=True)
 
-tokenizer_qwen = AutoTokenizer.from_pretrained(INTERVIEWER_MODEL)
-model_qwen = AutoModelForCausalLM.from_pretrained(
-    INTERVIEWER_MODEL,
-    torch_dtype="auto",
+# --- TinyLlama interviewer + LoRA ---
+print("🔄 Loading TinyLlama with the LoRA adapter...")
+tokenizer_llama = AutoTokenizer.from_pretrained(INTERVIEWER_LORA)
+base_model = AutoModelForCausalLM.from_pretrained(
+    INTERVIEWER_BASE,
+    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
     device_map="auto"
 )
+model_llora = PeftModel.from_pretrained(base_model, INTERVIEWER_LORA)
 
 llm_pipe = pipeline(
     "text-generation",
-    model=model_qwen,
-    tokenizer=tokenizer_qwen,
+    model=model_llora,
+    tokenizer=tokenizer_llama,
     max_new_tokens=70,
     temperature=0.7,
     top_p=0.9,
+    device_map="auto"
 )
 
 # ===============================================================
 # 2️⃣ Helper functions
 # ===============================================================
-
 def clean_question(text: str) -> str:
-    """
-    Removes any instructions and keeps only the clean question.
-    """
-    text = text.strip()
-
-    # Take only the first line, in case the LLM printed several
-    text = text.split("\n")[0]
-
-    # Qwen sometimes inserts quotes; remove them
-    text = text.strip('"').strip("'")
-
-    # In case the model printed "User:" / "Assistant:" / "Instruction:" etc.
+    text = text.strip().split("\n")[0].strip('"').strip("'")
     bad_tokens = ["user:", "assistant:", "instruction", "interviewer", "system:"]
     for bad in bad_tokens:
         if bad.lower() in text.lower():
             text = text.split(bad)[-1].strip()
-
-    # If the question does not end with a question mark, add one
     if "?" not in text:
         text = text.rstrip(".") + "?"
-
-    # Small safeguard against garbage output
     if len(text.split()) < 3:
         return "What do you usually enjoy doing in your free time?"
-
     return text.strip()
 
 def generate_first_question():
-    """The first question is fixed (no waiting for generation)"""
     return "What do you usually enjoy doing in your free time?"
 
 def analyze_and_ask(user_text, prev_count):
     if not user_text.strip():
         return "⚠️ Please enter an answer.", "", prev_count
-
     try:
         n = int(prev_count.split("/")[0]) + 1
     except Exception:
@@ -83,7 +69,6 @@ def analyze_and_ask(user_text, prev_count):
     res_sorted = sorted(res, key=lambda x: x["score"], reverse=True)
     mbti_text = "\n".join([f"{r['label']} → {r['score']:.3f}" for r in res_sorted[:3]])
 
-    # New, refined prompt
     prompt = (
         f"User said: '{user_text}'. "
         "Generate one natural, open-ended question that starts with 'What', 'Why', 'How', or 'When'. "
@@ -94,25 +79,18 @@ def analyze_and_ask(user_text, prev_count):
 
     raw = llm_pipe(prompt)[0]["generated_text"]
     cleaned = clean_question(raw)
-
-    # If the question does not start with an allowed word, create a fallback
-    valid_starts = ("What", "Why", "How", "When")
-    if not cleaned.startswith(valid_starts):
+    if not cleaned.startswith(("What", "Why", "How", "When")):
         cleaned = "What motivates you to do the things you enjoy most?"
-
     return mbti_text, cleaned, counter
 
-
 # ===============================================================
 # 3️⃣ Gradio interface
 # ===============================================================
-
 with gr.Blocks(theme=gr.themes.Soft(), title="MBTI Personality Interviewer") as demo:
     gr.Markdown(
         "## 🧠 MBTI Personality Interviewer\n"
         "Determine your personality type and get the next question from the interviewer."
     )
-
     with gr.Row():
         with gr.Column(scale=1):
             inp = gr.Textbox(
@@ -127,8 +105,6 @@ with gr.Blocks(theme=gr.themes.Soft(), title="MBTI Personality Interviewer") as
             progress = gr.Textbox(label="⏳ Progress", value="0/30")
 
     btn.click(analyze_and_ask, inputs=[inp, progress], outputs=[mbti_out, interviewer_out, progress])
-
-    # Automatically load the first question
     demo.load(lambda: ("", generate_first_question(), "0/30"), inputs=None, outputs=[mbti_out, interviewer_out, progress])
 
 demo.launch()
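The commit keeps the LoRA adapter wrapped in a PeftModel at inference time. peft can also fold the adapter into the base weights with merge_and_unload(), which is worth knowing if startup cost or per-token overhead matters on the Space's hardware. A minimal standalone sketch under the same model IDs; the merged name and the sample prompt are illustrative, not part of the commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel

BASE = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
LORA = "f3nsmart/TinyLlama-MBTI-Interviewer-LoRA"

tokenizer = AutoTokenizer.from_pretrained(LORA)
base = AutoModelForCausalLM.from_pretrained(
    BASE,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
# merge_and_unload() bakes the LoRA deltas into the base weights and returns
# a plain transformers model, so generation runs without the PeftModel
# indirection on every forward pass.
merged = PeftModel.from_pretrained(base, LORA).merge_and_unload()

ask = pipeline(
    "text-generation",
    model=merged,
    tokenizer=tokenizer,
    max_new_tokens=70,
    temperature=0.7,
    top_p=0.9,
)
print(ask("User said: 'I like hiking'. Ask one open-ended question.")[0]["generated_text"])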
 
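The rewritten clean_question chains the old stripping steps into a single line; the behavior is still what the removed comments described: keep only the first line, drop surrounding quotes, cut off role prefixes such as "User:" or "Assistant:", force a trailing question mark, and fall back to the stock opener when fewer than three words survive. A quick illustrative trace, assuming app.py's clean_question is in scope (inputs invented here):

print(clean_question('"Why do you prefer mornings."'))
# -> Why do you prefer mornings?
print(clean_question("ok"))
# -> What do you usually enjoy doing in your free time?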
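analyze_and_ask references res and counter that are assigned in lines outside the diff context (old 79-82 / new 75-78), so presumably res holds the classifier output for user_text. With return_all_scores=True, a text-classification pipeline returns one list of {label, score} dicts per input, which is what the sorting and top-3 formatting expect. A standalone sketch; the sample sentence is invented, and newer transformers releases spell the option top_k=None instead of the deprecated return_all_scores=True:

from transformers import pipeline

mbti_pipe = pipeline(
    "text-classification",
    model="f3nsmart/MBTIclassifier",
    return_all_scores=True,  # deprecated alias of top_k=None in recent transformers
)

# One input string -> a list with one entry, itself a list of {label, score} dicts.
res = mbti_pipe("I enjoy quiet evenings with a good book.")[0]
res_sorted = sorted(res, key=lambda x: x["score"], reverse=True)
print("\n".join(f"{r['label']} → {r['score']:.3f}" for r in res_sorted[:3]))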
requirements.txt CHANGED
@@ -3,4 +3,5 @@ datasets
 torch
 gradio
 openai
-accelerate
+accelerate
+peft
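On the dependency side, peft is the only genuinely new requirement; it provides the PeftModel import in app.py. accelerate was removed and re-added (most likely a whitespace or trailing-newline change) and remains necessary for device_map="auto" to work. If reproducible Space rebuilds matter, pinning both lines is common practice; the version bounds below are illustrative, not part of the commit:

accelerate>=0.30
peft>=0.10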