cosmo_backend/app/services/task_service.py

142 lines
3.8 KiB
Python
Raw Normal View History

2025-12-02 06:29:38 +00:00
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update
from typing import Optional, Dict, Any
from datetime import datetime
import logging
import asyncio
from app.models.db import Task
from app.services.redis_cache import redis_cache
# Per-module logger; its name follows this module's import path.
logger = logging.getLogger(name=__name__)
class TaskService:
    """Track background tasks in the database and mirror their state to Redis.

    The ``Task`` row is the durable record. Every state change also writes a
    small status dict to Redis under ``task:progress:<task_id>`` (1-hour TTL)
    so clients can poll progress without hitting the database. Timestamps are
    naive UTC values from ``datetime.utcnow()``.
    """

    def __init__(self):
        # Prefix for per-task Redis keys; the task id is appended to it.
        self.redis_prefix = "task:progress:"

    def _redis_key(self, task_id: int) -> str:
        """Return the Redis key that holds the transient state of ``task_id``."""
        return f"{self.redis_prefix}{task_id}"

    async def create_task(
        self,
        db: AsyncSession,
        task_type: str,
        description: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
        created_by: Optional[int] = None,
    ) -> Task:
        """Create a new task record in ``pending`` state and seed its Redis entry.

        Args:
            db: Async session used for the insert; committed by this method.
            task_type: Type label stored on the task row.
            description: Optional description stored on the row.
            params: Optional parameters dict stored on the row.
            created_by: Optional creator id stored on the row.

        Returns:
            The refreshed ``Task`` instance, including its database-assigned id.
        """
        task = Task(
            task_type=task_type,
            description=description,
            params=params,
            status="pending",
            created_by=created_by,
            progress=0,
        )
        db.add(task)
        await db.commit()
        # Refresh so task.id (assigned by the database) is available below.
        await db.refresh(task)
        # Init Redis status
        await self._update_redis(task.id, 0, "pending")
        return task

    async def update_progress(
        self,
        db: AsyncSession,
        task_id: int,
        progress: int,
        status: str = "running",
    ) -> None:
        """Update a task's progress and status in the DB and in Redis.

        ``started_at`` is stamped only when the call reports
        ``status == "running"`` with ``progress == 0``. All other calls leave
        the stored start time unchanged.

        Args:
            db: Async session used for the update; committed by this method.
            task_id: Primary key of the task row.
            progress: New progress value (stored as given, not clamped).
            status: New status string; defaults to ``"running"``.
        """
        values: Dict[str, Any] = {"progress": progress, "status": status}
        # Only include started_at when stamping it. Passing None on later
        # updates would overwrite (erase) the previously recorded start time.
        if status == "running" and progress == 0:
            values["started_at"] = datetime.utcnow()
        stmt = update(Task).where(Task.id == task_id).values(**values)
        await db.execute(stmt)
        await db.commit()
        # Update Redis for fast polling
        await self._update_redis(task_id, progress, status)

    async def complete_task(
        self,
        db: AsyncSession,
        task_id: int,
        result: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Mark a task as completed (progress 100) and store its result.

        Args:
            db: Async session used for the update; committed by this method.
            task_id: Primary key of the task row.
            result: Optional result payload stored on the row.
        """
        stmt = (
            update(Task)
            .where(Task.id == task_id)
            .values(
                status="completed",
                progress=100,
                completed_at=datetime.utcnow(),
                result=result,
            )
        )
        await db.execute(stmt)
        await db.commit()
        await self._update_redis(task_id, 100, "completed")

    async def fail_task(
        self,
        db: AsyncSession,
        task_id: int,
        error_message: str,
    ) -> None:
        """Mark a task as failed and record the error message.

        The DB progress column is left as it was. In Redis, progress is set to
        ``-1`` as a failure marker, alongside the error message.

        Args:
            db: Async session used for the update; committed by this method.
            task_id: Primary key of the task row.
            error_message: Error text stored on the row and mirrored to Redis.
        """
        stmt = (
            update(Task)
            .where(Task.id == task_id)
            .values(
                status="failed",
                completed_at=datetime.utcnow(),
                error_message=error_message,
            )
        )
        await db.execute(stmt)
        await db.commit()
        await self._update_redis(task_id, -1, "failed", error=error_message)

    async def get_task(self, db: AsyncSession, task_id: int) -> Optional[Task]:
        """Return the task row with ``task_id``, or ``None`` if it does not exist."""
        result = await db.execute(select(Task).where(Task.id == task_id))
        return result.scalar_one_or_none()

    async def _update_redis(
        self,
        task_id: int,
        progress: int,
        status: str,
        error: Optional[str] = None,
    ) -> None:
        """Write the task's transient state to Redis with a 1-hour TTL.

        Args:
            task_id: Task id; used in the key and stored in the payload.
            progress: Progress value to store (``-1`` is used for failures).
            status: Status string to store.
            error: Optional error text; only included when truthy.
        """
        data = {
            "id": task_id,
            "progress": progress,
            "status": status,
            "updated_at": datetime.utcnow().isoformat(),
        }
        if error:
            data["error"] = error
        # Set TTL for 1 hour
        await redis_cache.set(self._redis_key(task_id), data, ttl_seconds=3600)

    async def get_task_progress_from_redis(self, task_id: int) -> Optional[Dict]:
        """Return the task's transient state from Redis.

        The value is whatever ``redis_cache.get`` returns for the task's key.
        This is presumably ``None`` once the entry is missing or expired
        (verify against ``redis_cache``).
        """
        return await redis_cache.get(self._redis_key(task_id))
# Module-level TaskService instance, created once at import time.
task_service: TaskService = TaskService()