Add multiprocessing to inference and clean up code
app.py
CHANGED
@@ -1,29 +1,27 @@
-import streamlit as st
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from transformers import pipeline
-import torch
 import json
 import pandas as pd
 import requests
+from multiprocessing import Pool
+from functools import partial
+import streamlit as st
+
 
 GITHUB_CODE = "https://huggingface.co/datasets/lvwerra/github-code"
 INCODER_IMG = "https://huggingface.co/datasets/loubnabnl/repo-images/raw/main/incoder.png"
 
-@st.cache(allow_output_mutation=True)
-def load_tokenizer(model_ckpt):
-    return AutoTokenizer.from_pretrained(model_ckpt)
-
-@st.cache(allow_output_mutation=True)
-def load_model(model_ckpt):
-    model = AutoModelForCausalLM.from_pretrained(model_ckpt, low_cpu_mem_usage=True)
-    return model
-
 @st.cache()
 def load_examples():
     with open("utils/examples.json", "r") as f:
         examples = json.load(f)
     return examples
 
+def generate_code(model_name, gen_prompt, max_new_tokens, temperature, seed):
+    url = f'https://hf.space/embed/loubnabnl/{model_name.lower()}-subspace/+/api/predict/'
+    r = requests.post(url=url, json={"data": [gen_prompt, max_new_tokens, temperature, seed]})
+    generated_text = r.json()['data'][0]
+    st.markdown(model_name)
+    st.code(generated_text)
+
 st.set_page_config(page_icon=":laptop:", layout="wide")
 
 st.sidebar.header("Models")
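The new generate_code helper offloads inference over HTTP to one Hugging Face Space per model instead of loading tokenizers and models locally. A minimal standalone sketch of the same request, runnable outside Streamlit; the model name, prompt, and sampling values below are illustrative placeholders, not values taken from this commit:

import requests

# Placeholder inputs; in the app these come from the Streamlit widgets.
model_name = "codeparrot"              # assumed example model name
gen_prompt = "def print_hello_world():"
max_new_tokens, temperature, seed = 8, 0.2, 42

# Same endpoint pattern as the committed generate_code() function.
url = f"https://hf.space/embed/loubnabnl/{model_name.lower()}-subspace/+/api/predict/"
r = requests.post(url=url, json={"data": [gen_prompt, max_new_tokens, temperature, seed]})
generated_text = r.json()["data"][0]
print(generated_text)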
@@ -84,9 +82,11 @@ elif selected_task == "Code generation":
     gen_prompt = st.text_area("Generate code with prompt:", value=example_text, height=220,).strip()
     if st.button("Generate code!"):
         with st.spinner("Generating code..."):
-            [6 removed lines of the previous in-app generation code; their content is not shown in this capture]
+            # Create a multiprocessing Pool
+            pool = Pool()
+            generate_parallel = partial(generate_code,
+                                        gen_prompt=gen_prompt,
+                                        max_new_tokens=max_new_tokens,
+                                        temperature=temperature,
+                                        seed=seed)
+            pool.map(generate_parallel, selected_models)