# memes.py
import streamlit as st
import re
import torch
import requests
from transformers import AutoTokenizer, AutoModelForCausalLM

from prompts import SUMMARY_PROMPT, MEME_PROMPT

IMGFLIP_URL = "https://api.imgflip.com/caption_image"

# 12 template names → Imgflip template_ids
TEMPLATE_IDS = {
    "Drake Hotline Bling": "181913649",
    "UNO Draw 25 Cards": "217743513",
    "Bernie Asking For Support": "222403160",
    "Disaster Girl": "97984",
    "Waiting Skeleton": "109765",
    "Always Has Been": "252600902",
    "Woman Yelling at Cat": "188390779",
    "I Bet He's Thinking About Other Women": "110163934",
    "One Does Not Simply": "61579",
    "Success Kid": "61544",
    "Oprah You Get A": "28251713",
    "Hide the Pain Harold": "27813981",
}


@st.cache_resource
def load_mpt():
    """
    Load the MosaicML MPT-7B-Chat model and tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        "mosaicml/mpt-7b-chat",
        trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        "mosaicml/mpt-7b-chat",
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True
    )
    return tokenizer, model


def call_mpt(prompt: str, max_new_tokens: int = 200) -> str:
    """
    Generate text from MPT-7B-Chat given a prompt.
    """
    tokenizer, model = load_mpt()
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens so the prompt is not echoed back
    # into the returned text.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


def article_to_meme(article_text: str) -> str:
    """
    End-to-end pipeline:
      1) Summarize the article via MPT-7B-Chat.
      2) Ask MPT-7B-Chat to choose a meme template and produce two 6-8 word captions.
      3) Parse the model's response.
      4) Call the Imgflip API to render the meme and return its URL.
    """
    # 1) Summarize
    sum_prompt = SUMMARY_PROMPT.format(article_text=article_text)
    summary = call_mpt(sum_prompt, max_new_tokens=100).strip()

    # 2) Template + captions
    meme_prompt = MEME_PROMPT.format(summary=summary)
    mpt_out = call_mpt(meme_prompt, max_new_tokens=150)

    # 3) Parse the response, which is expected to contain lines of the form
    #    "template: <name>", "text0: <caption>", "text1: <caption>".
    tpl_match = re.search(r"template:\s*(.+)", mpt_out, re.IGNORECASE)
    text0_match = re.search(r"text0:\s*(.+)", mpt_out, re.IGNORECASE)
    text1_match = re.search(r"text1:\s*(.+)", mpt_out, re.IGNORECASE)
    if not (tpl_match and text0_match and text1_match):
        raise ValueError(f"Could not parse model output:\n{mpt_out}")
    template = tpl_match.group(1).strip()
    text0 = text0_match.group(1).strip()
    text1 = text1_match.group(1).strip()

    # 4) Render the meme via Imgflip
    template_id = TEMPLATE_IDS.get(template)
    if template_id is None:
        raise KeyError(f"Unknown template: {template}")
    # Imgflip credentials come from Streamlit secrets (an [imgflip] section with
    # `username` and `password` keys in .streamlit/secrets.toml).
    creds = st.secrets["imgflip"]
    params = {
        "template_id": template_id,
        "username": creds["username"],
        "password": creds["password"],
        "text0": text0,
        "text1": text1,
    }
    resp = requests.post(IMGFLIP_URL, params=params)
    resp.raise_for_status()
    data = resp.json()
    if not data["success"]:
        raise Exception(data["error_message"])
    return data["data"]["url"]
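

# Illustrative local smoke test (a minimal sketch, not part of the original
# Streamlit app): the `__main__` guard and sample_article placeholder below are
# additions for quick manual checks. It assumes Imgflip credentials are present
# in .streamlit/secrets.toml and that a GPU with enough memory for MPT-7B-Chat
# in float16 is available.
if __name__ == "__main__":
    sample_article = (
        "Astronomers announced today that they have observed a distant galaxy "
        "whose light took more than thirteen billion years to reach Earth."
    )
    print(article_to_meme(sample_article))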