Spaces:
Runtime error
Runtime error
Peter
commited on
Commit
·
950a38f
1
Parent(s):
203509f
✨ integrate constrained textgen
Browse filesSigned-off-by: Peter <[email protected]>
- app.py +2 -0
- converse.py +42 -17
app.py
CHANGED
|
@@ -85,6 +85,7 @@ def ask_gpt(
|
|
| 85 |
top_p=0.95,
|
| 86 |
top_k=25,
|
| 87 |
temperature=0.5,
|
|
|
|
| 88 |
) -> str:
|
| 89 |
"""
|
| 90 |
ask_gpt - helper function that asks the GPT model a question and returns the response
|
|
@@ -121,6 +122,7 @@ def ask_gpt(
|
|
| 121 |
temperature=temperature,
|
| 122 |
max_length=max_length,
|
| 123 |
min_length=min_length,
|
|
|
|
| 124 |
)
|
| 125 |
gpt_et = time.perf_counter()
|
| 126 |
gpt_rt = round(gpt_et - st, 2)
|
|
|
|
| 85 |
top_p=0.95,
|
| 86 |
top_k=25,
|
| 87 |
temperature=0.5,
|
| 88 |
+
constrained_generation=True,
|
| 89 |
) -> str:
|
| 90 |
"""
|
| 91 |
ask_gpt - helper function that asks the GPT model a question and returns the response
|
|
|
|
| 122 |
temperature=temperature,
|
| 123 |
max_length=max_length,
|
| 124 |
min_length=min_length,
|
| 125 |
+
constrained_generation = constrained_generation,
|
| 126 |
)
|
| 127 |
gpt_et = time.perf_counter()
|
| 128 |
gpt_rt = round(gpt_et - st, 2)
|
converse.py
CHANGED
|
@@ -10,6 +10,7 @@ import time
|
|
| 10 |
|
| 11 |
from grammar_improve import remove_trailing_punctuation
|
| 12 |
|
|
|
|
| 13 |
|
| 14 |
def discussion(
|
| 15 |
prompt_text: str,
|
|
@@ -28,6 +29,7 @@ def discussion(
|
|
| 28 |
num_return_sequences=1,
|
| 29 |
device=-1,
|
| 30 |
verbose=False,
|
|
|
|
| 31 |
):
|
| 32 |
"""
|
| 33 |
discussion - a function that takes in a prompt and generates a response. This function is meant to be used in a conversation loop, and is the main function for the bot.
|
|
@@ -64,23 +66,46 @@ def discussion(
|
|
| 64 |
pp.pprint(this_prompt, indent=4)
|
| 65 |
# call the model
|
| 66 |
print("\n... generating...")
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
if isinstance(bot_dialogue, list) and len(bot_dialogue) > 1:
|
| 85 |
bot_resp = ", ".join(bot_dialogue)
|
| 86 |
elif isinstance(bot_dialogue, list) and len(bot_dialogue) == 1:
|
|
|
|
| 10 |
|
| 11 |
from grammar_improve import remove_trailing_punctuation
|
| 12 |
|
| 13 |
+
from constrained_generation import constrained_generation
|
| 14 |
|
| 15 |
def discussion(
|
| 16 |
prompt_text: str,
|
|
|
|
| 29 |
num_return_sequences=1,
|
| 30 |
device=-1,
|
| 31 |
verbose=False,
|
| 32 |
+
constrained_generation=False,
|
| 33 |
):
|
| 34 |
"""
|
| 35 |
discussion - a function that takes in a prompt and generates a response. This function is meant to be used in a conversation loop, and is the main function for the bot.
|
|
|
|
| 66 |
pp.pprint(this_prompt, indent=4)
|
| 67 |
# call the model
|
| 68 |
print("\n... generating...")
|
| 69 |
+
if constrained_generation:
|
| 70 |
+
response = constrained_generation(
|
| 71 |
+
prompt=this_prompt,
|
| 72 |
+
pipeline=pipeline,
|
| 73 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 74 |
+
length_penalty=length_penalty,
|
| 75 |
+
repetition_penalty=1.0,
|
| 76 |
+
num_beams=4,
|
| 77 |
+
timeout=timeout,
|
| 78 |
+
verbose=verbose,
|
| 79 |
+
full_text=full_text,
|
| 80 |
+
speaker_name=speaker,
|
| 81 |
+
responder_name=responder,
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
bot_dialogue = consolidate_texts(
|
| 85 |
+
name_resp=responder,
|
| 86 |
+
model_resp=response,
|
| 87 |
+
name_spk=speaker,
|
| 88 |
+
verbose=verbose,
|
| 89 |
+
print_debug=True,
|
| 90 |
+
)
|
| 91 |
+
else:
|
| 92 |
+
bot_dialogue = gen_response(
|
| 93 |
+
this_prompt,
|
| 94 |
+
pipeline,
|
| 95 |
+
speaker,
|
| 96 |
+
responder,
|
| 97 |
+
timeout=timeout,
|
| 98 |
+
max_length=max_length,
|
| 99 |
+
top_p=top_p,
|
| 100 |
+
top_k=top_k,
|
| 101 |
+
temperature=temperature,
|
| 102 |
+
full_text=full_text,
|
| 103 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 104 |
+
length_penalty=length_penalty,
|
| 105 |
+
num_return_sequences=num_return_sequences,
|
| 106 |
+
device=device,
|
| 107 |
+
verbose=verbose,
|
| 108 |
+
)
|
| 109 |
if isinstance(bot_dialogue, list) and len(bot_dialogue) > 1:
|
| 110 |
bot_resp = ", ".join(bot_dialogue)
|
| 111 |
elif isinstance(bot_dialogue, list) and len(bot_dialogue) == 1:
|