memetoday / memes.py
rushankg's picture
Update memes.py
98ef818 verified
raw
history blame
4.64 kB
# memes.py
import streamlit as st
import re
import torch
import requests
from transformers import AutoTokenizer, AutoModelForCausalLM
from prompts import SUMMARY_PROMPT, MEME_PROMPT
# Imgflip endpoint that renders a captioned meme from a template id.
IMGFLIP_URL = "https://api.imgflip.com/caption_image"

# 12 template names → Imgflip template_ids.
# Keys are the display names looked up when resolving a chosen template;
# values are the ids sent to Imgflip as `template_id`.
TEMPLATE_IDS = {
    "Drake Hotline Bling": "181913649",
    "UNO Draw 25 Cards": "217743513",
    "Bernie Asking For Support": "222403160",
    "Disaster Girl": "97984",
    "Waiting Skeleton": "109765",
    "Always Has Been": "252600902",
    "Woman Yelling at Cat": "188390779",
    "I Bet He's Thinking About Other Women": "110163934",
    "One Does Not Simply": "61579",
    "Success Kid": "61544",
    "Oprah You Get A": "28251713",
    "Hide the Pain Harold": "27813981",
}
@st.cache_resource
def load_llama3():
    """
    Load Llama-3.2-1B and its tokenizer.

    The function is wrapped in ``st.cache_resource``. The Hugging Face access
    token is read from ``st.secrets["HUGGINGFACE_TOKEN"]``.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        KeyError: If ``HUGGINGFACE_TOKEN`` is missing from ``st.secrets``.
    """
    model_id = "meta-llama/Llama-3.2-1B"
    # Look the secret up once; both downloads authenticate with the same token.
    hf_token = st.secrets["HUGGINGFACE_TOKEN"]

    st.write("🔄 Loading Llama-3.2-1B model and tokenizer...")

    # `token=` replaces the deprecated `use_auth_token=` keyword.
    # `trust_remote_code` is intentionally not set: Llama is implemented inside
    # transformers itself, so there is no reason to opt in to executing code
    # shipped with the model repository.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        token=hf_token,
    )

    st.write("✅ Model and tokenizer loaded.")
    return tokenizer, model
def call_llama3(
    prompt: str,
    max_new_tokens: int = 200,
    return_full_text: bool = True,
) -> str:
    """
    Generate text with Llama-3.2-1B.

    Decoding is greedy (``do_sample=False``). The prompt (truncated to 200
    characters for display) and the decoded response are echoed to the
    Streamlit page.

    Args:
        prompt: Text to feed to the model.
        max_new_tokens: Upper bound on the number of tokens to generate.
        return_full_text: When True (the default, matching the original
            behaviour) the whole sequence returned by ``generate`` is decoded,
            i.e. the prompt tokens followed by the newly generated ones. When
            False only the tokens after the prompt are decoded, so the result
            is the continuation alone.

    Returns:
        The decoded text with special tokens removed.
    """
    st.write("📝 Sending prompt to Llama-3.2-1B...")
    st.write(prompt if len(prompt) < 200 else prompt[:200] + "...")

    tokenizer, model = load_llama3()
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        # Reuse EOS as the padding id, as the original call did.
        pad_token_id=tokenizer.eos_token_id,
    )

    sequence = outputs[0]
    if not return_full_text:
        # The generated sequence starts with the prompt tokens; skip them.
        prompt_length = inputs["input_ids"].shape[-1]
        sequence = sequence[prompt_length:]

    text = tokenizer.decode(sequence, skip_special_tokens=True)
    st.write("💡 Model response:")
    st.write(text)
    return text
def article_to_meme(article_text: str) -> str:
    """
    End-to-end pipeline:
    1) Summarize the article via Llama-3.2-1B.
    2) Ask Llama-3.2-1B to choose a meme template and produce two 6-8 word captions.
    3) Parse the model's response.
    4) Call Imgflip API to render the meme and return its URL.

    Args:
        article_text: Full text of the article to turn into a meme.

    Returns:
        The URL of the rendered meme as reported by Imgflip.

    Raises:
        ValueError: If ``template:``, ``text0:`` and ``text1:`` cannot all be
            found in the model output.
        KeyError: If the chosen template is not one of ``TEMPLATE_IDS``.
        requests.HTTPError: If Imgflip answers with an HTTP error status.
        RuntimeError: If Imgflip answers with ``success`` false.
    """
    # 1) Summarize
    st.write("▶️ Step 1: Summarizing article...")
    sum_prompt = SUMMARY_PROMPT.format(article_text=article_text)
    # The decoded generation can begin with the prompt itself; without removing
    # that echo the "summary" would carry the instructions and whole article.
    summary = _strip_prompt_echo(
        sum_prompt, call_llama3(sum_prompt, max_new_tokens=100)
    )
    st.write(f"📝 Summary: {summary}")

    # 2) Template + captions
    st.write("▶️ Step 2: Generating meme template and captions...")
    meme_prompt = MEME_PROMPT.format(summary=summary)
    llama_out = call_llama3(meme_prompt, max_new_tokens=150)

    # 3) Parse the response
    st.write("▶️ Step 3: Parsing model output...")
    # Prefer the continuation so that "template:" / "text0:" / "text1:" lines
    # inside the echoed prompt are not mistaken for the model's answer. Fall
    # back to the full text in case the prompt itself ends with a field label
    # that the model was meant to complete.
    completion = _strip_prompt_echo(meme_prompt, llama_out)
    fields = _parse_meme_fields(completion) or _parse_meme_fields(llama_out)
    if fields is None:
        st.error(f"❌ Could not parse model output:\n{llama_out}")
        raise ValueError(f"Could not parse model output:\n{llama_out}")

    template = fields["template"]
    text0 = fields["text0"]
    text1 = fields["text1"]
    st.write(f"🎯 Chosen template: {template}")
    st.write(f"💬 text0: {text0}")
    st.write(f"💬 text1: {text1}")

    # 4) Render the meme via Imgflip
    st.write("▶️ Step 4: Rendering meme via Imgflip...")
    template_id = _lookup_template_id(template)
    if template_id is None:
        st.error(f"❌ Unknown template: {template}")
        raise KeyError(f"Unknown template: {template}")

    meme_url = _render_imgflip_meme(template_id, text0, text1)
    st.write(f"✅ Meme URL: {meme_url}")
    return meme_url


def _strip_prompt_echo(prompt: str, generated: str) -> str:
    """Drop a leading copy of `prompt` from `generated` and trim whitespace."""
    # Best effort: if decoding did not reproduce the prompt verbatim the text
    # is returned unchanged (apart from trimming), as the original code did.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    return generated.strip()


def _parse_meme_fields(text: str):
    """Extract template/text0/text1 from `text`; return None if any is missing."""
    fields = {}
    for label in ("template", "text0", "text1"):
        match = re.search(label + r":\s*(.+)", text, re.IGNORECASE)
        if match is None:
            return None
        # Remove wrapping quotes / markdown emphasis the model may have added.
        fields[label] = match.group(1).strip().strip("\"'`*").strip()
    return fields


def _lookup_template_id(name: str):
    """Resolve a template name to its Imgflip id, or None if it is unknown."""
    template_id = TEMPLATE_IDS.get(name)
    if template_id is not None:
        return template_id
    # Tolerate differences in case and trailing punctuation in model output.
    wanted = name.strip().rstrip(".,;:!").casefold()
    for known_name, known_id in TEMPLATE_IDS.items():
        if known_name.casefold() == wanted:
            return known_id
    return None


def _render_imgflip_meme(template_id: str, text0: str, text1: str) -> str:
    """POST the captions to Imgflip and return the URL of the rendered meme."""
    creds = st.secrets["imgflip"]
    payload = {
        "template_id": template_id,
        "username": creds["username"],
        "password": creds["password"],
        "text0": text0,
        "text1": text1,
    }
    # Send the fields as a form body (`data=`) rather than in the query string
    # (`params=`) so the credentials do not end up in the request URL, and set
    # a timeout so a stalled connection cannot hang the app indefinitely.
    resp = requests.post(IMGFLIP_URL, data=payload, timeout=30)
    st.write("🔗 Imgflip API called...")
    resp.raise_for_status()

    data = resp.json()
    if not data.get("success", False):
        error_message = data.get("error_message")
        st.error(f"❌ Imgflip error: {error_message}")
        # RuntimeError is still caught by callers using `except Exception`.
        raise RuntimeError(error_message)
    return data["data"]["url"]