File size: 4,637 Bytes
f116710
 
 
 
 
 
 
 
 
 
 
 
 
3946047
 
f116710
3946047
 
f116710
3946047
 
 
 
 
f116710
 
 
afe469e
16fef17
afe469e
16fef17
98ef818
7ec3808
afe469e
bc0ca96
 
7ec3808
f116710
afe469e
7ec3808
16fef17
bc0ca96
 
f116710
98ef818
f116710
 
98ef818
afe469e
16fef17
afe469e
16fef17
98ef818
 
afe469e
16fef17
 
f116710
 
 
16fef17
f116710
98ef818
 
 
 
 
f116710
 
98ef818
 
 
 
 
 
 
 
 
f116710
afe469e
98ef818
f116710
98ef818
 
f116710
afe469e
f116710
98ef818
 
afe469e
 
 
f116710
98ef818
afe469e
f116710
 
 
 
98ef818
 
 
f116710
98ef818
 
f116710
 
98ef818
f116710
 
16fef17
f116710
 
16fef17
 
f116710
 
 
 
98ef818
f116710
 
98ef818
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# memes.py
import streamlit as st
import re
import torch
import requests
from transformers import AutoTokenizer, AutoModelForCausalLM
from prompts import SUMMARY_PROMPT, MEME_PROMPT

IMGFLIP_URL = "https://api.imgflip.com/caption_image"

# Supported meme templates as (display name, Imgflip template_id) pairs.
# article_to_meme resolves the template name chosen by the model against these
# keys, so presumably they must stay in sync with the names MEME_PROMPT offers
# the model (prompts.py is not visible here -- verify when editing either side).
_TEMPLATE_PAIRS = (
    ("Drake Hotline Bling", "181913649"),
    ("UNO Draw 25 Cards", "217743513"),
    ("Bernie Asking For Support", "222403160"),
    ("Disaster Girl", "97984"),
    ("Waiting Skeleton", "109765"),
    ("Always Has Been", "252600902"),
    ("Woman Yelling at Cat", "188390779"),
    ("I Bet He's Thinking About Other Women", "110163934"),
    ("One Does Not Simply", "61579"),
    ("Success Kid", "61544"),
    ("Oprah You Get A", "28251713"),
    ("Hide the Pain Harold", "27813981"),
)

# 12 template names -> Imgflip template_ids (insertion order preserved).
TEMPLATE_IDS = dict(_TEMPLATE_PAIRS)

@st.cache_resource
def load_llama3():
    """
    Load the Llama-3.2-1B tokenizer and causal-LM model.

    Decorated with ``st.cache_resource`` so the objects are created once and
    reused on later calls instead of reloading the weights every time.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    model_id = "meta-llama/Llama-3.2-1B"
    st.write("🔄 Loading Llama-3.2-1B model and tokenizer...")
    # Hugging Face access token used to download the model repo.
    hf_token = st.secrets["HUGGINGFACE_TOKEN"]

    # `token=` is the current name of the deprecated `use_auth_token=` argument.
    # Llama is implemented natively in transformers, so `trust_remote_code` is
    # not needed; leaving it off avoids executing code fetched from the Hub.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",          # let transformers place weights on available devices
        torch_dtype=torch.float16,  # half precision to reduce memory use
        token=hf_token,
    )
    st.write("✅ Model and tokenizer loaded.")
    return tokenizer, model


def call_llama3(prompt: str, max_new_tokens: int = 200) -> str:
    """
    Generate a greedy continuation of ``prompt`` with Llama-3.2-1B.

    Args:
        prompt: Text passed to the tokenizer as-is (no chat template applied).
        max_new_tokens: Upper bound on the number of tokens generated.

    Returns:
        Only the newly generated text. ``generate`` on a decoder-only model
        returns the prompt tokens followed by the continuation, so the prompt
        part is sliced off before decoding; otherwise callers would receive
        (and parse) their own prompt back.
    """
    st.write("📝 Sending prompt to Llama-3.2-1B...")
    # Show only a short preview so long articles don't flood the page.
    st.write(prompt if len(prompt) < 200 else prompt[:200] + '...')

    tokenizer, model = load_llama3()
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,                      # greedy decoding -> deterministic output
        pad_token_id=tokenizer.eos_token_id,  # reuse EOS as the padding id
    )
    # Keep only the tokens produced after the prompt.
    text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    st.write("💡 Model response:")
    st.write(text)
    return text


def _normalize_template_name(name: str) -> str:
    """Canonical form of a template name for lenient matching.

    Lowercases (casefold), strips surrounding quotes/brackets/punctuation and
    collapses internal whitespace; inner characters such as the apostrophe in
    "He's" are kept.
    """
    cleaned = name.strip().strip("\"'`*[]().,;:!").strip()
    return " ".join(cleaned.split()).casefold()


def _lookup_template_id(template: str):
    """Return the Imgflip template_id for ``template``, or None if unknown.

    Tries an exact key match first, then a case/punctuation-insensitive one,
    because free-form LLM output may differ from the key in casing or wrap the
    name in quotes / add trailing punctuation.
    """
    exact = TEMPLATE_IDS.get(template)
    if exact is not None:
        return exact
    wanted = _normalize_template_name(template)
    for name, template_id in TEMPLATE_IDS.items():
        if _normalize_template_name(name) == wanted:
            return template_id
    return None


def _render_imgflip(template_id: str, text0: str, text1: str) -> str:
    """Caption Imgflip template ``template_id`` with two texts; return the image URL.

    Raises:
        requests.RequestException: network failure, timeout, or HTTP error status.
        RuntimeError: Imgflip answered with ``success: false``.
    """
    creds = st.secrets["imgflip"]
    form = {
        "template_id": template_id,
        "username":    creds["username"],
        "password":    creds["password"],
        "text0":       text0,
        "text1":       text1,
    }
    # Send the fields as a POST form body (data=) instead of URL query
    # parameters (params=): query strings -- which here would carry the account
    # password -- are routinely recorded in server and proxy logs.
    # The timeout keeps a stalled request from hanging the Streamlit run forever.
    resp = requests.post(IMGFLIP_URL, data=form, timeout=30)
    st.write("🔗 Imgflip API called...")
    resp.raise_for_status()
    data = resp.json()
    if not data.get("success", False):
        st.error(f"❌ Imgflip error: {data.get('error_message')}")
        # RuntimeError subclasses Exception, so existing `except Exception`
        # handlers keep working.
        raise RuntimeError(data.get("error_message"))
    return data["data"]["url"]


def article_to_meme(article_text: str) -> str:
    """
    Turn an article into a captioned meme and return the meme image URL.

    Pipeline:
    1) Summarize the article via Llama-3.2-1B (SUMMARY_PROMPT).
    2) Ask Llama-3.2-1B to choose a meme template and produce two short
       (6-8 word) captions (MEME_PROMPT). The answer is expected to contain
       ``template:``, ``text0:`` and ``text1:`` lines.
    3) Parse those lines out of the model's response.
    4) Call the Imgflip ``caption_image`` API to render the meme.

    Progress and intermediate results are written to the Streamlit page.

    Raises:
        ValueError: the model output lacks a template/text0/text1 line.
        KeyError: the chosen template does not match any TEMPLATE_IDS entry.
        requests.RequestException: the Imgflip HTTP request failed.
        RuntimeError: Imgflip answered with ``success: false``.
    """
    # 1) Summarize
    st.write("▢️ Step 1: Summarizing article...")
    sum_prompt = SUMMARY_PROMPT.format(article_text=article_text)
    summary = call_llama3(sum_prompt, max_new_tokens=100).strip()
    st.write(f"📝 Summary: {summary}")

    # 2) Template + captions
    st.write("▢️ Step 2: Generating meme template and captions...")
    meme_prompt = MEME_PROMPT.format(summary=summary)
    llama_out = call_llama3(meme_prompt, max_new_tokens=150)

    # 3) Parse the response: take the first "key: value" line for each field.
    st.write("▢️ Step 3: Parsing model output...")
    tpl_match   = re.search(r"template:\s*(.+)", llama_out, re.IGNORECASE)
    text0_match = re.search(r"text0:\s*(.+)",    llama_out, re.IGNORECASE)
    text1_match = re.search(r"text1:\s*(.+)",    llama_out, re.IGNORECASE)
    if not (tpl_match and text0_match and text1_match):
        st.error(f"❌ Could not parse model output:\n{llama_out}")
        raise ValueError(f"Could not parse model output:\n{llama_out}")

    template = tpl_match.group(1).strip()
    text0    = text0_match.group(1).strip()
    text1    = text1_match.group(1).strip()
    st.write(f"🎯 Chosen template: {template}")
    st.write(f"💬 text0: {text0}")
    st.write(f"💬 text1: {text1}")

    # 4) Render the meme via Imgflip
    st.write("▢️ Step 4: Rendering meme via Imgflip...")
    template_id = _lookup_template_id(template)
    if template_id is None:
        st.error(f"❌ Unknown template: {template}")
        raise KeyError(f"Unknown template: {template}")

    meme_url = _render_imgflip(template_id, text0, text1)
    st.write(f"✅ Meme URL: {meme_url}")
    return meme_url