rushankg committed on
Commit
98ef818
·
verified ·
1 Parent(s): 3818df8

Update memes.py

Browse files
Files changed (1) hide show
  1. memes.py +39 -8
memes.py CHANGED
@@ -29,6 +29,7 @@ def load_llama3():
29
  """
30
  Load Llama-3.2-1B and its tokenizer.
31
  """
 
32
  tokenizer = AutoTokenizer.from_pretrained(
33
  "meta-llama/Llama-3.2-1B",
34
  trust_remote_code=True,
@@ -41,12 +42,16 @@ def load_llama3():
41
  trust_remote_code=True,
42
  use_auth_token=st.secrets["HUGGINGFACE_TOKEN"]
43
  )
 
44
  return tokenizer, model
45
 
 
46
  def call_llama3(prompt: str, max_new_tokens: int = 200) -> str:
47
  """
48
  Generate text with Llama-3.2-1B.
49
  """
 
 
50
  tokenizer, model = load_llama3()
51
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
52
  outputs = model.generate(
@@ -55,31 +60,52 @@ def call_llama3(prompt: str, max_new_tokens: int = 200) -> str:
55
  do_sample=False,
56
  pad_token_id=tokenizer.eos_token_id,
57
  )
58
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
59
 
60
  def article_to_meme(article_text: str) -> str:
61
- # Summarize
 
 
 
 
 
 
 
 
62
  sum_prompt = SUMMARY_PROMPT.format(article_text=article_text)
63
  summary = call_llama3(sum_prompt, max_new_tokens=100).strip()
 
64
 
65
- # Meme template + captions
 
66
  meme_prompt = MEME_PROMPT.format(summary=summary)
67
  llama_out = call_llama3(meme_prompt, max_new_tokens=150)
68
 
69
- # Parse response
 
70
  tpl_match = re.search(r"template:\s*(.+)", llama_out, re.IGNORECASE)
71
  text0_match = re.search(r"text0:\s*(.+)", llama_out, re.IGNORECASE)
72
  text1_match = re.search(r"text1:\s*(.+)", llama_out, re.IGNORECASE)
73
  if not (tpl_match and text0_match and text1_match):
 
74
  raise ValueError(f"Could not parse model output:\n{llama_out}")
75
 
76
  template = tpl_match.group(1).strip()
77
  text0 = text0_match.group(1).strip()
78
  text1 = text1_match.group(1).strip()
 
 
 
79
 
80
- # Render meme
 
81
  template_id = TEMPLATE_IDS.get(template)
82
  if template_id is None:
 
83
  raise KeyError(f"Unknown template: {template}")
84
 
85
  creds = st.secrets["imgflip"]
@@ -91,8 +117,13 @@ def article_to_meme(article_text: str) -> str:
91
  "text1": text1,
92
  }
93
  resp = requests.post(IMGFLIP_URL, params=params)
 
94
  resp.raise_for_status()
95
  data = resp.json()
96
- if not data["success"]:
97
- raise Exception(data["error_message"])
98
- return data["data"]["url"]
 
 
 
 
 
29
  """
30
  Load Llama-3.2-1B and its tokenizer.
31
  """
32
+ st.write("πŸ”„ Loading Llama-3.2-1B model and tokenizer...")
33
  tokenizer = AutoTokenizer.from_pretrained(
34
  "meta-llama/Llama-3.2-1B",
35
  trust_remote_code=True,
 
42
  trust_remote_code=True,
43
  use_auth_token=st.secrets["HUGGINGFACE_TOKEN"]
44
  )
45
+ st.write("βœ… Model and tokenizer loaded.")
46
  return tokenizer, model
47
 
48
+
49
  def call_llama3(prompt: str, max_new_tokens: int = 200) -> str:
50
  """
51
  Generate text with Llama-3.2-1B.
52
  """
53
+ st.write("πŸ“ Sending prompt to Llama-3.2-1B...")
54
+ st.write(prompt if len(prompt) < 200 else prompt[:200] + '...')
55
  tokenizer, model = load_llama3()
56
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
57
  outputs = model.generate(
 
60
  do_sample=False,
61
  pad_token_id=tokenizer.eos_token_id,
62
  )
63
+ text = tokenizer.decode(outputs[0], skip_special_tokens=True)
64
+ st.write("πŸ’‘ Model response:")
65
+ st.write(text)
66
+ return text
67
+
68
 
69
  def article_to_meme(article_text: str) -> str:
70
+ """
71
+ End-to-end pipeline:
72
+ 1) Summarize the article via Llama-3.2-1B.
73
+ 2) Ask Llama-3.2-1B to choose a meme template and produce two 6-8 word captions.
74
+ 3) Parse the model's response.
75
+ 4) Call Imgflip API to render the meme and return its URL.
76
+ """
77
+ # 1) Summarize
78
+ st.write("▢️ Step 1: Summarizing article...")
79
  sum_prompt = SUMMARY_PROMPT.format(article_text=article_text)
80
  summary = call_llama3(sum_prompt, max_new_tokens=100).strip()
81
+ st.write(f"πŸ“ Summary: {summary}")
82
 
83
+ # 2) Template + captions
84
+ st.write("▢️ Step 2: Generating meme template and captions...")
85
  meme_prompt = MEME_PROMPT.format(summary=summary)
86
  llama_out = call_llama3(meme_prompt, max_new_tokens=150)
87
 
88
+ # 3) Parse the response
89
+ st.write("▢️ Step 3: Parsing model output...")
90
  tpl_match = re.search(r"template:\s*(.+)", llama_out, re.IGNORECASE)
91
  text0_match = re.search(r"text0:\s*(.+)", llama_out, re.IGNORECASE)
92
  text1_match = re.search(r"text1:\s*(.+)", llama_out, re.IGNORECASE)
93
  if not (tpl_match and text0_match and text1_match):
94
+ st.error(f"❌ Could not parse model output:\n{llama_out}")
95
  raise ValueError(f"Could not parse model output:\n{llama_out}")
96
 
97
  template = tpl_match.group(1).strip()
98
  text0 = text0_match.group(1).strip()
99
  text1 = text1_match.group(1).strip()
100
+ st.write(f"🎯 Chosen template: {template}")
101
+ st.write(f"πŸ’¬ text0: {text0}")
102
+ st.write(f"πŸ’¬ text1: {text1}")
103
 
104
+ # 4) Render the meme via Imgflip
105
+ st.write("▢️ Step 4: Rendering meme via Imgflip...")
106
  template_id = TEMPLATE_IDS.get(template)
107
  if template_id is None:
108
+ st.error(f"❌ Unknown template: {template}")
109
  raise KeyError(f"Unknown template: {template}")
110
 
111
  creds = st.secrets["imgflip"]
 
117
  "text1": text1,
118
  }
119
  resp = requests.post(IMGFLIP_URL, params=params)
120
+ st.write("πŸ”— Imgflip API called...")
121
  resp.raise_for_status()
122
  data = resp.json()
123
+ if not data.get("success", False):
124
+ st.error(f"❌ Imgflip error: {data.get('error_message')}")
125
+ raise Exception(data.get("error_message"))
126
+
127
+ meme_url = data["data"]["url"]
128
+ st.write(f"βœ… Meme URL: {meme_url}")
129
+ return meme_url