Refactor code for improved readability and consistency
Test / test (push) Successful in 15s

- Reformatted function signatures in `organization_service.py` and `task_service.py` for better alignment.
- Updated import statements across multiple files for consistency and organization.
- Enhanced test files by improving formatting and ensuring consistent use of async session factories.
- Added type hints and improved type safety in various service and test files.
- Adjusted `pyproject.toml` to include configuration for isort, mypy, and ruff for better code quality checks.
- Cleaned up unused imports and organized existing ones in several test files.
This commit is contained in:
Artem Kashaev
2025-12-01 16:18:03 +05:00
parent eecb74c523
commit 5fcb574aca
62 changed files with 765 additions and 476 deletions
+9 -4
View File
@@ -1,9 +1,11 @@
"""Reusable FastAPI dependencies."""
from collections.abc import AsyncGenerator
import jwt
from fastapi import Depends, Header, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from redis.asyncio.client import Redis
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.cache import get_cache_client
@@ -18,9 +20,9 @@ from app.repositories.deal_repo import DealRepository
from app.repositories.org_repo import OrganizationRepository
from app.repositories.task_repo import TaskRepository
from app.repositories.user_repo import UserRepository
from app.services.activity_service import ActivityService
from app.services.analytics_service import AnalyticsService
from app.services.auth_service import AuthService
from app.services.activity_service import ActivityService
from app.services.contact_service import ContactService
from app.services.deal_service import DealService
from app.services.organization_service import (
@@ -30,7 +32,6 @@ from app.services.organization_service import (
OrganizationService,
)
from app.services.task_service import TaskService
from redis.asyncio.client import Redis
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.api_v1_prefix}/auth/token")
@@ -45,7 +46,9 @@ def get_user_repository(session: AsyncSession = Depends(get_db_session)) -> User
return UserRepository(session=session)
def get_organization_repository(session: AsyncSession = Depends(get_db_session)) -> OrganizationRepository:
def get_organization_repository(
session: AsyncSession = Depends(get_db_session),
) -> OrganizationRepository:
return OrganizationRepository(session=session)
@@ -65,7 +68,9 @@ def get_activity_repository(session: AsyncSession = Depends(get_db_session)) ->
return ActivityRepository(session=session)
def get_analytics_repository(session: AsyncSession = Depends(get_db_session)) -> AnalyticsRepository:
def get_analytics_repository(
session: AsyncSession = Depends(get_db_session),
) -> AnalyticsRepository:
return AnalyticsRepository(session=session)
+8 -7
View File
@@ -1,14 +1,15 @@
"""Root API router that aggregates versioned routers."""
from fastapi import APIRouter
from app.api.v1 import (
activities,
analytics,
auth,
contacts,
deals,
organizations,
tasks,
activities,
analytics,
auth,
contacts,
deals,
organizations,
tasks,
)
from app.core.config import settings
+15 -14
View File
@@ -1,20 +1,21 @@
"""Version 1 API routers."""
from . import (
activities,
analytics,
auth,
contacts,
deals,
organizations,
tasks,
activities,
analytics,
auth,
contacts,
deals,
organizations,
tasks,
)
__all__ = [
"activities",
"analytics",
"auth",
"contacts",
"deals",
"organizations",
"tasks",
"activities",
"analytics",
"auth",
"contacts",
"deals",
"organizations",
"tasks",
]
+1
View File
@@ -1,4 +1,5 @@
"""Activity timeline endpoints and payload schemas."""
from __future__ import annotations
from typing import Literal
+5 -1
View File
@@ -1,4 +1,5 @@
"""Analytics API endpoints for summaries and funnels."""
from __future__ import annotations
from decimal import Decimal
@@ -16,6 +17,7 @@ def _decimal_to_str(value: Decimal) -> str:
normalized = value.normalize()
return format(normalized, "f")
router = APIRouter(prefix="/analytics", tags=["analytics"])
@@ -92,4 +94,6 @@ async def deals_funnel(
"""Return funnel breakdown by stages and statuses."""
breakdowns: list[StageBreakdown] = await service.get_deal_funnel(context.organization_id)
return DealFunnelResponse(stages=[StageBreakdownModel.model_validate(item) for item in breakdowns])
return DealFunnelResponse(
stages=[StageBreakdownModel.model_validate(item) for item in breakdowns]
)
+3 -2
View File
@@ -1,8 +1,9 @@
"""Authentication API endpoints and payloads."""
from __future__ import annotations
from pydantic import BaseModel, EmailStr
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, EmailStr
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
@@ -41,7 +42,7 @@ async def register_user(
organization: Organization | None = None
if payload.organization_name:
existing_org = await repo.session.scalar(
select(Organization).where(Organization.name == payload.organization_name)
select(Organization).where(Organization.name == payload.organization_name),
)
if existing_org is not None:
raise HTTPException(
+4 -1
View File
@@ -1,4 +1,5 @@
"""Contact API endpoints."""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, Query, status
@@ -81,7 +82,9 @@ async def create_contact(
context: OrganizationContext = Depends(get_organization_context),
service: ContactService = Depends(get_contact_service),
) -> ContactRead:
data = payload.to_domain(organization_id=context.organization_id, fallback_owner=context.user_id)
data = payload.to_domain(
organization_id=context.organization_id, fallback_owner=context.user_id
)
try:
contact = await service.create_contact(data, context=context)
except ContactForbiddenError as exc:
+8 -3
View File
@@ -1,4 +1,5 @@
"""Deal API endpoints backed by DealService with inline payload schemas."""
from __future__ import annotations
from decimal import Decimal
@@ -8,7 +9,7 @@ from pydantic import BaseModel
from app.api.deps import get_deal_repository, get_deal_service, get_organization_context
from app.models.deal import DealCreate, DealRead, DealStage, DealStatus
from app.repositories.deal_repo import DealRepository, DealAccessError, DealQueryParams
from app.repositories.deal_repo import DealAccessError, DealQueryParams, DealRepository
from app.services.deal_service import (
DealService,
DealStageTransitionError,
@@ -66,7 +67,9 @@ async def list_deals(
statuses_value = [DealStatus(value) for value in status_filter] if status_filter else None
stage_value = DealStage(stage) if stage else None
except ValueError as exc:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid deal filter") from exc
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid deal filter"
) from exc
params = DealQueryParams(
organization_id=context.organization_id,
@@ -96,7 +99,9 @@ async def create_deal(
) -> DealRead:
"""Create a new deal within the current organization."""
data = payload.to_domain(organization_id=context.organization_id, fallback_owner=context.user_id)
data = payload.to_domain(
organization_id=context.organization_id, fallback_owner=context.user_id
)
try:
deal = await service.create_deal(data, context=context)
except DealAccessError as exc:
+1
View File
@@ -1,4 +1,5 @@
"""Organization-related API endpoints."""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, status
+1
View File
@@ -1,4 +1,5 @@
"""Task API endpoints with inline schemas."""
from __future__ import annotations
from datetime import date, datetime, time, timezone
+18 -8
View File
@@ -1,17 +1,18 @@
"""Redis cache utilities and availability tracking."""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any, Awaitable, Callable, Optional
from collections.abc import Awaitable, Callable
from typing import Any
import redis.asyncio as redis
from app.core.config import settings
from redis.asyncio.client import Redis
from redis.exceptions import RedisError
from app.core.config import settings
logger = logging.getLogger(__name__)
@@ -44,7 +45,9 @@ class RedisCacheManager:
async with self._lock:
if self._client is not None:
return
self._client = redis.from_url(settings.redis_url, encoding="utf-8", decode_responses=False)
self._client = redis.from_url(
settings.redis_url, encoding="utf-8", decode_responses=False
)
await self._refresh_availability()
async def shutdown(self) -> None:
@@ -59,7 +62,9 @@ class RedisCacheManager:
return
async with self._lock:
if self._client is None:
self._client = redis.from_url(settings.redis_url, encoding="utf-8", decode_responses=False)
self._client = redis.from_url(
settings.redis_url, encoding="utf-8", decode_responses=False
)
await self._refresh_availability()
async def _refresh_availability(self) -> None:
@@ -95,7 +100,7 @@ async def shutdown_cache() -> None:
await cache_manager.shutdown()
def get_cache_client() -> Optional[Redis]:
def get_cache_client() -> Redis | None:
"""Expose the active Redis client for dependency injection."""
return cache_manager.get_client()
@@ -113,12 +118,17 @@ async def read_json(client: Redis, key: str) -> Any | None:
cache_manager.mark_available()
try:
return json.loads(raw.decode("utf-8"))
except (UnicodeDecodeError, json.JSONDecodeError) as exc: # pragma: no cover - malformed payloads
except (
UnicodeDecodeError,
json.JSONDecodeError,
) as exc: # pragma: no cover - malformed payloads
logger.warning("Discarding malformed cache entry %s: %s", key, exc)
return None
async def write_json(client: Redis, key: str, value: Any, ttl_seconds: int, backoff_ms: int) -> None:
async def write_json(
client: Redis, key: str, value: Any, ttl_seconds: int, backoff_ms: int
) -> None:
"""Serialize data to JSON and store it with TTL using retry/backoff."""
payload = json.dumps(value, separators=(",", ":"), ensure_ascii=True).encode("utf-8")
+7 -2
View File
@@ -1,4 +1,5 @@
"""Application settings using Pydantic Settings."""
from pydantic import Field, SecretStr
from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -15,7 +16,9 @@ class Settings(BaseSettings):
db_port: int = Field(default=5432, description="Database port")
db_name: str = Field(default="test_task_crm", description="Database name")
db_user: str = Field(default="postgres", description="Database user")
db_password: SecretStr = Field(default=SecretStr("postgres"), description="Database user password")
db_password: SecretStr = Field(
default=SecretStr("postgres"), description="Database user password"
)
database_url_override: str | None = Field(
default=None,
alias="DATABASE_URL",
@@ -28,7 +31,9 @@ class Settings(BaseSettings):
refresh_token_expire_days: int = 7
redis_enabled: bool = Field(default=False, description="Toggle Redis-backed cache usage")
redis_url: str = Field(default="redis://localhost:6379/0", description="Redis connection URL")
analytics_cache_ttl_seconds: int = Field(default=120, ge=1, description="TTL for cached analytics responses")
analytics_cache_ttl_seconds: int = Field(
default=120, ge=1, description="TTL for cached analytics responses"
)
analytics_cache_backoff_ms: int = Field(
default=200,
ge=0,
+2 -2
View File
@@ -1,11 +1,11 @@
"""Database utilities for async SQLAlchemy engine and sessions."""
from __future__ import annotations
from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.core.config import settings
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
engine = create_async_engine(settings.database_url, echo=settings.sqlalchemy_echo)
AsyncSessionMaker = async_sessionmaker(bind=engine, expire_on_commit=False)
+2 -1
View File
@@ -1,11 +1,12 @@
"""Middleware that logs cache availability transitions."""
from __future__ import annotations
import logging
from starlette.types import ASGIApp, Receive, Scope, Send
from app.core.cache import cache_manager
from app.core.config import settings
from starlette.types import ASGIApp, Receive, Scope, Send
logger = logging.getLogger(__name__)
+4 -3
View File
@@ -1,13 +1,14 @@
"""Security helpers for hashing passwords and issuing JWT tokens."""
from __future__ import annotations
from collections.abc import Mapping
from datetime import datetime, timedelta, timezone
from typing import Any, Mapping
from typing import Any
import jwt
from passlib.context import CryptContext # type: ignore
from app.core.config import settings
from passlib.context import CryptContext # type: ignore
class PasswordHasher:
+6 -4
View File
@@ -7,6 +7,7 @@ from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
@@ -14,13 +15,12 @@ from app.api.routes import api_router
from app.core.cache import init_cache, shutdown_cache
from app.core.config import settings
from app.core.middleware.cache_monitor import CacheAvailabilityMiddleware
from fastapi.middleware.cors import CORSMiddleware
PROJECT_ROOT = Path(__file__).resolve().parent.parent
FRONTEND_DIST = PROJECT_ROOT / "frontend" / "dist"
FRONTEND_INDEX = FRONTEND_DIST / "index.html"
def create_app() -> FastAPI:
"""Build FastAPI application instance."""
@@ -43,7 +43,7 @@ def create_app() -> FastAPI:
# "http://localhost:8000",
# "http://0.0.0.0:8000",
# "http://127.0.0.1:8000",
"*" # ! TODO: Убрать
"*", # ! TODO: Убрать
],
allow_credentials=True,
allow_methods=["*"], # Разрешить все HTTP-методы
@@ -59,7 +59,9 @@ def create_app() -> FastAPI:
return FileResponse(FRONTEND_INDEX)
@application.get("/{path:path}", include_in_schema=False)
async def serve_frontend_path(path: str) -> FileResponse: # pragma: no cover - simple file response
async def serve_frontend_path(
path: str,
) -> FileResponse: # pragma: no cover - simple file response
if path == "" or path.startswith("api"):
raise HTTPException(status_code=404)
+13 -12
View File
@@ -1,4 +1,5 @@
"""Model exports for Alembic discovery."""
from app.models.activity import Activity, ActivityType
from app.models.base import Base
from app.models.contact import Contact
@@ -9,16 +10,16 @@ from app.models.task import Task
from app.models.user import User
__all__ = [
"Activity",
"ActivityType",
"Base",
"Contact",
"Deal",
"DealStage",
"DealStatus",
"Organization",
"OrganizationMember",
"OrganizationRole",
"Task",
"User",
"Activity",
"ActivityType",
"Base",
"Contact",
"Deal",
"DealStage",
"DealStatus",
"Organization",
"OrganizationMember",
"OrganizationRole",
"Task",
"User",
]
+12 -5
View File
@@ -1,4 +1,5 @@
"""Activity timeline ORM model and schemas."""
from __future__ import annotations
from datetime import datetime
@@ -6,10 +7,12 @@ from enum import StrEnum
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import DateTime, Enum as SqlEnum, ForeignKey, Integer, func, text
from sqlalchemy import DateTime, ForeignKey, Integer, func, text
from sqlalchemy import Enum as SqlEnum
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.types import JSON as GenericJSON, TypeDecorator
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.types import JSON as GenericJSON
from sqlalchemy.types import TypeDecorator
from app.models.base import Base, enum_values
@@ -44,10 +47,12 @@ class Activity(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
deal_id: Mapped[int] = mapped_column(ForeignKey("deals.id", ondelete="CASCADE"))
author_id: Mapped[int | None] = mapped_column(
ForeignKey("users.id", ondelete="SET NULL"), nullable=True
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
)
type: Mapped[ActivityType] = mapped_column(
SqlEnum(ActivityType, name="activity_type", values_callable=enum_values), nullable=False
SqlEnum(ActivityType, name="activity_type", values_callable=enum_values),
nullable=False,
)
payload: Mapped[dict[str, Any]] = mapped_column(
JSONBCompat().with_variant(GenericJSON(), "sqlite"),
@@ -55,7 +60,9 @@ class Activity(Base):
server_default=text("'{}'"),
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
deal = relationship("Deal", back_populates="activities")
+1
View File
@@ -1,4 +1,5 @@
"""Declarative base for SQLAlchemy models."""
from __future__ import annotations
from enum import StrEnum
+4 -1
View File
@@ -1,4 +1,5 @@
"""Contact ORM model and schemas."""
from __future__ import annotations
from datetime import datetime
@@ -22,7 +23,9 @@ class Contact(Base):
email: Mapped[str | None] = mapped_column(String(320), nullable=True)
phone: Mapped[str | None] = mapped_column(String(64), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
organization = relationship("Organization", back_populates="contacts")
+10 -3
View File
@@ -1,4 +1,5 @@
"""Deal ORM model and schemas."""
from __future__ import annotations
from datetime import datetime
@@ -6,7 +7,8 @@ from decimal import Decimal
from enum import StrEnum
from pydantic import BaseModel, ConfigDict
from sqlalchemy import DateTime, Enum as SqlEnum, ForeignKey, Integer, Numeric, String, func
from sqlalchemy import DateTime, ForeignKey, Integer, Numeric, String, func
from sqlalchemy import Enum as SqlEnum
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base, enum_values
@@ -49,10 +51,15 @@ class Deal(Base):
default=DealStage.QUALIFICATION,
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
organization = relationship("Organization", back_populates="deals")
+4 -1
View File
@@ -1,4 +1,5 @@
"""Organization ORM model and schemas."""
from __future__ import annotations
from datetime import datetime
@@ -18,7 +19,9 @@ class Organization(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
name: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
members = relationship(
+6 -2
View File
@@ -1,11 +1,13 @@
"""Organization member ORM model."""
from __future__ import annotations
from datetime import datetime
from enum import StrEnum
from pydantic import BaseModel, ConfigDict
from sqlalchemy import DateTime, Enum as SqlEnum, ForeignKey, Integer, UniqueConstraint, func
from sqlalchemy import DateTime, ForeignKey, Integer, UniqueConstraint, func
from sqlalchemy import Enum as SqlEnum
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import Base, enum_values
@@ -39,7 +41,9 @@ class OrganizationMember(Base):
default=OrganizationRole.MEMBER,
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
organization = relationship("Organization", back_populates="members")
+4 -1
View File
@@ -1,4 +1,5 @@
"""Task ORM model and schemas."""
from __future__ import annotations
from datetime import datetime
@@ -22,7 +23,9 @@ class Task(Base):
due_date: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
is_done: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
deal = relationship("Deal", back_populates="tasks")
+1
View File
@@ -1,4 +1,5 @@
"""Token-related Pydantic schemas."""
from __future__ import annotations
from datetime import datetime
+11 -3
View File
@@ -1,4 +1,5 @@
"""User ORM model and Pydantic schemas."""
from __future__ import annotations
from datetime import datetime
@@ -25,13 +26,20 @@ class User(Base):
name: Mapped[str] = mapped_column(String(255), nullable=False)
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
memberships = relationship("OrganizationMember", back_populates="user", cascade="all, delete-orphan")
memberships = relationship(
"OrganizationMember", back_populates="user", cascade="all, delete-orphan"
)
owned_contacts = relationship("Contact", back_populates="owner")
owned_deals = relationship("Deal", back_populates="owner")
activities = relationship("Activity", back_populates="author")
+4 -1
View File
@@ -1,4 +1,5 @@
"""Repository helpers for deal activities."""
from __future__ import annotations
from collections.abc import Sequence
@@ -39,7 +40,9 @@ class ActivityRepository:
stmt = (
select(Activity)
.join(Deal, Deal.id == Activity.deal_id)
.where(Activity.deal_id == params.deal_id, Deal.organization_id == params.organization_id)
.where(
Activity.deal_id == params.deal_id, Deal.organization_id == params.organization_id
)
.order_by(Activity.created_at)
)
stmt = self._apply_window(stmt, params)
+2 -1
View File
@@ -1,4 +1,5 @@
"""Analytics-specific data access helpers."""
from __future__ import annotations
from dataclasses import dataclass
@@ -58,7 +59,7 @@ class AnalyticsRepository:
deal_count=int(count or 0),
amount_sum=_to_decimal(amount_sum),
amount_count=int(amount_count or 0),
)
),
)
return rollup
+8 -3
View File
@@ -1,4 +1,5 @@
"""Repository helpers for contacts with role-aware access."""
from __future__ import annotations
from collections.abc import Mapping, Sequence
@@ -44,7 +45,9 @@ class ContactRepository:
role: OrganizationRole,
user_id: int,
) -> Sequence[Contact]:
stmt: Select[tuple[Contact]] = select(Contact).where(Contact.organization_id == params.organization_id)
stmt: Select[tuple[Contact]] = select(Contact).where(
Contact.organization_id == params.organization_id
)
stmt = self._apply_filters(stmt, params, role, user_id)
offset = (max(params.page, 1) - 1) * params.page_size
stmt = stmt.order_by(Contact.created_at.desc()).offset(offset).limit(params.page_size)
@@ -59,7 +62,9 @@ class ContactRepository:
role: OrganizationRole,
user_id: int,
) -> Contact | None:
stmt = select(Contact).where(Contact.id == contact_id, Contact.organization_id == organization_id)
stmt = select(Contact).where(
Contact.id == contact_id, Contact.organization_id == organization_id
)
result = await self._session.scalars(stmt)
return result.first()
@@ -117,7 +122,7 @@ class ContactRepository:
pattern = f"%{params.search.lower()}%"
stmt = stmt.where(
func.lower(Contact.name).like(pattern)
| func.lower(func.coalesce(Contact.email, "")).like(pattern)
| func.lower(func.coalesce(Contact.email, "")).like(pattern),
)
if params.owner_id is not None:
if role == OrganizationRole.MEMBER:
+117 -115
View File
@@ -1,4 +1,5 @@
"""Deal repository with access-aware CRUD helpers."""
from __future__ import annotations
from collections.abc import Mapping, Sequence
@@ -12,142 +13,143 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.models.deal import Deal, DealCreate, DealStage, DealStatus
from app.models.organization_member import OrganizationRole
ORDERABLE_COLUMNS: dict[str, Any] = {
"created_at": Deal.created_at,
"amount": Deal.amount,
"title": Deal.title,
"created_at": Deal.created_at,
"amount": Deal.amount,
"title": Deal.title,
}
class DealAccessError(Exception):
"""Raised when a user attempts an operation without sufficient permissions."""
"""Raised when a user attempts an operation without sufficient permissions."""
@dataclass(slots=True)
class DealQueryParams:
"""Filters supported by list queries."""
"""Filters supported by list queries."""
organization_id: int
page: int = 1
page_size: int = 20
statuses: Sequence[DealStatus] | None = None
stage: DealStage | None = None
owner_id: int | None = None
min_amount: Decimal | None = None
max_amount: Decimal | None = None
order_by: str | None = None
order_desc: bool = True
organization_id: int
page: int = 1
page_size: int = 20
statuses: Sequence[DealStatus] | None = None
stage: DealStage | None = None
owner_id: int | None = None
min_amount: Decimal | None = None
max_amount: Decimal | None = None
order_by: str | None = None
order_desc: bool = True
class DealRepository:
"""Provides CRUD helpers for deals with role-aware filtering."""
"""Provides CRUD helpers for deals with role-aware filtering."""
def __init__(self, session: AsyncSession) -> None:
self._session = session
def __init__(self, session: AsyncSession) -> None:
self._session = session
@property
def session(self) -> AsyncSession:
return self._session
@property
def session(self) -> AsyncSession:
return self._session
async def list(
self,
*,
params: DealQueryParams,
role: OrganizationRole,
user_id: int,
) -> Sequence[Deal]:
stmt = select(Deal).where(Deal.organization_id == params.organization_id)
stmt = self._apply_filters(stmt, params, role, user_id)
stmt = self._apply_ordering(stmt, params)
async def list(
self,
*,
params: DealQueryParams,
role: OrganizationRole,
user_id: int,
) -> Sequence[Deal]:
stmt = select(Deal).where(Deal.organization_id == params.organization_id)
stmt = self._apply_filters(stmt, params, role, user_id)
stmt = self._apply_ordering(stmt, params)
offset = (max(params.page, 1) - 1) * params.page_size
stmt = stmt.offset(offset).limit(params.page_size)
result = await self._session.scalars(stmt)
return result.all()
offset = (max(params.page, 1) - 1) * params.page_size
stmt = stmt.offset(offset).limit(params.page_size)
result = await self._session.scalars(stmt)
return result.all()
async def get(
self,
deal_id: int,
*,
organization_id: int,
role: OrganizationRole,
user_id: int,
require_owner: bool = False,
) -> Deal | None:
stmt = select(Deal).where(Deal.id == deal_id, Deal.organization_id == organization_id)
stmt = self._apply_role_clause(stmt, role, user_id, require_owner=require_owner)
result = await self._session.scalars(stmt)
return result.first()
async def get(
self,
deal_id: int,
*,
organization_id: int,
role: OrganizationRole,
user_id: int,
require_owner: bool = False,
) -> Deal | None:
stmt = select(Deal).where(Deal.id == deal_id, Deal.organization_id == organization_id)
stmt = self._apply_role_clause(stmt, role, user_id, require_owner=require_owner)
result = await self._session.scalars(stmt)
return result.first()
async def create(
self,
data: DealCreate,
*,
role: OrganizationRole,
user_id: int,
) -> Deal:
if role == OrganizationRole.MEMBER and data.owner_id != user_id:
raise DealAccessError("Members can only create deals they own")
deal = Deal(**data.model_dump())
self._session.add(deal)
await self._session.flush()
return deal
async def create(
self,
data: DealCreate,
*,
role: OrganizationRole,
user_id: int,
) -> Deal:
if role == OrganizationRole.MEMBER and data.owner_id != user_id:
raise DealAccessError("Members can only create deals they own")
deal = Deal(**data.model_dump())
self._session.add(deal)
await self._session.flush()
return deal
async def update(
self,
deal: Deal,
updates: Mapping[str, Any],
*,
role: OrganizationRole,
user_id: int,
) -> Deal:
if role == OrganizationRole.MEMBER and deal.owner_id != user_id:
raise DealAccessError("Members can only modify their own deals")
for field, value in updates.items():
if hasattr(deal, field):
setattr(deal, field, value)
await self._session.flush()
await self._session.refresh(deal)
return deal
async def update(
self,
deal: Deal,
updates: Mapping[str, Any],
*,
role: OrganizationRole,
user_id: int,
) -> Deal:
if role == OrganizationRole.MEMBER and deal.owner_id != user_id:
raise DealAccessError("Members can only modify their own deals")
for field, value in updates.items():
if hasattr(deal, field):
setattr(deal, field, value)
await self._session.flush()
await self._session.refresh(deal)
return deal
def _apply_filters(
self,
stmt: Select[tuple[Deal]],
params: DealQueryParams,
role: OrganizationRole,
user_id: int,
) -> Select[tuple[Deal]]:
if params.statuses:
stmt = stmt.where(Deal.status.in_(params.statuses))
if params.stage:
stmt = stmt.where(Deal.stage == params.stage)
if params.owner_id is not None:
if role == OrganizationRole.MEMBER and params.owner_id != user_id:
raise DealAccessError("Members cannot filter by other owners")
stmt = stmt.where(Deal.owner_id == params.owner_id)
if params.min_amount is not None:
stmt = stmt.where(Deal.amount >= params.min_amount)
if params.max_amount is not None:
stmt = stmt.where(Deal.amount <= params.max_amount)
def _apply_filters(
self,
stmt: Select[tuple[Deal]],
params: DealQueryParams,
role: OrganizationRole,
user_id: int,
) -> Select[tuple[Deal]]:
if params.statuses:
stmt = stmt.where(Deal.status.in_(params.statuses))
if params.stage:
stmt = stmt.where(Deal.stage == params.stage)
if params.owner_id is not None:
if role == OrganizationRole.MEMBER and params.owner_id != user_id:
raise DealAccessError("Members cannot filter by other owners")
stmt = stmt.where(Deal.owner_id == params.owner_id)
if params.min_amount is not None:
stmt = stmt.where(Deal.amount >= params.min_amount)
if params.max_amount is not None:
stmt = stmt.where(Deal.amount <= params.max_amount)
return self._apply_role_clause(stmt, role, user_id)
return self._apply_role_clause(stmt, role, user_id)
def _apply_role_clause(
self,
stmt: Select[tuple[Deal]],
role: OrganizationRole,
user_id: int,
*,
require_owner: bool = False,
) -> Select[tuple[Deal]]:
if role in {OrganizationRole.OWNER, OrganizationRole.ADMIN, OrganizationRole.MANAGER}:
return stmt
if require_owner:
return stmt.where(Deal.owner_id == user_id)
return stmt
def _apply_role_clause(
self,
stmt: Select[tuple[Deal]],
role: OrganizationRole,
user_id: int,
*,
require_owner: bool = False,
) -> Select[tuple[Deal]]:
if role in {OrganizationRole.OWNER, OrganizationRole.ADMIN, OrganizationRole.MANAGER}:
return stmt
if require_owner:
return stmt.where(Deal.owner_id == user_id)
return stmt
def _apply_ordering(self, stmt: Select[tuple[Deal]], params: DealQueryParams) -> Select[tuple[Deal]]:
column = ORDERABLE_COLUMNS.get(params.order_by or "created_at", Deal.created_at)
order_func = desc if params.order_desc else asc
return stmt.order_by(order_func(column))
def _apply_ordering(
self, stmt: Select[tuple[Deal]], params: DealQueryParams
) -> Select[tuple[Deal]]:
column = ORDERABLE_COLUMNS.get(params.order_by or "created_at", Deal.created_at)
order_func = desc if params.order_desc else asc
return stmt.order_by(order_func(column))
+2 -1
View File
@@ -1,11 +1,12 @@
"""Organization repository for database operations."""
from __future__ import annotations
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.models.organization import Organization, OrganizationCreate
from app.models.organization_member import OrganizationMember
+4 -1
View File
@@ -1,4 +1,5 @@
"""Task repository providing role-aware CRUD helpers."""
from __future__ import annotations
from collections.abc import Mapping, Sequence
@@ -105,7 +106,9 @@ class TaskRepository:
await self._session.flush()
return task
def _apply_filters(self, stmt: Select[tuple[Task]], params: TaskQueryParams) -> Select[tuple[Task]]:
def _apply_filters(
self, stmt: Select[tuple[Task]], params: TaskQueryParams
) -> Select[tuple[Task]]:
if params.deal_id is not None:
stmt = stmt.where(Task.deal_id == params.deal_id)
if params.only_open:
+1
View File
@@ -1,4 +1,5 @@
"""User repository handling database operations."""
from __future__ import annotations
from collections.abc import Sequence
+2 -1
View File
@@ -1,4 +1,5 @@
"""Business logic services."""
from .activity_service import ( # noqa: F401
ActivityForbiddenError,
ActivityListFilters,
@@ -22,4 +23,4 @@ from .task_service import ( # noqa: F401
TaskService,
TaskServiceError,
TaskUpdateData,
)
)
+1
View File
@@ -1,4 +1,5 @@
"""Business logic for timeline activities."""
from __future__ import annotations
from collections.abc import Sequence
+23 -16
View File
@@ -1,11 +1,13 @@
"""Analytics-related business logic."""
from __future__ import annotations
import logging
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from decimal import Decimal, InvalidOperation
from typing import Any, Iterable
from typing import Any
from redis.asyncio.client import Redis
from redis.exceptions import RedisError
@@ -105,9 +107,7 @@ class AnalyticsService:
won_amount_count = row.amount_count
won_count = row.deal_count
won_average = (
(won_amount_sum / won_amount_count) if won_amount_count > 0 else Decimal("0")
)
won_average = (won_amount_sum / won_amount_count) if won_amount_count > 0 else Decimal("0")
window_threshold = _threshold_from_days(days)
new_deals = await self._repository.count_new_deals_since(organization_id, window_threshold)
@@ -137,7 +137,7 @@ class AnalyticsService:
breakdowns: list[StageBreakdown] = []
totals = {stage: sum(by_status.values()) for stage, by_status in stage_map.items()}
for index, stage in enumerate(_STAGE_ORDER):
by_status = stage_map.get(stage, {status: 0 for status in DealStatus})
by_status = stage_map.get(stage, dict.fromkeys(DealStatus, 0))
total = totals.get(stage, 0)
conversion = None
if index < len(_STAGE_ORDER) - 1:
@@ -151,7 +151,7 @@ class AnalyticsService:
total=total,
by_status=by_status,
conversion_to_next=conversion,
)
),
)
await self._store_funnel_cache(organization_id, breakdowns)
return breakdowns
@@ -168,7 +168,9 @@ class AnalyticsService:
return None
return _deserialize_summary(payload)
async def _store_summary_cache(self, organization_id: int, days: int, summary: DealSummary) -> None:
async def _store_summary_cache(
self, organization_id: int, days: int, summary: DealSummary
) -> None:
if not self._is_cache_enabled() or self._cache is None:
return
key = _summary_cache_key(organization_id, days)
@@ -184,7 +186,9 @@ class AnalyticsService:
return None
return _deserialize_funnel(payload)
async def _store_funnel_cache(self, organization_id: int, breakdowns: list[StageBreakdown]) -> None:
async def _store_funnel_cache(
self, organization_id: int, breakdowns: list[StageBreakdown]
) -> None:
if not self._is_cache_enabled() or self._cache is None:
return
key = _funnel_cache_key(organization_id)
@@ -198,11 +202,10 @@ def _threshold_from_days(days: int) -> datetime:
def _build_stage_map(rollup: Iterable[StageStatusRollup]) -> dict[DealStage, dict[DealStatus, int]]:
stage_map: dict[DealStage, dict[DealStatus, int]] = {
stage: {status: 0 for status in DealStatus}
for stage in _STAGE_ORDER
stage: dict.fromkeys(DealStatus, 0) for stage in _STAGE_ORDER
}
for item in rollup:
stage_map.setdefault(item.stage, {status: 0 for status in DealStatus})
stage_map.setdefault(item.stage, dict.fromkeys(DealStatus, 0))
stage_map[item.stage][item.status] = item.deal_count
return stage_map
@@ -263,7 +266,7 @@ def _deserialize_summary(payload: Any) -> DealSummary | None:
status=DealStatus(item["status"]),
count=int(item["count"]),
amount_sum=Decimal(item["amount_sum"]),
)
),
)
won = WonStatistics(
count=int(won_payload["count"]),
@@ -289,7 +292,7 @@ def _serialize_funnel(breakdowns: list[StageBreakdown]) -> list[dict[str, Any]]:
"total": item.total,
"by_status": {status.value: count for status, count in item.by_status.items()},
"conversion_to_next": item.conversion_to_next,
}
},
)
return serialized
@@ -307,15 +310,19 @@ def _deserialize_funnel(payload: Any) -> list[StageBreakdown] | None:
stage=DealStage(item["stage"]),
total=int(item["total"]),
by_status=by_status,
conversion_to_next=float(item["conversion_to_next"]) if item["conversion_to_next"] is not None else None,
)
conversion_to_next=float(item["conversion_to_next"])
if item["conversion_to_next"] is not None
else None,
),
)
except (KeyError, TypeError, ValueError):
return None
return breakdowns
async def invalidate_analytics_cache(cache: Redis | None, organization_id: int, backoff_ms: int) -> None:
async def invalidate_analytics_cache(
cache: Redis | None, organization_id: int, backoff_ms: int
) -> None:
"""Remove cached analytics payloads for the organization."""
if cache is None:
+1
View File
@@ -1,4 +1,5 @@
"""Authentication workflows."""
from __future__ import annotations
from datetime import timedelta
+7 -2
View File
@@ -1,4 +1,5 @@
"""Business logic for contact workflows."""
from __future__ import annotations
from collections.abc import Sequence
@@ -78,7 +79,9 @@ class ContactService:
owner_id=filters.owner_id,
)
try:
return await self._repository.list(params=params, role=context.role, user_id=context.user_id)
return await self._repository.list(
params=params, role=context.role, user_id=context.user_id
)
except ContactAccessError as exc:
raise ContactForbiddenError(str(exc)) from exc
@@ -122,7 +125,9 @@ class ContactService:
if not payload:
return contact
try:
return await self._repository.update(contact, payload, role=context.role, user_id=context.user_id)
return await self._repository.update(
contact, payload, role=context.role, user_id=context.user_id
)
except ContactAccessError as exc:
raise ContactForbiddenError(str(exc)) from exc
+132 -120
View File
@@ -1,4 +1,5 @@
"""Business logic for deals."""
from __future__ import annotations
from collections.abc import Iterable
@@ -16,162 +17,173 @@ from app.repositories.deal_repo import DealRepository
from app.services.analytics_service import invalidate_analytics_cache
from app.services.organization_service import OrganizationContext
STAGE_ORDER = {
stage: index
for index, stage in enumerate(
[
DealStage.QUALIFICATION,
DealStage.PROPOSAL,
DealStage.NEGOTIATION,
DealStage.CLOSED,
]
)
stage: index
for index, stage in enumerate(
[
DealStage.QUALIFICATION,
DealStage.PROPOSAL,
DealStage.NEGOTIATION,
DealStage.CLOSED,
],
)
}
class DealServiceError(Exception):
"""Base class for deal service errors."""
"""Base class for deal service errors."""
class DealOrganizationMismatchError(DealServiceError):
"""Raised when attempting to use resources from another organization."""
"""Raised when attempting to use resources from another organization."""
class DealStageTransitionError(DealServiceError):
"""Raised when stage transition violates business rules."""
"""Raised when stage transition violates business rules."""
class DealStatusValidationError(DealServiceError):
"""Raised when invalid status transitions are requested."""
"""Raised when invalid status transitions are requested."""
class ContactHasDealsError(DealServiceError):
"""Raised when attempting to delete a contact with active deals."""
"""Raised when attempting to delete a contact with active deals."""
@dataclass(slots=True)
class DealUpdateData:
"""Structured container for deal update operations."""
"""Structured container for deal update operations."""
status: DealStatus | None = None
stage: DealStage | None = None
amount: Decimal | None = None
currency: str | None = None
status: DealStatus | None = None
stage: DealStage | None = None
amount: Decimal | None = None
currency: str | None = None
class DealService:
"""Encapsulates deal workflows and validations."""
"""Encapsulates deal workflows and validations."""
def __init__(
self,
repository: DealRepository,
cache: Redis | None = None,
*,
cache_backoff_ms: int = 0,
) -> None:
self._repository = repository
self._cache = cache
self._cache_backoff_ms = cache_backoff_ms
def __init__(
self,
repository: DealRepository,
cache: Redis | None = None,
*,
cache_backoff_ms: int = 0,
) -> None:
self._repository = repository
self._cache = cache
self._cache_backoff_ms = cache_backoff_ms
async def create_deal(self, data: DealCreate, *, context: OrganizationContext) -> Deal:
self._ensure_same_organization(data.organization_id, context)
await self._ensure_contact_in_organization(data.contact_id, context.organization_id)
deal = await self._repository.create(data=data, role=context.role, user_id=context.user_id)
await invalidate_analytics_cache(self._cache, context.organization_id, self._cache_backoff_ms)
return deal
async def create_deal(self, data: DealCreate, *, context: OrganizationContext) -> Deal:
self._ensure_same_organization(data.organization_id, context)
await self._ensure_contact_in_organization(data.contact_id, context.organization_id)
deal = await self._repository.create(data=data, role=context.role, user_id=context.user_id)
await invalidate_analytics_cache(
self._cache, context.organization_id, self._cache_backoff_ms
)
return deal
async def update_deal(
self,
deal: Deal,
updates: DealUpdateData,
*,
context: OrganizationContext,
) -> Deal:
self._ensure_same_organization(deal.organization_id, context)
changes: dict[str, object] = {}
stage_activity: tuple[ActivityType, dict[str, str]] | None = None
status_activity: tuple[ActivityType, dict[str, str]] | None = None
async def update_deal(
self,
deal: Deal,
updates: DealUpdateData,
*,
context: OrganizationContext,
) -> Deal:
self._ensure_same_organization(deal.organization_id, context)
changes: dict[str, object] = {}
stage_activity: tuple[ActivityType, dict[str, str]] | None = None
status_activity: tuple[ActivityType, dict[str, str]] | None = None
if updates.amount is not None:
changes["amount"] = updates.amount
if updates.currency is not None:
changes["currency"] = updates.currency
if updates.amount is not None:
changes["amount"] = updates.amount
if updates.currency is not None:
changes["currency"] = updates.currency
if updates.stage is not None and updates.stage != deal.stage:
self._validate_stage_transition(deal.stage, updates.stage, context.role)
changes["stage"] = updates.stage
stage_activity = (
ActivityType.STAGE_CHANGED,
{"old_stage": deal.stage, "new_stage": updates.stage},
)
if updates.stage is not None and updates.stage != deal.stage:
self._validate_stage_transition(deal.stage, updates.stage, context.role)
changes["stage"] = updates.stage
stage_activity = (
ActivityType.STAGE_CHANGED,
{"old_stage": deal.stage, "new_stage": updates.stage},
)
if updates.status is not None and updates.status != deal.status:
self._validate_status_transition(deal, updates)
changes["status"] = updates.status
status_activity = (
ActivityType.STATUS_CHANGED,
{"old_status": deal.status, "new_status": updates.status},
)
if updates.status is not None and updates.status != deal.status:
self._validate_status_transition(deal, updates)
changes["status"] = updates.status
status_activity = (
ActivityType.STATUS_CHANGED,
{"old_status": deal.status, "new_status": updates.status},
)
if not changes:
return deal
if not changes:
return deal
updated = await self._repository.update(deal, changes, role=context.role, user_id=context.user_id)
await self._log_activities(
deal_id=deal.id,
author_id=context.user_id,
activities=[activity for activity in [stage_activity, status_activity] if activity],
)
await invalidate_analytics_cache(self._cache, context.organization_id, self._cache_backoff_ms)
return updated
updated = await self._repository.update(
deal, changes, role=context.role, user_id=context.user_id
)
await self._log_activities(
deal_id=deal.id,
author_id=context.user_id,
activities=[activity for activity in [stage_activity, status_activity] if activity],
)
await invalidate_analytics_cache(
self._cache, context.organization_id, self._cache_backoff_ms
)
return updated
async def ensure_contact_can_be_deleted(self, contact_id: int) -> None:
stmt = select(func.count()).select_from(Deal).where(Deal.contact_id == contact_id)
count = await self._repository.session.scalar(stmt)
if count and count > 0:
raise ContactHasDealsError("Contact has related deals and cannot be deleted")
async def ensure_contact_can_be_deleted(self, contact_id: int) -> None:
stmt = select(func.count()).select_from(Deal).where(Deal.contact_id == contact_id)
count = await self._repository.session.scalar(stmt)
if count and count > 0:
raise ContactHasDealsError("Contact has related deals and cannot be deleted")
async def _log_activities(
self,
*,
deal_id: int,
author_id: int,
activities: Iterable[tuple[ActivityType, dict[str, str]]],
) -> None:
entries = list(activities)
if not entries:
return
for activity_type, payload in entries:
activity = Activity(deal_id=deal_id, author_id=author_id, type=activity_type, payload=payload)
self._repository.session.add(activity)
await self._repository.session.flush()
async def _log_activities(
self,
*,
deal_id: int,
author_id: int,
activities: Iterable[tuple[ActivityType, dict[str, str]]],
) -> None:
entries = list(activities)
if not entries:
return
for activity_type, payload in entries:
activity = Activity(
deal_id=deal_id, author_id=author_id, type=activity_type, payload=payload
)
self._repository.session.add(activity)
await self._repository.session.flush()
def _ensure_same_organization(self, organization_id: int, context: OrganizationContext) -> None:
if organization_id != context.organization_id:
raise DealOrganizationMismatchError("Operation targets a different organization")
def _ensure_same_organization(self, organization_id: int, context: OrganizationContext) -> None:
if organization_id != context.organization_id:
raise DealOrganizationMismatchError("Operation targets a different organization")
async def _ensure_contact_in_organization(self, contact_id: int, organization_id: int) -> Contact:
contact = await self._repository.session.get(Contact, contact_id)
if contact is None or contact.organization_id != organization_id:
raise DealOrganizationMismatchError("Contact belongs to another organization")
return contact
async def _ensure_contact_in_organization(
self, contact_id: int, organization_id: int
) -> Contact:
contact = await self._repository.session.get(Contact, contact_id)
if contact is None or contact.organization_id != organization_id:
raise DealOrganizationMismatchError("Contact belongs to another organization")
return contact
def _validate_stage_transition(
self,
current_stage: DealStage,
new_stage: DealStage,
role: OrganizationRole,
) -> None:
if STAGE_ORDER[new_stage] < STAGE_ORDER[current_stage] and role not in {
OrganizationRole.OWNER,
OrganizationRole.ADMIN,
}:
raise DealStageTransitionError("Stage rollback requires owner or admin role")
def _validate_stage_transition(
self,
current_stage: DealStage,
new_stage: DealStage,
role: OrganizationRole,
) -> None:
if STAGE_ORDER[new_stage] < STAGE_ORDER[current_stage] and role not in {
OrganizationRole.OWNER,
OrganizationRole.ADMIN,
}:
raise DealStageTransitionError("Stage rollback requires owner or admin role")
def _validate_status_transition(self, deal: Deal, updates: DealUpdateData) -> None:
if updates.status != DealStatus.WON:
return
effective_amount = updates.amount if updates.amount is not None else deal.amount
if effective_amount is None or Decimal(effective_amount) <= Decimal("0"):
raise DealStatusValidationError("Amount must be greater than zero to mark a deal as won")
def _validate_status_transition(self, deal: Deal, updates: DealUpdateData) -> None:
if updates.status != DealStatus.WON:
return
effective_amount = updates.amount if updates.amount is not None else deal.amount
if effective_amount is None or Decimal(effective_amount) <= Decimal("0"):
raise DealStatusValidationError(
"Amount must be greater than zero to mark a deal as won"
)
+8 -3
View File
@@ -1,4 +1,5 @@
"""Organization-related business rules."""
from __future__ import annotations
from dataclasses import dataclass
@@ -54,7 +55,9 @@ class OrganizationService:
def __init__(self, repository: OrganizationRepository) -> None:
self._repository = repository
async def get_context(self, *, user_id: int, organization_id: int | None) -> OrganizationContext:
async def get_context(
self, *, user_id: int, organization_id: int | None
) -> OrganizationContext:
"""Resolve request context ensuring the user belongs to the given organization."""
if organization_id is None:
@@ -66,7 +69,9 @@ class OrganizationService:
return OrganizationContext(organization=membership.organization, membership=membership)
def ensure_entity_in_context(self, *, entity_organization_id: int, context: OrganizationContext) -> None:
def ensure_entity_in_context(
self, *, entity_organization_id: int, context: OrganizationContext
) -> None:
"""Make sure a resource belongs to the current organization."""
if entity_organization_id != context.organization_id:
@@ -113,4 +118,4 @@ class OrganizationService:
self._repository.session.add(membership)
await self._repository.session.commit()
await self._repository.session.refresh(membership)
return membership
return membership
+6 -1
View File
@@ -1,4 +1,5 @@
"""Business logic for tasks linked to deals."""
from __future__ import annotations
from collections.abc import Mapping, Sequence
@@ -9,10 +10,14 @@ from typing import Any
from app.models.activity import ActivityCreate, ActivityType
from app.models.organization_member import OrganizationRole
from app.models.task import Task, TaskCreate
from app.repositories.activity_repo import ActivityRepository, ActivityOrganizationMismatchError
from app.repositories.activity_repo import ActivityOrganizationMismatchError, ActivityRepository
from app.repositories.task_repo import (
TaskAccessError as RepoTaskAccessError,
)
from app.repositories.task_repo import (
TaskOrganizationMismatchError as RepoTaskOrganizationMismatchError,
)
from app.repositories.task_repo import (
TaskQueryParams,
TaskRepository,
)