Add security headers middleware and tests; improve user model schema
- Added security headers middleware to enforce best practices (e.g., XSS and clickjacking prevention, CSP, HSTS in production). - Updated `User` model schema: refined field constraints and switched `preferences` to `JSONB` for PostgreSQL compatibility. - Introduced tests to validate security headers across endpoints and error responses. - Ensured headers like `X-Frame-Options`, `X-Content-Type-Options`, and `Permissions-Policy` are correctly configured.
This commit is contained in:
@@ -43,6 +43,37 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
|
||||
# Add security headers middleware
|
||||
@app.middleware("http")
|
||||
async def add_security_headers(request: Request, call_next):
|
||||
"""Add security headers to all responses"""
|
||||
response = await call_next(request)
|
||||
|
||||
# Prevent clickjacking
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
|
||||
# Prevent MIME type sniffing
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
|
||||
# Enable XSS protection
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
|
||||
# Enforce HTTPS in production
|
||||
if settings.ENVIRONMENT == "production":
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
|
||||
# Content Security Policy
|
||||
response.headers["Content-Security-Policy"] = "default-src 'self'; frame-ancestors 'none'"
|
||||
|
||||
# Permissions Policy (formerly Feature Policy)
|
||||
response.headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()"
|
||||
|
||||
# Referrer Policy
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def root():
|
||||
return """
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from sqlalchemy import Column, String, JSON, Boolean
|
||||
from sqlalchemy import Column, String, Boolean
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from .base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
@@ -6,14 +7,14 @@ from .base import Base, TimestampMixin, UUIDMixin
|
||||
class User(Base, UUIDMixin, TimestampMixin):
|
||||
__tablename__ = 'users'
|
||||
|
||||
email = Column(String, unique=True, nullable=False, index=True)
|
||||
password_hash = Column(String, nullable=False)
|
||||
first_name = Column(String, nullable=False, default="user")
|
||||
last_name = Column(String, nullable=True)
|
||||
phone_number = Column(String)
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
is_superuser = Column(Boolean, default=False, nullable=False)
|
||||
preferences = Column(JSON)
|
||||
email = Column(String(255), unique=True, nullable=False, index=True)
|
||||
password_hash = Column(String(255), nullable=False)
|
||||
first_name = Column(String(100), nullable=False, default="user")
|
||||
last_name = Column(String(100), nullable=True)
|
||||
phone_number = Column(String(20))
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
is_superuser = Column(Boolean, default=False, nullable=False, index=True)
|
||||
preferences = Column(JSONB)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User {self.email}>"
|
||||
94
backend/tests/api/test_security_headers.py
Normal file
94
backend/tests/api/test_security_headers.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# tests/api/test_security_headers.py
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.main import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Create a FastAPI test client for the main app."""
|
||||
# Mock get_db to avoid database connection issues
|
||||
with patch("app.main.get_db") as mock_get_db:
|
||||
def mock_session_generator():
|
||||
from unittest.mock import MagicMock
|
||||
mock_session = MagicMock()
|
||||
mock_session.execute.return_value = None
|
||||
mock_session.close.return_value = None
|
||||
yield mock_session
|
||||
|
||||
mock_get_db.side_effect = lambda: mock_session_generator()
|
||||
yield TestClient(app)
|
||||
|
||||
|
||||
class TestSecurityHeaders:
|
||||
"""Tests for security headers middleware"""
|
||||
|
||||
def test_x_frame_options_header(self, client):
|
||||
"""Test that X-Frame-Options header is set to DENY"""
|
||||
response = client.get("/health")
|
||||
assert "X-Frame-Options" in response.headers
|
||||
assert response.headers["X-Frame-Options"] == "DENY"
|
||||
|
||||
def test_x_content_type_options_header(self, client):
|
||||
"""Test that X-Content-Type-Options header is set to nosniff"""
|
||||
response = client.get("/health")
|
||||
assert "X-Content-Type-Options" in response.headers
|
||||
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||
|
||||
def test_x_xss_protection_header(self, client):
|
||||
"""Test that X-XSS-Protection header is set"""
|
||||
response = client.get("/health")
|
||||
assert "X-XSS-Protection" in response.headers
|
||||
assert response.headers["X-XSS-Protection"] == "1; mode=block"
|
||||
|
||||
def test_content_security_policy_header(self, client):
|
||||
"""Test that Content-Security-Policy header is set"""
|
||||
response = client.get("/health")
|
||||
assert "Content-Security-Policy" in response.headers
|
||||
assert "default-src 'self'" in response.headers["Content-Security-Policy"]
|
||||
assert "frame-ancestors 'none'" in response.headers["Content-Security-Policy"]
|
||||
|
||||
def test_permissions_policy_header(self, client):
|
||||
"""Test that Permissions-Policy header is set"""
|
||||
response = client.get("/health")
|
||||
assert "Permissions-Policy" in response.headers
|
||||
assert "geolocation=()" in response.headers["Permissions-Policy"]
|
||||
assert "microphone=()" in response.headers["Permissions-Policy"]
|
||||
assert "camera=()" in response.headers["Permissions-Policy"]
|
||||
|
||||
def test_referrer_policy_header(self, client):
|
||||
"""Test that Referrer-Policy header is set"""
|
||||
response = client.get("/health")
|
||||
assert "Referrer-Policy" in response.headers
|
||||
assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin"
|
||||
|
||||
def test_strict_transport_security_not_in_development(self, client):
|
||||
"""Test that Strict-Transport-Security header is not set in development"""
|
||||
from app.core.config import settings
|
||||
|
||||
# In development, HSTS should not be present
|
||||
if settings.ENVIRONMENT == "development":
|
||||
response = client.get("/health")
|
||||
assert "Strict-Transport-Security" not in response.headers
|
||||
|
||||
def test_security_headers_on_all_endpoints(self, client):
|
||||
"""Test that security headers are present on all endpoints"""
|
||||
# Test health endpoint
|
||||
response = client.get("/health")
|
||||
assert "X-Frame-Options" in response.headers
|
||||
assert "X-Content-Type-Options" in response.headers
|
||||
|
||||
# Test root endpoint
|
||||
response = client.get("/")
|
||||
assert "X-Frame-Options" in response.headers
|
||||
assert "X-Content-Type-Options" in response.headers
|
||||
|
||||
def test_security_headers_on_404(self, client):
|
||||
"""Test that security headers are present even on 404 responses"""
|
||||
response = client.get("/nonexistent-endpoint")
|
||||
assert response.status_code == 404
|
||||
assert "X-Frame-Options" in response.headers
|
||||
assert "X-Content-Type-Options" in response.headers
|
||||
assert "X-XSS-Protection" in response.headers
|
||||
Reference in New Issue
Block a user