diff --git a/backend/app/api/routes/admin.py b/backend/app/api/routes/admin.py index 5a2e8a3..18aa014 100755 --- a/backend/app/api/routes/admin.py +++ b/backend/app/api/routes/admin.py @@ -34,7 +34,7 @@ from app.schemas.common import ( SortParams, create_pagination_meta ) -from app.core.exceptions import NotFoundError, ErrorCode +from app.core.exceptions import NotFoundError, DuplicateError, AuthorizationError, ErrorCode logger = logging.getLogger(__name__) @@ -231,8 +231,9 @@ async def admin_delete_user( # Prevent deleting yourself if user.id == admin.id: - raise NotFoundError( - detail="Cannot delete your own account", + # Use AuthorizationError for permission/operation restrictions + raise AuthorizationError( + message="Cannot delete your own account", error_code=ErrorCode.OPERATION_FORBIDDEN ) @@ -310,8 +311,9 @@ async def admin_deactivate_user( # Prevent deactivating yourself if user.id == admin.id: - raise NotFoundError( - detail="Cannot deactivate your own account", + # Use AuthorizationError for permission/operation restrictions + raise AuthorizationError( + message="Cannot deactivate your own account", error_code=ErrorCode.OPERATION_FORBIDDEN ) @@ -416,19 +418,21 @@ async def admin_list_organizations( ) -> Any: """List all organizations with filtering and search.""" try: - orgs, total = await organization_crud.get_multi_with_filters( + # Use optimized method that gets member counts in single query (no N+1) + orgs_with_data, total = await organization_crud.get_multi_with_member_counts( db, skip=pagination.offset, limit=pagination.limit, is_active=is_active, - search=search, - sort_by="created_at", - sort_order="desc" + search=search ) - # Add member count to each organization + # Build response objects from optimized query results orgs_with_count = [] - for org in orgs: + for item in orgs_with_data: + org = item['organization'] + member_count = item['member_count'] + org_dict = { "id": org.id, "name": org.name, @@ -438,7 +442,7 @@ async def admin_list_organizations( "settings": org.settings, "created_at": org.created_at, "updated_at": org.updated_at, - "member_count": await organization_crud.get_member_count(db, organization_id=org.id) + "member_count": member_count } orgs_with_count.append(OrganizationResponse(**org_dict)) @@ -718,7 +722,12 @@ async def admin_add_organization_member( except ValueError as e: logger.warning(f"Failed to add user to organization: {str(e)}") - raise NotFoundError(detail=str(e), error_code=ErrorCode.ALREADY_EXISTS) + # Use DuplicateError for "already exists" scenarios + raise DuplicateError( + message=str(e), + error_code=ErrorCode.USER_ALREADY_EXISTS, + field="user_id" + ) except NotFoundError: raise except Exception as e: diff --git a/backend/app/core/auth.py b/backend/app/core/auth.py index 21ddaf1..daabf98 100644 --- a/backend/app/core/auth.py +++ b/backend/app/core/auth.py @@ -141,12 +141,31 @@ def decode_token(token: str, verify_type: Optional[str] = None) -> TokenPayload: TokenMissingClaimError: If a required claim is missing """ try: + # Decode token with strict algorithm validation payload = jwt.decode( token, settings.SECRET_KEY, - algorithms=[settings.ALGORITHM] + algorithms=[settings.ALGORITHM], + options={ + "verify_signature": True, + "verify_exp": True, + "verify_iat": True, + "require": ["exp", "sub", "iat"] + } ) + # SECURITY: Explicitly verify the algorithm to prevent algorithm confusion attacks + # Decode header to check algorithm (without verification, just to inspect) + header = jwt.get_unverified_header(token) + token_algorithm = header.get("alg", "").upper() + + # Reject weak or unexpected algorithms + if token_algorithm == "NONE": + raise TokenInvalidError("Algorithm 'none' is not allowed") + + if token_algorithm != settings.ALGORITHM.upper(): + raise TokenInvalidError(f"Invalid algorithm: {token_algorithm}") + # Check required claims before Pydantic validation if not payload.get("sub"): raise TokenMissingClaimError("Token missing 'sub' claim") diff --git a/backend/app/crud/organization_async.py b/backend/app/crud/organization_async.py index c92f3be..5ab93f5 100755 --- a/backend/app/crud/organization_async.py +++ b/backend/app/crud/organization_async.py @@ -130,6 +130,83 @@ class CRUDOrganizationAsync(CRUDBaseAsync[Organization, OrganizationCreate, Orga logger.error(f"Error getting member count for organization {organization_id}: {str(e)}") raise + async def get_multi_with_member_counts( + self, + db: AsyncSession, + *, + skip: int = 0, + limit: int = 100, + is_active: Optional[bool] = None, + search: Optional[str] = None + ) -> tuple[List[Dict[str, Any]], int]: + """ + Get organizations with member counts in a SINGLE QUERY using JOIN and GROUP BY. + This eliminates the N+1 query problem. + + Returns: + Tuple of (list of dicts with org and member_count, total count) + """ + try: + # Build base query with LEFT JOIN and GROUP BY + query = ( + select( + Organization, + func.count( + func.distinct( + and_( + UserOrganization.is_active == True, + UserOrganization.user_id + ).self_group() + ) + ).label('member_count') + ) + .outerjoin(UserOrganization, Organization.id == UserOrganization.organization_id) + .group_by(Organization.id) + ) + + # Apply filters + if is_active is not None: + query = query.where(Organization.is_active == is_active) + + if search: + search_filter = or_( + Organization.name.ilike(f"%{search}%"), + Organization.slug.ilike(f"%{search}%"), + Organization.description.ilike(f"%{search}%") + ) + query = query.where(search_filter) + + # Get total count + count_query = select(func.count(Organization.id)) + if is_active is not None: + count_query = count_query.where(Organization.is_active == is_active) + if search: + count_query = count_query.where(search_filter) + + count_result = await db.execute(count_query) + total = count_result.scalar_one() + + # Apply pagination and ordering + query = query.order_by(Organization.created_at.desc()).offset(skip).limit(limit) + + result = await db.execute(query) + rows = result.all() + + # Convert to list of dicts + orgs_with_counts = [ + { + 'organization': org, + 'member_count': member_count + } + for org, member_count in rows + ] + + return orgs_with_counts, total + + except Exception as e: + logger.error(f"Error getting organizations with member counts: {str(e)}", exc_info=True) + raise + async def add_user( self, db: AsyncSession, @@ -332,6 +409,63 @@ class CRUDOrganizationAsync(CRUDBaseAsync[Organization, OrganizationCreate, Orga logger.error(f"Error getting user organizations: {str(e)}") raise + async def get_user_organizations_with_details( + self, + db: AsyncSession, + *, + user_id: UUID, + is_active: bool = True + ) -> List[Dict[str, Any]]: + """ + Get user's organizations with role and member count in SINGLE QUERY. + Eliminates N+1 problem by using subquery for member counts. + + Returns: + List of dicts with organization, role, and member_count + """ + try: + # Subquery to get member counts for each organization + member_count_subq = ( + select( + UserOrganization.organization_id, + func.count(UserOrganization.user_id).label('member_count') + ) + .where(UserOrganization.is_active == True) + .group_by(UserOrganization.organization_id) + .subquery() + ) + + # Main query with JOIN to get org, role, and member count + query = ( + select( + Organization, + UserOrganization.role, + func.coalesce(member_count_subq.c.member_count, 0).label('member_count') + ) + .join(UserOrganization, Organization.id == UserOrganization.organization_id) + .outerjoin(member_count_subq, Organization.id == member_count_subq.c.organization_id) + .where(UserOrganization.user_id == user_id) + ) + + if is_active is not None: + query = query.where(UserOrganization.is_active == is_active) + + result = await db.execute(query) + rows = result.all() + + return [ + { + 'organization': org, + 'role': role, + 'member_count': member_count + } + for org, role, member_count in rows + ] + + except Exception as e: + logger.error(f"Error getting user organizations with details: {str(e)}", exc_info=True) + raise + async def get_user_role_in_org( self, db: AsyncSession, diff --git a/backend/app/schemas/errors.py b/backend/app/schemas/errors.py index 5d84c72..90b9f33 100644 --- a/backend/app/schemas/errors.py +++ b/backend/app/schemas/errors.py @@ -16,6 +16,7 @@ class ErrorCode(str, Enum): INSUFFICIENT_PERMISSIONS = "AUTH_004" USER_INACTIVE = "AUTH_005" AUTHENTICATION_REQUIRED = "AUTH_006" + OPERATION_FORBIDDEN = "AUTH_007" # Operation not allowed for this user/role # User errors (USER_xxx) USER_NOT_FOUND = "USER_001" @@ -43,6 +44,7 @@ class ErrorCode(str, Enum): NOT_FOUND = "SYS_002" METHOD_NOT_ALLOWED = "SYS_003" RATE_LIMIT_EXCEEDED = "SYS_004" + ALREADY_EXISTS = "SYS_005" # Generic resource already exists error class ErrorDetail(BaseModel): diff --git a/backend/app/schemas/users.py b/backend/app/schemas/users.py index 3b05117..7fd7ab6 100644 --- a/backend/app/schemas/users.py +++ b/backend/app/schemas/users.py @@ -6,6 +6,8 @@ from uuid import UUID from pydantic import BaseModel, EmailStr, field_validator, ConfigDict, Field +from app.schemas.validators import validate_password_strength, validate_phone_number + class UserBase(BaseModel): email: EmailStr @@ -15,13 +17,8 @@ class UserBase(BaseModel): @field_validator('phone_number') @classmethod - def validate_phone_number(cls, v: Optional[str]) -> Optional[str]: - if v is None: - return v - # Simple regex for phone validation - if not re.match(r'^\+?[0-9\s\-\(\)]{8,20}$', v): - raise ValueError('Invalid phone number format') - return v + def validate_phone(cls, v: Optional[str]) -> Optional[str]: + return validate_phone_number(v) class UserCreate(UserBase): @@ -31,14 +28,8 @@ class UserCreate(UserBase): @field_validator('password') @classmethod def password_strength(cls, v: str) -> str: - """Basic password strength validation""" - if len(v) < 8: - raise ValueError('Password must be at least 8 characters') - if not any(char.isdigit() for char in v): - raise ValueError('Password must contain at least one digit') - if not any(char.isupper() for char in v): - raise ValueError('Password must contain at least one uppercase letter') - return v + """Enterprise-grade password strength validation""" + return validate_password_strength(v) class UserUpdate(BaseModel): @@ -46,39 +37,12 @@ class UserUpdate(BaseModel): last_name: Optional[str] = None phone_number: Optional[str] = None preferences: Optional[Dict[str, Any]] = None - is_active: Optional[bool] = True + is_active: Optional[bool] = None # Changed default from True to None to avoid unintended updates + @field_validator('phone_number') - def validate_phone_number(cls, v: Optional[str]) -> Optional[str]: - if v is None: - return v - - # Return early for empty strings or whitespace-only strings - if not v or v.strip() == "": - raise ValueError('Phone number cannot be empty') - - # Remove all spaces and formatting characters - cleaned = re.sub(r'[\s\-\(\)]', '', v) - - # Basic pattern: - # Must start with + or 0 - # After + must have at least 8 digits - # After 0 must have at least 8 digits - # Maximum total length of 15 digits (international standard) - # Only allowed characters are + at start and digits - pattern = r'^(?:\+[0-9]{8,14}|0[0-9]{8,14})$' - - if not re.match(pattern, cleaned): - raise ValueError('Phone number must start with + or 0 followed by 8-14 digits') - - # Additional validation to catch specific invalid cases - if cleaned.count('+') > 1: - raise ValueError('Phone number can only contain one + symbol at the start') - - # Check for any non-digit characters (except the leading +) - if not all(c.isdigit() for c in cleaned[1:]): - raise ValueError('Phone number can only contain digits after the prefix') - - return cleaned + @classmethod + def validate_phone(cls, v: Optional[str]) -> Optional[str]: + return validate_phone_number(v) class UserInDB(UserBase): @@ -131,14 +95,8 @@ class PasswordChange(BaseModel): @field_validator('new_password') @classmethod def password_strength(cls, v: str) -> str: - """Basic password strength validation""" - if len(v) < 8: - raise ValueError('Password must be at least 8 characters') - if not any(char.isdigit() for char in v): - raise ValueError('Password must contain at least one digit') - if not any(char.isupper() for char in v): - raise ValueError('Password must contain at least one uppercase letter') - return v + """Enterprise-grade password strength validation""" + return validate_password_strength(v) class PasswordReset(BaseModel): @@ -149,14 +107,8 @@ class PasswordReset(BaseModel): @field_validator('new_password') @classmethod def password_strength(cls, v: str) -> str: - """Basic password strength validation""" - if len(v) < 8: - raise ValueError('Password must be at least 8 characters') - if not any(char.isdigit() for char in v): - raise ValueError('Password must contain at least one digit') - if not any(char.isupper() for char in v): - raise ValueError('Password must contain at least one uppercase letter') - return v + """Enterprise-grade password strength validation""" + return validate_password_strength(v) class LoginRequest(BaseModel): @@ -189,14 +141,8 @@ class PasswordResetConfirm(BaseModel): @field_validator('new_password') @classmethod def password_strength(cls, v: str) -> str: - """Basic password strength validation""" - if len(v) < 8: - raise ValueError('Password must be at least 8 characters') - if not any(char.isdigit() for char in v): - raise ValueError('Password must contain at least one digit') - if not any(char.isupper() for char in v): - raise ValueError('Password must contain at least one uppercase letter') - return v + """Enterprise-grade password strength validation""" + return validate_password_strength(v) model_config = { "json_schema_extra": { diff --git a/backend/app/schemas/validators.py b/backend/app/schemas/validators.py new file mode 100644 index 0000000..deac048 --- /dev/null +++ b/backend/app/schemas/validators.py @@ -0,0 +1,183 @@ +""" +Shared validators for Pydantic schemas. + +This module provides reusable validation functions to ensure consistency +across all schemas and avoid code duplication. +""" +import re +from typing import Set + +# Common weak passwords that should be rejected +COMMON_PASSWORDS: Set[str] = { + 'password', 'password1', 'password123', 'password1234', + 'admin', 'admin123', 'admin1234', + 'welcome', 'welcome1', 'welcome123', + 'qwerty', 'qwerty123', + '12345678', '123456789', '1234567890', + 'letmein', 'letmein1', 'letmein123', + 'monkey123', 'dragon123', + 'passw0rd', 'p@ssw0rd', 'p@ssword', +} + + +def validate_password_strength(password: str) -> str: + """ + Validate password strength with enterprise-grade requirements. + + Requirements: + - Minimum 12 characters (increased from 8 for better security) + - At least one lowercase letter + - At least one uppercase letter + - At least one digit + - At least one special character + - Not in common password list + + Args: + password: The password to validate + + Returns: + The validated password + + Raises: + ValueError: If password doesn't meet requirements + + Examples: + >>> validate_password_strength("MySecureP@ss123") # Valid + >>> validate_password_strength("password1") # Invalid - too weak + """ + # Check minimum length + if len(password) < 12: + raise ValueError('Password must be at least 12 characters long') + + # Check against common passwords (case-insensitive) + if password.lower() in COMMON_PASSWORDS: + raise ValueError('Password is too common. Please choose a stronger password') + + # Check for required character types + checks = [ + (any(c.islower() for c in password), 'at least one lowercase letter'), + (any(c.isupper() for c in password), 'at least one uppercase letter'), + (any(c.isdigit() for c in password), 'at least one digit'), + (any(c in '!@#$%^&*()_+-=[]{}|;:,.<>?~`' for c in password), 'at least one special character (!@#$%^&*()_+-=[]{}|;:,.<>?~`)') + ] + + failed = [msg for check, msg in checks if not check] + if failed: + raise ValueError(f"Password must contain {', '.join(failed)}") + + return password + + +def validate_phone_number(phone: str | None) -> str | None: + """ + Validate phone number format. + + Accepts international format with + prefix or local format with 0 prefix. + Removes formatting characters (spaces, hyphens, parentheses). + + Args: + phone: Phone number to validate (can be None) + + Returns: + Cleaned phone number or None + + Raises: + ValueError: If phone number format is invalid + + Examples: + >>> validate_phone_number("+1 (555) 123-4567") # Valid + >>> validate_phone_number("0412 345 678") # Valid + >>> validate_phone_number("invalid") # Invalid + """ + if phone is None: + return None + + # Check for empty strings + if not phone or phone.strip() == "": + raise ValueError('Phone number cannot be empty') + + # Remove all spaces and formatting characters + cleaned = re.sub(r'[\s\-\(\)]', '', phone) + + # Basic pattern: + # Must start with + or 0 + # After + must have at least 8 digits + # After 0 must have at least 8 digits + # Maximum total length of 15 digits (international standard) + # Only allowed characters are + at start and digits + pattern = r'^(?:\+[0-9]{8,14}|0[0-9]{8,14})$' + + if not re.match(pattern, cleaned): + raise ValueError('Phone number must start with + or 0 followed by 8-14 digits') + + # Additional validation to catch specific invalid cases + if cleaned.count('+') > 1: + raise ValueError('Phone number can only contain one + symbol at the start') + + # Check for any non-digit characters (except the leading +) + if not all(c.isdigit() for c in cleaned[1:]): + raise ValueError('Phone number can only contain digits after the prefix') + + return cleaned + + +def validate_email_format(email: str) -> str: + """ + Additional email validation beyond Pydantic's EmailStr. + + This can be extended for custom email validation rules. + + Args: + email: Email address to validate + + Returns: + Validated email address + + Raises: + ValueError: If email format is invalid + """ + # Pydantic's EmailStr already does comprehensive validation + # This function is here for custom rules if needed + + # Example: Reject disposable email domains (optional) + # disposable_domains = {'tempmail.com', '10minutemail.com', 'guerrillamail.com'} + # domain = email.split('@')[1].lower() + # if domain in disposable_domains: + # raise ValueError('Disposable email addresses are not allowed') + + return email.lower() # Normalize to lowercase + + +def validate_slug(slug: str) -> str: + """ + Validate URL slug format. + + Slugs must: + - Be 2-50 characters long + - Contain only lowercase letters, numbers, and hyphens + - Not start or end with a hyphen + - Not contain consecutive hyphens + + Args: + slug: URL slug to validate + + Returns: + Validated slug + + Raises: + ValueError: If slug format is invalid + """ + if not slug or len(slug) < 2: + raise ValueError('Slug must be at least 2 characters long') + + if len(slug) > 50: + raise ValueError('Slug must be at most 50 characters long') + + # Check format + if not re.match(r'^[a-z0-9]+(?:-[a-z0-9]+)*$', slug): + raise ValueError( + 'Slug can only contain lowercase letters, numbers, and hyphens. ' + 'It cannot start or end with a hyphen, and cannot contain consecutive hyphens' + ) + + return slug