-
Notifications
You must be signed in to change notification settings - Fork 2.9k
perf: optimize hot paths with caching and O(1) operations #1816
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,14 +17,13 @@ | |
| from mcp.shared.experimental.tasks.store import TaskStore | ||
| from mcp.types import Result, Task, TaskMetadata, TaskStatus | ||
|
|
||
| CLEANUP_INTERVAL_SECONDS = 1.0 | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I guess this would be a breaking change? Setting a throttling limit on the cleanup may not resolve all expired tasks and is timing-dependent, but this makes sense if the tasks are accessed very frequently (though I'm not sure what the benchmarks on this look like). How often does this happen, and are there use cases that would be addressed by throttling cleanup? |
||
|
|
||
|
|
||
| @dataclass | ||
| class StoredTask: | ||
| """Internal storage representation of a task.""" | ||
|
|
||
| task: Task | ||
| result: Result | None = None | ||
| # Time when this task should be removed (None = never) | ||
| expires_at: datetime | None = field(default=None) | ||
|
|
||
|
|
||
|
|
@@ -49,21 +48,26 @@ def __init__(self, page_size: int = 10) -> None: | |
| self._tasks: dict[str, StoredTask] = {} | ||
| self._page_size = page_size | ||
| self._update_events: dict[str, anyio.Event] = {} | ||
| self._last_cleanup: datetime | None = None | ||
|
|
||
| def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None: | ||
| """Calculate expiry time from TTL in milliseconds.""" | ||
| if ttl_ms is None: | ||
| return None | ||
| return datetime.now(timezone.utc) + timedelta(milliseconds=ttl_ms) | ||
|
|
||
| def _is_expired(self, stored: StoredTask) -> bool: | ||
| """Check if a task has expired.""" | ||
| if stored.expires_at is None: | ||
| return False | ||
| return datetime.now(timezone.utc) >= stored.expires_at | ||
|
|
||
| def _cleanup_expired(self) -> None: | ||
| """Remove all expired tasks. Called lazily during access operations.""" | ||
| now = datetime.now(timezone.utc) | ||
| if self._last_cleanup is not None: | ||
| elapsed = (now - self._last_cleanup).total_seconds() | ||
| if elapsed < CLEANUP_INTERVAL_SECONDS: | ||
| return | ||
|
|
||
| self._last_cleanup = now | ||
| expired_ids = [task_id for task_id, stored in self._tasks.items() if self._is_expired(stored)] | ||
| for task_id in expired_ids: | ||
| del self._tasks[task_id] | ||
|
|
@@ -73,34 +77,21 @@ async def create_task( | |
| metadata: TaskMetadata, | ||
| task_id: str | None = None, | ||
| ) -> Task: | ||
| """Create a new task with the given metadata.""" | ||
| # Cleanup expired tasks on access | ||
| self._cleanup_expired() | ||
|
|
||
| task = create_task_state(metadata, task_id) | ||
|
|
||
| if task.taskId in self._tasks: | ||
| raise ValueError(f"Task with ID {task.taskId} already exists") | ||
|
|
||
| stored = StoredTask( | ||
| task=task, | ||
| expires_at=self._calculate_expiry(metadata.ttl), | ||
| ) | ||
| stored = StoredTask(task=task, expires_at=self._calculate_expiry(metadata.ttl)) | ||
| self._tasks[task.taskId] = stored | ||
|
|
||
| # Return a copy to prevent external modification | ||
| return Task(**task.model_dump()) | ||
|
|
||
| async def get_task(self, task_id: str) -> Task | None: | ||
| """Get a task by ID.""" | ||
| # Cleanup expired tasks on access | ||
| self._cleanup_expired() | ||
|
|
||
| stored = self._tasks.get(task_id) | ||
| if stored is None: | ||
| return None | ||
|
|
||
| # Return a copy to prevent external modification | ||
| return Task(**stored.task.model_dump()) | ||
|
|
||
| async def update_task( | ||
|
|
@@ -109,12 +100,10 @@ async def update_task( | |
| status: TaskStatus | None = None, | ||
| status_message: str | None = None, | ||
| ) -> Task: | ||
| """Update a task's status and/or message.""" | ||
| stored = self._tasks.get(task_id) | ||
| if stored is None: | ||
| raise ValueError(f"Task with ID {task_id} not found") | ||
|
|
||
| # Per spec: Terminal states MUST NOT transition to any other status | ||
| if status is not None and status != stored.task.status and is_terminal(stored.task.status): | ||
| raise ValueError(f"Cannot transition from terminal status '{stored.task.status}'") | ||
|
|
||
|
|
@@ -126,94 +115,69 @@ async def update_task( | |
| if status_message is not None: | ||
| stored.task.statusMessage = status_message | ||
|
|
||
| # Update lastUpdatedAt on any change | ||
| stored.task.lastUpdatedAt = datetime.now(timezone.utc) | ||
|
|
||
| # If task is now terminal and has TTL, reset expiry timer | ||
| if status is not None and is_terminal(status) and stored.task.ttl is not None: | ||
| stored.expires_at = self._calculate_expiry(stored.task.ttl) | ||
|
|
||
| # Notify waiters if status changed | ||
| if status_changed: | ||
| await self.notify_update(task_id) | ||
|
|
||
| return Task(**stored.task.model_dump()) | ||
|
|
||
| async def store_result(self, task_id: str, result: Result) -> None: | ||
| """Store the result for a task.""" | ||
| stored = self._tasks.get(task_id) | ||
| if stored is None: | ||
| raise ValueError(f"Task with ID {task_id} not found") | ||
|
|
||
| stored.result = result | ||
|
|
||
| async def get_result(self, task_id: str) -> Result | None: | ||
| """Get the stored result for a task.""" | ||
| stored = self._tasks.get(task_id) | ||
| if stored is None: | ||
| return None | ||
|
|
||
| return stored.result | ||
| return stored.result if stored else None | ||
|
|
||
| async def list_tasks( | ||
| self, | ||
| cursor: str | None = None, | ||
| ) -> tuple[list[Task], str | None]: | ||
| """List tasks with pagination.""" | ||
| # Cleanup expired tasks on access | ||
| self._cleanup_expired() | ||
|
|
||
| all_task_ids = list(self._tasks.keys()) | ||
|
|
||
| start_index = 0 | ||
| if cursor is not None: | ||
| try: | ||
| cursor_index = all_task_ids.index(cursor) | ||
| start_index = cursor_index + 1 | ||
| start_index = all_task_ids.index(cursor) + 1 | ||
| except ValueError: | ||
| raise ValueError(f"Invalid cursor: {cursor}") | ||
|
|
||
| page_task_ids = all_task_ids[start_index : start_index + self._page_size] | ||
| tasks = [Task(**self._tasks[tid].task.model_dump()) for tid in page_task_ids] | ||
|
|
||
| # Determine next cursor | ||
| next_cursor = None | ||
| if start_index + self._page_size < len(all_task_ids) and page_task_ids: | ||
| next_cursor = page_task_ids[-1] | ||
|
|
||
| return tasks, next_cursor | ||
|
|
||
| async def delete_task(self, task_id: str) -> bool: | ||
| """Delete a task.""" | ||
| if task_id not in self._tasks: | ||
| return False | ||
|
|
||
| del self._tasks[task_id] | ||
| return True | ||
|
|
||
| async def wait_for_update(self, task_id: str) -> None: | ||
| """Wait until the task status changes.""" | ||
| if task_id not in self._tasks: | ||
| raise ValueError(f"Task with ID {task_id} not found") | ||
|
|
||
| # Create a fresh event for waiting (anyio.Event can't be cleared) | ||
| self._update_events[task_id] = anyio.Event() | ||
| event = self._update_events[task_id] | ||
| await event.wait() | ||
| await self._update_events[task_id].wait() | ||
|
|
||
| async def notify_update(self, task_id: str) -> None: | ||
| """Signal that a task has been updated.""" | ||
| if task_id in self._update_events: | ||
| self._update_events[task_id].set() | ||
|
|
||
| # --- Testing/debugging helpers --- | ||
|
|
||
| def cleanup(self) -> None: | ||
| """Cleanup all tasks (useful for testing or graceful shutdown).""" | ||
| self._tasks.clear() | ||
| self._update_events.clear() | ||
|
|
||
| def get_all_tasks(self) -> list[Task]: | ||
| """Get all tasks (useful for debugging). Returns copies to prevent modification.""" | ||
| self._cleanup_expired() | ||
| return [Task(**stored.task.model_dump()) for stored in self._tasks.values()] | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: if I understand correctly, these changes optimize repeated calls to `pre_parse_json`. I'm not sure what the protocol is for optimization-related tests, but it could be useful to add a small benchmark to the PR description: on a ~20-field Pydantic model, call `pre_parse_json` repeatedly (50k–100k iterations) and report the before vs. after timings.