rushankg commited on
Commit
66a27f1
Β·
verified Β·
1 Parent(s): 146332a

Update memes.py

Browse files
Files changed (1) hide show
  1. memes.py +42 -87
memes.py CHANGED
@@ -3,7 +3,7 @@ import streamlit as st
3
  import re
4
  import torch
5
  import requests
6
- from transformers import AutoTokenizer, AutoModelForCausalLM
7
  from prompts import SUMMARY_PROMPT, MEME_PROMPT
8
 
9
  IMGFLIP_URL = "https://api.imgflip.com/caption_image"
@@ -24,106 +24,61 @@ TEMPLATE_IDS = {
24
  "Hide the Pain Harold": "27813981",
25
  }
26
 
27
- @st.cache_resource
28
- def load_llama3():
29
- """
30
- Load Llama-3.2-1B and its tokenizer.
31
- """
32
- st.write("πŸ”„ Loading Llama-3.2-1B model and tokenizer...")
33
- tokenizer = AutoTokenizer.from_pretrained(
34
- "meta-llama/Llama-3.2-1B",
35
- trust_remote_code=True,
36
- use_auth_token=st.secrets["HUGGINGFACE_TOKEN"]
37
- )
38
- model = AutoModelForCausalLM.from_pretrained(
39
- "meta-llama/Llama-3.2-1B",
40
- device_map="auto",
41
- torch_dtype=torch.float16,
42
- trust_remote_code=True,
43
- use_auth_token=st.secrets["HUGGINGFACE_TOKEN"]
44
- )
45
- st.write("βœ… Model and tokenizer loaded.")
46
- return tokenizer, model
47
 
48
-
49
- def call_llama3(prompt: str, max_new_tokens: int = 200) -> str:
50
- """
51
- Generate text with Llama-3.2-1B.
52
- """
53
- st.write("πŸ“ Sending prompt to Llama-3.2-1B...")
54
- st.write(prompt if len(prompt) < 200 else prompt[:200] + '...')
55
- tokenizer, model = load_llama3()
56
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
57
- outputs = model.generate(
58
- **inputs,
59
- max_new_tokens=max_new_tokens,
60
- do_sample=False,
61
- pad_token_id=tokenizer.eos_token_id,
62
  )
63
- text = tokenizer.decode(outputs[0], skip_special_tokens=True)
64
- st.write("πŸ’‘ Model response:")
65
- st.write(text)
66
- return text
67
-
68
 
69
  def article_to_meme(article_text: str) -> str:
70
- """
71
- End-to-end pipeline:
72
- 1) Summarize the article via Llama-3.2-1B.
73
- 2) Ask Llama-3.2-1B to choose a meme template and produce two 6-8 word captions.
74
- 3) Parse the model's response.
75
- 4) Call Imgflip API to render the meme and return its URL.
76
- """
77
  # 1) Summarize
78
- st.write("▢️ Step 1: Summarizing article...")
79
- sum_prompt = SUMMARY_PROMPT.format(article_text=article_text)
80
- summary = call_llama3(sum_prompt, max_new_tokens=500).strip()
81
- st.write(f"πŸ“ Summary: {summary}")
82
 
83
- # 2) Template + captions
84
- st.write("▢️ Step 2: Generating meme template and captions...")
85
- meme_prompt = MEME_PROMPT.format(summary=summary)
86
- llama_out = call_llama3(meme_prompt, max_new_tokens=150)
87
 
88
- # 3) Parse the response
89
- st.write("▢️ Step 3: Parsing model output...")
90
- tpl_match = re.search(r"template:\s*(.+)", llama_out, re.IGNORECASE)
91
- text0_match = re.search(r"text0:\s*(.+)", llama_out, re.IGNORECASE)
92
- text1_match = re.search(r"text1:\s*(.+)", llama_out, re.IGNORECASE)
93
- if not (tpl_match and text0_match and text1_match):
94
- st.error(f"❌ Could not parse model output:\n{llama_out}")
95
- raise ValueError(f"Could not parse model output:\n{llama_out}")
 
96
 
97
- template = tpl_match.group(1).strip()
98
- text0 = text0_match.group(1).strip()
99
- text1 = text1_match.group(1).strip()
100
- st.write(f"🎯 Chosen template: {template}")
101
- st.write(f"πŸ’¬ text0: {text0}")
102
- st.write(f"πŸ’¬ text1: {text1}")
103
-
104
- # 4) Render the meme via Imgflip
105
- st.write("▢️ Step 4: Rendering meme via Imgflip...")
106
- template_id = TEMPLATE_IDS.get(template)
107
- if template_id is None:
108
- st.error(f"❌ Unknown template: {template}")
109
  raise KeyError(f"Unknown template: {template}")
110
-
111
  creds = st.secrets["imgflip"]
112
- params = {
113
- "template_id": template_id,
114
- "username": creds["username"],
115
- "password": creds["password"],
116
- "text0": text0,
117
- "text1": text1,
118
- }
119
- resp = requests.post(IMGFLIP_URL, params=params)
120
- st.write("πŸ”— Imgflip API called...")
 
121
  resp.raise_for_status()
122
  data = resp.json()
123
  if not data.get("success", False):
124
- st.error(f"❌ Imgflip error: {data.get('error_message')}")
125
  raise Exception(data.get("error_message"))
126
 
127
  meme_url = data["data"]["url"]
128
- st.write(f"βœ… Meme URL: {meme_url}")
129
  return meme_url
 
3
  import re
4
  import torch
5
  import requests
6
+ from openai import ChatCompletion
7
  from prompts import SUMMARY_PROMPT, MEME_PROMPT
8
 
9
  IMGFLIP_URL = "https://api.imgflip.com/caption_image"
 
24
  "Hide the Pain Harold": "27813981",
25
  }
26
 
27
+ # OpenAI config
28
+ openai_api_key = st.secrets["OPENAI_API_KEY"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+ def call_openai(prompt: str) -> str:
31
+ """Call gpt-4o-mini once with given prompt."""
32
+ response = ChatCompletion.create(
33
+ model="gpt-4o-mini",
34
+ messages=[{"role": "user", "content": prompt}],
35
+ max_tokens=200,
36
+ temperature=0.7
 
 
 
 
 
 
 
37
  )
38
+ return response.choices[0].message.content.strip()
 
 
 
 
39
 
40
  def article_to_meme(article_text: str) -> str:
 
 
 
 
 
 
 
41
  # 1) Summarize
42
+ st.write("⏳ Summarizing article...")
43
+ summary = call_openai(SUMMARY_PROMPT.format(article_text=article_text))
44
+ st.write("βœ… Summary complete.")
 
45
 
46
+ # 2) Choose template + captions
47
+ st.write("⏳ Generating meme captions...")
48
+ output = call_openai(MEME_PROMPT.format(summary=summary))
49
+ st.write("βœ… Captions generated.")
50
 
51
+ # 3) Parse model output
52
+ match_t = re.search(r"template:\s*(.+)", output, re.IGNORECASE)
53
+ match0 = re.search(r"text0:\s*(.+)", output, re.IGNORECASE)
54
+ match1 = re.search(r"text1:\s*(.+)", output, re.IGNORECASE)
55
+ if not (match_t and match0 and match1):
56
+ raise ValueError(f"Parsing failed: {output}")
57
+ template = match_t.group(1).strip()
58
+ text0 = match0.group(1).strip()
59
+ text1 = match1.group(1).strip()
60
 
61
+ # 4) Render meme
62
+ st.write("⏳ Rendering meme...")
63
+ tpl_id = TEMPLATE_IDS.get(template)
64
+ if not tpl_id:
 
 
 
 
 
 
 
 
65
  raise KeyError(f"Unknown template: {template}")
 
66
  creds = st.secrets["imgflip"]
67
+ resp = requests.post(
68
+ IMGFLIP_URL,
69
+ params={
70
+ "template_id": tpl_id,
71
+ "username": creds["username"],
72
+ "password": creds["password"],
73
+ "text0": text0,
74
+ "text1": text1,
75
+ }
76
+ )
77
  resp.raise_for_status()
78
  data = resp.json()
79
  if not data.get("success", False):
 
80
  raise Exception(data.get("error_message"))
81
 
82
  meme_url = data["data"]["url"]
83
+ st.write(f"βœ… Meme ready: [View here]({meme_url})")
84
  return meme_url