Rausda6 committed on
Commit
8b21b4a
·
verified ·
1 Parent(s): 5e62778

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -42
app.py CHANGED
@@ -1,7 +1,4 @@
1
  import gradio as gr
2
- from pydub import AudioSegment
3
- from google import genai
4
- from google.genai import types
5
  import random
6
  import time
7
  import os
@@ -14,11 +11,13 @@ import edge_tts
14
  import asyncio
15
  import aiofiles
16
  import mimetypes
17
- from typing import List, Dict
18
 
 
 
19
 
20
  # Define model name clearly
21
- MODEL_NAME = "unsloth/gemma-3-1b-pt" #HuggingFaceH4/zephyr-7b-alpha"
22
 
23
  # Device setup
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -27,7 +26,7 @@ print(f"Using device: {device}")
27
  # Load model and tokenizer (explicit evaluation mode)
28
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
29
  model = AutoModelForCausalLM.from_pretrained(
30
- MODEL_NAME,
31
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
32
  ).eval().to(device)
33
 
@@ -97,7 +96,7 @@ You are a professional podcast generator. Your task is to generate a professiona
97
  Follow this example structure:
98
  {example}
99
  """
100
- user_prompt = ""
101
  if prompt and file_obj:
102
  user_prompt = f"Please generate a podcast script based on the uploaded file following user input:\n{prompt}"
103
  elif prompt:
@@ -105,42 +104,35 @@ Follow this example structure:
105
  else:
106
  user_prompt = "Please generate a podcast script based on the uploaded file."
107
 
108
- messages = []
109
-
110
- # If file is provided, add it to the messages
111
  if file_obj:
112
- file_data = await self._read_file_bytes(file_obj)
113
- mime_type = self._get_mime_type(file_obj.name)
114
-
115
- messages.append(
116
- types.Content(
117
- role="user",
118
- parts=[
119
- types.Part.from_bytes(
120
- data=file_data,
121
- mime_type=mime_type,
122
- )
123
- ],
124
- )
125
- )
126
-
127
- # Add text prompt
128
- messages.append(
129
- types.Content(
130
- role="user",
131
- parts=[
132
- types.Part.from_text(text=user_prompt)
133
- ],
134
- )
135
- )
136
 
137
  try:
138
  if progress:
139
  progress(0.3, "Generating podcast script...")
140
 
141
- # Compose the prompt from your messages
142
- prompt_text = system_prompt + "\n" + "\n".join([msg["content"] for msg in messages])
143
-
144
  def hf_generate(prompt_text):
145
  inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
146
  outputs = model.generate(
@@ -149,8 +141,7 @@ Follow this example structure:
149
  do_sample=True,
150
  temperature=1.0
151
  )
152
- text = tokenizer.decode(outputs[0], skip_special_tokens=True)
153
- return text
154
 
155
  generated_text = await asyncio.wait_for(
156
  asyncio.to_thread(hf_generate, prompt_text),
@@ -162,14 +153,14 @@ Follow this example structure:
162
  except Exception as e:
163
  raise Exception(f"Failed to generate podcast script: {e}")
164
 
165
- print(f"Generated podcast script:\n{generated_text}")
166
-
167
  if progress:
168
  progress(0.4, "Script generated successfully!")
169
 
170
- # Ensure the return type matches the original code (as JSON)
171
  return json.loads(generated_text)
172
 
 
 
 
173
 
174
  async def _read_file_bytes(self, file_obj) -> bytes:
175
  """Read file bytes from a file object"""
 
1
  import gradio as gr
 
 
 
2
  import random
3
  import time
4
  import os
 
11
  import asyncio
12
  import aiofiles
13
  import mimetypes
14
+ from typing import List
15
 
16
+ # New import for PDF parsing
17
+ from PyPDF2 import PdfReader
18
 
19
  # Define model name clearly
20
+ MODEL_NAME = "unsloth/gemma-3-1b-pt" # HuggingFaceH4/zephyr-7b-alpha
21
 
22
  # Device setup
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
26
  # Load model and tokenizer (explicit evaluation mode)
27
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
28
  model = AutoModelForCausalLM.from_pretrained(
29
+ MODEL_NAME,
30
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
31
  ).eval().to(device)
32
 
 
96
  Follow this example structure:
97
  {example}
98
  """
99
+ # Build the user prompt
100
  if prompt and file_obj:
101
  user_prompt = f"Please generate a podcast script based on the uploaded file following user input:\n{prompt}"
102
  elif prompt:
 
104
  else:
105
  user_prompt = "Please generate a podcast script based on the uploaded file."
106
 
107
+ # If a file is provided, extract its text and append
 
 
108
  if file_obj:
109
+ # enforce size limit
110
+ file_size = getattr(file_obj, 'size', os.path.getsize(file_obj.name))
111
+ if file_size > MAX_FILE_SIZE_BYTES:
112
+ raise Exception(f"File size exceeds the {MAX_FILE_SIZE_MB}MB limit. Please upload a smaller file.")
113
+
114
+ # extract text based on mime
115
+ ext = os.path.splitext(file_obj.name)[1].lower()
116
+ if ext == '.pdf':
117
+ reader = PdfReader(file_obj)
118
+ text = "\n\n".join(page.extract_text() or '' for page in reader.pages)
119
+ else:
120
+ # txt or other
121
+ if hasattr(file_obj, 'read'):
122
+ raw = file_obj.read()
123
+ else:
124
+ raw = await aiofiles.open(file_obj.name, 'rb').read()
125
+ text = raw.decode(errors='ignore')
126
+
127
+ user_prompt += f"\n\n―― FILE CONTENT ――\n{text}"
128
+
129
+ # Combine system and user prompts
130
+ prompt_text = system_prompt + "\n" + user_prompt
 
 
131
 
132
  try:
133
  if progress:
134
  progress(0.3, "Generating podcast script...")
135
 
 
 
 
136
  def hf_generate(prompt_text):
137
  inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
138
  outputs = model.generate(
 
141
  do_sample=True,
142
  temperature=1.0
143
  )
144
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
145
 
146
  generated_text = await asyncio.wait_for(
147
  asyncio.to_thread(hf_generate, prompt_text),
 
153
  except Exception as e:
154
  raise Exception(f"Failed to generate podcast script: {e}")
155
 
 
 
156
  if progress:
157
  progress(0.4, "Script generated successfully!")
158
 
 
159
  return json.loads(generated_text)
160
 
161
+ # ... rest of class unchanged ...
162
+
163
+
164
 
165
  async def _read_file_bytes(self, file_obj) -> bytes:
166
  """Read file bytes from a file object"""