Spaces:
Sleeping
Sleeping
| import pytest | |
| import numpy as np | |
| from unittest.mock import MagicMock | |
| from app.engine import PromptSearchEngine | |
@pytest.fixture
def mock_prompts():
    """Provide a small, fixed list of prompt strings used as the test corpus.

    Registered as a pytest fixture so that tests can request it by naming
    ``mock_prompts`` as a parameter.
    """
    return ["prompt 1", "prompt 2", "prompt 3"]
@pytest.fixture
def mock_model():
    """Provide a ``MagicMock`` standing in for the embedding model.

    The mock's ``encode`` method ignores its arguments and always returns the
    same 3x3 array, so tests can compare against known vectors.

    Registered as a pytest fixture so that tests can request it by naming
    ``mock_model`` as a parameter.
    """
    model = MagicMock()
    model.encode = MagicMock(
        return_value=np.array(
            [
                [0.1, 0.2, 0.3],
                [0.4, 0.5, 0.6],
                [0.7, 0.8, 0.9],
            ]
        )
    )
    return model
def test_engine_initialization(mock_prompts, mock_model, monkeypatch):
    """Check the attributes the engine exposes right after construction.

    ``app.engine.SentenceTransformer`` is replaced by a factory that hands
    back the mock model, then ``prompts`` and ``corpus_vectors`` are compared
    against the mock inputs.
    """
    # Swap the SentenceTransformer name in app.engine for a mock factory.
    fake_transformer_cls = MagicMock(return_value=mock_model)
    monkeypatch.setattr("app.engine.SentenceTransformer", fake_transformer_cls)

    engine = PromptSearchEngine(mock_prompts)

    # Same values the mock model's encode() is configured to return.
    expected_vectors = np.array(
        [
            [0.1, 0.2, 0.3],
            [0.4, 0.5, 0.6],
            [0.7, 0.8, 0.9],
        ]
    )
    assert engine.prompts == mock_prompts
    assert engine.corpus_vectors.shape == (3, 3)
    assert np.array_equal(engine.corpus_vectors, expected_vectors)
def test_most_similar_valid_query(mock_prompts, mock_model, monkeypatch):
    """Check that ``most_similar(..., n=2)`` yields two (float, str) pairs."""
    monkeypatch.setattr(
        "app.engine.SentenceTransformer",
        MagicMock(return_value=mock_model),
    )
    engine = PromptSearchEngine(mock_prompts)

    # Stub the vectorizer's transform so it yields a single query vector.
    query_vector = np.array([[0.1, 0.2, 0.3]])
    engine.vectorizer.transform = MagicMock(return_value=query_vector)

    results = engine.most_similar("test query", n=2)

    assert len(results) == 2
    # Each result must unpack into a float score and a string prompt.
    for score, prompt in results:
        assert isinstance(score, float) and isinstance(prompt, str)
def test_most_similar_exceeding_n(mock_prompts, mock_model, monkeypatch):
    """Check ``most_similar`` when ``n`` is larger than the number of prompts."""
    monkeypatch.setattr(
        "app.engine.SentenceTransformer",
        MagicMock(return_value=mock_model),
    )
    engine = PromptSearchEngine(mock_prompts)
    query_vector = np.array([[0.1, 0.2, 0.3]])
    engine.vectorizer.transform = MagicMock(return_value=query_vector)

    # Request more matches (10) than there are prompts in the corpus.
    results = engine.most_similar("test query", n=10)

    # The result count is expected to equal the number of prompts.
    expected_count = len(mock_prompts)
    assert len(results) == expected_count
    for score, prompt in results:
        assert isinstance(score, float) and isinstance(prompt, str)
def test_most_similar_integration(mock_prompts):
    """Run ``most_similar`` on an engine built without any patching.

    Querying with the text of the first prompt must return two
    (float, str) pairs, with that same prompt as the top match.
    """
    engine = PromptSearchEngine(mock_prompts)

    results = engine.most_similar("prompt 1", n=2)

    # Expected number of matches, each a (float score, str prompt) pair.
    assert len(results) == 2
    for score, prompt in results:
        assert isinstance(score, float) and isinstance(prompt, str)

    # The first result's prompt must be the query text itself.
    top_prompt = results[0][1]
    assert top_prompt == "prompt 1"