refactor: enhance type hinting and casting for improved type safety across multiple files
This commit is contained in:
+20
-14
@@ -6,7 +6,7 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import redis.asyncio as redis
|
||||
from app.core.config import settings
|
||||
@@ -45,10 +45,13 @@ class RedisCacheManager:
|
||||
async with self._lock:
|
||||
if self._client is not None:
|
||||
return
|
||||
self._client = redis.from_url(
|
||||
settings.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=False,
|
||||
self._client = cast(
|
||||
Redis,
|
||||
redis.from_url( # type: ignore[no-untyped-call]
|
||||
settings.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=False,
|
||||
),
|
||||
)
|
||||
await self._refresh_availability()
|
||||
|
||||
@@ -64,10 +67,13 @@ class RedisCacheManager:
|
||||
return
|
||||
async with self._lock:
|
||||
if self._client is None:
|
||||
self._client = redis.from_url(
|
||||
settings.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=False,
|
||||
self._client = cast(
|
||||
Redis,
|
||||
redis.from_url( # type: ignore[no-untyped-call]
|
||||
settings.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=False,
|
||||
),
|
||||
)
|
||||
await self._refresh_availability()
|
||||
|
||||
@@ -76,7 +82,7 @@ class RedisCacheManager:
|
||||
self._available = False
|
||||
return
|
||||
try:
|
||||
await self._client.ping()
|
||||
await cast(Awaitable[Any], self._client.ping())
|
||||
except RedisError as exc: # pragma: no cover - logging only
|
||||
self._available = False
|
||||
logger.warning("Redis ping failed: %s", exc)
|
||||
@@ -140,8 +146,8 @@ async def write_json(
|
||||
"""Serialize data to JSON and store it with TTL using retry/backoff."""
|
||||
payload = json.dumps(value, separators=(",", ":"), ensure_ascii=True).encode("utf-8")
|
||||
|
||||
async def _operation() -> Any:
|
||||
return await client.set(name=key, value=payload, ex=ttl_seconds)
|
||||
async def _operation() -> None:
|
||||
await client.set(name=key, value=payload, ex=ttl_seconds)
|
||||
|
||||
await _run_with_retry(_operation, backoff_ms)
|
||||
|
||||
@@ -151,8 +157,8 @@ async def delete_keys(client: Redis, keys: list[str], backoff_ms: int) -> None:
|
||||
if not keys:
|
||||
return
|
||||
|
||||
async def _operation() -> Any:
|
||||
return await client.delete(*keys)
|
||||
async def _operation() -> None:
|
||||
await client.delete(*keys)
|
||||
|
||||
await _run_with_retry(_operation, backoff_ms)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import jwt
|
||||
from app.core.config import settings
|
||||
@@ -18,10 +18,10 @@ class PasswordHasher:
|
||||
self._context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
|
||||
|
||||
def hash(self, password: str) -> str:
|
||||
return self._context.hash(password)
|
||||
return cast(str, self._context.hash(password))
|
||||
|
||||
def verify(self, password: str, hashed_password: str) -> bool:
|
||||
return self._context.verify(password, hashed_password)
|
||||
return bool(self._context.verify(password, hashed_password))
|
||||
|
||||
|
||||
class JWTService:
|
||||
@@ -45,10 +45,10 @@ class JWTService:
|
||||
}
|
||||
if claims:
|
||||
payload.update(claims)
|
||||
return jwt.encode(payload, self._secret_key, algorithm=self._algorithm)
|
||||
return cast(str, jwt.encode(payload, self._secret_key, algorithm=self._algorithm))
|
||||
|
||||
def decode(self, token: str) -> dict[str, Any]:
|
||||
return jwt.decode(token, self._secret_key, algorithms=[self._algorithm])
|
||||
return cast(dict[str, Any], jwt.decode(token, self._secret_key, algorithms=[self._algorithm]))
|
||||
|
||||
|
||||
password_hasher = PasswordHasher()
|
||||
|
||||
@@ -10,9 +10,11 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import DateTime, ForeignKey, Integer, func, text
|
||||
from sqlalchemy import Enum as SqlEnum
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.engine import Dialect
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.types import JSON as SA_JSON
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
from sqlalchemy.sql.type_api import TypeEngine
|
||||
|
||||
from app.models.base import Base, enum_values
|
||||
|
||||
@@ -31,7 +33,7 @@ class JSONBCompat(TypeDecorator):
|
||||
impl = JSONB
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect): # type: ignore[override]
|
||||
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
|
||||
if dialect.name == "sqlite":
|
||||
from sqlalchemy.dialects.sqlite import JSON as SQLITE_JSON # local import
|
||||
|
||||
|
||||
+1
-1
@@ -14,7 +14,7 @@ class Base(DeclarativeBase):
|
||||
"""Base class that configures naming conventions."""
|
||||
|
||||
@declared_attr.directive
|
||||
def __tablename__(cls) -> str: # type: ignore[misc] # noqa: N805 - SQLAlchemy expects cls
|
||||
def __tablename__(cls) -> str: # noqa: N805 - SQLAlchemy expects cls
|
||||
return cls.__name__.lower()
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import Select, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -123,6 +123,9 @@ class TaskRepository:
|
||||
|
||||
async def _resolve_task_owner(self, task: Task) -> int | None:
|
||||
if task.deal is not None:
|
||||
return task.deal.owner_id
|
||||
return int(task.deal.owner_id)
|
||||
stmt = select(Deal.owner_id).where(Deal.id == task.deal_id)
|
||||
return await self._session.scalar(stmt)
|
||||
owner_id_raw: Any = await self._session.scalar(stmt)
|
||||
if owner_id_raw is None:
|
||||
return None
|
||||
return cast(int, owner_id_raw)
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -151,11 +152,11 @@ class ContactService:
|
||||
def _build_update_mapping(self, updates: ContactUpdateData) -> dict[str, str | None]:
|
||||
payload: dict[str, str | None] = {}
|
||||
if updates.name is not UNSET:
|
||||
payload["name"] = updates.name
|
||||
payload["name"] = cast(str | None, updates.name)
|
||||
if updates.email is not UNSET:
|
||||
payload["email"] = updates.email
|
||||
payload["email"] = cast(str | None, updates.email)
|
||||
if updates.phone is not UNSET:
|
||||
payload["phone"] = updates.phone
|
||||
payload["phone"] = cast(str | None, updates.phone)
|
||||
return payload
|
||||
|
||||
async def _ensure_no_related_deals(self, contact_id: int) -> None:
|
||||
|
||||
Reference in New Issue
Block a user