Files
test_task_crm/app/repositories/task_repo.py
T
Artem Kashaev 5fcb574aca
Test / test (push) Successful in 15s
Refactor code for improved readability and consistency
- Reformatted function signatures in `organization_service.py` and `task_service.py` for better alignment.
- Updated import statements across multiple files for consistency and organization.
- Enhanced test files by improving formatting and ensuring consistent use of async session factories.
- Added type hints and improved type safety in various service and test files.
- Adjusted `pyproject.toml` to include configuration for isort, mypy, and ruff for better code quality checks.
- Cleaned up unused imports and organized existing ones in several test files.
2025-12-01 16:18:03 +05:00

127 lines
4.2 KiB
Python

"""Task repository providing role-aware CRUD helpers."""
from __future__ import annotations

from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime
from typing import Any

from sqlalchemy import Select, inspect, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.deal import Deal
from app.models.organization_member import OrganizationRole
from app.models.task import Task, TaskCreate
class TaskAccessError(Exception):
    """Raised when the caller's role does not permit the requested task operation.

    For example: a member trying to create a task on a deal they do not own,
    or trying to modify a task whose deal they do not own.
    """
class TaskOrganizationMismatchError(Exception):
    """Raised when a task or deal cannot be tied to the caller's organization.

    This covers a deal that belongs to another organization, and also a deal
    that cannot be found at all: a nonexistent deal id, or a task whose
    owning deal (and hence owner) cannot be resolved.
    """
@dataclass(slots=True)
class TaskQueryParams:
"""Filtering options supported by list queries."""
organization_id: int
deal_id: int | None = None
only_open: bool = False
due_before: datetime | None = None
due_after: datetime | None = None
class TaskRepository:
    """Encapsulates database access for Task entities.

    Reads are scoped to one organization by joining ``Task`` to ``Deal`` and
    filtering on ``Deal.organization_id``. Writes enforce one role rule: a
    user whose role is ``OrganizationRole.MEMBER`` may only act on tasks
    whose deal ``owner_id`` equals their ``user_id``. Other roles are not
    restricted by this class.

    Writes call ``flush()`` but never ``commit()``. Committing is left to
    whoever owns the session.
    """

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    @property
    def session(self) -> AsyncSession:
        """The async session all queries are issued through."""
        return self._session

    async def list(self, *, params: TaskQueryParams) -> Sequence[Task]:
        """Return the organization's tasks that match ``params``.

        Tasks are ordered by due date, with undated tasks last, then by id
        for a stable order. They come back with ``Task.deal`` eagerly loaded.
        """
        stmt = (
            select(Task)
            .join(Deal, Deal.id == Task.deal_id)
            .where(Deal.organization_id == params.organization_id)
            .options(selectinload(Task.deal))
            # ``due_date IS NULL`` is false for dated tasks, so those sort
            # first and undated tasks end up last.
            .order_by(Task.due_date.is_(None), Task.due_date, Task.id)
        )
        stmt = self._apply_filters(stmt, params)
        result = await self._session.scalars(stmt)
        return result.all()

    async def get(self, task_id: int, *, organization_id: int) -> Task | None:
        """Return task ``task_id`` with its deal loaded, or None.

        None is returned both when the task does not exist and when its deal
        belongs to a different organization.
        """
        stmt = (
            select(Task)
            .join(Deal, Deal.id == Task.deal_id)
            .where(Task.id == task_id, Deal.organization_id == organization_id)
            .options(selectinload(Task.deal))
        )
        result = await self._session.scalars(stmt)
        return result.first()

    async def create(
        self,
        data: TaskCreate,
        *,
        organization_id: int,
        role: OrganizationRole,
        user_id: int,
    ) -> Task:
        """Insert a task built from ``data`` and flush it.

        Args:
            data: Creation payload. ``data.deal_id`` selects the parent deal,
                and ``data.model_dump()`` is passed as keyword arguments to
                the ``Task`` constructor.
            organization_id: Organization the caller is acting in.
            role: Caller's role in that organization.
            user_id: Caller's user id, compared against the deal's owner.

        Returns:
            The new, flushed ``Task``.

        Raises:
            TaskOrganizationMismatchError: The deal does not exist or belongs
                to another organization.
            TaskAccessError: ``role`` is MEMBER and the caller does not own
                the deal.
        """
        deal = await self._session.get(Deal, data.deal_id)
        # A missing deal raises the same error as a deal from another org.
        if deal is None or deal.organization_id != organization_id:
            raise TaskOrganizationMismatchError("Deal belongs to another organization")
        if role == OrganizationRole.MEMBER and deal.owner_id != user_id:
            raise TaskAccessError("Members can only create tasks for their own deals")
        task = Task(**data.model_dump())
        self._session.add(task)
        # Flush (not commit): the INSERT is emitted now, but the transaction
        # stays under the caller's control.
        await self._session.flush()
        return task

    async def update(
        self,
        task: Task,
        updates: Mapping[str, Any],
        *,
        role: OrganizationRole,
        user_id: int,
    ) -> Task:
        """Apply ``updates`` to ``task`` after an ownership check, then flush.

        Keys in ``updates`` that are not attributes of ``task`` are silently
        ignored.

        Returns:
            The same ``task`` instance, modified in place and flushed.

        Raises:
            TaskOrganizationMismatchError: The owner of the task's deal
                cannot be resolved.
            TaskAccessError: ``role`` is MEMBER and the caller does not own
                the task's deal.
        """
        owner_id = await self._resolve_task_owner(task)
        if owner_id is None:
            raise TaskOrganizationMismatchError("Task is missing an owner context")
        if role == OrganizationRole.MEMBER and owner_id != user_id:
            raise TaskAccessError("Members can only modify their own tasks")
        # NOTE(review): any existing attribute is writable here, including
        # ``deal_id``. The organization/ownership checks are not repeated for
        # a new deal. Confirm callers restrict ``updates`` to editable fields.
        for field, value in updates.items():
            if hasattr(task, field):
                setattr(task, field, value)
        await self._session.flush()
        return task

    def _apply_filters(
        self, stmt: Select[tuple[Task]], params: TaskQueryParams
    ) -> Select[tuple[Task]]:
        """Narrow ``stmt`` with the optional criteria set on ``params``.

        Due-date bounds are inclusive. Tasks without a due date never satisfy
        a bound, because SQL comparisons against NULL are not true.
        """
        if params.deal_id is not None:
            stmt = stmt.where(Task.deal_id == params.deal_id)
        if params.only_open:
            stmt = stmt.where(Task.is_done.is_(False))
        if params.due_before is not None:
            stmt = stmt.where(Task.due_date <= params.due_before)
        if params.due_after is not None:
            stmt = stmt.where(Task.due_date >= params.due_after)
        return stmt

    async def _resolve_task_owner(self, task: Task) -> int | None:
        """Return the ``owner_id`` of the deal that ``task`` belongs to.

        Uses ``task.deal`` when that relationship is already loaded.
        Otherwise, queries ``Deal.owner_id`` by ``task.deal_id``. Returns
        None when no matching deal row exists.
        """
        # Only read ``task.deal`` when it is already loaded. Touching an
        # unloaded relationship triggers an implicit lazy load (IO), which an
        # AsyncSession cannot perform; it raises MissingGreenlet instead.
        # Example: a task just returned by ``create`` (flushed, deal unloaded).
        if "deal" not in inspect(task).unloaded:
            deal = task.deal
            if deal is not None:
                return deal.owner_id
        stmt = select(Deal.owner_id).where(Deal.id == task.deal_id)
        return await self._session.scalar(stmt)