Refactor(backend): improve formatting in services, repositories & tests

- Consistently format multi-line function headers, exception handling, and repository method calls for readability.
- Reorganize misplaced imports across modules (e.g., services & tests) into proper sorted order.
- Adjust indentation, line breaks, and spacing inconsistencies in tests and migration files.
- Cleanup unnecessary trailing newlines and reorganize `__all__` declarations for consistency.
This commit is contained in:
2026-02-28 18:37:56 +01:00
parent 98b455fdc3
commit 4c6bf55bcc
38 changed files with 567 additions and 337 deletions

View File

@@ -9,11 +9,11 @@ from .user_service import UserService, user_service
__all__ = [
"AuthService",
"OAuthService",
"UserService",
"OrganizationService",
"SessionService",
"UserService",
"oauth_provider_service",
"user_service",
"organization_service",
"session_service",
"user_service",
]

View File

@@ -30,13 +30,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.models.oauth_client import OAuthClient
from app.schemas.oauth import OAuthClientCreate
from app.models.user import User
from app.repositories.oauth_authorization_code import oauth_authorization_code_repo
from app.repositories.oauth_client import oauth_client_repo
from app.repositories.oauth_consent import oauth_consent_repo
from app.repositories.oauth_provider_token import oauth_provider_token_repo
from app.repositories.user import user_repo
from app.schemas.oauth import OAuthClientCreate
logger = logging.getLogger(__name__)
@@ -691,9 +691,7 @@ async def revoke_token(
jti = payload.get("jti")
if jti:
# Find and revoke the associated refresh token
refresh_record = await oauth_provider_token_repo.get_by_jti(
db, jti=jti
)
refresh_record = await oauth_provider_token_repo.get_by_jti(db, jti=jti)
if refresh_record:
if client_id and refresh_record.client_id != client_id:
raise InvalidClientError("Token was not issued to this client")
@@ -807,9 +805,7 @@ async def introspect_token(
# Check if associated refresh token is revoked
jti = payload.get("jti")
if jti:
refresh_record = await oauth_provider_token_repo.get_by_jti(
db, jti=jti
)
refresh_record = await oauth_provider_token_repo.get_by_jti(db, jti=jti)
if refresh_record and refresh_record.revoked:
return {"active": False}
@@ -862,7 +858,9 @@ async def get_consent(
client_id: str,
):
"""Get existing consent record for user-client pair."""
return await oauth_consent_repo.get_consent(db, user_id=user_id, client_id=client_id)
return await oauth_consent_repo.get_consent(
db, user_id=user_id, client_id=client_id
)
async def check_consent(

View File

@@ -24,9 +24,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import create_access_token, create_refresh_token
from app.core.config import settings
from app.core.exceptions import AuthenticationError
from app.models.user import User
from app.repositories.oauth_account import oauth_account_repo as oauth_account
from app.repositories.oauth_state import oauth_state_repo as oauth_state
from app.models.user import User
from app.repositories.user import user_repo
from app.schemas.oauth import (
OAuthAccountCreate,
@@ -344,7 +344,9 @@ class OAuthService:
await oauth_account.update_tokens(
db,
account=existing_oauth,
access_token=token.get("access_token"), refresh_token=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
access_token=token.get("access_token"),
refresh_token=token.get("refresh_token"),
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600)),
)
@@ -373,7 +375,9 @@ class OAuthService:
provider=provider,
provider_user_id=provider_user_id,
provider_email=provider_email,
access_token=token.get("access_token"), refresh_token=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
access_token=token.get("access_token"),
refresh_token=token.get("refresh_token"),
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600))
if token.get("expires_in")
else None,
@@ -639,7 +643,9 @@ class OAuthService:
provider=provider,
provider_user_id=provider_user_id,
provider_email=email,
access_token=token.get("access_token"), refresh_token=token.get("refresh_token"), token_expires_at=datetime.now(UTC)
access_token=token.get("access_token"),
refresh_token=token.get("refresh_token"),
token_expires_at=datetime.now(UTC)
+ timedelta(seconds=token.get("expires_in", 3600))
if token.get("expires_in")
else None,

View File

@@ -51,9 +51,7 @@ class OrganizationService:
"""Permanently delete an organization by ID."""
await self._repo.remove(db, id=org_id)
async def get_member_count(
self, db: AsyncSession, *, organization_id: UUID
) -> int:
async def get_member_count(self, db: AsyncSession, *, organization_id: UUID) -> int:
"""Get number of active members in an organization."""
return await self._repo.get_member_count(db, organization_id=organization_id)

View File

@@ -25,7 +25,9 @@ class SessionService:
"""Create a new session record."""
return await self._repo.create_session(db, obj_in=obj_in)
async def get_session(self, db: AsyncSession, session_id: str) -> UserSession | None:
async def get_session(
self, db: AsyncSession, session_id: str
) -> UserSession | None:
"""Get session by ID."""
return await self._repo.get(db, id=session_id)
@@ -72,9 +74,7 @@ class SessionService:
db, session=session, new_jti=new_jti, new_expires_at=new_expires_at
)
async def cleanup_expired_for_user(
self, db: AsyncSession, *, user_id: str
) -> int:
async def cleanup_expired_for_user(self, db: AsyncSession, *, user_id: str) -> int:
"""Remove expired sessions for a user. Returns count removed."""
return await self._repo.cleanup_expired_for_user(db, user_id=user_id)

View File

@@ -96,7 +96,9 @@ class UserService:
await db.execute(select(func.count()).select_from(User))
).scalar() or 0
active_count = (
await db.execute(select(func.count()).select_from(User).where(User.is_active))
await db.execute(
select(func.count()).select_from(User).where(User.is_active)
)
).scalar() or 0
inactive_count = (
await db.execute(
@@ -104,9 +106,7 @@ class UserService:
)
).scalar() or 0
all_users = list(
(
await db.execute(select(User).order_by(User.created_at))
).scalars().all()
(await db.execute(select(User).order_by(User.created_at))).scalars().all()
)
return {
"total_users": total_users,