barunsaha committed on
Commit
1e52982
·
1 Parent(s): 9ca95c2

Add model and template validations as well as some refactoring of the SlideDeckAI class

Browse files
Files changed (1) hide show
  1. src/slidedeckai/core.py +153 -77
src/slidedeckai/core.py CHANGED
@@ -1,59 +1,148 @@
1
  """
2
- Core classes for SlideDeckAI.
3
  """
4
  import logging
5
  import os
6
  import pathlib
7
  import tempfile
8
- from typing import Union
9
 
10
  import json5
11
  from dotenv import load_dotenv
12
 
13
  from . import global_config as gcfg
14
  from .global_config import GlobalConfig
 
15
  from .helpers import llm_helper, pptx_helper, text_helper
16
  from .helpers.chat_helper import ChatMessageHistory
17
 
18
  load_dotenv()
19
 
20
  RUN_IN_OFFLINE_MODE = os.getenv('RUN_IN_OFFLINE_MODE', 'False').lower() == 'true'
 
 
21
 
22
  logger = logging.getLogger(__name__)
23
 
24
- from .helpers import file_manager as filem
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  class SlideDeckAI:
27
  """
28
  The main class for generating slide decks.
29
  """
30
 
31
- def __init__(self, model, topic, api_key=None, pdf_path_or_stream=None, pdf_page_range=None, template_idx=0):
 
 
 
 
 
 
 
 
32
  """
33
- Initializes the SlideDeckAI object.
34
-
35
- :param model: The name of the LLM model to use.
36
- :param topic: The topic of the slide deck.
37
- :param api_key: The API key for the LLM provider.
38
- :param pdf_path_or_stream: The path to a PDF file or a file-like object.
39
- :param pdf_page_range: A tuple representing the page range to use from the PDF file.
40
- :param template_idx: The index of the PowerPoint template to use.
 
 
 
 
41
  """
42
- self.model = model
43
- self.topic = topic
44
- self.api_key = api_key
 
 
 
 
 
 
45
  self.pdf_path_or_stream = pdf_path_or_stream
46
  self.pdf_page_range = pdf_page_range
47
- self.template_idx = template_idx
 
 
48
  self.chat_history = ChatMessageHistory()
49
  self.last_response = None
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def _get_prompt_template(self, is_refinement: bool) -> str:
52
  """
53
  Return a prompt template.
54
 
55
- :param is_refinement: Whether this is the initial or refinement prompt.
56
- :return: The prompt template as f-string.
 
 
 
57
  """
58
  if is_refinement:
59
  with open(GlobalConfig.REFINEMENT_PROMPT_TEMPLATE, 'r', encoding='utf-8') as in_file:
@@ -65,8 +154,13 @@ class SlideDeckAI:
65
 
66
  def generate(self, progress_callback=None):
67
  """
68
- Generates the initial slide deck.
69
- :return: The path to the generated .pptx file.
 
 
 
 
 
70
  """
71
  additional_info = ''
72
  if self.pdf_path_or_stream:
@@ -74,29 +168,13 @@ class SlideDeckAI:
74
 
75
  self.chat_history.add_user_message(self.topic)
76
  prompt_template = self._get_prompt_template(is_refinement=False)
77
- formatted_template = prompt_template.format(question=self.topic, additional_info=additional_info)
78
-
79
- provider, llm_name = llm_helper.get_provider_model(self.model, use_ollama=RUN_IN_OFFLINE_MODE)
80
-
81
- llm = llm_helper.get_litellm_llm(
82
- provider=provider,
83
- model=llm_name,
84
- max_new_tokens=gcfg.get_max_output_tokens(self.model),
85
- api_key=self.api_key,
86
  )
87
 
88
- response = ""
89
- for chunk in llm.stream(formatted_template):
90
- if isinstance(chunk, str):
91
- response += chunk
92
- else:
93
- content = getattr(chunk, 'content', None)
94
- if content is not None:
95
- response += content
96
- else:
97
- response += str(chunk)
98
- if progress_callback:
99
- progress_callback(len(response))
100
 
101
  self.last_response = text_helper.get_clean_json(response)
102
  self.chat_history.add_ai_message(self.last_response)
@@ -105,22 +183,32 @@ class SlideDeckAI:
105
 
106
  def revise(self, instructions, progress_callback=None):
107
  """
108
- Revises the slide deck with new instructions.
 
 
 
 
 
 
 
109
 
110
- :param instructions: The instructions for revising the slide deck.
111
- :return: The path to the revised .pptx file.
112
  """
113
  if not self.last_response:
114
- raise ValueError("You must generate a slide deck before you can revise it.")
115
 
116
  if len(self.chat_history.messages) >= 16:
117
- raise ValueError("Chat history is full. Please reset to continue.")
118
 
119
  self.chat_history.add_user_message(instructions)
120
 
121
  prompt_template = self._get_prompt_template(is_refinement=True)
122
 
123
- list_of_msgs = [f'{idx + 1}. {msg.content}' for idx, msg in enumerate(self.chat_history.messages) if msg.role == 'user']
 
 
 
124
 
125
  additional_info = ''
126
  if self.pdf_path_or_stream:
@@ -132,27 +220,8 @@ class SlideDeckAI:
132
  additional_info=additional_info,
133
  )
134
 
135
- provider, llm_name = llm_helper.get_provider_model(self.model, use_ollama=RUN_IN_OFFLINE_MODE)
136
-
137
- llm = llm_helper.get_litellm_llm(
138
- provider=provider,
139
- model=llm_name,
140
- max_new_tokens=gcfg.get_max_output_tokens(self.model),
141
- api_key=self.api_key,
142
- )
143
-
144
- response = ""
145
- for chunk in llm.stream(formatted_template):
146
- if isinstance(chunk, str):
147
- response += chunk
148
- else:
149
- content = getattr(chunk, 'content', None)
150
- if content is not None:
151
- response += content
152
- else:
153
- response += str(chunk)
154
- if progress_callback:
155
- progress_callback(len(response))
156
 
157
  self.last_response = text_helper.get_clean_json(response)
158
  self.chat_history.add_ai_message(self.last_response)
@@ -163,17 +232,20 @@ class SlideDeckAI:
163
  """
164
  Create a slide deck and return the file path.
165
 
166
- :param json_str: The content in *valid* JSON format.
167
- :return: The path to the .pptx file or `None` in case of error.
 
 
 
168
  """
169
  try:
170
  parsed_data = json5.loads(json_str)
171
  except (ValueError, RecursionError) as e:
172
- logger.error("Error parsing JSON: %s", e)
173
  try:
174
  parsed_data = json5.loads(text_helper.fix_malformed_json(json_str))
175
  except (ValueError, RecursionError) as e2:
176
- logger.error("Error parsing fixed JSON: %s", e2)
177
  return None
178
 
179
  temp = tempfile.NamedTemporaryFile(delete=False, suffix='.pptx')
@@ -183,7 +255,7 @@ class SlideDeckAI:
183
  try:
184
  pptx_helper.generate_powerpoint_presentation(
185
  parsed_data,
186
- slides_template=list(GlobalConfig.PPTX_TEMPLATE_FILES.keys())[self.template_idx],
187
  output_file_path=path
188
  )
189
  except Exception as ex:
@@ -194,15 +266,19 @@ class SlideDeckAI:
194
 
195
  def set_template(self, idx):
196
  """
197
- Sets the PowerPoint template to use.
198
 
199
- :param idx: The index of the template to use.
 
200
  """
201
- self.template_idx = idx
 
202
 
203
  def reset(self):
204
  """
205
- Resets the chat history.
206
  """
207
  self.chat_history = ChatMessageHistory()
208
  self.last_response = None
 
 
 
1
  """
2
+ Core functionality of SlideDeckAI.
3
  """
4
  import logging
5
  import os
6
  import pathlib
7
  import tempfile
8
+ from typing import Union, Any
9
 
10
  import json5
11
  from dotenv import load_dotenv
12
 
13
  from . import global_config as gcfg
14
  from .global_config import GlobalConfig
15
+ from .helpers import file_manager as filem
16
  from .helpers import llm_helper, pptx_helper, text_helper
17
  from .helpers.chat_helper import ChatMessageHistory
18
 
19
  load_dotenv()
20
 
21
  RUN_IN_OFFLINE_MODE = os.getenv('RUN_IN_OFFLINE_MODE', 'False').lower() == 'true'
22
+ VALID_MODEL_NAMES = list(GlobalConfig.VALID_MODELS.keys())
23
+ VALID_TEMPLATE_NAMES = list(GlobalConfig.PPTX_TEMPLATE_FILES.keys())
24
 
25
  logger = logging.getLogger(__name__)
26
 
27
+
28
+ def _process_llm_chunk(chunk: Any) -> str:
29
+ """
30
+ Helper function to process LLM response chunks consistently.
31
+
32
+ Args:
33
+ chunk: The chunk received from the LLM stream.
34
+
35
+ Returns:
36
+ The processed text from the chunk.
37
+ """
38
+ if isinstance(chunk, str):
39
+ return chunk
40
+
41
+ content = getattr(chunk, 'content', None)
42
+ return content if content is not None else str(chunk)
43
+
44
+
45
def _stream_llm_response(llm: Any, prompt: str, progress_callback=None) -> str:
    """
    Stream a complete response from an LLM, accumulating chunk text.

    Args:
        llm: The LLM instance to use for generating responses.
        prompt: The prompt to send to the LLM.
        progress_callback: A callback function to report progress; it is
            called with the number of characters received so far.

    Returns:
        The complete response from the LLM.

    Raises:
        RuntimeError: If there's an error getting response from LLM.
    """
    pieces = []
    received = 0

    try:
        for chunk in llm.stream(prompt):
            text = _process_llm_chunk(chunk)
            pieces.append(text)
            received += len(text)
            if progress_callback:
                progress_callback(received)

        return ''.join(pieces)
    except Exception as e:
        logger.error('Error streaming LLM response: %s', str(e))
        raise RuntimeError(f'Failed to get response from LLM: {str(e)}') from e
71
+
72
 
73
  class SlideDeckAI:
74
  """
75
  The main class for generating slide decks.
76
  """
77
 
78
+ def __init__(
79
+ self,
80
+ model: str,
81
+ topic: str,
82
+ api_key: str = None,
83
+ pdf_path_or_stream=None,
84
+ pdf_page_range=None,
85
+ template_idx: int = 0
86
+ ):
87
  """
88
+ Initialize the SlideDeckAI object.
89
+
90
+ Args:
91
+ model: The name of the LLM model to use.
92
+ topic: The topic of the slide deck.
93
+ api_key: The API key for the LLM provider.
94
+ pdf_path_or_stream: The path to a PDF file or a file-like object.
95
+ pdf_page_range: A tuple representing the page range to use from the PDF file.
96
+ template_idx: The index of the PowerPoint template to use.
97
+
98
+ Raises:
99
+ ValueError: If the model name is not in VALID_MODELS.
100
  """
101
+ if model not in GlobalConfig.VALID_MODELS:
102
+ raise ValueError(
103
+ f'Invalid model name: {model}.'
104
+ f' Must be one of: {", ".join(VALID_MODEL_NAMES)}.'
105
+ )
106
+
107
+ self.model: str = model
108
+ self.topic: str = topic
109
+ self.api_key: str = api_key
110
  self.pdf_path_or_stream = pdf_path_or_stream
111
  self.pdf_page_range = pdf_page_range
112
+ # Validate template_idx is within valid range
113
+ num_templates = len(GlobalConfig.PPTX_TEMPLATE_FILES)
114
+ self.template_idx: int = template_idx if 0 <= template_idx < num_templates else 0
115
  self.chat_history = ChatMessageHistory()
116
  self.last_response = None
117
 
118
+ def _initialize_llm(self):
119
+ """
120
+ Initialize and return an LLM instance with the current configuration.
121
+
122
+ Returns:
123
+ Configured LLM instance.
124
+ """
125
+ provider, llm_name = llm_helper.get_provider_model(
126
+ self.model,
127
+ use_ollama=RUN_IN_OFFLINE_MODE
128
+ )
129
+
130
+ return llm_helper.get_litellm_llm(
131
+ provider=provider,
132
+ model=llm_name,
133
+ max_new_tokens=gcfg.get_max_output_tokens(self.model),
134
+ api_key=self.api_key,
135
+ )
136
+
137
  def _get_prompt_template(self, is_refinement: bool) -> str:
138
  """
139
  Return a prompt template.
140
 
141
+ Args:
142
+ is_refinement: Whether this is the initial or refinement prompt.
143
+
144
+ Returns:
145
+ The prompt template as f-string.
146
  """
147
  if is_refinement:
148
  with open(GlobalConfig.REFINEMENT_PROMPT_TEMPLATE, 'r', encoding='utf-8') as in_file:
 
154
 
155
  def generate(self, progress_callback=None):
156
  """
157
+ Generate the initial slide deck.
158
+
159
+ Args:
160
+ progress_callback: Optional callback function to report progress.
161
+
162
+ Returns:
163
+ The path to the generated .pptx file.
164
  """
165
  additional_info = ''
166
  if self.pdf_path_or_stream:
 
168
 
169
  self.chat_history.add_user_message(self.topic)
170
  prompt_template = self._get_prompt_template(is_refinement=False)
171
+ formatted_template = prompt_template.format(
172
+ question=self.topic,
173
+ additional_info=additional_info
 
 
 
 
 
 
174
  )
175
 
176
+ llm = self._initialize_llm()
177
+ response = _stream_llm_response(llm, formatted_template, progress_callback)
 
 
 
 
 
 
 
 
 
 
178
 
179
  self.last_response = text_helper.get_clean_json(response)
180
  self.chat_history.add_ai_message(self.last_response)
 
183
 
184
  def revise(self, instructions, progress_callback=None):
185
  """
186
+ Revise the slide deck with new instructions.
187
+
188
+ Args:
189
+ instructions: The instructions for revising the slide deck.
190
+ progress_callback: Optional callback function to report progress.
191
+
192
+ Returns:
193
+ The path to the revised .pptx file.
194
 
195
+ Raises:
196
+ ValueError: If no slide deck exists or chat history is full.
197
  """
198
  if not self.last_response:
199
+ raise ValueError('You must generate a slide deck before you can revise it.')
200
 
201
  if len(self.chat_history.messages) >= 16:
202
+ raise ValueError('Chat history is full. Please reset to continue.')
203
 
204
  self.chat_history.add_user_message(instructions)
205
 
206
  prompt_template = self._get_prompt_template(is_refinement=True)
207
 
208
+ list_of_msgs = [
209
+ f'{idx + 1}. {msg.content}'
210
+ for idx, msg in enumerate(self.chat_history.messages) if msg.role == 'user'
211
+ ]
212
 
213
  additional_info = ''
214
  if self.pdf_path_or_stream:
 
220
  additional_info=additional_info,
221
  )
222
 
223
+ llm = self._initialize_llm()
224
+ response = _stream_llm_response(llm, formatted_template, progress_callback)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
  self.last_response = text_helper.get_clean_json(response)
227
  self.chat_history.add_ai_message(self.last_response)
 
232
  """
233
  Create a slide deck and return the file path.
234
 
235
+ Args:
236
+ json_str: The content in valid JSON format.
237
+
238
+ Returns:
239
+ The path to the .pptx file or None in case of error.
240
  """
241
  try:
242
  parsed_data = json5.loads(json_str)
243
  except (ValueError, RecursionError) as e:
244
+ logger.error('Error parsing JSON: %s', e)
245
  try:
246
  parsed_data = json5.loads(text_helper.fix_malformed_json(json_str))
247
  except (ValueError, RecursionError) as e2:
248
+ logger.error('Error parsing fixed JSON: %s', e2)
249
  return None
250
 
251
  temp = tempfile.NamedTemporaryFile(delete=False, suffix='.pptx')
 
255
  try:
256
  pptx_helper.generate_powerpoint_presentation(
257
  parsed_data,
258
+ slides_template=VALID_TEMPLATE_NAMES[self.template_idx],
259
  output_file_path=path
260
  )
261
  except Exception as ex:
 
266
 
267
  def set_template(self, idx):
268
  """
269
+ Set the PowerPoint template to use.
270
 
271
+ Args:
272
+ idx: The index of the template to use.
273
  """
274
+ num_templates = len(GlobalConfig.PPTX_TEMPLATE_FILES)
275
+ self.template_idx = idx if 0 <= idx < num_templates else 0
276
 
277
  def reset(self):
278
  """
279
+ Reset the chat history and internal state.
280
  """
281
  self.chat_history = ChatMessageHistory()
282
  self.last_response = None
283
+ self.template_idx = 0
284
+ self.topic = 0