Spaces:
Runtime error
Runtime error
| # Test_SQLite_DB.py | |
| # Description: Test file for SQLite_DB.py | |
| # | |
| # Usage: python -m unittest test_sqlite_db.py | |
| # | |
| # Imports | |
| import unittest | |
| import sqlite3 | |
| import threading | |
| import time | |
| from unittest.mock import patch | |
| # | |
| # Local Imports | |
| from App_Function_Libraries.DB.SQLite_DB import Database, add_media_with_keywords, add_media_version, DatabaseError | |
| # | |
| ####################################################################################################################### | |
| # | |
| # Functions: | |
class TestDatabase(unittest.TestCase):
    """Tests for the ``Database`` wrapper: connection checkout, queries and lock retry."""

    def setUp(self):
        # Use an in-memory database so every test starts from a clean, throwaway store.
        self.db = Database(':memory:')

    def test_connection_management(self):
        """A checked-out connection is a real sqlite3 connection and is tracked in ``pool``."""
        with self.db.get_connection() as conn:
            self.assertIsInstance(conn, sqlite3.Connection)
            self.assertEqual(len(self.db.pool), 1)

    def test_execute_query(self):
        """``execute_query`` runs DDL and parameterised DML that a later connection can read."""
        self.db.execute_query("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)")
        self.db.execute_query("INSERT INTO test (name) VALUES (?)", ("test_name",))
        with self.db.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM test")
            result = cursor.fetchone()
            self.assertEqual(result[0], "test_name")

    def test_execute_many(self):
        """``execute_many`` inserts one row per parameter tuple."""
        self.db.execute_query("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)")
        data = [("name1",), ("name2",), ("name3",)]
        self.db.execute_many("INSERT INTO test (name) VALUES (?)", data)
        with self.db.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT COUNT(*) FROM test")
            count = cursor.fetchone()[0]
            self.assertEqual(count, 3)

    def test_connection_retry(self):
        """``execute_query`` raises ``DatabaseError`` while another thread holds an exclusive lock.

        NOTE(review): with ':memory:' every sqlite3 connection opens its own private
        database, so a lock taken on one pooled connection is only visible to another
        if ``Database`` hands out the same underlying connection — confirm, otherwise
        this test may not exercise lock contention at all.
        """
        lock_acquired = threading.Event()
        release_lock = threading.Event()

        def lock_database():
            with self.db.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("BEGIN EXCLUSIVE TRANSACTION")
                # Tell the main thread the lock is really held (instead of guessing with a sleep).
                lock_acquired.set()
                # Hold the lock until the main thread has finished asserting; bounded so a
                # failure on the main side can never hang this worker indefinitely.
                release_lock.wait(timeout=10)

        thread = threading.Thread(target=lock_database)
        thread.start()
        try:
            self.assertTrue(
                lock_acquired.wait(timeout=5),
                "background thread never acquired the exclusive lock",
            )
            with self.assertRaises(DatabaseError):
                self.db.execute_query("SELECT 1")  # Should retry and eventually fail.
        finally:
            # Always release and join, even when an assertion above fails, so the
            # lock-holding thread cannot leak into subsequent tests.
            release_lock.set()
            thread.join()
| class TestAddMediaWithKeywords(unittest.TestCase): | |
| def setUp(self): | |
| self.db = Database(':memory:') | |
| self.db.execute_query(""" | |
| CREATE TABLE Media ( | |
| id INTEGER PRIMARY KEY, | |
| url TEXT, | |
| title TEXT NOT NULL, | |
| type TEXT NOT NULL, | |
| content TEXT, | |
| author TEXT, | |
| ingestion_date TEXT, | |
| transcription_model TEXT | |
| ) | |
| """) | |
| self.db.execute_query("CREATE TABLE Keywords (id INTEGER PRIMARY KEY, keyword TEXT NOT NULL UNIQUE)") | |
| self.db.execute_query(""" | |
| CREATE TABLE MediaKeywords ( | |
| id INTEGER PRIMARY KEY, | |
| media_id INTEGER NOT NULL, | |
| keyword_id INTEGER NOT NULL, | |
| FOREIGN KEY (media_id) REFERENCES Media(id), | |
| FOREIGN KEY (keyword_id) REFERENCES Keywords(id) | |
| ) | |
| """) | |
| self.db.execute_query(""" | |
| CREATE TABLE MediaModifications ( | |
| id INTEGER PRIMARY KEY, | |
| media_id INTEGER NOT NULL, | |
| prompt TEXT, | |
| summary TEXT, | |
| modification_date TEXT, | |
| FOREIGN KEY (media_id) REFERENCES Media(id) | |
| ) | |
| """) | |
| self.db.execute_query(""" | |
| CREATE TABLE MediaVersion ( | |
| id INTEGER PRIMARY KEY, | |
| media_id INTEGER NOT NULL, | |
| version INTEGER NOT NULL, | |
| prompt TEXT, | |
| summary TEXT, | |
| created_at TEXT NOT NULL, | |
| FOREIGN KEY (media_id) REFERENCES Media(id) | |
| ) | |
| """) | |
| self.db.execute_query("CREATE VIRTUAL TABLE media_fts USING fts5(title, content)") | |
| def test_add_new_media(self, mock_db): | |
| mock_db.get_connection = self.db.get_connection | |
| result = add_media_with_keywords( | |
| url="http://example.com", | |
| title="Test Title", | |
| media_type="article", | |
| content="Test content", | |
| keywords="test,keyword", | |
| prompt="Test prompt", | |
| summary="Test summary", | |
| transcription_model="Test model", | |
| author="Test Author", | |
| ingestion_date="2023-01-01" | |
| ) | |
| self.assertIn("added/updated successfully", result) | |
| with self.db.get_connection() as conn: | |
| cursor = conn.cursor() | |
| cursor.execute("SELECT COUNT(*) FROM Media") | |
| self.assertEqual(cursor.fetchone()[0], 1) | |
| cursor.execute("SELECT COUNT(*) FROM Keywords") | |
| self.assertEqual(cursor.fetchone()[0], 2) | |
| cursor.execute("SELECT COUNT(*) FROM MediaKeywords") | |
| self.assertEqual(cursor.fetchone()[0], 2) | |
| cursor.execute("SELECT COUNT(*) FROM MediaModifications") | |
| self.assertEqual(cursor.fetchone()[0], 1) | |
| cursor.execute("SELECT COUNT(*) FROM MediaVersion") | |
| self.assertEqual(cursor.fetchone()[0], 1) | |
| def test_update_existing_media(self, mock_db): | |
| mock_db.get_connection = self.db.get_connection | |
| add_media_with_keywords( | |
| url="http://example.com", | |
| title="Test Title", | |
| media_type="article", | |
| content="Test content", | |
| keywords="test,keyword", | |
| prompt="Test prompt", | |
| summary="Test summary", | |
| transcription_model="Test model", | |
| author="Test Author", | |
| ingestion_date="2023-01-01" | |
| ) | |
| result = add_media_with_keywords( | |
| url="http://example.com", | |
| title="Updated Title", | |
| media_type="article", | |
| content="Updated content", | |
| keywords="test,new", | |
| prompt="Updated prompt", | |
| summary="Updated summary", | |
| transcription_model="Updated model", | |
| author="Updated Author", | |
| ingestion_date="2023-01-02" | |
| ) | |
| self.assertIn("added/updated successfully", result) | |
| with self.db.get_connection() as conn: | |
| cursor = conn.cursor() | |
| cursor.execute("SELECT COUNT(*) FROM Media") | |
| self.assertEqual(cursor.fetchone()[0], 1) | |
| cursor.execute("SELECT title FROM Media") | |
| self.assertEqual(cursor.fetchone()[0], "Updated Title") | |
| cursor.execute("SELECT COUNT(*) FROM Keywords") | |
| self.assertEqual(cursor.fetchone()[0], 3) | |
| cursor.execute("SELECT COUNT(*) FROM MediaKeywords") | |
| self.assertEqual(cursor.fetchone()[0], 3) | |
| cursor.execute("SELECT COUNT(*) FROM MediaModifications") | |
| self.assertEqual(cursor.fetchone()[0], 2) | |
| cursor.execute("SELECT COUNT(*) FROM MediaVersion") | |
| self.assertEqual(cursor.fetchone()[0], 2) | |
if __name__ == "__main__":
    # Running this file directly executes every TestCase defined above.
    unittest.main()
| # | |
| # End of File | |
| ####################################################################################################################### | |