refactor: enhance type hinting and casting for improved type safety across multiple files

This commit is contained in:
Artem Kashaev
2025-12-01 16:44:14 +05:00
parent f234e60e65
commit 688ade0452
14 changed files with 62 additions and 42 deletions
+20 -14
View File
@@ -6,7 +6,7 @@ import asyncio
import json
import logging
from collections.abc import Awaitable, Callable
from typing import Any
from typing import Any, cast
import redis.asyncio as redis
from app.core.config import settings
@@ -45,10 +45,13 @@ class RedisCacheManager:
async with self._lock:
if self._client is not None:
return
self._client = redis.from_url(
settings.redis_url,
encoding="utf-8",
decode_responses=False,
self._client = cast(
Redis,
redis.from_url( # type: ignore[no-untyped-call]
settings.redis_url,
encoding="utf-8",
decode_responses=False,
),
)
await self._refresh_availability()
@@ -64,10 +67,13 @@ class RedisCacheManager:
return
async with self._lock:
if self._client is None:
self._client = redis.from_url(
settings.redis_url,
encoding="utf-8",
decode_responses=False,
self._client = cast(
Redis,
redis.from_url( # type: ignore[no-untyped-call]
settings.redis_url,
encoding="utf-8",
decode_responses=False,
),
)
await self._refresh_availability()
@@ -76,7 +82,7 @@ class RedisCacheManager:
self._available = False
return
try:
await self._client.ping()
await cast(Awaitable[Any], self._client.ping())
except RedisError as exc: # pragma: no cover - logging only
self._available = False
logger.warning("Redis ping failed: %s", exc)
@@ -140,8 +146,8 @@ async def write_json(
"""Serialize data to JSON and store it with TTL using retry/backoff."""
payload = json.dumps(value, separators=(",", ":"), ensure_ascii=True).encode("utf-8")
async def _operation() -> Any:
return await client.set(name=key, value=payload, ex=ttl_seconds)
async def _operation() -> None:
await client.set(name=key, value=payload, ex=ttl_seconds)
await _run_with_retry(_operation, backoff_ms)
@@ -151,8 +157,8 @@ async def delete_keys(client: Redis, keys: list[str], backoff_ms: int) -> None:
if not keys:
return
async def _operation() -> Any:
return await client.delete(*keys)
async def _operation() -> None:
await client.delete(*keys)
await _run_with_retry(_operation, backoff_ms)
+5 -5
View File
@@ -4,7 +4,7 @@ from __future__ import annotations
from collections.abc import Mapping
from datetime import datetime, timedelta, timezone
from typing import Any
from typing import Any, cast
import jwt
from app.core.config import settings
@@ -18,10 +18,10 @@ class PasswordHasher:
self._context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
def hash(self, password: str) -> str:
return self._context.hash(password)
return cast(str, self._context.hash(password))
def verify(self, password: str, hashed_password: str) -> bool:
return self._context.verify(password, hashed_password)
return bool(self._context.verify(password, hashed_password))
class JWTService:
@@ -45,10 +45,10 @@ class JWTService:
}
if claims:
payload.update(claims)
return jwt.encode(payload, self._secret_key, algorithm=self._algorithm)
return cast(str, jwt.encode(payload, self._secret_key, algorithm=self._algorithm))
def decode(self, token: str) -> dict[str, Any]:
return jwt.decode(token, self._secret_key, algorithms=[self._algorithm])
return cast(dict[str, Any], jwt.decode(token, self._secret_key, algorithms=[self._algorithm]))
password_hasher = PasswordHasher()
+3 -1
View File
@@ -10,9 +10,11 @@ from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import DateTime, ForeignKey, Integer, func, text
from sqlalchemy import Enum as SqlEnum
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.engine import Dialect
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.types import JSON as SA_JSON
from sqlalchemy.types import TypeDecorator
from sqlalchemy.sql.type_api import TypeEngine
from app.models.base import Base, enum_values
@@ -31,7 +33,7 @@ class JSONBCompat(TypeDecorator):
impl = JSONB
cache_ok = True
def load_dialect_impl(self, dialect): # type: ignore[override]
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
if dialect.name == "sqlite":
from sqlalchemy.dialects.sqlite import JSON as SQLITE_JSON # local import
+1 -1
View File
@@ -14,7 +14,7 @@ class Base(DeclarativeBase):
"""Base class that configures naming conventions."""
@declared_attr.directive
def __tablename__(cls) -> str: # type: ignore[misc] # noqa: N805 - SQLAlchemy expects cls
def __tablename__(cls) -> str: # noqa: N805 - SQLAlchemy expects cls
return cls.__name__.lower()
+6 -3
View File
@@ -5,7 +5,7 @@ from __future__ import annotations
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from typing import Any, cast
from sqlalchemy import Select, select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -123,6 +123,9 @@ class TaskRepository:
async def _resolve_task_owner(self, task: Task) -> int | None:
if task.deal is not None:
return task.deal.owner_id
return int(task.deal.owner_id)
stmt = select(Deal.owner_id).where(Deal.id == task.deal_id)
return await self._session.scalar(stmt)
owner_id_raw: Any = await self._session.scalar(stmt)
if owner_id_raw is None:
return None
return cast(int, owner_id_raw)
+4 -3
View File
@@ -4,6 +4,7 @@ from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import cast
from sqlalchemy import select
@@ -151,11 +152,11 @@ class ContactService:
def _build_update_mapping(self, updates: ContactUpdateData) -> dict[str, str | None]:
payload: dict[str, str | None] = {}
if updates.name is not UNSET:
payload["name"] = updates.name
payload["name"] = cast(str | None, updates.name)
if updates.email is not UNSET:
payload["email"] = updates.email
payload["email"] = cast(str | None, updates.email)
if updates.phone is not UNSET:
payload["phone"] = updates.phone
payload["phone"] = cast(str | None, updates.phone)
return payload
async def _ensure_no_related_deals(self, contact_id: int) -> None: