# memes.py
import streamlit as st
import re
import torch
import requests
from transformers import AutoTokenizer, AutoModelForCausalLM
from prompts import SUMMARY_PROMPT, MEME_PROMPT
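# SUMMARY_PROMPT and MEME_PROMPT are format strings defined in prompts.py,
# with {article_text} and {summary} placeholders respectively.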
IMGFLIP_URL = "https://api.imgflip.com/caption_image"
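
# Imgflip's caption_image endpoint renders a known template with two caption
# boxes (text0 = top, text1 = bottom); template IDs for other memes can be
# looked up via Imgflip's public get_memes endpoint.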
# 12 template names → Imgflip template_ids
TEMPLATE_IDS = {
"Drake Hotline Bling": "181913649",
"UNO Draw 25 Cards": "217743513",
"Bernie Asking For Support": "222403160",
"Disaster Girl": "97984",
"Waiting Skeleton": "109765",
"Always Has Been": "252600902",
"Woman Yelling at Cat": "188390779",
"I Bet He's Thinking About Other Women": "110163934",
"One Does Not Simply": "61579",
"Success Kid": "61544",
"Oprah You Get A": "28251713",
"Hide the Pain Harold": "27813981",
}


@st.cache_resource
def load_llama3():
    """
    Load Llama-3.2-1B and its tokenizer once per process.

    st.cache_resource keeps the model and tokenizer in memory across
    Streamlit reruns instead of reloading them on every interaction.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        "meta-llama/Llama-3.2-1B",
        trust_remote_code=True,
        token=st.secrets["HUGGINGFACE_TOKEN"],  # gated repo: needs an HF access token
    )
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B",
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
        token=st.secrets["HUGGINGFACE_TOKEN"],
    )
    return tokenizer, model


def call_llama3(prompt: str, max_new_tokens: int = 200) -> str:
    """
    Generate text with Llama-3.2-1B (greedy decoding).
    """
    tokenizer, model = load_llama3()
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens so the prompt text is not echoed
    # back into the result (the summary and regex parsing below rely on this).
    generated = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)
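

# article_to_meme expects the MEME_PROMPT completion to contain lines of the
# form below (the exact wording is defined in prompts.py):
#   Template: <one of the TEMPLATE_IDS keys>
#   Text0: <top caption>
#   Text1: <bottom caption>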
def article_to_meme(article_text: str) -> str:
    # Summarize
    sum_prompt = SUMMARY_PROMPT.format(article_text=article_text)
    summary = call_llama3(sum_prompt, max_new_tokens=100).strip()

    # Meme template + captions
    meme_prompt = MEME_PROMPT.format(summary=summary)
    llama_out = call_llama3(meme_prompt, max_new_tokens=150)

    # Parse response
    tpl_match = re.search(r"template:\s*(.+)", llama_out, re.IGNORECASE)
    text0_match = re.search(r"text0:\s*(.+)", llama_out, re.IGNORECASE)
    text1_match = re.search(r"text1:\s*(.+)", llama_out, re.IGNORECASE)
    if not (tpl_match and text0_match and text1_match):
        raise ValueError(f"Could not parse model output:\n{llama_out}")
    template = tpl_match.group(1).strip()
    text0 = text0_match.group(1).strip()
    text1 = text1_match.group(1).strip()

    # Render meme
    template_id = TEMPLATE_IDS.get(template)
    if template_id is None:
        raise KeyError(f"Unknown template: {template}")
    creds = st.secrets["imgflip"]
    params = {
        "template_id": template_id,
        "username": creds["username"],
        "password": creds["password"],
        "text0": text0,
        "text1": text1,
    }
    resp = requests.post(IMGFLIP_URL, params=params)
    resp.raise_for_status()
    data = resp.json()
    if not data["success"]:
        raise Exception(data["error_message"])
    return data["data"]["url"]
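
# Illustrative usage from a Streamlit page (sketch only; the actual app
# entrypoint lives outside this file):
#
#     from memes import article_to_meme
#
#     article = st.text_area("Paste an article")
#     if st.button("Make a meme") and article:
#         meme_url = article_to_meme(article)
#         st.image(meme_url)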