barunsaha committed
Commit 9b00e1e · 1 Parent(s): 5713b5e

Increase the test coverage of LLM helper

Files changed (1):
  1. tests/unit/test_llm_helper.py +266 -0
tests/unit/test_llm_helper.py ADDED
@@ -0,0 +1,266 @@
+ """
+ Unit tests for llm_helper module.
+ """
+ from unittest.mock import patch, MagicMock
+ 
+ import pytest
+ 
+ from slidedeckai.helpers.llm_helper import (
+     get_provider_model,
+     is_valid_llm_provider_model,
+     get_litellm_model_name,
+     stream_litellm_completion,
+     get_litellm_llm,
+ )
+ from slidedeckai.global_config import GlobalConfig
+ 
+ 
+ @pytest.mark.parametrize(
+     'provider_model, use_ollama, expected',
+     [
+         ('[co]command', False, ('co', 'command')),
+         ('[gg]gemini-pro', False, ('gg', 'gemini-pro')),
+         ('[or]gpt-4', False, ('or', 'gpt-4')),
+         ('mistral', True, (GlobalConfig.PROVIDER_OLLAMA, 'mistral')),
+         ('llama2', True, (GlobalConfig.PROVIDER_OLLAMA, 'llama2')),
+         ('invalid[]model', False, ('', '')),
+         ('', False, ('', '')),
+         ('[invalid]model', False, ('', '')),
+         ('[hf]mistral', False, ('', '')),  # hf is not in VALID_PROVIDERS
+     ],
+ )
+ def test_get_provider_model(provider_model, use_ollama, expected):
+     """Test get_provider_model with various inputs."""
+     result = get_provider_model(provider_model, use_ollama)
+     assert result == expected
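+ 
+ # Editor's note, a sketch of the parsing contract exercised above (derived
+ # from the parametrized cases, not from separate API docs):
+ #   get_provider_model('[gg]gemini-pro', use_ollama=False)
+ #   # -> ('gg', 'gemini-pro')
+ #   get_provider_model('llama2', use_ollama=True)
+ #   # -> (GlobalConfig.PROVIDER_OLLAMA, 'llama2')
+ # Malformed or unrecognized provider tags fall back to ('', '').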
+ 
+ 
+ @pytest.mark.parametrize(
+     (
+         'provider, model, api_key, azure_endpoint_url,'
+         ' azure_deployment_name, azure_api_version, expected'
+     ),
+     [
+         # Valid non-Azure cases
+         ('co', 'command', 'valid-key-12345', '', '', '', True),
+         ('gg', 'gemini-pro', 'valid-key-12345', '', '', '', True),
+         ('or', 'gpt-4', 'valid-key-12345', '', '', '', True),
+         # Invalid cases
+         ('', 'model', 'key', '', '', '', False),
+         ('invalid', 'model', 'key', '', '', '', False),
+         ('co', '', 'key', '', '', '', False),
+         ('co', 'model', '', '', '', '', False),
+         ('co', 'model', 'short', '', '', '', False),
+         # Ollama cases (no API key needed)
+         (GlobalConfig.PROVIDER_OLLAMA, 'llama2', '', '', '', '', True),
+         # Azure cases
+         (
+             GlobalConfig.PROVIDER_AZURE_OPENAI,
+             'gpt-4',
+             'valid-key-12345',
+             'https://valid.azure.com',
+             'deployment1',
+             '2024-02-01',
+             True,
+         ),
+         (
+             GlobalConfig.PROVIDER_AZURE_OPENAI,
+             'gpt-4',
+             'valid-key-12345',
+             'https://invalid-url',
+             'deployment1',
+             '2024-02-01',
+             True,  # URL validation is not done
+         ),
+         (
+             GlobalConfig.PROVIDER_AZURE_OPENAI,
+             'gpt-4',
+             'valid-key-12345',
+             'https://valid.azure.com',
+             '',
+             '2024-02-01',
+             False,
+         ),
+     ],
+ )
+ def test_is_valid_llm_provider_model(
+     provider,
+     model,
+     api_key,
+     azure_endpoint_url,
+     azure_deployment_name,
+     azure_api_version,
+     expected,
+ ):
+     """Test is_valid_llm_provider_model with various inputs."""
+     result = is_valid_llm_provider_model(
+         provider,
+         model,
+         api_key,
+         azure_endpoint_url,
+         azure_deployment_name,
+         azure_api_version,
+     )
+     assert result == expected
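+ 
+ # Editor's note: taken together, the cases above pin down the validation
+ # rules: provider and model must be non-empty, the API key must pass a
+ # length check (the 'short' key fails, which suggests a minimum-length
+ # rule), Ollama needs no key, and Azure OpenAI additionally requires a
+ # deployment name, while the endpoint URL's format is deliberately not
+ # validated.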
+ 
+ 
+ @pytest.mark.parametrize(
+     'provider, model, expected',
+     [
+         (GlobalConfig.PROVIDER_HUGGING_FACE, 'mistral', 'huggingface/mistral'),
+         (GlobalConfig.PROVIDER_GOOGLE_GEMINI, 'gemini-pro', 'gemini/gemini-pro'),
+         (GlobalConfig.PROVIDER_OPENROUTER, 'openai/gpt-4', 'openrouter/openai/gpt-4'),
+         (GlobalConfig.PROVIDER_COHERE, 'command', 'cohere/command'),
+         (GlobalConfig.PROVIDER_TOGETHER_AI, 'llama2', 'together_ai/llama2'),
+         (GlobalConfig.PROVIDER_OLLAMA, 'mistral', 'ollama/mistral'),
+         ('invalid', 'model', None),
+     ],
+ )
+ def test_get_litellm_model_name(provider, model, expected):
+     """Test get_litellm_model_name with various providers and models."""
+     result = get_litellm_model_name(provider, model)
+     assert result == expected
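+ 
+ # Editor's note: LiteLLM routes requests via '<provider>/<model>' strings
+ # (e.g., 'cohere/command'), so get_litellm_model_name is expected to prepend
+ # the matching prefix and return None for unknown providers, e.g.:
+ #   get_litellm_model_name(GlobalConfig.PROVIDER_COHERE, 'command')
+ #   # -> 'cohere/command'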
+ 
+ 
+ @patch('slidedeckai.helpers.llm_helper.litellm')
+ def test_stream_litellm_completion_success(mock_litellm):
+     """Test successful streaming completion."""
+     # Mock response chunks
+     mock_chunk1 = MagicMock()
+     mock_chunk1.choices = [
+         MagicMock(delta=MagicMock(content='Hello')),
+     ]
+     mock_chunk2 = MagicMock()
+     mock_chunk2.choices = [
+         MagicMock(delta=MagicMock(content=' world')),
+     ]
+     mock_litellm.completion.return_value = [mock_chunk1, mock_chunk2]
+ 
+     messages = [{'role': 'user', 'content': 'Say hello'}]
+     result = list(
+         stream_litellm_completion(
+             provider='hf',
+             model='mistral',
+             messages=messages,
+             max_tokens=100,
+             api_key='test-key',
+         )
+     )
+ 
+     assert result == ['Hello', ' world']
+     mock_litellm.completion.assert_called_once()
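+ 
+ # Editor's note: the mocks above mirror the OpenAI-style streaming shape
+ # that LiteLLM emits, where each chunk exposes its text at
+ # chunk.choices[0].delta.content; reproducing that shape lets the generator
+ # be exercised without any network call.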
+ 
+ 
+ @patch('slidedeckai.helpers.llm_helper.litellm')
+ def test_stream_litellm_completion_azure(mock_litellm):
+     """Test streaming completion with Azure OpenAI."""
+     mock_chunk = MagicMock()
+     mock_chunk.choices = [
+         MagicMock(delta=MagicMock(content='Response')),
+     ]
+     mock_litellm.completion.return_value = [mock_chunk]
+ 
+     messages = [{'role': 'user', 'content': 'Test'}]
+     result = list(
+         stream_litellm_completion(
+             provider=GlobalConfig.PROVIDER_AZURE_OPENAI,
+             model='gpt-4',
+             messages=messages,
+             max_tokens=100,
+             api_key='test-key',
+             azure_endpoint_url='https://test.azure.com',
+             azure_deployment_name='deployment1',
+             azure_api_version='2024-02-01',
+         )
+     )
+ 
+     assert result == ['Response']
+     mock_litellm.completion.assert_called_once()
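+ 
+ # Editor's note: a stricter variant could also assert that the Azure-specific
+ # arguments were forwarded, e.g. (the kwarg names here are an assumption
+ # about llm_helper's internal mapping, which is why the test above does not
+ # assert them):
+ #   _, kwargs = mock_litellm.completion.call_args
+ #   assert kwargs.get('api_version') == '2024-02-01'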
+ 
+ 
+ @patch('slidedeckai.helpers.llm_helper.litellm')
+ def test_stream_litellm_completion_error(mock_litellm):
+     """Test error handling in streaming completion."""
+     mock_litellm.completion.side_effect = Exception('API Error')
+ 
+     messages = [{'role': 'user', 'content': 'Test'}]
+     with pytest.raises(Exception) as exc_info:
+         list(
+             stream_litellm_completion(
+                 provider='hf',
+                 model='mistral',
+                 messages=messages,
+                 max_tokens=100,
+                 api_key='test-key',
+             )
+         )
+     assert str(exc_info.value) == 'API Error'
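+ 
+ # Editor's note: stream_litellm_completion is a generator, so its body only
+ # runs when consumed; wrapping the call in list() forces iteration and makes
+ # the mocked exception surface inside pytest.raises rather than at call time.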
+ 
+ 
+ @patch('slidedeckai.helpers.llm_helper.stream_litellm_completion')
+ def test_get_litellm_llm(mock_stream):
+     """Test LiteLLM wrapper creation and streaming."""
+     mock_stream.return_value = iter(['Hello', ' world'])
+ 
+     llm = get_litellm_llm(
+         provider='hf',
+         model='mistral',
+         max_new_tokens=100,
+         api_key='test-key',
+     )
+ 
+     result = list(llm.stream('Say hello'))
+     assert result == ['Hello', ' world']
+     mock_stream.assert_called_once()
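+ 
+ # Editor's note: get_litellm_llm wraps stream_litellm_completion behind an
+ # object exposing .stream(), so patching the lower-level function verifies
+ # the delegation without importing or invoking LiteLLM itself.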
+ 
+ 
+ def test_litellm_not_installed():
+     """Test behavior when LiteLLM is not installed."""
+     with patch('slidedeckai.helpers.llm_helper.litellm', None):
+         with pytest.raises(ImportError) as exc_info:
+             # Try to use stream_litellm_completion, which requires LiteLLM
+             list(stream_litellm_completion(
+                 provider='co',
+                 model='command',
+                 messages=[],
+                 max_tokens=100,
+                 api_key='test-key',
+             ))
+ 
+         assert 'LiteLLM is not installed' in str(exc_info.value)
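+ 
+ # Editor's note: patching the module-level 'litellm' attribute to None
+ # simulates an environment where the optional dependency is absent; the
+ # helper presumably guards with something like
+ # 'if litellm is None: raise ImportError(...)', which is what the assertion
+ # on the message checks.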
+ 
+ 
+ @patch('slidedeckai.helpers.llm_helper.litellm')
+ def test_stream_litellm_completion_message_format(mock_litellm):
+     """Test that chunks without content are skipped in the streamed response."""
+     # First chunk carries no text: delta.content is None
+     mock_chunk = MagicMock()
+     mock_delta = MagicMock()
+     mock_delta.content = None
+     mock_choices = [MagicMock(delta=mock_delta)]
+     mock_chunk.choices = mock_choices
+ 
+     # Second chunk carries the actual content
+     mock_chunk2 = MagicMock()
+     mock_delta2 = MagicMock()
+     mock_delta2.content = 'Alternative format'
+     mock_choices2 = [MagicMock(delta=mock_delta2)]
+     mock_chunk2.choices = mock_choices2
+ 
+     mock_litellm.completion.return_value = [mock_chunk, mock_chunk2]
+ 
+     messages = [{'role': 'user', 'content': 'Test'}]
+     result = list(
+         stream_litellm_completion(
+             provider='hf',
+             model='mistral',
+             messages=messages,
+             max_tokens=100,
+             api_key='test-key',
+         )
+     )
+ 
+     assert result == ['Alternative format']
+     mock_litellm.completion.assert_called_once()
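+ 
+ # Editor's note: OpenAI-style streams typically open with a role-only delta
+ # whose content is None; this final test pins down that such chunks are
+ # skipped rather than yielded, so consumers never receive a literal None.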