### import packages
import torch
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
import streamlit as st
from PIL import Image
import os

### access token: store it in the Space's secrets as HF_TOKEN
token = os.environ.get('HF_TOKEN')

### choose a PaliGemma model
# See https://huggingface.co/collections/google/paligemma-2-release-67500e1e1dbfdd4dee27ba48
model_id = "google/paligemma2-3b-pt-896"
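# Other PaliGemma 2 checkpoints from the same collection should work as drop-in
# replacements here (an assumption, not tested in this Space), e.g. the smaller
# "google/paligemma2-3b-pt-224" or the larger "google/paligemma2-10b-pt-448".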
@st.cache_resource
def model_setup(model_id):
    """
    Sets up the model and processor; @st.cache_resource caches them across reruns.

    Args:
        model_id: one of the PaliGemma models
    Returns:
        model: from PaliGemmaForConditionalGeneration.from_pretrained
        processor: from PaliGemmaProcessor.from_pretrained
    """
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=token,
    ).eval()
    processor = PaliGemmaProcessor.from_pretrained(model_id, token=token)
    return model, processor
def run_model(prompt, image):
    """
    Performs inference on the user's prompt and image.

    Args:
        prompt: user prompt or task
        image: user's uploaded image
    Returns:
        output text
    """
    model_inputs = (
        processor(text=prompt, images=image, return_tensors="pt")
        .to(torch.bfloat16)
        .to(model.device)
    )
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = model.generate(**model_inputs, max_new_tokens=1000, do_sample=False)
        generation = generation[0][input_len:]
    return processor.decode(generation, skip_special_tokens=True)
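### example (hypothetical, for reference only): calling run_model outside the chat UI.
### "caption en" is a standard PaliGemma task prompt; "example.jpg" is a placeholder path.
# image = Image.open("example.jpg").convert("RGB")
# print(run_model("caption en", image))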
def initialize():
    """
    Initializes (resets) the chat history.
    """
    st.session_state.messages = []

### load the model and processor (cached by @st.cache_resource)
model, processor = model_setup(model_id)

### upload a file; changing the image calls initialize() and clears the chat history
uploaded_file = st.file_uploader("Choose an image", on_change=initialize)
if uploaded_file:
    st.image(uploaded_file)
    image = Image.open(uploaded_file).convert("RGB")

    # tasks: Caption by default; accept a free-form prompt only when
    # 'Enter your prompt' is selected
    task = st.radio(
        "Task",
        ('Caption', 'OCR', 'Segment', 'Enter your prompt'),
        horizontal=True)

    # display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # use the selected task as the prompt, or wait for free-form input
    if task == 'Enter your prompt':
        prompt = st.chat_input("Type here!", key="user_prompt")
    else:
        prompt = task

    if prompt:
        # display user message in chat message container
        with st.chat_message("user"):
            st.markdown(prompt)
        # add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})
        # run the VLM
        response = run_model(prompt, image)
        # display assistant response in chat message container
        with st.chat_message("assistant"):
            st.markdown(response)
        # add assistant response to chat history
        st.session_state.messages.append({"role": "assistant", "content": response})
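### Note (assumed, not part of this file): a Space running this app typically also needs
### a requirements.txt listing streamlit, torch, transformers, accelerate (required for
### device_map="auto"), and pillow.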