refactor(backend): migrate type checking from mypy to pyright
Replace mypy>=1.8.0 with pyright>=1.1.390. Remove all [tool.mypy] and [tool.pydantic-mypy] sections from pyproject.toml and add pyrightconfig.json (standard mode, SQLAlchemy false-positive rules suppressed globally). Fixes surfaced by pyright: - Remove unreachable except AuthError clauses in login/login_oauth (same class as AuthenticationError) - Fix Pydantic v2 list Field: min_items/max_items → min_length/max_length - Split OAuthProviderConfig TypedDict into required + optional(email_url) inheritance - Move JWTError/ExpiredSignatureError from lazy try-block imports to module level - Add timezone-aware guard to UserSession.is_expired to match sibling models - Fix is_active: bool → bool | None in three organization repo signatures - Initialize search_filter = None before conditional block (possibly unbound fix) - Add bool() casts to model is_expired and repo is_active/is_superuser returns - Restructure except (JWTError, Exception) into separate except clauses
This commit is contained in:
@@ -65,7 +65,7 @@ class BulkUserAction(BaseModel):
|
||||
|
||||
action: BulkAction = Field(..., description="Action to perform on selected users")
|
||||
user_ids: list[UUID] = Field(
|
||||
..., min_items=1, max_items=100, description="List of user IDs (max 100)"
|
||||
..., min_length=1, max_length=100, description="List of user IDs (max 100)"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -183,9 +183,6 @@ async def login(
|
||||
# Handle specific authentication errors like inactive accounts
|
||||
logger.warning(f"Authentication failed: {e!s}")
|
||||
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
|
||||
except AuthError:
|
||||
# Re-raise custom auth exceptions without modification
|
||||
raise
|
||||
except Exception as e:
|
||||
# Handle unexpected errors
|
||||
logger.error(f"Unexpected error during login: {e!s}", exc_info=True)
|
||||
@@ -232,9 +229,6 @@ async def login_oauth(
|
||||
except AuthenticationError as e:
|
||||
logger.warning(f"OAuth authentication failed: {e!s}")
|
||||
raise AuthError(message=str(e), error_code=ErrorCode.INVALID_CREDENTIALS)
|
||||
except AuthError:
|
||||
# Re-raise custom auth exceptions without modification
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during OAuth login: {e!s}", exc_info=True)
|
||||
raise DatabaseError(
|
||||
|
||||
@@ -655,7 +655,7 @@ async def introspect(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Token introspection error: {e}")
|
||||
return OAuthTokenIntrospectionResponse(active=False)
|
||||
return OAuthTokenIntrospectionResponse(active=False) # pyright: ignore[reportCallIssue]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
@@ -222,7 +222,7 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
|
||||
)
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(code=error_code, message=str(exc.detail))]
|
||||
errors=[ErrorDetail(code=error_code, message=str(exc.detail), field=None)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
@@ -254,7 +254,7 @@ async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONR
|
||||
message = f"{type(exc).__name__}: {exc!s}"
|
||||
|
||||
error_response = ErrorResponse(
|
||||
errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message)]
|
||||
errors=[ErrorDetail(code=ErrorCode.INTERNAL_ERROR, message=message, field=None)]
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
|
||||
@@ -92,7 +92,7 @@ class OAuthAuthorizationCode(Base, UUIDMixin, TimestampMixin):
|
||||
# Handle both timezone-aware and naive datetimes from DB
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
return now > expires_at
|
||||
return bool(now > expires_at)
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
|
||||
@@ -99,7 +99,7 @@ class OAuthProviderRefreshToken(Base, UUIDMixin, TimestampMixin):
|
||||
# Handle both timezone-aware and naive datetimes from DB
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
return now > expires_at
|
||||
return bool(now > expires_at)
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
|
||||
@@ -76,7 +76,11 @@ class UserSession(Base, UUIDMixin, TimestampMixin):
|
||||
"""Check if session has expired."""
|
||||
from datetime import datetime
|
||||
|
||||
return self.expires_at < datetime.now(UTC)
|
||||
now = datetime.now(UTC)
|
||||
expires_at = self.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
return bool(expires_at < now)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert session to dictionary for serialization."""
|
||||
|
||||
@@ -174,6 +174,7 @@ class OrganizationRepository(
|
||||
if is_active is not None:
|
||||
query = query.where(Organization.is_active == is_active)
|
||||
|
||||
search_filter = None
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Organization.name.ilike(f"%{search}%"),
|
||||
@@ -185,7 +186,7 @@ class OrganizationRepository(
|
||||
count_query = select(func.count(Organization.id))
|
||||
if is_active is not None:
|
||||
count_query = count_query.where(Organization.is_active == is_active)
|
||||
if search:
|
||||
if search_filter is not None:
|
||||
count_query = count_query.where(search_filter)
|
||||
|
||||
count_result = await db.execute(count_query)
|
||||
@@ -333,7 +334,7 @@ class OrganizationRepository(
|
||||
organization_id: UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool = True,
|
||||
is_active: bool | None = True,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Get members of an organization with user details."""
|
||||
try:
|
||||
@@ -387,7 +388,7 @@ class OrganizationRepository(
|
||||
raise
|
||||
|
||||
async def get_user_organizations(
|
||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
|
||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
|
||||
) -> list[Organization]:
|
||||
"""Get all organizations a user belongs to."""
|
||||
try:
|
||||
@@ -410,7 +411,7 @@ class OrganizationRepository(
|
||||
raise
|
||||
|
||||
async def get_user_organizations_with_details(
|
||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool = True
|
||||
self, db: AsyncSession, *, user_id: UUID, is_active: bool | None = True
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get user's organizations with role and member count in SINGLE QUERY."""
|
||||
try:
|
||||
@@ -476,7 +477,7 @@ class OrganizationRepository(
|
||||
)
|
||||
user_org = result.scalar_one_or_none()
|
||||
|
||||
return user_org.role if user_org else None
|
||||
return user_org.role if user_org else None # pyright: ignore[reportReturnType]
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user role in org: {e!s}")
|
||||
raise
|
||||
|
||||
@@ -256,11 +256,11 @@ class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
|
||||
|
||||
def is_active(self, user: User) -> bool:
|
||||
"""Check if user is active."""
|
||||
return user.is_active
|
||||
return bool(user.is_active)
|
||||
|
||||
def is_superuser(self, user: User) -> bool:
|
||||
"""Check if user is a superuser."""
|
||||
return user.is_superuser
|
||||
return bool(user.is_superuser)
|
||||
|
||||
|
||||
# Singleton instance
|
||||
|
||||
@@ -48,7 +48,7 @@ class OrganizationCreate(OrganizationBase):
|
||||
"""Schema for creating a new organization."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str = Field(..., min_length=1, max_length=255) # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
|
||||
|
||||
class OrganizationUpdate(BaseModel):
|
||||
|
||||
@@ -25,7 +25,8 @@ from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from jose import jwt
|
||||
from jose import JWTError, jwt
|
||||
from jose.exceptions import ExpiredSignatureError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
@@ -677,8 +678,6 @@ async def revoke_token(
|
||||
# Try as access token (JWT)
|
||||
if token_type_hint != "refresh_token":
|
||||
try:
|
||||
from jose.exceptions import JWTError
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
@@ -700,7 +699,9 @@ async def revoke_token(
|
||||
f"Revoked refresh token via access token JTI {jti[:8]}..."
|
||||
)
|
||||
return True
|
||||
except (JWTError, Exception): # noqa: S110 - Intentional: invalid JWT not an error
|
||||
except JWTError:
|
||||
pass
|
||||
except Exception: # noqa: S110 - Intentional: invalid JWT not an error
|
||||
pass
|
||||
|
||||
return False
|
||||
@@ -791,8 +792,6 @@ async def introspect_token(
|
||||
# Try as access token (JWT) first
|
||||
if token_type_hint != "refresh_token":
|
||||
try:
|
||||
from jose.exceptions import ExpiredSignatureError, JWTError
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.SECRET_KEY,
|
||||
@@ -823,7 +822,9 @@ async def introspect_token(
|
||||
}
|
||||
except ExpiredSignatureError:
|
||||
return {"active": False}
|
||||
except (JWTError, Exception): # noqa: S110 - Intentional: invalid JWT falls through to refresh token check
|
||||
except JWTError:
|
||||
pass
|
||||
except Exception: # noqa: S110 - Intentional: invalid JWT falls through to refresh token check
|
||||
pass
|
||||
|
||||
# Try as refresh token
|
||||
|
||||
@@ -39,19 +39,22 @@ from app.schemas.oauth import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthProviderConfig(TypedDict, total=False):
|
||||
"""Type definition for OAuth provider configuration."""
|
||||
|
||||
class _OAuthProviderConfigRequired(TypedDict):
|
||||
name: str
|
||||
icon: str
|
||||
authorize_url: str
|
||||
token_url: str
|
||||
userinfo_url: str
|
||||
email_url: str # Optional, GitHub-only
|
||||
scopes: list[str]
|
||||
supports_pkce: bool
|
||||
|
||||
|
||||
class OAuthProviderConfig(_OAuthProviderConfigRequired, total=False):
|
||||
"""Type definition for OAuth provider configuration."""
|
||||
|
||||
email_url: str # Optional, GitHub-only
|
||||
|
||||
|
||||
# Provider configurations
|
||||
OAUTH_PROVIDERS: dict[str, OAuthProviderConfig] = {
|
||||
"google": {
|
||||
@@ -485,7 +488,7 @@ class OAuthService:
|
||||
# GitHub requires separate request for email
|
||||
if provider == "github" and not user_info.get("email"):
|
||||
email_resp = await client.get(
|
||||
config["email_url"],
|
||||
config["email_url"], # pyright: ignore[reportTypedDictNotRequiredAccess]
|
||||
headers=headers,
|
||||
)
|
||||
email_resp.raise_for_status()
|
||||
|
||||
@@ -65,10 +65,10 @@ async def setup_async_test_db():
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
AsyncTestingSessionLocal = sessionmaker(
|
||||
AsyncTestingSessionLocal = sessionmaker( # pyright: ignore[reportCallIssue]
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=test_engine,
|
||||
bind=test_engine, # pyright: ignore[reportArgumentType]
|
||||
expire_on_commit=False,
|
||||
class_=AsyncSession,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user