File size: 4,637 Bytes
f116710 3946047 f116710 3946047 f116710 3946047 f116710 afe469e 16fef17 afe469e 16fef17 98ef818 7ec3808 afe469e bc0ca96 7ec3808 f116710 afe469e 7ec3808 16fef17 bc0ca96 f116710 98ef818 f116710 98ef818 afe469e 16fef17 afe469e 16fef17 98ef818 afe469e 16fef17 f116710 16fef17 f116710 98ef818 f116710 98ef818 f116710 afe469e 98ef818 f116710 98ef818 f116710 afe469e f116710 98ef818 afe469e f116710 98ef818 afe469e f116710 98ef818 f116710 98ef818 f116710 98ef818 f116710 16fef17 f116710 16fef17 f116710 98ef818 f116710 98ef818 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# memes.py
import streamlit as st
import re
import torch
import requests
from transformers import AutoTokenizer, AutoModelForCausalLM
from prompts import SUMMARY_PROMPT, MEME_PROMPT
IMGFLIP_URL = "https://api.imgflip.com/caption_image"

# Display name of each supported meme template (the exact string the model is
# expected to pick) → the numeric Imgflip template_id, kept as a string.
TEMPLATE_IDS = {
    "Drake Hotline Bling": "181913649",
    "UNO Draw 25 Cards": "217743513",
    "Bernie Asking For Support": "222403160",
    "Disaster Girl": "97984",
    "Waiting Skeleton": "109765",
    "Always Has Been": "252600902",
    "Woman Yelling at Cat": "188390779",
    "I Bet He's Thinking About Other Women": "110163934",
    "One Does Not Simply": "61579",
    "Success Kid": "61544",
    "Oprah You Get A": "28251713",
    "Hide the Pain Harold": "27813981",
}
@st.cache_resource
def load_llama3():
    """Load the Llama-3.2-1B tokenizer and model from the Hugging Face Hub.

    Wrapped in ``st.cache_resource`` so later calls reuse the same objects
    instead of reloading the weights.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    st.write("π Loading Llama-3.2-1B model and tokenizer...")
    model_id = "meta-llama/Llama-3.2-1B"
    hf_token = st.secrets["HUGGINGFACE_TOKEN"]
    # Llama is a built-in transformers architecture, so no remote code needs to
    # be trusted/executed. `token=` replaces the deprecated `use_auth_token=`.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    # device_map="auto" falls back to CPU when no GPU is present; float16 is
    # poorly supported / slow on CPU, so only use half precision with CUDA.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=dtype,
        token=hf_token,
    )
    st.write("✅ Model and tokenizer loaded.")
    return tokenizer, model
def call_llama3(prompt: str, max_new_tokens: int = 200) -> str:
    """Generate a greedy continuation of *prompt* with Llama-3.2-1B.

    A truncated copy of the prompt and the full response are echoed to the
    Streamlit page for debugging. The prompt is tokenized as raw text; no chat
    template is applied.

    Args:
        prompt: Text fed directly to the tokenizer.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        Only the newly generated text, without the prompt.
    """
    st.write("π Sending prompt to Llama-3.2-1B...")
    st.write(prompt if len(prompt) < 200 else prompt[:200] + '...')
    tokenizer, model = load_llama3()
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    # For a causal LM, generate() returns prompt tokens followed by the new
    # tokens. Slice the prompt off so callers parse the model's answer, not
    # any labels (e.g. "template:") that appear in the prompt itself.
    prompt_len = inputs["input_ids"].shape[-1]
    text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    st.write("π‘ Model response:")
    st.write(text)
    return text
def _clean_field(value: str) -> str:
    """Strip whitespace plus wrapping quotes, backticks and markdown asterisks."""
    return value.strip().strip("\"'`*").strip()


def _parse_field(label: str, text: str):
    """Return the cleaned value following ``label:`` in *text*, or None if absent.

    The label match is case-insensitive and the first occurrence wins.
    """
    match = re.search(rf"{re.escape(label)}:\s*(.+)", text, re.IGNORECASE)
    return _clean_field(match.group(1)) if match else None


def _lookup_template_id(template: str):
    """Map a model-chosen template name to its Imgflip id, or None if unknown.

    An exact match is tried first, then a case-insensitive match that also
    ignores trailing periods (small-model output is often loosely formatted).
    """
    exact = TEMPLATE_IDS.get(template)
    if exact is not None:
        return exact
    wanted = template.rstrip(".").strip().casefold()
    for name, template_id in TEMPLATE_IDS.items():
        if name.casefold() == wanted:
            return template_id
    return None


def article_to_meme(article_text: str) -> str:
    """
    End-to-end pipeline:
    1) Summarize the article via Llama-3.2-1B.
    2) Ask Llama-3.2-1B to choose a meme template and produce two 6-8 word captions.
    3) Parse the model's response.
    4) Call Imgflip API to render the meme and return its URL.

    Returns:
        The URL of the rendered meme image, as reported by Imgflip.

    Raises:
        ValueError: the model output lacks a template/text0/text1 field.
        KeyError: the chosen template is not in TEMPLATE_IDS.
        requests.HTTPError / requests.Timeout: the Imgflip request failed.
        RuntimeError: Imgflip answered with ``success`` false.
    """
    # 1) Summarize
    st.write("βΆοΈ Step 1: Summarizing article...")
    sum_prompt = SUMMARY_PROMPT.format(article_text=article_text)
    summary = call_llama3(sum_prompt, max_new_tokens=100).strip()
    st.write(f"π Summary: {summary}")

    # 2) Template + captions
    st.write("βΆοΈ Step 2: Generating meme template and captions...")
    meme_prompt = MEME_PROMPT.format(summary=summary)
    llama_out = call_llama3(meme_prompt, max_new_tokens=150)

    # 3) Parse the response
    st.write("βΆοΈ Step 3: Parsing model output...")
    template = _parse_field("template", llama_out)
    text0 = _parse_field("text0", llama_out)
    text1 = _parse_field("text1", llama_out)
    if template is None or text0 is None or text1 is None:
        st.error(f"β Could not parse model output:\n{llama_out}")
        raise ValueError(f"Could not parse model output:\n{llama_out}")
    st.write(f"π― Chosen template: {template}")
    st.write(f"π¬ text0: {text0}")
    st.write(f"π¬ text1: {text1}")

    # 4) Render the meme via Imgflip
    st.write("βΆοΈ Step 4: Rendering meme via Imgflip...")
    template_id = _lookup_template_id(template)
    if template_id is None:
        st.error(f"β Unknown template: {template}")
        raise KeyError(f"Unknown template: {template}")
    creds = st.secrets["imgflip"]
    form = {
        "template_id": template_id,
        "username": creds["username"],
        "password": creds["password"],
        "text0": text0,
        "text1": text1,
    }
    # Send the credentials in the POST body rather than the query string, where
    # they would end up in URLs and request logs. Bound the wait so a stalled
    # connection cannot hang the app indefinitely.
    resp = requests.post(IMGFLIP_URL, data=form, timeout=30)
    st.write("π Imgflip API called...")
    resp.raise_for_status()
    data = resp.json()
    if not data.get("success", False):
        st.error(f"β Imgflip error: {data.get('error_message')}")
        raise RuntimeError(data.get("error_message"))
    meme_url = data["data"]["url"]
    st.write(f"✅ Meme URL: {meme_url}")
    return meme_url
|