create space and add fonts

Files changed:
- IDEFICS_logo.png +0 -0
- app_dialogue.py +1022 -0
- fonts/Impacted.ttf +0 -0
- fonts/Roboto-Regular.ttf +0 -0
- fonts/impact.ttf +0 -0
- fonts/unicode.impact.ttf +0 -0
- requirements.txt +19 -0
IDEFICS_logo.png ADDED

app_dialogue.py ADDED
@@ -0,0 +1,1022 @@
import ast
import copy
import glob
import hashlib
import logging
import os
import re
from pathlib import Path
from typing import List, Optional, Tuple
from urllib.parse import urlparse
from PIL import Image, ImageDraw, ImageFont

import random
import gradio as gr
import PIL
from gradio import processing_utils
from gradio_client.client import DEFAULT_TEMP_DIR
from text_generation import Client
from transformers import AutoProcessor

MODELS = [
    # "HuggingFaceM4/idefics-9b-instruct",
    "HuggingFaceM4/idefics-80b-instruct",
]

API_PATHS = {
    "HuggingFaceM4/idefics-9b-instruct": (
        "https://api-inference.huggingface.co/models/HuggingFaceM4/idefics-9b-instruct"
    ),
    "HuggingFaceM4/idefics-80b-instruct": (
        "https://api-inference.huggingface.co/models/HuggingFaceM4/idefics-80b-instruct"
    ),
}

SYSTEM_PROMPT = [
    """The following is a conversation between a highly knowledgeable and intelligent visual AI assistant, called Assistant, and a human user, called User.
In the following interactions, User and Assistant will converse in natural language, and Assistant will answer in a sassy way.
Assistant's main purpose is to create funny meme texts from the images User provides.
Assistant should be funny, sassy, and impertinent, and sometimes Assistant roasts people.
Assistant should not be mean. It should not say toxic, homophobic, sexist, or racist things, or any demeaning things that can make people uncomfortable.
Assistant was created by Hugging Face.

Here's a conversation example:""",
    """\nUser:""",
    "https://ichef.bbci.co.uk/news/976/cpsprodpb/7727/production/_103330503_musk3.jpg",
    "Write a meme for that image.<end_of_utterance>",
    """\nAssistant: When you're trying to quit smoking but the cravings are too strong.<end_of_utterance>""",
    "\nUser: How about this image?",
    "https://www.boredpanda.com/blog/wp-content/uploads/2017/01/image-copy-copy-587d0e7918b57-png__700.jpg",
    "Write something funny about this image.<end_of_utterance>",
    """\nAssistant: Eggcellent service!<end_of_utterance>""",
    "\nUser: Roast this person",
    "https://i.pinimg.com/564x/98/34/4b/98344b2483bd7c8b71a5c0fed6fe20b6.jpg",
    "<end_of_utterance>",
    """\nAssistant: Damn your handwriting is pretty awful. But I suppose it must be pretty hard to hold a pen, considering you are a hammerhead shark.<end_of_utterance>""",
]

BAN_TOKENS = (  # For documentation purposes. We are not using this list; it is hardcoded inside `idefics_causal_lm.py` inside TGI.
    "<image>;<fake_token_around_image>"
)
EOS_STRINGS = ["<end_of_utterance>", "\nUser:"]
STOP_SUSPECT_LIST = []

GRADIO_LINK = "https://huggingfacem4-ai-meme-generator.hf.space"
API_TOKEN = os.getenv("HF_AUTH_TOKEN")
IDEFICS_LOGO = "https://huggingface.co/spaces/HuggingFaceM4/idefics_playground/resolve/main/IDEFICS_logo.png"

PROCESSOR = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics-9b-instruct",
    token=API_TOKEN,
)

BOT_AVATAR = "IDEFICS_logo.png"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()

# Monkey patch adapted from gradio.components.image.Image - mostly to make the `save` step optional in `pil_to_temp_file`
def hash_bytes(bytes: bytes):
    sha1 = hashlib.sha1()
    sha1.update(bytes)
    return sha1.hexdigest()


def pil_to_temp_file(
    img: PIL.Image.Image, dir: str = DEFAULT_TEMP_DIR, format: str = "png"
) -> str:
    """Save a PIL image into a temp file"""
    bytes_data = processing_utils.encode_pil_to_bytes(img, format)
    temp_dir = Path(dir) / hash_bytes(bytes_data)
    temp_dir.mkdir(exist_ok=True, parents=True)
    filename = str(temp_dir / f"image.{format}")
    if not os.path.exists(filename):
        img.save(filename, pnginfo=processing_utils.get_pil_metadata(img))
    return filename

def add_file(file):
    return file.name, gr.update(label="🖼️ Uploaded!")


def add_file_gallery(selected_state: gr.SelectData, gallery_list: List[str]):
    gr.update(label="Upload image", interactive=True)
    return (
        "Write a meme about this image.",
        gallery_list[selected_state.index]["name"],
        "",
    )

def choose_gallery(gallery_type: str):
    if gallery_type == "Meme templates":
        image_gallery_list = [
            f"example_images/meme_templates/{ex_image}"
            for ex_image in os.listdir("example_images/meme_templates")
        ]
    elif gallery_type == "Funny images":
        image_gallery_list = [
            f"example_images/funny_images/{ex_image}"
            for ex_image in os.listdir("example_images/funny_images")
        ]
    elif gallery_type == "Politics":
        image_gallery_list = [
            f"example_images/politics_memes/{ex_image}"
            for ex_image in os.listdir("example_images/politics_memes")
        ]
    else:
        image_gallery_list = [
            f"example_images/{image_dir}/{ex_image}"
            for image_dir in os.listdir("example_images")
            for ex_image in os.listdir(f"example_images/{image_dir}")
        ]
    random.shuffle(image_gallery_list)
    return image_gallery_list

# This is a hack to make pre-computing the default examples work.
# During normal inference, we pass images as URLs to a local file using the method `gradio_link`,
# which allows the TGI server to fetch the local image from the frontend server.
# However, while we are building the space (and pre-computing is part of building the space), the frontend is not available
# and won't answer. So the TGI server would try to fetch an image that is not available yet, which would result in a timeout error
# because TGI would never be able to return the generation.
# To bypass that, we instead pass the image URLs from the space repo.
DEFAULT_IMAGES_TMP_PATH_TO_URL = {}
for image_dir in os.listdir("example_images"):
    for im_path in os.listdir(f"example_images/{image_dir}"):
        H = gr.Image(
            f"example_images/{image_dir}/{im_path}", visible=False, type="filepath"
        )
        tmp_filename = H.preprocess(H.value)
        DEFAULT_IMAGES_TMP_PATH_TO_URL[
            tmp_filename
        ] = f"https://huggingface.co/spaces/HuggingFaceM4/AI_Meme_Generator/resolve/main/example_images/{image_dir}/{im_path}"

# Utils to handle the image markdown display logic
def split_str_on_im_markdown(string: str) -> List[str]:
    """
    Extract from a string (typically the user prompt string) the potential images specified in markdown.
    Examples:
    - `User:![](https://favurl.com/chicken_on_money.png)Describe this image.` would become `["User:", "https://favurl.com/chicken_on_money.png", "Describe this image."]`
    - `User:![](/my_temp/chicken_on_money.png)Describe this image.` would become `["User:", "/my_temp/chicken_on_money.png", "Describe this image."]`
    """
    IMAGES_PATTERN = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
    parts = []
    cursor = 0
    for pattern in IMAGES_PATTERN.finditer(string):
        start = pattern.start()
        if start != cursor:
            parts.append(string[cursor:start])
        image_url = pattern.group(1)
        if image_url.startswith("/file="):
            image_url = image_url[6:]  # Remove the '/file=' prefix
        parts.append(image_url)
        cursor = pattern.end()
    if cursor != len(string):
        parts.append(string[cursor:])
    return parts

def is_image(string: str) -> bool:
    """
    There are two ways to pass an image: a local image path or a URL.
    """
    return is_url(string) or string.startswith(DEFAULT_TEMP_DIR)


def is_url(string: str) -> bool:
    """
    Checks if the passed string contains a valid URL and nothing else. E.g., if a space is included,
    the URL is immediately invalidated.
    """
    if " " in string:
        return False
    result = urlparse(string)
    return all([result.scheme, result.netloc])
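
# Illustration (editor's sketch, not part of the original commit): `urlparse` must
# return both a scheme and a netloc for the string to count as a URL.
#   is_url("https://example.com/img.png")   -> True
#   is_url("example.com/img.png")           -> False (no scheme)
#   is_url("https://a.com/i.png extra")     -> False (contains a space)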

def isolate_images_urls(prompt_list: List) -> List:
    """
    Convert a full string prompt to the list format expected by the processor.
    In particular, image URLs (as delimited by <fake_token_around_image>) should be their own elements.
    From:
    ```
    [
        "bonjour<fake_token_around_image><image:IMG_URL><fake_token_around_image>hello",
        PIL.Image.Image,
        "Aurevoir",
    ]
    ```
    to:
    ```
    [
        "bonjour",
        IMG_URL,
        "hello",
        PIL.Image.Image,
        "Aurevoir",
    ]
    ```
    """
    linearized_list = []
    for prompt in prompt_list:
        # Prompt can be either a string or a PIL image
        if isinstance(prompt, PIL.Image.Image):
            linearized_list.append(prompt)
        elif isinstance(prompt, str):
            if "<fake_token_around_image>" not in prompt:
                linearized_list.append(prompt)
            else:
                prompt_splitted = prompt.split("<fake_token_around_image>")
                for ps in prompt_splitted:
                    if ps == "":
                        continue
                    if ps.startswith("<image:"):
                        # Strip the "<image:" prefix and the trailing ">"
                        linearized_list.append(ps[7:-1])
                    else:
                        linearized_list.append(ps)
        else:
            raise TypeError(
                f"Unrecognized type for `prompt`. Got {type(prompt)}. Was expecting something in [`str`,"
                " `PIL.Image.Image`]"
            )
    return linearized_list

def fetch_images(url_list: List[str]) -> List[PIL.Image.Image]:
    """Fetch images from a list of URLs"""
    return PROCESSOR.image_processor.fetch_images(url_list)

def handle_manual_images_in_user_prompt(user_prompt: str) -> List[str]:
    """
    Handle the case of images manually typed into the user prompt (i.e. the
    `<fake_token_around_image><image:IMG_URL><fake_token_around_image>` syntax) by fetching them,
    saving them locally and replacing the whole sub-sequence with the image's local path.
    """
    if "<fake_token_around_image>" in user_prompt:
        splitted_user_prompt = isolate_images_urls([user_prompt])
        resulting_user_prompt = []
        for u_p in splitted_user_prompt:
            if is_url(u_p):
                img = fetch_images([u_p])[0]
                tmp_file = pil_to_temp_file(img)
                resulting_user_prompt.append(tmp_file)
            else:
                resulting_user_prompt.append(u_p)
        return resulting_user_prompt
    else:
        return [user_prompt]

def gradio_link(img_path: str) -> str:
    url = f"{GRADIO_LINK}/file={img_path}"
    return url

def prompt_list_to_markdown(prompt_list: List[str], size: int = None) -> str:
    """
    Convert a user prompt in the list format (i.e. elements are either a PIL image or a string) into
    the markdown format that is used for the chatbot history and rendering.
    """
    resulting_string = ""
    for elem in prompt_list:
        if is_image(elem):
            if is_url(elem):
                if size is not None:
                    resulting_string += f"<img src={elem} width={size} height={size}>"
                else:
                    resulting_string += f"![]({elem})"
            else:
                if size is not None:
                    resulting_string += f"<img src='/file={str(elem)}' width='{size}' height={str(size)}>"
                else:
                    resulting_string += f"![](/file={str(elem)})"
        else:
            resulting_string += elem
    return resulting_string
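
# Illustration (editor's sketch, not part of the original commit): a list prompt renders
# to chatbot markdown, with local files routed through Gradio's /file= endpoint.
#   prompt_list_to_markdown(["\nUser:", "https://example.com/cat.png", "nice cat"])
#   -> "\nUser:![](https://example.com/cat.png)nice cat"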

def prompt_list_to_tgi_input(prompt_list: List[str]) -> str:
    """
    TGI expects a string that contains both text and images in the image markdown format (i.e. `![](...)`).
    The image links are parsed on the TGI side.
    """
    result_string_input = ""
    for elem in prompt_list:
        if is_image(elem):
            if is_url(elem):
                result_string_input += f"![]({elem})"
            else:
                # Local files are exposed to TGI through the Gradio frontend URL
                result_string_input += f"![]({gradio_link(elem)})"
        else:
            result_string_input += elem
    return result_string_input
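
# Illustration (editor's sketch, not part of the original commit): how such a string
# could be sent to a TGI endpoint with the `text_generation.Client` imported above.
# `formatted_prompt_list`, the headers and the generation parameters are placeholders,
# not values from this file.
#   client = Client(base_url=API_PATHS["HuggingFaceM4/idefics-80b-instruct"],
#                   headers={"Authorization": f"Bearer {API_TOKEN}"})
#   response = client.generate(prompt=prompt_list_to_tgi_input(formatted_prompt_list),
#                              max_new_tokens=256, stop_sequences=EOS_STRINGS)
#   generated_text = response.generated_text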

def remove_spaces_around_token(text: str) -> str:
    pattern = r"\s*(<fake_token_around_image>)\s*"
    replacement = r"\1"
    result = re.sub(pattern, replacement, text)
    return result

# Chatbot utils
Radio_options_to_font = {}


def insert_backslash(string, max_length=50):
    # Check if the string length is less than or equal to max_length
    if len(string) <= max_length:
        return string

    # Start from the max_length character and search for the last space character before it
    for i in range(max_length - 1, -1, -1):
        if string[i] == " ":
            # Insert a line break before the last space character
            return string[:i] + "\n" + string[i:]

    # If no space character is found, just insert a line break at the max_length character
    return string[:max_length] + "\n" + string[max_length:]
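
# Illustration (editor's sketch, not part of the original commit): the break lands at
# the last space before max_length, so a long caption wraps cleanly.
#   insert_backslash("a very long meme caption that overflows", max_length=20)
#   -> "a very long meme\n caption that overflows"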

def resize_with_ratio(image: PIL.Image.Image, fixed_width: int) -> PIL.Image.Image:
    # Get the current width and height
    width, height = image.size

    # Calculate the new height while maintaining the aspect ratio, capped at a 2:3 width:height ratio
    new_width = fixed_width
    new_height = min(int(height * (new_width / width)), int(1.5 * new_width))

    # Resize the image
    resized_img = image.resize((new_width, new_height), Image.LANCZOS)

    return resized_img

def test_font_size(
    draw,
    image,
    text,
    font,
    font_meme_text,
    num_lines=1,
    min_font=35,
    font_size_reduction=5,
):
    text_width = draw.textlength(text, font)
    text_is_too_long = True

    if num_lines == 1:
        while font.size > min_font and text_is_too_long:
            font = ImageFont.truetype(
                f"fonts/{font_meme_text}.ttf", size=font.size - font_size_reduction
            )
            text_width = draw.textlength(text, font)
            text_is_too_long = text_width > image.width

    elif num_lines == 2:
        while font.size > min_font and text_is_too_long:
            font = ImageFont.truetype(
                f"fonts/{font_meme_text}.ttf", size=font.size - font_size_reduction
            )
            max_len_increment = 0
            while (
                text_is_too_long
                and max_len_increment < 10
                and max_len_increment < (len(text)) // 2
            ):
                temp_text = insert_backslash(
                    text, max_length=(len(text) + max_len_increment) // 2
                )
                first_line, second_line = (
                    temp_text.split("\n")[0],
                    temp_text.split("\n")[1],
                )
                text_width = max(
                    draw.textlength(first_line, font),
                    draw.textlength(second_line, font),
                )
                text_is_too_long = text_width > image.width
                max_len_increment += 1

    elif num_lines == 3:
        while font.size > min_font and text_is_too_long:
            font = ImageFont.truetype(
                f"fonts/{font_meme_text}.ttf", size=font.size - font_size_reduction
            )
            max_len_incr_1_split = 0
            while text_is_too_long and max_len_incr_1_split < 10:
                first_temp_text = insert_backslash(
                    text, max_length=(len(text) + max_len_incr_1_split) // 3
                )
                first_line, second_line = (
                    first_temp_text.split("\n")[0],
                    first_temp_text.split("\n")[1],
                )
                max_len_incr_2_split = 0
                while text_is_too_long and max_len_incr_2_split < 10:
                    temp_text_second_line = insert_backslash(
                        second_line,
                        max_length=(len(second_line) + max_len_incr_2_split) // 2,
                    )
                    second_line_1, second_line_2 = (
                        temp_text_second_line.split("\n")[0],
                        temp_text_second_line.split("\n")[1],
                    )
                    temp_text = first_line + "\n" + second_line_1 + "\n" + second_line_2
                    text_width = max(
                        draw.textlength(first_line, font),
                        draw.textlength(second_line_1, font),
                        draw.textlength(second_line_2, font),
                    )
                    text_is_too_long = text_width > image.width
                    max_len_incr_2_split += 1
                max_len_incr_1_split += 1
    else:
        raise ValueError("num_lines can only be 1, 2 or 3")

    if not text_is_too_long and num_lines > 1:
        text = temp_text
    return text, font, text_width, text_is_too_long

def make_meme_image(
    image: str,
    text: str,
    font_meme_text: str,
    all_caps_meme_text: bool = False,
    text_at_the_top: bool = False,
) -> PIL.Image.Image:
    """
    Takes an image and a text and returns a meme image.
    """
    text = text.replace("\nUser", " ").replace("\n", " ").strip().rstrip(".")
    if all_caps_meme_text:
        text = text.upper()
    # Resize image
    fixed_width = 700
    image = Image.open(image)
    image = resize_with_ratio(image, fixed_width)
    image_width, image_height = image.size
    height_width_ratio = image_height / image_width

    draw = ImageDraw.Draw(image)
    min_font = 30
    initial_font_size = 60
    if height_width_ratio > 1:
        min_font = 40
        initial_font_size = 80
    text_is_too_long = True
    num_lines = 0
    while text_is_too_long and num_lines < 3:
        num_lines += 1
        font = ImageFont.truetype(f"fonts/{font_meme_text}.ttf", size=initial_font_size)
        text, font, text_width, text_is_too_long = test_font_size(
            draw,
            image,
            text,
            font,
            font_meme_text,
            num_lines=num_lines,
            min_font=min_font,
            font_size_reduction=5,
        )

    if text_is_too_long:
        text = "Text is too long to fit the image"
        if all_caps_meme_text:
            text = text.upper()
        font = ImageFont.truetype(f"fonts/{font_meme_text}.ttf", size=font.size)
        text_width = draw.textlength(text, font)

    outline_width = 2
    text_x = (image_width - text_width) / 2
    text_y = image_height - num_lines * font.size - 10 - num_lines
    if text_at_the_top:
        text_y = 0

    # Draw a black outline by offsetting the text in every direction, then the white text on top
    for i in range(-outline_width, outline_width + 1):
        for j in range(-outline_width, outline_width + 1):
            draw.multiline_text(
                (text_x + i, text_y + j), text, fill="black", align="center", font=font
            )
    draw.multiline_text((text_x, text_y), text, fill="white", align="center", font=font)

    return image
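
# Illustration (editor's sketch, not part of the original commit): the font name maps to
# a file under fonts/, so "impact" resolves to the fonts/impact.ttf shipped in this
# commit. The image path and caption below are placeholders.
#   meme = make_meme_image("example_images/meme_templates/cat.jpg",
#                          "when the build finally passes", "impact",
#                          all_caps_meme_text=True)
#   meme.save("my_meme.png")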

def format_user_prompt_with_im_history_and_system_conditioning(
    system_prompt: List[str],
    current_user_prompt_str: str,
    current_image: Optional[str],
    history: List[Tuple[str, str]],
) -> Tuple[List[str], List[str]]:
    """
    Produces the resulting list that needs to go inside the processor.
    It handles the potential image box input, the history and the system conditioning.
    """
    # resulting_list = copy.deepcopy(SYSTEM_PROMPT)
    resulting_list = system_prompt

    # Format history
    for turn in history:
        user_utterance, assistant_utterance = turn
        splitted_user_utterance = split_str_on_im_markdown(user_utterance)

        optional_space = ""
        if not is_image(splitted_user_utterance[0]):
            optional_space = " "
        resulting_list.append(f"\nUser:{optional_space}")
        resulting_list.extend(splitted_user_utterance)
        resulting_list.append(f"<end_of_utterance>\nAssistant: {assistant_utterance}")

    # Format current input
    current_user_prompt_str = remove_spaces_around_token(current_user_prompt_str)
    if current_image is None:
        if "![](" in current_user_prompt_str:
            current_user_prompt_list = split_str_on_im_markdown(current_user_prompt_str)
        else:
            current_user_prompt_list = handle_manual_images_in_user_prompt(
                current_user_prompt_str
            )

        optional_space = ""
        if not is_image(current_user_prompt_list[0]):
            # Check if the first element is an image (and more precisely a path to an image)
            optional_space = " "
        resulting_list.append(f"\nUser:{optional_space}")
        resulting_list.extend(current_user_prompt_list)
        resulting_list.append("<end_of_utterance>\nAssistant:")
    else:
        # Choosing to put the image first when the image is inputted through the UI, but this is an arbitrary choice.
        resulting_list.extend(
            [
                "\nUser:",
                current_image,
                f"{current_user_prompt_str}<end_of_utterance>\nAssistant:",
            ]
        )
        current_user_prompt_list = [current_user_prompt_str]

    return resulting_list, current_user_prompt_list

# dope_callback = gr.CSVLogger()
# problematic_callback = gr.CSVLogger()

textbox = gr.Textbox(
    placeholder="Upload an image and ask the AI to create a meme!",
    show_label=False,
    value="Write a meme about this image.",
    visible=True,
    container=False,
    label="Text input",
    scale=8,
    max_lines=5,
)
chatbot = gr.Chatbot(
    elem_id="chatbot",
    label="AI Meme Generator Chatbot",
    visible=False,
    avatar_images=[None, BOT_AVATAR],
)
with gr.Blocks(title="AI Meme Generator", theme=gr.themes.Base()) as demo:
    gr.HTML("""<h1 align="center">AI Meme Generator</h1>""")
    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            gr.Image(
                IDEFICS_LOGO,
                elem_id="banner-image",
                show_label=False,
                show_download_button=False,
                height=200,
                width=250,
            )
        with gr.Column(scale=5):
            gr.HTML(
                """
                <p><strong>AI Meme Generator</strong> is an AI system that writes humorous content inspired by images, allowing you to make the funniest memes with little effort. Upload your image and ask the Idefics chatbot to make a tailored meme.</p>
                <p>AI Meme Generator is a Space inspired by <a href="https://huggingface.co/spaces/HuggingFaceM4/ai_dad_jokes">AI Dad Jokes</a> and powered by <a href="https://huggingface.co/blog/idefics">IDEFICS</a>, an open-access large visual language model developed by Hugging Face. Like GPT-4, the multimodal model accepts arbitrary sequences of image and text inputs and produces text outputs. IDEFICS can answer questions about images, describe visual content, create stories grounded in multiple images, etc.</p>

                <p>⛔️ <strong>Intended uses and limitations:</strong> This demo is provided as a research artifact to the community, showcasing IDEFICS' capabilities. We detail misuses and out-of-scope uses <a href="https://huggingface.co/HuggingFaceM4/idefics-80b#misuse-and-out-of-scope-use">here</a>. In particular, the system should not be used to engage in harassment, abuse, or bullying. The model can produce factually incorrect texts, hallucinate facts (with or without an image), and will struggle with small details in images. While the system will tend to refuse answering questionable user requests, it can produce problematic outputs (including racist, stereotypical, and disrespectful texts), in particular when prompted to do so.</p>
                """
            )

    with gr.Row(elem_id="model_selector_row"):
        model_selector = gr.Dropdown(
            choices=MODELS,
            value="HuggingFaceM4/idefics-80b-instruct",
            interactive=True,
            show_label=False,
            container=False,
            label="Model",
            visible=False,
        )

    with gr.Row(equal_height=True):
        with gr.Box(elem_id="gallery_box"):
            gallery_type_choice = gr.Radio(
                [
                    "All",
                    "Meme templates",
                    "Funny images",
                    "Politics",
                ],
                value="All",
                label="Gallery Type",
                interactive=True,
                visible=False,
                info="Choose the type of gallery you want to see.",
            )
            template_gallery = gr.Gallery(
                # value= value given by gallery_type_choice,
                label="Templates Gallery",
                allow_preview=False,
                columns=6,
                elem_id="gallery",
                show_share_button=False,
                height=400,
            )
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            imagebox = gr.Image(
                type="filepath", label="Image to meme", height=400, visible=True
            )
            with gr.Group():
                with gr.Row():
                    textbox.render()
                with gr.Row():
                    submit_btn = gr.Button(value="▶️ Submit", visible=True)
                    clear_btn = gr.ClearButton(
                        [textbox, imagebox, chatbot], value="🧹 Clear"
                    )
                    regenerate_btn = gr.Button(value="🔄 Regenerate", visible=True)
                    upload_btn = gr.UploadButton(
                        "📁 Upload image", file_types=["image"], visible=False
                    )
            with gr.Accordion(
                "Advanced settings", open=False, visible=True
            ) as parameter_row:
                with gr.Row():
                    with gr.Column():
                        all_caps_meme_text = gr.Checkbox(
                            value=True,
                            label="All Caps",
                            interactive=True,
                            info="",
                        )
                        text_at_the_top = gr.Checkbox(
                            value=False,
                            label="Text at the top",
                            interactive=True,
                            info="",
                        )
                    with gr.Column():
                        font_meme_text = gr.Radio(
                            [
                                "impact",
                                "Roboto-Regular",
                            ],
                            value="impact",
                            label="Font",
                            interactive=True,
                            info="",
                        )
                system_prompt = gr.Textbox(
                    value=SYSTEM_PROMPT,
                    visible=False,
                    lines=20,
                    max_lines=50,
                    interactive=True,
                )
                max_new_tokens = gr.Slider(
                    minimum=8,
                    maximum=150,
                    value=90,
                    step=1,
                    interactive=True,
                    label="Maximum number of new tokens to generate",
                )
                repetition_penalty = gr.Slider(
                    minimum=0.0,
                    maximum=5.0,
                    value=1.2,
                    step=0.01,
                    interactive=True,
                    label="Repetition penalty",
                    info="1.0 is equivalent to no penalty",
                )
                decoding_strategy = gr.Radio(
                    [
                        "Greedy",
                        "Top P Sampling",
                    ],
                    value="Top P Sampling",
                    label="Decoding strategy",
                    interactive=True,
                    info="Choose between greedy decoding and nucleus (top-p) sampling.",
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=5.0,
                    value=0.6,
                    step=0.1,
                    interactive=True,
                    visible=True,
                    label="Sampling temperature",
                    info="Higher values will produce more diverse outputs.",
                )
                # Show the temperature slider only for sampling-based strategies.
                decoding_strategy.change(
                    fn=lambda selection: gr.Slider.update(
                        visible=(
                            selection
                            in [
                                "contrastive_sampling",
                                "beam_sampling",
                                "Top P Sampling",
                                "sampling_top_k",
                            ]
                        )
                    ),
                    inputs=decoding_strategy,
                    outputs=temperature,
                )
                top_p = gr.Slider(
                    minimum=0.01,
                    maximum=0.99,
                    value=0.8,
                    step=0.01,
                    interactive=True,
                    visible=True,
                    label="Top P",
                    info="Higher values are equivalent to sampling more low-probability tokens.",
                )
                decoding_strategy.change(
                    fn=lambda selection: gr.Slider.update(
                        visible=(selection in ["Top P Sampling"])
                    ),
                    inputs=decoding_strategy,
                    outputs=top_p,
                )
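                # For illustration: with the defaults above and "Top P Sampling"
                # selected, the inference call below ends up sending TGI
                # do_sample=True, temperature=0.6, top_p=0.8, max_new_tokens=90
                # and repetition_penalty=1.2, plus the EOS_STRINGS stop sequences.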
        with gr.Column(scale=2):
            generated_memes_gallery = gr.Gallery(
                # value="Images generated will appear here",
                label="Generated Memes",
                allow_preview=True,
                elem_id="generated_memes_gallery",
                show_download_button=True,
                show_share_button=True,
            ).style(columns=[2], object_fit="contain", height=600)
    with gr.Row():
        chatbot.render()

    def model_inference(
        model_selector,
        system_prompt,
        user_prompt_str,
        chat_history,
        image,
        decoding_strategy,
        temperature,
        max_new_tokens,
        repetition_penalty,
        top_p,
        all_caps_meme_text,
        text_at_the_top,
        font_meme_text,
    ):
        # The chat history is reset on every call: each meme request is treated
        # as a fresh, single-turn conversation.
        chat_history = []
        if user_prompt_str.strip() == "" and image is None:
            # This function is a generator (it yields below), so emit the empty
            # update rather than returning it.
            yield "", None, chat_history
            return

        system_prompt = ast.literal_eval(system_prompt)
        (
            formatted_prompt_list,
            user_prompt_list,
        ) = format_user_prompt_with_im_history_and_system_conditioning(
            system_prompt=system_prompt,
            current_user_prompt_str=user_prompt_str.strip(),
            current_image=image,
            history=chat_history,
        )

        client_endpoint = API_PATHS[model_selector]
        client = Client(
            base_url=client_endpoint,
            headers={"x-use-cache": "0", "Authorization": f"Bearer {API_TOKEN}"},
        )

        # Common parameters to all decoding strategies.
        # This documentation is useful to read: https://huggingface.co/docs/transformers/main/en/generation_strategies
        generation_args = {
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
            "stop_sequences": EOS_STRINGS,
        }

        assert decoding_strategy in [
            "Greedy",
            "Top P Sampling",
        ]
        if decoding_strategy == "Greedy":
            generation_args["do_sample"] = False
        elif decoding_strategy == "Top P Sampling":
            generation_args["temperature"] = temperature
            generation_args["do_sample"] = True
            generation_args["top_p"] = top_p

        if image is None:
            # Case where there is no image OR the image is passed as `<fake_token_around_image><image:IMAGE_URL><fake_token_around_image>`
            chat_history.append([prompt_list_to_markdown(user_prompt_list), ""])
        else:
            # Case where the image is passed through the Image Box.
            # Convert the image into base64 for both passing it through the chat history and
            # displaying the image inside the same bubble as the text.
            chat_history.append(
                [
                    f"{prompt_list_to_markdown([image] + user_prompt_list)}",
                    "",
                ]
            )

        query = prompt_list_to_tgi_input(formatted_prompt_list)
        all_meme_images = []
        # Generate four meme candidates for the same prompt.
        for i in range(4):
            stream = client.generate_stream(prompt=query, **generation_args)

            acc_text = ""
            full_text = ""
            for idx, response in enumerate(stream):
                text_token = response.token.text

                if response.details:
                    # `details` is only set on the last streamed response: that's the exit condition.
                    if image is not None and full_text != "":
                        meme_image = make_meme_image(
                            image=image,
                            text=full_text,
                            font_meme_text=font_meme_text,
                            all_caps_meme_text=all_caps_meme_text,
                            text_at_the_top=text_at_the_top,
                        )
                        meme_image = pil_to_temp_file(meme_image)
                        all_meme_images.append(meme_image)
                        yield "", all_meme_images, chat_history
                    if i == 3:
                        return

                if text_token in STOP_SUSPECT_LIST:
                    acc_text += text_token
                    continue

                if idx == 0 and text_token.startswith(" "):
                    text_token = text_token.lstrip()

                acc_text += text_token
                # Commented out to avoid a chatbot history that could confuse the user.

                # last_turn = chat_history.pop(-1)
                # last_turn[-1] += acc_text
                # if last_turn[-1].endswith("\nUser"):
                #     # Safeguard: sometimes (rarely), the model won't generate the token `<end_of_utterance>` and will go directly to generating `\nUser:`
                #     # It will thus stop the generation on `\nUser:`. But when it exits, it will have already generated `\nUser`
                #     # This post-processing ensures that we don't have an additional `\nUser` wandering around.
                #     last_turn[-1] = last_turn[-1][:-5]
                # chat_history.append(last_turn)
                # yield "", None, chat_history
                full_text += acc_text
                acc_text = ""

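    # The same inference generator is wired to four triggers below: pressing
    # Enter in the textbox, clicking Submit, clicking Regenerate, and selecting
    # a template from the gallery.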
    textbox.submit(
        fn=model_inference,
        inputs=[
            model_selector,
            system_prompt,
            textbox,
            chatbot,
            imagebox,
            decoding_strategy,
            temperature,
            max_new_tokens,
            repetition_penalty,
            top_p,
            all_caps_meme_text,
            text_at_the_top,
            font_meme_text,
        ],
        outputs=[textbox, generated_memes_gallery, chatbot],
    )
    submit_btn.click(fn=lambda: "", inputs=[], outputs=[generated_memes_gallery]).then(
        fn=model_inference,
        inputs=[
            model_selector,
            system_prompt,
            textbox,
            chatbot,
            imagebox,
            decoding_strategy,
            temperature,
            max_new_tokens,
            repetition_penalty,
            top_p,
            all_caps_meme_text,
            text_at_the_top,
            font_meme_text,
        ],
        outputs=[
            textbox,
            generated_memes_gallery,
            chatbot,
        ],
    )

    def remove_last_turn(chat_history):
        if len(chat_history) == 0:
            # Return one update per output (chatbot, textbox, generated_memes_gallery).
            return gr.update(), gr.update(), ""
        last_interaction = chat_history[-1]
        chat_history = chat_history[:-1]
        # Strip the inlined image markdown so only the text prompt is restored.
        last_interaction[0] = re.sub(r"!\[]\(/file=.*?\)", "", last_interaction[0])
        chat_update = gr.update(value=chat_history)
        text_update = gr.update(value=last_interaction[0])
        return chat_update, text_update, ""

    regenerate_btn.click(
        fn=remove_last_turn,
        inputs=chatbot,
        outputs=[chatbot, textbox, generated_memes_gallery],
    ).then(
        fn=model_inference,
        inputs=[
            model_selector,
            system_prompt,
            textbox,
            chatbot,
            imagebox,
            decoding_strategy,
            temperature,
            max_new_tokens,
            repetition_penalty,
            top_p,
            all_caps_meme_text,
            text_at_the_top,
            font_meme_text,
        ],
        outputs=[
            textbox,
            generated_memes_gallery,
            chatbot,
        ],
    )

    upload_btn.upload(add_file, [upload_btn], [imagebox, upload_btn], queue=False)
    submit_btn.click(
        lambda: gr.update(label="📁 Upload image", interactive=True), [], upload_btn
    )
    textbox.submit(
        lambda: gr.update(label="📁 Upload image", interactive=True), [], upload_btn
    )
    clear_btn.click(
        lambda: gr.update(label="📁 Upload image", interactive=True), [], upload_btn
    )
    gallery_type_choice.change(
        fn=choose_gallery,
        inputs=[gallery_type_choice],
        outputs=[template_gallery],
        queue=False,
    )
    template_gallery.select(
        fn=add_file_gallery,
        inputs=[template_gallery],
        outputs=[textbox, imagebox, generated_memes_gallery],
    ).success(
        fn=model_inference,
        inputs=[
            model_selector,
            system_prompt,
            textbox,
            chatbot,
            imagebox,
            decoding_strategy,
            temperature,
            max_new_tokens,
            repetition_penalty,
            top_p,
            all_caps_meme_text,
            text_at_the_top,
            font_meme_text,
        ],
        outputs=[
            textbox,
            generated_memes_gallery,
            chatbot,
        ],
    )
    demo.load(
        fn=choose_gallery, inputs=[gallery_type_choice], outputs=[template_gallery]
    )
demo.queue(concurrency_count=40, max_size=40)
demo.launch()
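The `make_meme_image` helper called in the streaming loop above is defined earlier in `app_dialogue.py` (outside this excerpt). As a minimal sketch of what such a helper has to do, given the call site above and the font files added in this commit; the word-wrapping and sizing choices here are assumptions, not the Space's exact implementation:

from PIL import Image, ImageDraw, ImageFont

def make_meme_image(image, text, font_meme_text, all_caps_meme_text, text_at_the_top):
    # `image` is a file path (the gr.Image component uses type="filepath").
    im = Image.open(image).convert("RGB")
    if all_caps_meme_text:
        text = text.upper()
    # The font names in the UI radio match the files added in fonts/:
    # fonts/impact.ttf and fonts/Roboto-Regular.ttf.
    font = ImageFont.truetype(f"fonts/{font_meme_text}.ttf", size=max(im.width // 15, 16))
    draw = ImageDraw.Draw(im)
    # Naive greedy word wrap at ~90% of the image width (an assumption).
    lines, line = [], ""
    for word in text.split():
        candidate = f"{line} {word}".strip()
        if line and draw.textlength(candidate, font=font) > 0.9 * im.width:
            lines.append(line)
            line = word
        else:
            line = candidate
    lines.append(line)
    line_height = font.size + 6
    y = 10 if text_at_the_top else im.height - line_height * len(lines) - 10
    for ln in lines:
        x = (im.width - draw.textlength(ln, font=font)) / 2
        # White text with a black stroke: the classic meme caption look.
        draw.text((x, y), ln, font=font, fill="white", stroke_width=2, stroke_fill="black")
        y += line_height
    return im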
    	
fonts/Impacted.ttf ADDED (binary file, 107 kB)
fonts/Roboto-Regular.ttf ADDED (binary file, 168 kB)
fonts/impact.ttf ADDED (binary file, 136 kB)
fonts/unicode.impact.ttf ADDED (binary file, 77.7 kB)
requirements.txt ADDED

--extra-index-url https://download.pytorch.org/whl/cu113
torch
transformers @ git+https://github.com/huggingface/transformers@5c67682b169576c4859700d551090ff79d450a9a
requests
pillow
torchvision
PyYAML
opencv-python
numpy
accelerate
joblib
deepspeed
parameterized
einops
pynvml
sentencepiece
text_generation
gradio-client @ git+https://github.com/gradio-app/gradio@bd4570ed4343f75a7ae335ef06d5eb313d107bc9#subdirectory=client/python
https://gradio-main-build.s3.amazonaws.com/92282cea6afdf7e9930ece1046d8a63be34b3cea/gradio-3.40.1-py3-none-any.whl
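The pinned `text_generation` package is what `model_inference` uses to stream tokens from a Text Generation Inference endpoint. A minimal standalone example of the same call pattern, where the endpoint URL and the prompt are placeholders and the stop strings mirror what `EOS_STRINGS` presumably contains:

from text_generation import Client

# Placeholder endpoint: any TGI server hosting an IDEFICS checkpoint works here.
client = Client(base_url="https://your-tgi-endpoint")

text = ""
for response in client.generate_stream(
    "User: Write a meme about this image.<end_of_utterance>\nAssistant:",
    max_new_tokens=90,
    do_sample=True,
    temperature=0.6,
    top_p=0.8,
    repetition_penalty=1.2,
    stop_sequences=["<end_of_utterance>", "\nUser:"],
):
    # Skip special tokens; accumulate regular text, as the app does.
    if not response.token.special:
        text += response.token.text
print(text)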

