google-labs-jules[bot] committed
Commit 70680dd · 1 Parent(s): f243b4c

fix: Address PR comments

- Refactor PDF handling to be internal to the SlideDeckAI class.
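In practice, callers stop pre-extracting PDF text and passing it as additional_info; they hand the PDF source and page range to the constructor instead, as the two diffs below show. A minimal usage sketch of the resulting API (the generate() call and the model string are assumptions; only the constructor signature and the ".pptx path" return value appear in this diff):

    from slidedeckai.core import SlideDeckAI

    deck = SlideDeckAI(
        model='provider/model-name',        # placeholder model identifier
        topic='A short primer on solar energy',
        api_key='YOUR_API_KEY',             # placeholder key
        pdf_path_or_stream='notes.pdf',     # a path or a file-like object
        pdf_page_range=(1, 5),              # pages of the PDF to draw text from
        template_idx=0,
    )
    pptx_path = deck.generate()             # assumed method name; returns the path to the generated .pptx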

Files changed (2):
  1. app.py +2 -7
  2. src/slidedeckai/core.py +15 -7
app.py CHANGED
@@ -386,12 +386,6 @@ def set_up_chat_ui():
             f' {st.session_state["end_page"]} in {st.session_state["pdf_file"].name}'
         )
 
-        # Get pdf contents
-        st.session_state[ADDITIONAL_INFO] = filem.get_pdf_contents(
-            st.session_state[PDF_FILE_KEY],
-            (st.session_state['start_page'], st.session_state['end_page'])
-        )
-
         st.chat_message('user').write(prompt_text)
 
         slide_generator = SlideDeckAI(
@@ -399,7 +393,8 @@ def set_up_chat_ui():
             topic=prompt_text,
             api_key=api_key_token.strip(),
             template_idx=list(GlobalConfig.PPTX_TEMPLATE_FILES.keys()).index(pptx_template),
-            additional_info=st.session_state.get(ADDITIONAL_INFO, ''),
+            pdf_path_or_stream=st.session_state.get(PDF_FILE_KEY),
+            pdf_page_range=(st.session_state.get('start_page'), st.session_state.get('end_page')),
         )
 
         progress_bar = st.progress(0, 'Preparing to call LLM...')
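Since pdf_path_or_stream accepts a file-like object as well as a path (here app.py passes the Streamlit upload straight from session state), the same call should work outside Streamlit with an ordinary open file. A hedged sketch with placeholder values; note that extraction is deferred until the deck is generated, so the stream must stay open until then:

    from slidedeckai.core import SlideDeckAI

    with open('report.pdf', 'rb') as pdf_file:
        slide_generator = SlideDeckAI(
            model='provider/model-name',            # placeholder
            topic='Summarize this report as slides',
            api_key='YOUR_API_KEY',                 # placeholder
            pdf_path_or_stream=pdf_file,            # binary stream instead of a path
            pdf_page_range=(2, 10),                 # (start_page, end_page)
            template_idx=0,
        )
        # ... generate the deck here, while pdf_file is still open ...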
src/slidedeckai/core.py CHANGED
@@ -21,30 +21,30 @@ RUN_IN_OFFLINE_MODE = os.getenv('RUN_IN_OFFLINE_MODE', 'False').lower() == 'true
 
 logger = logging.getLogger(__name__)
 
+from .helpers import file_manager as filem
+
 class SlideDeckAI:
     """
     The main class for generating slide decks.
     """
 
-    def __init__(self, model, topic, api_key=None, pdf_file_path=None, pdf_page_range=None, template_idx=0, additional_info=''):
+    def __init__(self, model, topic, api_key=None, pdf_path_or_stream=None, pdf_page_range=None, template_idx=0):
         """
         Initializes the SlideDeckAI object.
 
         :param model: The name of the LLM model to use.
         :param topic: The topic of the slide deck.
         :param api_key: The API key for the LLM provider.
-        :param pdf_file_path: The path to a PDF file to use as a source for the slide deck.
+        :param pdf_path_or_stream: The path to a PDF file or a file-like object.
         :param pdf_page_range: A tuple representing the page range to use from the PDF file.
         :param template_idx: The index of the PowerPoint template to use.
-        :param additional_info: Additional information to be sent to the LLM, such as text from a PDF.
         """
         self.model = model
         self.topic = topic
         self.api_key = api_key
-        self.pdf_file_path = pdf_file_path
+        self.pdf_path_or_stream = pdf_path_or_stream
         self.pdf_page_range = pdf_page_range
         self.template_idx = template_idx
-        self.additional_info = additional_info
         self.chat_history = ChatMessageHistory()
         self.last_response = None
 
@@ -68,9 +68,13 @@ class SlideDeckAI:
         Generates the initial slide deck.
         :return: The path to the generated .pptx file.
         """
+        additional_info = ''
+        if self.pdf_path_or_stream:
+            additional_info = filem.get_pdf_contents(self.pdf_path_or_stream, self.pdf_page_range)
+
         self.chat_history.add_user_message(self.topic)
         prompt_template = self._get_prompt_template(is_refinement=False)
-        formatted_template = prompt_template.format(question=self.topic, additional_info=self.additional_info)
+        formatted_template = prompt_template.format(question=self.topic, additional_info=additional_info)
 
         provider, llm_name = llm_helper.get_provider_model(self.model, use_ollama=RUN_IN_OFFLINE_MODE)
 
@@ -118,10 +122,14 @@ class SlideDeckAI:
 
         list_of_msgs = [f'{idx + 1}. {msg.content}' for idx, msg in enumerate(self.chat_history.messages) if msg.role == 'user']
 
+        additional_info = ''
+        if self.pdf_path_or_stream:
+            additional_info = filem.get_pdf_contents(self.pdf_path_or_stream, self.pdf_page_range)
+
         formatted_template = prompt_template.format(
             instructions='\n'.join(list_of_msgs),
             previous_content=self.last_response,
+            additional_info=additional_info,
         )
 
         provider, llm_name = llm_helper.get_provider_model(self.model, use_ollama=RUN_IN_OFFLINE_MODE)
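For context, the new code calls filem.get_pdf_contents(pdf_path_or_stream, pdf_page_range) from .helpers.file_manager, whose body is not part of this change. A rough sketch of what such a helper typically does, assuming pypdf and a 1-based inclusive (start, end) range; this is an illustration, not the project's actual implementation:

    from pypdf import PdfReader


    def get_pdf_contents_sketch(pdf_path_or_stream, page_range=None) -> str:
        """Illustrative stand-in for file_manager.get_pdf_contents (assumed behavior)."""
        reader = PdfReader(pdf_path_or_stream)   # accepts a path or a binary stream
        start, end = page_range or (1, len(reader.pages))
        # Clamp to the document and convert the 1-based range to 0-based indices
        start = max(1, start or 1)
        end = min(len(reader.pages), end or len(reader.pages))
        return '\n'.join(
            reader.pages[i].extract_text() or '' for i in range(start - 1, end)
        )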