File size: 3,307 Bytes
f116710 3946047 f116710 3946047 f116710 3946047 f116710 afe469e 16fef17 afe469e 16fef17 7ec3808 afe469e bc0ca96 7ec3808 f116710 afe469e 7ec3808 16fef17 bc0ca96 f116710 afe469e 16fef17 afe469e 16fef17 afe469e 16fef17 f116710 16fef17 f116710 16fef17 f116710 afe469e f116710 afe469e f116710 afe469e f116710 afe469e f116710 afe469e f116710 afe469e f116710 afe469e f116710 16fef17 f116710 16fef17 f116710 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
# memes.py
import streamlit as st
import re
import torch
import requests
from transformers import AutoTokenizer, AutoModelForCausalLM
from prompts import SUMMARY_PROMPT, MEME_PROMPT
# Imgflip endpoint that renders a template with caption text.
IMGFLIP_URL = "https://api.imgflip.com/caption_image"

# The meme templates on offer, each paired with its Imgflip template_id.
# article_to_meme looks up the template name parsed from the model output in
# this mapping, so the names presumably need to match the ones MEME_PROMPT
# lists -- TODO confirm against prompts.py.
TEMPLATE_IDS = dict(
    (
        ("Drake Hotline Bling", "181913649"),
        ("UNO Draw 25 Cards", "217743513"),
        ("Bernie Asking For Support", "222403160"),
        ("Disaster Girl", "97984"),
        ("Waiting Skeleton", "109765"),
        ("Always Has Been", "252600902"),
        ("Woman Yelling at Cat", "188390779"),
        ("I Bet He's Thinking About Other Women", "110163934"),
        ("One Does Not Simply", "61579"),
        ("Success Kid", "61544"),
        ("Oprah You Get A", "28251713"),
        ("Hide the Pain Harold", "27813981"),
    )
)
@st.cache_resource
def load_llama3(model_id: str = "meta-llama/Llama-3.2-1B"):
    """Load a Llama causal-LM checkpoint and its tokenizer.

    Decorated with ``st.cache_resource``, so the expensive download/load
    happens once and the same objects are returned on later calls with the
    same ``model_id``.

    Args:
        model_id: Hugging Face Hub repo id. Defaults to the
            ``meta-llama/Llama-3.2-1B`` base model (a gated repo, hence the
            access token read from ``st.secrets["HUGGINGFACE_TOKEN"]``).

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    hf_token = st.secrets["HUGGINGFACE_TOKEN"]

    # ``token`` replaces the deprecated ``use_auth_token`` argument.
    # ``trust_remote_code`` is intentionally not passed: Llama is implemented
    # inside transformers itself, so there is no repo code to run, and
    # enabling it would execute any Python that later appears in the repo.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)

    # float16 halves memory on a GPU, but half-precision generation on CPU
    # is slow or unsupported, so fall back to float32 without CUDA.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=dtype,
        token=hf_token,
    )
    return tokenizer, model
def call_llama3(
    prompt: str,
    max_new_tokens: int = 200,
    *,
    include_prompt: bool = False,
) -> str:
    """Greedily continue *prompt* with the model returned by ``load_llama3``.

    Args:
        prompt: Text for the model to continue.
        max_new_tokens: Upper bound on the number of generated tokens.
        include_prompt: If True, return the prompt followed by the
            completion (the full decoded ``generate`` sequence). The default,
            False, returns only the newly generated text, which is what
            callers that parse or display the answer want.

    Returns:
        The decoded text with special tokens removed.
    """
    tokenizer, model = load_llama3()
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy decoding: no sampling randomness
        # Explicit pad id so generate() does not depend on the tokenizer
        # defining a pad token (presumably absent for this base model).
        pad_token_id=tokenizer.eos_token_id,
    )

    sequence = outputs[0]
    if not include_prompt:
        # For a causal LM, generate() returns the prompt ids followed by the
        # new ids; drop the prompt so it is not echoed back to the caller.
        prompt_len = inputs["input_ids"].shape[-1]
        sequence = sequence[prompt_len:]
    return tokenizer.decode(sequence, skip_special_tokens=True)
# Seconds to wait for Imgflip; requests never times out unless told to.
_IMGFLIP_TIMEOUT_S = 30

# Labels that article_to_meme expects in the model's meme answer.
_MEME_FIELDS = ("template", "text0", "text1")


def _strip_prompt_echo(prompt: str, text: str) -> str:
    """Return *text* without a leading copy of *prompt*, if it has one.

    Best-effort guard for a generator that returns prompt + completion;
    text that does not start with *prompt* is returned unchanged.
    """
    if text.startswith(prompt):
        return text[len(prompt):]
    return text


def _clean_value(raw: str) -> str:
    """Trim whitespace, markdown emphasis, and one surrounding quote pair."""
    value = raw.strip().strip("*`").strip()
    if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
        value = value[1:-1].strip()
    return value


def _parse_meme_fields(llama_out: str) -> dict:
    """Extract the template/text0/text1 values from the model's answer.

    Each field comes from the first ``<label>: value`` occurrence
    (case-insensitive). A value ends at the end of its line or where the
    next known label begins, so ``text0: a text1: b`` on one line splits
    correctly.

    Raises:
        ValueError: if any of the three fields is missing.
    """
    labels_alt = "|".join(_MEME_FIELDS)
    fields = {}
    for label in _MEME_FIELDS:
        match = re.search(
            rf"\b{label}:\s*(.+?)\s*(?=\b(?:{labels_alt}):|$)",
            llama_out,
            re.IGNORECASE | re.MULTILINE,
        )
        if match is None:
            raise ValueError(f"Could not parse model output:\n{llama_out}")
        fields[label] = _clean_value(match.group(1))
    return fields


def _normalize_template_name(name: str) -> str:
    """Casefold *name*, collapse inner whitespace, drop edge punctuation."""
    return " ".join(name.strip(".,:;!?\"'*` \t").casefold().split())


def _lookup_template_id(template: str, templates: dict):
    """Return the id for *template* in *templates*, or None if absent.

    An exact key match wins. Otherwise, names are compared after
    normalization, so e.g. "woman yelling at cat." still resolves.
    """
    if template in templates:
        return templates[template]
    wanted = _normalize_template_name(template)
    return next(
        (
            template_id
            for name, template_id in templates.items()
            if _normalize_template_name(name) == wanted
        ),
        None,
    )


def article_to_meme(article_text: str) -> str:
    """Turn an article into a captioned meme and return the image URL.

    Steps: summarize the article with ``SUMMARY_PROMPT``, ask the model
    (via ``MEME_PROMPT``) for a template name plus ``text0``/``text1``
    captions, resolve the template in ``TEMPLATE_IDS``, then render it with
    Imgflip's caption endpoint using ``st.secrets["imgflip"]`` credentials.

    Args:
        article_text: Raw article text substituted into ``SUMMARY_PROMPT``.

    Returns:
        The rendered meme's URL as reported by Imgflip.

    Raises:
        ValueError: if the model answer lacks a template/text0/text1 field.
        KeyError: if the chosen template is not in ``TEMPLATE_IDS``.
        requests.HTTPError: if Imgflip answers with an HTTP error status.
        requests.Timeout: if Imgflip does not respond in time.
        RuntimeError: if Imgflip reports an unsuccessful caption request.
    """
    # 1. Summarize. call_llama3 may echo the prompt before the completion,
    #    so drop any echoed prompt before using the text.
    sum_prompt = SUMMARY_PROMPT.format(article_text=article_text)
    summary = _strip_prompt_echo(
        sum_prompt, call_llama3(sum_prompt, max_new_tokens=100)
    ).strip()

    # 2. Ask for a template name and two captions.
    meme_prompt = MEME_PROMPT.format(summary=summary)
    llama_out = _strip_prompt_echo(
        meme_prompt, call_llama3(meme_prompt, max_new_tokens=150)
    )

    # 3. Parse the answer and resolve the template id.
    fields = _parse_meme_fields(llama_out)
    template = fields["template"]
    template_id = _lookup_template_id(template, TEMPLATE_IDS)
    if template_id is None:
        raise KeyError(f"Unknown template: {template}")

    # 4. Render the meme.
    creds = st.secrets["imgflip"]
    payload = {
        "template_id": template_id,
        "username": creds["username"],
        "password": creds["password"],
        "text0": fields["text0"],
        "text1": fields["text1"],
    }
    # Send as a form-encoded body rather than query parameters so the
    # credentials are not embedded in the request URL (and hence in
    # server/proxy access logs).
    resp = requests.post(IMGFLIP_URL, data=payload, timeout=_IMGFLIP_TIMEOUT_S)
    resp.raise_for_status()
    data = resp.json()
    if not data.get("success"):
        raise RuntimeError(data.get("error_message", "Imgflip request failed"))
    return data["data"]["url"]
|