Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -16,7 +16,7 @@ def load_model(model_name):
|
|
| 16 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 17 |
model = AutoModelForCausalLM.from_pretrained(model_name)
|
| 18 |
return model, tokenizer
|
| 19 |
-
def extend(input_text, num_return_sequences, max_size=20, top_k=50, top_p=0.95
|
| 20 |
if len(input_text) == 0:
|
| 21 |
input_text = ""
|
| 22 |
encoded_prompt = tokenizer.encode(
|
|
@@ -25,8 +25,7 @@ def extend(input_text, num_return_sequences, max_size=20, top_k=50, top_p=0.95,
|
|
| 25 |
if encoded_prompt.size()[-1] == 0:
|
| 26 |
input_ids = None
|
| 27 |
else:
|
| 28 |
-
input_ids = encoded_prompt
|
| 29 |
-
|
| 30 |
bad_words = bad_words.split()
|
| 31 |
bad_word_ids = []
|
| 32 |
for bad_word in bad_words:
|
|
@@ -90,11 +89,11 @@ if __name__ == "__main__":
|
|
| 90 |
if len(text_area.strip()) == 0:
|
| 91 |
text_area = random.choice(suggested_text_list)
|
| 92 |
result = extend(input_text=text_area,
|
| 93 |
-
num_return_sequences=int(num_return_sequences),
|
|
|
|
| 94 |
max_size=int(max_len),
|
| 95 |
top_k=int(top_k),
|
| 96 |
-
top_p=float(top_p)
|
| 97 |
-
bad_words = bad_words)
|
| 98 |
print("Done length: " + str(len(result)) + " bytes")
|
| 99 |
#<div class="rtl" dir="rtl" style="text-align:right;">
|
| 100 |
st.markdown(f"{result}", unsafe_allow_html=True)
|
|
|
|
| 16 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 17 |
model = AutoModelForCausalLM.from_pretrained(model_name)
|
| 18 |
return model, tokenizer
|
| 19 |
+
def extend(input_text, num_return_sequences, bad_words, max_size=20, top_k=50, top_p=0.95):
|
| 20 |
if len(input_text) == 0:
|
| 21 |
input_text = ""
|
| 22 |
encoded_prompt = tokenizer.encode(
|
|
|
|
| 25 |
if encoded_prompt.size()[-1] == 0:
|
| 26 |
input_ids = None
|
| 27 |
else:
|
| 28 |
+
input_ids = encoded_prompt
|
|
|
|
| 29 |
bad_words = bad_words.split()
|
| 30 |
bad_word_ids = []
|
| 31 |
for bad_word in bad_words:
|
|
|
|
| 89 |
if len(text_area.strip()) == 0:
|
| 90 |
text_area = random.choice(suggested_text_list)
|
| 91 |
result = extend(input_text=text_area,
|
| 92 |
+
num_return_sequences=int(num_return_sequences),
|
| 93 |
+
bad_words = bad_words,
|
| 94 |
max_size=int(max_len),
|
| 95 |
top_k=int(top_k),
|
| 96 |
+
top_p=float(top_p))
|
|
|
|
| 97 |
print("Done length: " + str(len(result)) + " bytes")
|
| 98 |
#<div class="rtl" dir="rtl" style="text-align:right;">
|
| 99 |
st.markdown(f"{result}", unsafe_allow_html=True)
|