Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,12 +5,15 @@ import streamlit as st
|
|
| 5 |
import torch
|
| 6 |
from transformers import pipeline, set_seed
|
| 7 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 8 |
-
|
|
|
|
|
|
|
| 9 |
|
| 10 |
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
|
| 11 |
DEVICE = os.environ.get("DEVICE", "cpu") # cuda:0
|
| 12 |
if DEVICE != "cpu" and not torch.cuda.is_available():
|
| 13 |
DEVICE = "cpu"
|
|
|
|
| 14 |
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
|
| 15 |
MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
|
| 16 |
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
|
|
@@ -133,15 +136,13 @@ class TextGeneration:
|
|
| 133 |
)[0]["generated_text"]
|
| 134 |
|
| 135 |
|
| 136 |
-
#@st.cache(
|
| 137 |
-
@st.cache(allow_output_mutation=True
|
| 138 |
def load_text_generator():
|
| 139 |
text_generator = TextGeneration()
|
| 140 |
text_generator.load()
|
| 141 |
return text_generator
|
| 142 |
|
| 143 |
-
generator = load_text_generator()
|
| 144 |
-
|
| 145 |
|
| 146 |
def main():
|
| 147 |
st.set_page_config(
|
|
@@ -151,7 +152,7 @@ def main():
|
|
| 151 |
initial_sidebar_state="expanded"
|
| 152 |
)
|
| 153 |
style()
|
| 154 |
-
|
| 155 |
st.sidebar.markdown(SIDEBAR_INFO, unsafe_allow_html=True)
|
| 156 |
|
| 157 |
max_length = st.sidebar.slider(
|
|
|
|
# --- Model / runtime configuration (all overridable via environment variables) ---
import torch
from transformers import pipeline, set_seed
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging

# Root logger with a stream handler so messages reach the Space's console logs.
logger = logging.getLogger()
logger.addHandler(logging.StreamHandler())

# Optional Hub token for gated/private models (None when unset).
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
DEVICE = os.environ.get("DEVICE", "cpu")  # cuda:0
# Fall back to CPU when a CUDA device was requested but none is available.
if DEVICE != "cpu" and not torch.cuda.is_available():
    DEVICE = "cpu"
logger.info(f"DEVICE {DEVICE}")
# Half precision only makes sense on GPU; CPU inference stays in float32.
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
|
|
|
|
| 136 |
)[0]["generated_text"]
|
| 137 |
|
| 138 |
|
| 139 |
+
# Alternative caching strategy kept for reference (hashes model params away):
#@st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None})
@st.cache(allow_output_mutation=True)
def load_text_generator():
    """Construct and load the TextGeneration wrapper, cached by Streamlit.

    allow_output_mutation=True lets the cached model object be reused
    without Streamlit re-hashing it on every rerun.
    """
    generator = TextGeneration()
    generator.load()
    return generator
|
| 145 |
|
|
|
|
|
|
|
| 146 |
|
| 147 |
def main():
|
| 148 |
st.set_page_config(
|
|
|
|
| 152 |
initial_sidebar_state="expanded"
|
| 153 |
)
|
| 154 |
style()
|
| 155 |
+
generator = load_text_generator()
|
| 156 |
st.sidebar.markdown(SIDEBAR_INFO, unsafe_allow_html=True)
|
| 157 |
|
| 158 |
max_length = st.sidebar.slider(
|