rushankg committed on
Commit
16fef17
·
verified ·
1 Parent(s): 7ec3808

Update memes.py

Browse files
Files changed (1) hide show
  1. memes.py +38 -23
memes.py CHANGED
@@ -25,59 +25,73 @@ TEMPLATE_IDS = {
25
  }
26
 
27
  @st.cache_resource
28
- def load_gemma():
 
 
 
29
  tokenizer = AutoTokenizer.from_pretrained(
30
- "google/gemma-3-27b-it",
31
- use_auth_token=st.secrets["HUGGINGFACE_TOKEN"]
32
  )
33
  model = AutoModelForCausalLM.from_pretrained(
34
- "google/gemma-3-27b-it",
35
- torch_dtype=torch.float16,
36
  device_map="auto",
37
- use_auth_token=st.secrets["HUGGINGFACE_TOKEN"]
 
38
  )
39
  return tokenizer, model
40
 
41
- def call_gemma(prompt: str, max_new_tokens=200) -> str:
42
- tok, model = load_gemma()
43
- inputs = tok(prompt, return_tensors="pt").to(model.device)
44
- out = model.generate(
 
 
 
45
  **inputs,
46
  max_new_tokens=max_new_tokens,
47
  do_sample=False,
48
- pad_token_id=tok.eos_token_id
49
  )
50
- return tok.decode(out[0], skip_special_tokens=True)
51
 
52
  def article_to_meme(article_text: str) -> str:
 
 
 
 
 
 
 
53
  # 1) Summarize
54
  sum_prompt = SUMMARY_PROMPT.format(article_text=article_text)
55
- summary = call_gemma(sum_prompt, max_new_tokens=100).strip()
56
 
57
- # 2) Meme‐text generation + template choice
58
  meme_prompt = MEME_PROMPT.format(summary=summary)
59
- gemma_out = call_gemma(meme_prompt, max_new_tokens=150)
60
 
61
- # 3) Parse Gemini’s response
62
- tpl_match = re.search(r"template:\s*(.+)", gemma_out, re.IGNORECASE)
63
- text0_match = re.search(r"text0:\s*(.+)", gemma_out, re.IGNORECASE)
64
- text1_match = re.search(r"text1:\s*(.+)", gemma_out, re.IGNORECASE)
65
  if not (tpl_match and text0_match and text1_match):
66
- raise ValueError(f"Could not parse Gemini output:\n{gemma_out}")
67
 
68
  template = tpl_match.group(1).strip()
69
  text0 = text0_match.group(1).strip()
70
  text1 = text1_match.group(1).strip()
71
 
72
- # 4) Render meme
73
  template_id = TEMPLATE_IDS.get(template)
74
  if template_id is None:
75
  raise KeyError(f"Unknown template: {template}")
76
 
 
77
  params = {
78
  "template_id": template_id,
79
- "username": st.secrets["IMGFLIP_USERNAME"],
80
- "password": st.secrets["IMGFLIP_PASSWORD"],
81
  "text0": text0,
82
  "text1": text1,
83
  }
@@ -86,4 +100,5 @@ def article_to_meme(article_text: str) -> str:
86
  data = resp.json()
87
  if not data["success"]:
88
  raise Exception(data["error_message"])
 
89
  return data["data"]["url"]
 
25
  }
26
 
27
  @st.cache_resource
28
+ def load_mpt():
29
+ """
30
+ Load the MosaicML MPT-7B-Chat model and tokenizer.
31
+ """
32
  tokenizer = AutoTokenizer.from_pretrained(
33
+ "mosaicml/mpt-7b-chat",
34
+ trust_remote_code=True
35
  )
36
  model = AutoModelForCausalLM.from_pretrained(
37
+ "mosaicml/mpt-7b-chat",
 
38
  device_map="auto",
39
+ torch_dtype=torch.float16,
40
+ trust_remote_code=True
41
  )
42
  return tokenizer, model
43
 
44
+ def call_mpt(prompt: str, max_new_tokens: int = 200) -> str:
45
+ """
46
+ Generate text from MPT-7B-Chat given a prompt.
47
+ """
48
+ tokenizer, model = load_mpt()
49
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
50
+ outputs = model.generate(
51
  **inputs,
52
  max_new_tokens=max_new_tokens,
53
  do_sample=False,
54
+ pad_token_id=tokenizer.eos_token_id,
55
  )
56
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
57
 
58
  def article_to_meme(article_text: str) -> str:
59
+ """
60
+ End-to-end pipeline:
61
+ 1) Summarize the article via MPT-7B-Chat.
62
+ 2) Ask MPT-7B-Chat to choose a meme template and produce two 6-8 word captions.
63
+ 3) Parse the model's response.
64
+ 4) Call Imgflip API to render the meme and return its URL.
65
+ """
66
  # 1) Summarize
67
  sum_prompt = SUMMARY_PROMPT.format(article_text=article_text)
68
+ summary = call_mpt(sum_prompt, max_new_tokens=100).strip()
69
 
70
+ # 2) Template + captions
71
  meme_prompt = MEME_PROMPT.format(summary=summary)
72
+ mpt_out = call_mpt(meme_prompt, max_new_tokens=150)
73
 
74
+ # 3) Parse the response
75
+ tpl_match = re.search(r"template:\s*(.+)", mpt_out, re.IGNORECASE)
76
+ text0_match = re.search(r"text0:\s*(.+)", mpt_out, re.IGNORECASE)
77
+ text1_match = re.search(r"text1:\s*(.+)", mpt_out, re.IGNORECASE)
78
  if not (tpl_match and text0_match and text1_match):
79
+ raise ValueError(f"Could not parse model output:\n{mpt_out}")
80
 
81
  template = tpl_match.group(1).strip()
82
  text0 = text0_match.group(1).strip()
83
  text1 = text1_match.group(1).strip()
84
 
85
+ # 4) Render the meme via Imgflip
86
  template_id = TEMPLATE_IDS.get(template)
87
  if template_id is None:
88
  raise KeyError(f"Unknown template: {template}")
89
 
90
+ creds = st.secrets["imgflip"]
91
  params = {
92
  "template_id": template_id,
93
+ "username": creds["username"],
94
+ "password": creds["password"],
95
  "text0": text0,
96
  "text1": text1,
97
  }
 
100
  data = resp.json()
101
  if not data["success"]:
102
  raise Exception(data["error_message"])
103
+
104
  return data["data"]["url"]