barunsaha commited on
Commit
e1720ec
·
1 Parent(s): 33f121d

Improve the formatting of displayed model names

Browse files
Files changed (1) hide show
  1. src/slidedeckai/cli.py +163 -8
src/slidedeckai/cli.py CHANGED
@@ -4,16 +4,167 @@ Command-line interface for SlideDeck AI.
4
  import argparse
5
  import sys
6
  import shutil
 
7
 
8
  from slidedeckai.core import SlideDeckAI
9
  from slidedeckai.global_config import GlobalConfig
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def main():
13
  """
14
  The main function for the CLI.
15
  """
16
- parser = argparse.ArgumentParser(description='Generate slide decks with SlideDeck AI.')
 
 
 
17
  subparsers = parser.add_subparsers(dest='command')
18
 
19
  # Top-level flag to list supported models
@@ -25,15 +176,21 @@ def main():
25
  )
26
 
27
  # 'generate' command
28
- parser_generate = subparsers.add_parser('generate', help='Generate a new slide deck.')
 
 
 
 
 
29
  parser_generate.add_argument(
30
  '--model',
31
  required=True,
 
32
  help=(
33
- 'Model name to use. The model must be one of the supported models;'
34
- ' see `--list-models` for details.'
35
- ' Model name must be in the `[provider-code]model_name` format.'
36
  ),
 
37
  )
38
  parser_generate.add_argument(
39
  '--topic',
@@ -69,9 +226,7 @@ def main():
69
 
70
  # If --list-models flag was provided, print models and exit
71
  if getattr(args, 'list_models', False):
72
- print('Supported SlideDeck AI models (these are the only supported models):')
73
- for k in GlobalConfig.VALID_MODELS:
74
- print(k)
75
  return
76
 
77
  if args.command == 'generate':
 
4
  import argparse
5
  import sys
6
  import shutil
7
+ from typing import Any
8
 
9
  from slidedeckai.core import SlideDeckAI
10
  from slidedeckai.global_config import GlobalConfig
11
 
12
 
13
+ class CustomHelpFormatter(argparse.HelpFormatter):
14
+ """
15
+ Custom formatter for argparse that improves the display of choices.
16
+ """
17
+ def _format_action_invocation(self, action: Any) -> str:
18
+ if not action.option_strings or action.nargs == 0:
19
+ return super()._format_action_invocation(action)
20
+
21
+ default = self._get_default_metavar_for_optional(action)
22
+ args_string = self._format_args(action, default)
23
+
24
+ # If there are choices, and it's the model argument, handle it specially
25
+ if action.choices and '--model' in action.option_strings:
26
+ return ', '.join(action.option_strings) + ' MODEL'
27
+
28
+ return f"{', '.join(action.option_strings)} {args_string}"
29
+
30
+ def _split_lines(self, text: str, width: int) -> list[str]:
31
+ if text.startswith('Model choices:') or text.startswith('choose from'):
32
+ # Special handling for model choices and error messages
33
+ lines = []
34
+ header = 'Available models:'
35
+ lines.append(header)
36
+ lines.append('-' * len(header))
37
+
38
+ # Extract models from text
39
+ if text.startswith('choose from'):
40
+ models = [
41
+ m.strip("' ") for m in text.replace('choose from', '').split(',')
42
+ ]
43
+ else:
44
+ models = text.split('\n')[1:]
45
+
46
+ # Group models by provider
47
+ provider_models = {}
48
+ for model in sorted(models):
49
+ if not model.strip():
50
+ continue
51
+ if match := GlobalConfig.PROVIDER_REGEX.match(model):
52
+ provider = match.group(1)
53
+ if provider not in provider_models:
54
+ provider_models[provider] = []
55
+ provider_models[provider].append(model.strip())
56
+
57
+ # Add models grouped by provider
58
+ for provider in sorted(provider_models.keys()):
59
+ lines.append(f'\n{provider}:')
60
+ for model in provider_models[provider]:
61
+ lines.append(f' {model}')
62
+
63
+ return lines
64
+
65
+ return super()._split_lines(text, width)
66
+
67
+
68
+ class CustomArgumentParser(argparse.ArgumentParser):
69
+ """
70
+ Custom argument parser that formats error messages better.
71
+ """
72
+ def error(self, message: str) -> None:
73
+ """Custom error handler that formats model choices better"""
74
+ if 'invalid choice' in message and '--model' in message:
75
+ # Extract models from the error message
76
+ choices_str = message[message.find('(choose from'):]
77
+ models = [m.strip("' ") for m in choices_str.replace('(choose from', '').rstrip(')').split(',')]
78
+
79
+ # Group models by provider
80
+ provider_models = {}
81
+ for model in sorted(models):
82
+ if match := GlobalConfig.PROVIDER_REGEX.match(model):
83
+ provider = match.group(1)
84
+ if provider not in provider_models:
85
+ provider_models[provider] = []
86
+ provider_models[provider].append(model.strip())
87
+
88
+ # Format the error message with grouped models
89
+ error_lines = ['Error: Invalid model choice. Available models:']
90
+ for provider in sorted(provider_models.keys()):
91
+ error_lines.append(f'\n{provider}:')
92
+ for model in sorted(provider_models[provider]):
93
+ error_lines.append(f' • {model}')
94
+
95
+ self.print_help()
96
+ print('\n' + '\n'.join(error_lines), file=sys.stderr)
97
+ sys.exit(2)
98
+
99
+ super().error(message)
100
+
101
+
102
+ def format_models_list() -> str:
103
+ """Format the models list in a nice grouped format with descriptions."""
104
+ lines = ['Supported SlideDeck AI models:', '']
105
+
106
+ # Group models by provider
107
+ provider_models = {}
108
+ for model, info in sorted(GlobalConfig.VALID_MODELS.items()):
109
+ if match := GlobalConfig.PROVIDER_REGEX.match(model):
110
+ provider = match.group(1)
111
+ if provider not in provider_models:
112
+ provider_models[provider] = []
113
+ provider_models[provider].append((model, info))
114
+
115
+ # Add models grouped by provider
116
+ for provider in sorted(provider_models.keys()):
117
+ lines.append(f'{provider}:')
118
+ # Find the longest model name for alignment
119
+ max_model_len = max(len(model) for model, _ in provider_models[provider])
120
+ max_desc_len = max(len(info['description']) for _, info in provider_models[provider])
121
+
122
+ # Format as a table with aligned columns
123
+ format_str = f' {{:<{max_model_len}}} | {{:<{max_desc_len}}} | {{:>4}}'
124
+ lines.append(' ' + '-' * (max_model_len + max_desc_len + 13))
125
+
126
+ for model, info in sorted(provider_models[provider]):
127
+ paid_status = 'Paid' if info.get('paid', False) else 'Free'
128
+ lines.append(format_str.format(
129
+ model,
130
+ info['description'],
131
+ paid_status
132
+ ))
133
+ lines.append('') # Add spacing between provider sections
134
+
135
+ return '\n'.join(lines)
136
+
137
+
138
+ def format_model_help() -> str:
139
+ """Format model choices as a grouped bulleted list for help text."""
140
+ lines = []
141
+
142
+ # Group models by provider
143
+ provider_models = {}
144
+ for model in sorted(GlobalConfig.VALID_MODELS.keys()):
145
+ if match := GlobalConfig.PROVIDER_REGEX.match(model):
146
+ provider = match.group(1)
147
+ if provider not in provider_models:
148
+ provider_models[provider] = []
149
+ provider_models[provider].append(model)
150
+
151
+ # Add models grouped by provider
152
+ for provider in sorted(provider_models.keys()):
153
+ lines.append(f'\n{provider}:')
154
+ for model in sorted(provider_models[provider]):
155
+ lines.append(f' • {model}')
156
+
157
+ return '\n'.join(lines)
158
+
159
+
160
  def main():
161
  """
162
  The main function for the CLI.
163
  """
164
+ parser = CustomArgumentParser(
165
+ description='Generate slide decks with SlideDeck AI.',
166
+ formatter_class=CustomHelpFormatter
167
+ )
168
  subparsers = parser.add_subparsers(dest='command')
169
 
170
  # Top-level flag to list supported models
 
176
  )
177
 
178
  # 'generate' command
179
+ parser_generate = subparsers.add_parser(
180
+ 'generate',
181
+ help='Generate a new slide deck.',
182
+ formatter_class=CustomHelpFormatter
183
+ )
184
+
185
  parser_generate.add_argument(
186
  '--model',
187
  required=True,
188
+ choices=GlobalConfig.VALID_MODELS.keys(),
189
  help=(
190
+ 'Model name to use. Must be one of the supported models in the'
191
+ ' `[provider-code]model_name` format.' + format_model_help()
 
192
  ),
193
+ metavar='MODEL'
194
  )
195
  parser_generate.add_argument(
196
  '--topic',
 
226
 
227
  # If --list-models flag was provided, print models and exit
228
  if getattr(args, 'list_models', False):
229
+ print(format_models_list())
 
 
230
  return
231
 
232
  if args.command == 'generate':