refactor: enhance type hinting and casting for improved type safety across multiple files

This commit is contained in:
Artem Kashaev
2025-12-01 16:44:14 +05:00
parent f234e60e65
commit 688ade0452
14 changed files with 62 additions and 42 deletions
View File
View File
View File
+6 -4
View File
@@ -3,10 +3,12 @@
from __future__ import annotations
from enum import StrEnum
from typing import cast
from app.models.activity import Activity, ActivityType
from app.models.deal import Deal, DealStage, DealStatus
from app.models.organization_member import OrganizationMember, OrganizationRole
from sqlalchemy import Enum as SqlEnum
def _values(enum_cls: type[StrEnum]) -> list[str]:
@@ -14,20 +16,20 @@ def _values(enum_cls: type[StrEnum]) -> list[str]:
def test_organization_role_column_uses_value_strings() -> None:
role_type = OrganizationMember.__table__.c.role.type # noqa: SLF001 - runtime inspection
role_type = cast(SqlEnum, OrganizationMember.__table__.c.role.type) # noqa: SLF001
assert role_type.enums == _values(OrganizationRole)
def test_deal_status_column_uses_value_strings() -> None:
status_type = Deal.__table__.c.status.type # noqa: SLF001 - runtime inspection
status_type = cast(SqlEnum, Deal.__table__.c.status.type) # noqa: SLF001
assert status_type.enums == _values(DealStatus)
def test_deal_stage_column_uses_value_strings() -> None:
stage_type = Deal.__table__.c.stage.type # noqa: SLF001 - runtime inspection
stage_type = cast(SqlEnum, Deal.__table__.c.stage.type) # noqa: SLF001
assert stage_type.enums == _values(DealStage)
def test_activity_type_column_uses_value_strings() -> None:
activity_type = Activity.__table__.c.type.type # noqa: SLF001 - runtime inspection
activity_type = cast(SqlEnum, Activity.__table__.c.type.type) # noqa: SLF001
assert activity_type.enums == _values(ActivityType)
+13 -7
View File
@@ -5,6 +5,7 @@ from __future__ import annotations
from collections.abc import AsyncGenerator
from datetime import datetime, timedelta, timezone
from decimal import Decimal
from typing import cast
import pytest
import pytest_asyncio
@@ -14,8 +15,13 @@ from app.models.deal import Deal, DealStage, DealStatus
from app.models.organization import Organization
from app.models.organization_member import OrganizationMember, OrganizationRole
from app.models.user import User
from app.repositories.analytics_repo import AnalyticsRepository
from app.repositories.analytics_repo import (
AnalyticsRepository,
StageStatusRollup,
StatusRollup,
)
from app.services.analytics_service import AnalyticsService, invalidate_analytics_cache
from redis.asyncio.client import Redis
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from tests.utils.fake_redis import InMemoryRedis
@@ -170,20 +176,20 @@ async def test_funnel_breakdown_contains_stage_conversions(session: AsyncSession
class _ExplodingRepository(AnalyticsRepository):
async def fetch_status_rollup(self, organization_id: int): # type: ignore[override]
async def fetch_status_rollup(self, organization_id: int) -> list[StatusRollup]:
raise AssertionError("cache not used for status rollup")
async def count_new_deals_since(self, organization_id: int, threshold): # type: ignore[override]
async def count_new_deals_since(self, organization_id: int, threshold: datetime) -> int:
raise AssertionError("cache not used for new deal count")
async def fetch_stage_status_rollup(self, organization_id: int): # type: ignore[override]
async def fetch_stage_status_rollup(self, organization_id: int) -> list[StageStatusRollup]:
raise AssertionError("cache not used for funnel rollup")
@pytest.mark.asyncio
async def test_summary_reads_from_cache_when_available(session: AsyncSession) -> None:
org_id, _, _ = await _seed_data(session)
cache = InMemoryRedis()
cache = cast(Redis, InMemoryRedis())
service = AnalyticsService(
repository=AnalyticsRepository(session),
cache=cache,
@@ -201,7 +207,7 @@ async def test_summary_reads_from_cache_when_available(session: AsyncSession) ->
@pytest.mark.asyncio
async def test_invalidation_refreshes_cached_summary(session: AsyncSession) -> None:
org_id, _, contact_id = await _seed_data(session)
cache = InMemoryRedis()
cache = cast(Redis, InMemoryRedis())
service = AnalyticsService(
repository=AnalyticsRepository(session),
cache=cache,
@@ -235,7 +241,7 @@ async def test_invalidation_refreshes_cached_summary(session: AsyncSession) -> N
@pytest.mark.asyncio
async def test_funnel_reads_from_cache_when_available(session: AsyncSession) -> None:
org_id, _, _ = await _seed_data(session)
cache = InMemoryRedis()
cache = cast(Redis, InMemoryRedis())
service = AnalyticsService(
repository=AnalyticsRepository(session),
cache=cache,
+1 -1
View File
@@ -5,7 +5,7 @@ from __future__ import annotations
from typing import cast
from unittest.mock import MagicMock
import pytest # type: ignore[import-not-found]
import pytest
from app.core.security import JWTService, PasswordHasher
from app.models.user import User
from app.repositories.user_repo import UserRepository
+2 -2
View File
@@ -6,8 +6,8 @@ import uuid
from collections.abc import AsyncGenerator
from decimal import Decimal
import pytest # type: ignore[import-not-found]
import pytest_asyncio # type: ignore[import-not-found]
import pytest
import pytest_asyncio
from app.models.activity import Activity, ActivityType
from app.models.base import Base
from app.models.contact import Contact
+1 -1
View File
@@ -5,7 +5,7 @@ from __future__ import annotations
from typing import cast
from unittest.mock import MagicMock
import pytest # type: ignore[import-not-found]
import pytest
from app.models.organization import Organization
from app.models.organization_member import OrganizationMember, OrganizationRole
from app.repositories.org_repo import OrganizationRepository