rushankg committed on
Commit
afe469e
·
verified ·
1 Parent(s): 5a63adb

Update memes.py

Browse files
Files changed (1) hide show
  1. memes.py +17 -25
memes.py CHANGED
@@ -25,27 +25,27 @@ TEMPLATE_IDS = {
25
  }
26
 
27
@st.cache_resource
def load_mpt():
    """
    Load the MosaicML MPT-7B-Chat model and tokenizer.

    Cached by Streamlit so the (expensive) download/initialization runs
    only once per process.

    Returns:
        tuple: ``(tokenizer, model)`` ready for text generation.
    """
    checkpoint = "mosaicml/mpt-7b-chat"
    # MPT ships custom modeling code, so remote code must be trusted
    # for both the tokenizer and the model.
    tok = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    mdl = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        device_map="auto",            # let accelerate place layers on available devices
        torch_dtype=torch.float16,    # half precision to cut memory roughly in half
        trust_remote_code=True,
    )
    return tok, mdl
43
 
44
- def call_mpt(prompt: str, max_new_tokens: int = 200) -> str:
45
  """
46
- Generate text from MPT-7B-Chat given a prompt.
47
  """
48
- tokenizer, model = load_mpt()
49
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
50
  outputs = model.generate(
51
  **inputs,
@@ -56,33 +56,26 @@ def call_mpt(prompt: str, max_new_tokens: int = 200) -> str:
56
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
57
 
58
  def article_to_meme(article_text: str) -> str:
59
- """
60
- End-to-end pipeline:
61
- 1) Summarize the article via MPT-7B-Chat.
62
- 2) Ask MPT-7B-Chat to choose a meme template and produce two 6-8 word captions.
63
- 3) Parse the model's response.
64
- 4) Call Imgflip API to render the meme and return its URL.
65
- """
66
- # 1) Summarize
67
  sum_prompt = SUMMARY_PROMPT.format(article_text=article_text)
68
- summary = call_mpt(sum_prompt, max_new_tokens=100).strip()
69
 
70
- # 2) Template + captions
71
  meme_prompt = MEME_PROMPT.format(summary=summary)
72
- mpt_out = call_mpt(meme_prompt, max_new_tokens=150)
73
 
74
- # 3) Parse the response
75
- tpl_match = re.search(r"template:\s*(.+)", mpt_out, re.IGNORECASE)
76
- text0_match = re.search(r"text0:\s*(.+)", mpt_out, re.IGNORECASE)
77
- text1_match = re.search(r"text1:\s*(.+)", mpt_out, re.IGNORECASE)
78
  if not (tpl_match and text0_match and text1_match):
79
- raise ValueError(f"Could not parse model output:\n{mpt_out}")
80
 
81
  template = tpl_match.group(1).strip()
82
  text0 = text0_match.group(1).strip()
83
  text1 = text1_match.group(1).strip()
84
 
85
- # 4) Render the meme via Imgflip
86
  template_id = TEMPLATE_IDS.get(template)
87
  if template_id is None:
88
  raise KeyError(f"Unknown template: {template}")
@@ -100,5 +93,4 @@ def article_to_meme(article_text: str) -> str:
100
  data = resp.json()
101
  if not data["success"]:
102
  raise Exception(data["error_message"])
103
-
104
  return data["data"]["url"]
 
25
  }
26
 
27
@st.cache_resource
def load_llama3():
    """
    Load Llama-3.2-1B and its tokenizer.

    Cached by Streamlit so the (expensive) download/initialization runs
    only once per process.

    Returns:
        tuple: ``(tokenizer, model)`` ready for text generation.
    """
    checkpoint = "meta-llama/Llama-3.2-1B"
    tok = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    mdl = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        device_map="auto",            # let accelerate place layers on available devices
        torch_dtype=torch.float16,    # half precision to cut memory roughly in half
        trust_remote_code=True,
    )
    return tok, mdl
43
 
44
+ def call_llama3(prompt: str, max_new_tokens: int = 200) -> str:
45
  """
46
+ Generate text with Llama-3.2-1B.
47
  """
48
+ tokenizer, model = load_llama3()
49
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
50
  outputs = model.generate(
51
  **inputs,
 
56
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
57
 
58
  def article_to_meme(article_text: str) -> str:
59
+ # Summarize
 
 
 
 
 
 
 
60
  sum_prompt = SUMMARY_PROMPT.format(article_text=article_text)
61
+ summary = call_llama3(sum_prompt, max_new_tokens=100).strip()
62
 
63
+ # Meme template + captions
64
  meme_prompt = MEME_PROMPT.format(summary=summary)
65
+ llama_out = call_llama3(meme_prompt, max_new_tokens=150)
66
 
67
+ # Parse response
68
+ tpl_match = re.search(r"template:\s*(.+)", llama_out, re.IGNORECASE)
69
+ text0_match = re.search(r"text0:\s*(.+)", llama_out, re.IGNORECASE)
70
+ text1_match = re.search(r"text1:\s*(.+)", llama_out, re.IGNORECASE)
71
  if not (tpl_match and text0_match and text1_match):
72
+ raise ValueError(f"Could not parse model output:\n{llama_out}")
73
 
74
  template = tpl_match.group(1).strip()
75
  text0 = text0_match.group(1).strip()
76
  text1 = text1_match.group(1).strip()
77
 
78
+ # Render meme
79
  template_id = TEMPLATE_IDS.get(template)
80
  if template_id is None:
81
  raise KeyError(f"Unknown template: {template}")
 
93
  data = resp.json()
94
  if not data["success"]:
95
  raise Exception(data["error_message"])
 
96
  return data["data"]["url"]