|
|
|
|
|
import streamlit as st |
|
|
import re |
|
|
import torch |
|
|
import requests |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
from prompts import SUMMARY_PROMPT, MEME_PROMPT |
|
|
|
|
|
# Imgflip endpoint that article_to_meme() POSTs to in order to render a captioned meme.
IMGFLIP_URL = "https://api.imgflip.com/caption_image"

# Template name -> Imgflip template ID.
# article_to_meme() looks up the "template:" value parsed from the model output
# in this mapping and sends the matching ID to Imgflip as "template_id".
# NOTE(review): presumably MEME_PROMPT (prompts.py) lists these same names for the
# model to choose from -- keep the two in sync; verify against prompts.py.
TEMPLATE_IDS = {
    "Drake Hotline Bling": "181913649",
    "UNO Draw 25 Cards": "217743513",
    "Bernie Asking For Support": "222403160",
    "Disaster Girl": "97984",
    "Waiting Skeleton": "109765",
    "Always Has Been": "252600902",
    "Woman Yelling at Cat": "188390779",
    "I Bet He's Thinking About Other Women": "110163934",
    "One Does Not Simply": "61579",
    "Success Kid": "61544",
    "Oprah You Get A": "28251713",
    "Hide the Pain Harold": "27813981",
}
|
|
|
|
|
@st.cache_resource
def load_llama3(model_name: str = "meta-llama/Llama-3.2-1B"):
    """Load (once per process, via Streamlit's resource cache) a causal LM and its tokenizer.

    Args:
        model_name: Hugging Face Hub repo id; defaults to Llama-3.2-1B.

    Returns:
        ``(tokenizer, model)`` -- the model is placed with ``device_map="auto"``.

    The gated Llama repo is accessed with the ``HUGGINGFACE_TOKEN`` Streamlit secret.
    """
    hf_token = st.secrets["HUGGINGFACE_TOKEN"]

    # Half precision only pays off (and is fully supported) on an accelerator;
    # on a CPU-only host fp16 generation is very slow or hits unimplemented ops.
    mps = getattr(torch.backends, "mps", None)
    on_accelerator = torch.cuda.is_available() or (mps is not None and mps.is_available())
    dtype = torch.float16 if on_accelerator else torch.float32

    # Llama is a native transformers architecture, so trust_remote_code is not
    # needed (and would let the hub repo run arbitrary code). `token` replaces the
    # deprecated `use_auth_token` argument.
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=dtype,
        token=hf_token,
    )
    return tokenizer, model
|
|
|
|
|
def call_llama3(
    prompt: str,
    max_new_tokens: int = 200,
    *,
    return_full_text: bool = False,
) -> str:
    """Greedily generate a continuation of *prompt* with the cached Llama model.

    Args:
        prompt: Raw text fed to the model.
        max_new_tokens: Upper bound on the number of generated tokens.
        return_full_text: If True, return the decoded prompt followed by the
            continuation; by default only the newly generated text is returned.

    Returns:
        The decoded text with special tokens removed.
    """
    tokenizer, model = load_llama3()
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )

    sequence = outputs[0]
    if not return_full_text:
        # For a decoder-only model generate() returns the prompt ids followed by
        # the new ids; drop the prompt so callers get only the model's answer.
        prompt_len = inputs["input_ids"].shape[-1]
        sequence = sequence[prompt_len:]
    return tokenizer.decode(sequence, skip_special_tokens=True)
|
|
|
|
|
# Seconds to wait for Imgflip; requests has no default timeout and would block forever.
_IMGFLIP_TIMEOUT_S = 30

# "label: value" patterns for the fields the model is asked to emit. The value is
# taken from the same line only ([ \t]*), so an empty "text0:" cannot swallow the
# following "text1: ..." line.
_MEME_FIELD_RES = {
    label: re.compile(rf"{label}[ \t]*:[ \t]*(.+)", re.IGNORECASE)
    for label in ("template", "text0", "text1")
}


def _normalize_template_name(name: str) -> str:
    """Canonicalize a template name so casing, quotes, markdown and spacing don't matter."""
    cleaned = name.replace("\u2019", "'").strip().strip("\"'*`[]<>.").strip()
    return " ".join(cleaned.split()).casefold()


def _parse_meme_output(llama_out: str, prompt_tail: str = "") -> tuple:
    """Extract ``(template_id, text0, text1)`` from the model's meme answer.

    Args:
        llama_out: Text returned by the model.
        prompt_tail: Final line of the prompt, prepended so that a label the
            prompt ends with (e.g. ``"template:"``) is parsed together with the
            model's continuation.

    Raises:
        ValueError: if a ``template:``, ``text0:`` or ``text1:`` field is missing.
        KeyError: if no ``template:`` value names an entry of TEMPLATE_IDS.
    """
    text = prompt_tail + llama_out
    tpl_matches = list(_MEME_FIELD_RES["template"].finditer(text))
    has_text0 = _MEME_FIELD_RES["text0"].search(text) is not None
    has_text1 = _MEME_FIELD_RES["text1"].search(text) is not None
    if not (tpl_matches and has_text0 and has_text1):
        raise ValueError(f"Could not parse model output:\n{llama_out}")

    # Use the first template line that names a known template; placeholder lines
    # such as "template: <name>" (e.g. echoed format instructions) are skipped.
    known = {_normalize_template_name(name): tid for name, tid in TEMPLATE_IDS.items()}
    for match in tpl_matches:
        template_id = known.get(_normalize_template_name(match.group(1)))
        if template_id is not None:
            break
    else:
        raise KeyError(f"Unknown template: {tpl_matches[0].group(1).strip()}")

    # Prefer the captions belonging to the chosen template (those after it),
    # falling back to the first occurrence anywhere if they precede it.
    captions = []
    for label in ("text0", "text1"):
        pattern = _MEME_FIELD_RES[label]
        found = pattern.search(text, match.end()) or pattern.search(text)
        captions.append(found.group(1).strip())
    text0, text1 = captions
    return template_id, text0, text1


def _caption_image(template_id: str, text0: str, text1: str) -> str:
    """Render the meme through Imgflip's caption_image API and return the image URL.

    Raises:
        requests.RequestException: on network errors, timeouts or HTTP error status.
        RuntimeError: if Imgflip reports ``success: false``.
    """
    creds = st.secrets["imgflip"]
    form = {
        "template_id": template_id,
        "username": creds["username"],
        "password": creds["password"],
        "text0": text0,
        "text1": text1,
    }
    # Send as a form body rather than query params so the credentials do not end
    # up in URLs (and therefore in proxy/server logs).
    resp = requests.post(IMGFLIP_URL, data=form, timeout=_IMGFLIP_TIMEOUT_S)
    resp.raise_for_status()
    payload = resp.json()
    if not payload.get("success"):
        raise RuntimeError(payload.get("error_message", "Imgflip caption_image request failed"))
    return payload["data"]["url"]


def article_to_meme(article_text: str) -> str:
    """Turn a news article into a captioned meme and return the Imgflip image URL.

    Pipeline: summarize the article (SUMMARY_PROMPT), ask the model for a meme
    (MEME_PROMPT), parse its ``template:/text0:/text1:`` answer, then caption the
    chosen template via Imgflip.

    Raises:
        ValueError: if *article_text* is empty or the model output can't be parsed.
        KeyError: if the model names no template from TEMPLATE_IDS.
        RuntimeError: if Imgflip rejects the caption request.
        requests.RequestException: on Imgflip network/HTTP failures.
    """
    if not article_text or not article_text.strip():
        raise ValueError("article_text must be a non-empty string")

    sum_prompt = SUMMARY_PROMPT.format(article_text=article_text)
    summary = call_llama3(sum_prompt, max_new_tokens=100).strip()

    meme_prompt = MEME_PROMPT.format(summary=summary)
    llama_out = call_llama3(meme_prompt, max_new_tokens=150)

    prompt_tail = meme_prompt.rpartition("\n")[2]
    template_id, text0, text1 = _parse_meme_output(llama_out, prompt_tail)
    return _caption_image(template_id, text0, text1)
|
|
|