Spaces:
Sleeping
Sleeping
| from PIL import Image | |
| import streamlit as st | |
| from llmlib.runtime import filled_model_registry | |
| from llmlib.model_registry import ModelEntry, ModelRegistry | |
| from llmlib.base_llm import Message | |
| from llmlib.bundler import Bundler | |
| from llmlib.bundler_request import BundlerRequest | |
| from login_mask_simple import check_password | |
# Page config is applied before anything else so it also covers the login page
# (the original order skipped it whenever check_password() failed and st.stop()
# ran). Streamlit documents set_page_config as needing to be the first
# Streamlit command of the script.
st.set_page_config(page_title="LLM App", layout="wide")

# Gate the rest of the app behind the password check; halt this run otherwise.
if not check_password():
    st.stop()

st.title("LLM App")

# Registry of all selectable models; rebuilt on every script run.
model_registry: ModelRegistry = filled_model_registry()
@st.cache_resource
def create_model_bundler() -> Bundler:
    """Return a Bundler built from ``model_registry``, created once and then reused.

    Streamlit re-executes this script on every interaction. Without caching, each
    prompt constructed a brand-new Bundler, discarding whatever model state the
    previous one held before ``set_model_on_gpu`` is called again.
    ``st.cache_resource`` keeps a single instance alive across reruns.

    NOTE(review): a cache_resource object is shared by all sessions of the server;
    if several users may generate concurrently, confirm Bundler tolerates
    interleaved set_model_on_gpu/get_response calls.
    """
    return Bundler(registry=model_registry)
def display_warnings(r: ModelRegistry, model_id: str) -> None:
    """Show the registry's warnings for ``model_id``, if any, in one ``st.warning`` box.

    Args:
        r: Registry in which the model entry is looked up.
        model_id: Identifier of the model whose warnings are displayed.
    """
    entry: ModelEntry = r.get_entry(model_id)
    if entry.warnings:
        # Markdown needs two trailing spaces before "\n" for a hard line break;
        # with a single space all warnings would run together on one line.
        st.warning("  \n".join(entry.warnings))
# Two-column header: model picker on the left, optional image upload on the right.
cs = st.columns(2)
model_col, upload_col = cs

with model_col:
    model1_id: str = st.selectbox("Select model", model_registry.all_model_ids())
    display_warnings(model_registry, model1_id)

with upload_col:
    # The uploader's widget key is a counter kept in session state; it starts at 0
    # and is incremented after an image has been sent, which yields a fresh widget.
    st.session_state.setdefault("img-key", 0)
    image = st.file_uploader("Include an image", key=st.session_state["img-key"])
def _reset_chat() -> None:
    """Empty both chat histories kept in session state (each a list[Message])."""
    st.session_state.messages1 = []
    st.session_state.messages2 = []


# First run of a session: create empty histories.
if "messages1" not in st.session_state:
    _reset_chat()

# Explicit user request to start over.
if st.button("Restart chat"):
    _reset_chat()
def render_messages(msgs: list[Message]) -> None:
    """Draw each message of a chat history, in list order, via ``render_message``."""
    for message in msgs:
        render_message(message)
def render_message(msg: Message):
    """Draw one chat bubble for ``msg.role``: attached image first (if any), then the text as Markdown."""
    with st.chat_message(msg.role):
        has_image = msg.img_name is not None
        if has_image:
            render_img(msg)
        st.markdown(msg.msg)
def render_img(msg: Message):
    """Display the image attached to ``msg`` at width 400, captioned with its file name."""
    st.image(msg.img, width=400, caption=msg.img_name)
# Replay the stored conversation: Streamlit re-executes the whole script on every
# interaction, so the history must be redrawn from session state each run.
# (An unused one-column layout that also clobbered `cs` was removed here.)
render_messages(st.session_state.messages1)

# Wait for the next prompt; nothing below runs until the user submits one.
prompt = st.chat_input("Type here")
if prompt is None:
    st.stop()
# Build the user's turn, attaching the uploaded file (if any) as a PIL image.
img_name = None
img = None
if image is not None:
    try:
        img = Image.open(image)
        # Image.open only parses the header; decode now so an unreadable upload
        # fails here with a clear message instead of inside the model call.
        img.load()
    except OSError as err:  # PIL.UnidentifiedImageError is an OSError subclass
        st.error(f"Could not read '{image.name}' as an image: {err}")
        st.stop()
    img_name = image.name
    # Bumping the uploader's key makes the next run create a fresh, empty
    # widget, so the same image is not attached to the following prompt.
    st.session_state["img-key"] += 1

msg = Message(role="user", msg=prompt, img_name=img_name, img=img)
st.session_state.messages1.append(msg)
render_message(msg)

model_bundler: Bundler = create_model_bundler()
try:
    with st.spinner("Initializing model..."):
        model_bundler.set_model_on_gpu(model_id=model1_id)
    with st.spinner("Generating response..."):
        req = BundlerRequest(model_id=model1_id, msgs=st.session_state.messages1)
        response = model_bundler.get_response(req)
except BaseException:
    # Roll back the unanswered user turn so the stored history does not keep a
    # dangling message that every later request would carry. BaseException so
    # the rollback also runs if the run is interrupted, not only on errors; the
    # exception is always re-raised unchanged.
    st.session_state.messages1.pop()
    raise

msg = Message(role="assistant", msg=response)
st.session_state.messages1.append(msg)
render_message(msg)