diff --git a/requirements.txt b/requirements.txt index 70bd86a..1bb66c0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,6 @@ llama-index-storage-docstore-mongodb==0.1.3 crawlee[playwright]==0.3.8 redis==5.2.0 defusedxml==0.7.1 -temporalio==1.8.0 pydantic==2.9.2 motor>=3.6, <4.0.0 +tc-temporal-backend==1.0.0 diff --git a/test_run_workflow.py b/test_run_workflow.py index c0763cb..858c1fc 100644 --- a/test_run_workflow.py +++ b/test_run_workflow.py @@ -6,6 +6,7 @@ from datetime import timedelta from dotenv import load_dotenv +from tc_temporal_backend.client import TemporalClient from temporalio.client import ( Schedule, ScheduleActionStartWorkflow, @@ -13,7 +14,6 @@ ScheduleSpec, ScheduleState, ) -from utils.temporal_client import TemporalClient async def start_workflow(): diff --git a/tests/unit/test_temporal_client.py b/tests/unit/test_temporal_client.py deleted file mode 100644 index c7d22f5..0000000 --- a/tests/unit/test_temporal_client.py +++ /dev/null @@ -1,99 +0,0 @@ -import unittest -from unittest.mock import AsyncMock, patch - -from temporalio.client import Client -from utils.temporal_client import TemporalClient - - -class TestTemporalClient(unittest.IsolatedAsyncioTestCase): - def setUp(self): - """Set up test environment before each test""" - self.default_env = { - "TEMPORAL_HOST": "test-host", - "TEMPORAL_API_KEY": "test-api-key", - "TEMPORAL_PORT": "8080", - } - self.env_patcher = patch.dict("os.environ", self.default_env) - self.env_patcher.start() - self.client = TemporalClient() - - def tearDown(self): - """Clean up test environment after each test""" - self.env_patcher.stop() - - def test_initialization(self): - """Test class initialization and load_dotenv call""" - with patch("utils.temporal_client.load_dotenv") as mock_load_dotenv: - client = TemporalClient() - mock_load_dotenv.assert_called_once() - - def test_load_credentials_success(self): - """Test successful loading of credentials""" - credentials = self.client._load_credentials() - - self.assertIsInstance(credentials, dict) - self.assertEqual(credentials["host"], "test-host") - self.assertEqual(credentials["api_key"], "test-api-key") - self.assertEqual(credentials["port"], "8080") - - def test_load_credentials_missing_host(self): - """Test handling of missing host""" - with patch.dict("os.environ", {"TEMPORAL_HOST": ""}): - with self.assertRaises(ValueError) as context: - self.client._load_credentials() - self.assertIn("TEMPORAL_HOST", str(context.exception)) - - def test_load_credentials_missing_api_key(self): - """Test handling of missing API key""" - with patch.dict( - "os.environ", {"TEMPORAL_HOST": "", "TEMPORAL_PORT": ""}, clear=True - ): - with self.assertRaises(ValueError) as context: - self.client._load_credentials() - self.assertNotIn("TEMPORAL_API_KEY", str(context.exception)) - - def test_load_credentials_missing_port(self): - """Test handling of missing port""" - with patch.dict("os.environ", {"TEMPORAL_PORT": ""}): - with self.assertRaises(ValueError) as context: - self.client._load_credentials() - self.assertIn("TEMPORAL_PORT", str(context.exception)) - - def test_load_credentials_empty_env(self): - """Test behavior with completely empty environment""" - with patch.dict("os.environ", {}, clear=True): - with self.assertRaises(ValueError) as context: - self.client._load_credentials() - self.assertIn("TEMPORAL_HOST", str(context.exception)) - - async def test_get_client_success(self): - """Test successful client connection""" - mock_client = AsyncMock(spec=Client) - - with patch( - "temporalio.client.Client.connect", new_callable=AsyncMock - ) as mock_connect: - mock_connect.return_value = mock_client - - result = await self.client.get_client() - - # Verify the connection was attempted with correct parameters - mock_connect.assert_called_once_with( - "test-host:8080", api_key="test-api-key" - ) - self.assertEqual(result, mock_client) - - async def test_get_client_connection_error(self): - """Test handling of connection errors""" - with patch( - "temporalio.client.Client.connect", new_callable=AsyncMock - ) as mock_connect: - mock_connect.side_effect = Exception("Connection failed") - - with self.assertRaises(Exception) as context: - await self.client.get_client() - self.assertIn("Connection failed", str(context.exception)) - - -if __name__ == "__main__": - unittest.main() diff --git a/utils/__init__.py b/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/utils/temporal_client.py b/utils/temporal_client.py deleted file mode 100644 index 24f55a4..0000000 --- a/utils/temporal_client.py +++ /dev/null @@ -1,56 +0,0 @@ -import os - -from dotenv import load_dotenv -from temporalio.client import Client - - -class TemporalClient: - def __init__(self) -> None: - load_dotenv() - - async def get_client(self) -> Client: - credentials = self._load_credentials() - - url: str = credentials["host"] + ":" + credentials["port"] - client = await Client.connect(url, api_key=credentials["api_key"]) - - return client - - def _load_credentials(self) -> dict[str, str]: - """ - load the credentials for temporal - - Returns - ------------ - credentials : dict[str, str] - a dictionary holding temporal credentials - { - 'host': str, - 'api_key': str, - 'port': str - } - """ - host = os.getenv("TEMPORAL_HOST") - api_key = os.getenv("TEMPORAL_API_KEY") - port = os.getenv("TEMPORAL_PORT") - - if not host: - raise ValueError( - "`TEMPORAL_HOST` is not configured right in env credentials!" - ) - if not port: - raise ValueError( - "`TEMPORAL_PORT` is not configured right in env credentials!" - ) - if api_key is None: - raise ValueError( - "`TEMPORAL_API_KEY` is not configured right in env credentials!" - ) - - credentials: dict[str, str] = { - "host": host, - "api_key": api_key, - "port": port, - } - - return credentials diff --git a/worker.py b/worker.py index e0c7b2d..1a7b2c2 100644 --- a/worker.py +++ b/worker.py @@ -4,8 +4,8 @@ from dotenv import load_dotenv from registry import ACTIVITIES, WORKFLOWS +from tc_temporal_backend.client import TemporalClient from temporalio.worker import Worker -from utils.temporal_client import TemporalClient async def main():