Production hardening: security, resilience, observability, and compliance

Password complexity: custom validator requiring uppercase, lowercase, digit (min 8 chars)
Token expiry: 90-day token lifetime with refresh endpoint (60-90 day renewal window)
Health check: /api/health/ now pings Postgres + Redis, returns 503 on failure
Audit logging: async audit_log table for auth events (login, register, delete, etc.)
Circuit breaker: APNs/FCM push sends wrapped with 5-failure threshold, 30s recovery
FK indexes: 27 missing foreign key indexes across all tables (migration 017)
CSP header: default-src 'none'; frame-ancestors 'none'
Gzip compression: level 5 with media endpoint skipper
Prometheus metrics: /metrics endpoint using existing monitoring service
External timeouts: 15s push, 30s SMTP, context timeouts on all external calls

Migrations: 016 (token created_at), 017 (FK indexes), 018 (audit_log)
Tests: circuit breaker (15), audit service (8), token refresh (7), health (4),
       middleware expiry (5), validator (new)
This commit is contained in:
Trey T
2026-03-26 14:05:28 -05:00
parent 4abc57535e
commit b679f28e55
30 changed files with 2077 additions and 47 deletions

View File

@@ -134,6 +134,8 @@ type SecurityConfig struct {
PasswordResetExpiry time.Duration PasswordResetExpiry time.Duration
ConfirmationExpiry time.Duration ConfirmationExpiry time.Duration
MaxPasswordResetRate int // per hour MaxPasswordResetRate int // per hour
TokenExpiryDays int // Number of days before auth tokens expire (default 90)
TokenRefreshDays int // Token must be at least this many days old before refresh (default 60)
} }
// StorageConfig holds file storage settings // StorageConfig holds file storage settings
@@ -262,6 +264,8 @@ func Load() (*Config, error) {
PasswordResetExpiry: 15 * time.Minute, PasswordResetExpiry: 15 * time.Minute,
ConfirmationExpiry: 24 * time.Hour, ConfirmationExpiry: 24 * time.Hour,
MaxPasswordResetRate: 3, MaxPasswordResetRate: 3,
TokenExpiryDays: viper.GetInt("TOKEN_EXPIRY_DAYS"),
TokenRefreshDays: viper.GetInt("TOKEN_REFRESH_DAYS"),
}, },
Storage: StorageConfig{ Storage: StorageConfig{
UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"), UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"),
@@ -369,6 +373,10 @@ func setDefaults() {
viper.SetDefault("OVERDUE_REMINDER_HOUR", 15) // 9:00 AM UTC viper.SetDefault("OVERDUE_REMINDER_HOUR", 15) // 9:00 AM UTC
viper.SetDefault("DAILY_DIGEST_HOUR", 3) // 3:00 AM UTC viper.SetDefault("DAILY_DIGEST_HOUR", 3) // 3:00 AM UTC
// Token expiry defaults
viper.SetDefault("TOKEN_EXPIRY_DAYS", 90) // Tokens expire after 90 days
viper.SetDefault("TOKEN_REFRESH_DAYS", 60) // Tokens can be refreshed after 60 days
// Storage defaults // Storage defaults
viper.SetDefault("STORAGE_UPLOAD_DIR", "./uploads") viper.SetDefault("STORAGE_UPLOAD_DIR", "./uploads")
viper.SetDefault("STORAGE_BASE_URL", "/uploads") viper.SetDefault("STORAGE_BASE_URL", "/uploads")

View File

@@ -11,7 +11,7 @@ type LoginRequest struct {
type RegisterRequest struct { type RegisterRequest struct {
Username string `json:"username" validate:"required,min=3,max=150"` Username string `json:"username" validate:"required,min=3,max=150"`
Email string `json:"email" validate:"required,email,max=254"` Email string `json:"email" validate:"required,email,max=254"`
Password string `json:"password" validate:"required,min=8"` Password string `json:"password" validate:"required,min=8,password_complexity"`
FirstName string `json:"first_name" validate:"max=150"` FirstName string `json:"first_name" validate:"max=150"`
LastName string `json:"last_name" validate:"max=150"` LastName string `json:"last_name" validate:"max=150"`
} }
@@ -35,7 +35,7 @@ type VerifyResetCodeRequest struct {
// ResetPasswordRequest represents the reset password request body // ResetPasswordRequest represents the reset password request body
type ResetPasswordRequest struct { type ResetPasswordRequest struct {
ResetToken string `json:"reset_token" validate:"required"` ResetToken string `json:"reset_token" validate:"required"`
NewPassword string `json:"new_password" validate:"required,min=8"` NewPassword string `json:"new_password" validate:"required,min=8,password_complexity"`
} }
// UpdateProfileRequest represents the profile update request body // UpdateProfileRequest represents the profile update request body

View File

@@ -79,6 +79,12 @@ type ResetPasswordResponse struct {
Message string `json:"message"` Message string `json:"message"`
} }
// RefreshTokenResponse represents the token refresh response
type RefreshTokenResponse struct {
Token string `json:"token"`
Message string `json:"message"`
}
// MessageResponse represents a simple message response // MessageResponse represents a simple message response
type MessageResponse struct { type MessageResponse struct {
Message string `json:"message"` Message string `json:"message"`

View File

@@ -23,6 +23,7 @@ type AuthHandler struct {
appleAuthService *services.AppleAuthService appleAuthService *services.AppleAuthService
googleAuthService *services.GoogleAuthService googleAuthService *services.GoogleAuthService
storageService *services.StorageService storageService *services.StorageService
auditService *services.AuditService
} }
// NewAuthHandler creates a new auth handler // NewAuthHandler creates a new auth handler
@@ -49,6 +50,11 @@ func (h *AuthHandler) SetStorageService(storageService *services.StorageService)
h.storageService = storageService h.storageService = storageService
} }
// SetAuditService sets the audit service for logging security events
func (h *AuthHandler) SetAuditService(auditService *services.AuditService) {
h.auditService = auditService
}
// Login handles POST /api/auth/login/ // Login handles POST /api/auth/login/
func (h *AuthHandler) Login(c echo.Context) error { func (h *AuthHandler) Login(c echo.Context) error {
var req requests.LoginRequest var req requests.LoginRequest
@@ -62,9 +68,19 @@ func (h *AuthHandler) Login(c echo.Context) error {
response, err := h.authService.Login(&req) response, err := h.authService.Login(&req)
if err != nil { if err != nil {
log.Debug().Err(err).Str("identifier", req.Username).Msg("Login failed") log.Debug().Err(err).Str("identifier", req.Username).Msg("Login failed")
if h.auditService != nil {
h.auditService.LogEvent(c, nil, services.AuditEventLoginFailed, map[string]interface{}{
"identifier": req.Username,
})
}
return err return err
} }
if h.auditService != nil {
userID := response.User.ID
h.auditService.LogEvent(c, &userID, services.AuditEventLogin, nil)
}
return c.JSON(http.StatusOK, response) return c.JSON(http.StatusOK, response)
} }
@@ -84,6 +100,14 @@ func (h *AuthHandler) Register(c echo.Context) error {
return err return err
} }
if h.auditService != nil {
userID := response.User.ID
h.auditService.LogEvent(c, &userID, services.AuditEventRegister, map[string]interface{}{
"username": req.Username,
"email": req.Email,
})
}
// Send welcome email with confirmation code (async) // Send welcome email with confirmation code (async)
if h.emailService != nil && confirmationCode != "" { if h.emailService != nil && confirmationCode != "" {
go func() { go func() {
@@ -108,6 +132,14 @@ func (h *AuthHandler) Logout(c echo.Context) error {
return apperrors.Unauthorized("error.not_authenticated") return apperrors.Unauthorized("error.not_authenticated")
} }
// Log audit event before invalidating the token
if h.auditService != nil {
user := middleware.GetAuthUser(c)
if user != nil {
h.auditService.LogEvent(c, &user.ID, services.AuditEventLogout, nil)
}
}
// Invalidate token in database // Invalidate token in database
if err := h.authService.Logout(token); err != nil { if err := h.authService.Logout(token); err != nil {
log.Warn().Err(err).Msg("Failed to delete token from database") log.Warn().Err(err).Msg("Failed to delete token from database")
@@ -270,6 +302,12 @@ func (h *AuthHandler) ForgotPassword(c echo.Context) error {
}() }()
} }
if h.auditService != nil {
h.auditService.LogEvent(c, nil, services.AuditEventPasswordReset, map[string]interface{}{
"email": req.Email,
})
}
// Always return success to prevent email enumeration // Always return success to prevent email enumeration
return c.JSON(http.StatusOK, responses.ForgotPasswordResponse{ return c.JSON(http.StatusOK, responses.ForgotPasswordResponse{
Message: "Password reset email sent", Message: "Password reset email sent",
@@ -314,6 +352,12 @@ func (h *AuthHandler) ResetPassword(c echo.Context) error {
return err return err
} }
if h.auditService != nil {
h.auditService.LogEvent(c, nil, services.AuditEventPasswordChanged, map[string]interface{}{
"method": "reset_token",
})
}
return c.JSON(http.StatusOK, responses.ResetPasswordResponse{ return c.JSON(http.StatusOK, responses.ResetPasswordResponse{
Message: "Password reset successful", Message: "Password reset successful",
}) })
@@ -413,6 +457,34 @@ func (h *AuthHandler) GoogleSignIn(c echo.Context) error {
return c.JSON(http.StatusOK, response) return c.JSON(http.StatusOK, response)
} }
// RefreshToken handles POST /api/auth/refresh/
func (h *AuthHandler) RefreshToken(c echo.Context) error {
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
token := middleware.GetAuthToken(c)
if token == "" {
return apperrors.Unauthorized("error.not_authenticated")
}
response, err := h.authService.RefreshToken(token, user.ID)
if err != nil {
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Token refresh failed")
return err
}
// If the token was refreshed (new token), invalidate the old one from cache
if response.Token != token && h.cache != nil {
if cacheErr := h.cache.InvalidateAuthToken(c.Request().Context(), token); cacheErr != nil {
log.Warn().Err(cacheErr).Msg("Failed to invalidate old token from cache during refresh")
}
}
return c.JSON(http.StatusOK, response)
}
// DeleteAccount handles DELETE /api/auth/account/ // DeleteAccount handles DELETE /api/auth/account/
func (h *AuthHandler) DeleteAccount(c echo.Context) error { func (h *AuthHandler) DeleteAccount(c echo.Context) error {
user, err := middleware.MustGetAuthUser(c) user, err := middleware.MustGetAuthUser(c)
@@ -431,6 +503,14 @@ func (h *AuthHandler) DeleteAccount(c echo.Context) error {
return err return err
} }
if h.auditService != nil {
h.auditService.LogEvent(c, &user.ID, services.AuditEventAccountDeleted, map[string]interface{}{
"user_id": user.ID,
"username": user.Username,
"email": user.Email,
})
}
// Delete files from disk (best effort, don't fail the request) // Delete files from disk (best effort, don't fail the request)
if h.storageService != nil && len(fileURLs) > 0 { if h.storageService != nil && len(fileURLs) > 0 {
go func() { go func() {

View File

@@ -6,8 +6,11 @@
"error.email_taken": "Email already registered", "error.email_taken": "Email already registered",
"error.email_already_taken": "Email already taken", "error.email_already_taken": "Email already taken",
"error.registration_failed": "Registration failed", "error.registration_failed": "Registration failed",
"error.password_complexity": "Password must be at least 8 characters with at least one uppercase letter, one lowercase letter, and one digit",
"error.not_authenticated": "Not authenticated", "error.not_authenticated": "Not authenticated",
"error.invalid_token": "Invalid token", "error.invalid_token": "Invalid token",
"error.token_expired": "Your session has expired. Please log in again.",
"error.token_refresh_not_needed": "Token is still valid.",
"error.failed_to_get_user": "Failed to get user", "error.failed_to_get_user": "Failed to get user",
"error.failed_to_update_profile": "Failed to update profile", "error.failed_to_update_profile": "Failed to update profile",
"error.invalid_verification_code": "Invalid verification code", "error.invalid_verification_code": "Invalid verification code",

View File

@@ -12,6 +12,7 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"github.com/treytartt/honeydue-api/internal/apperrors" "github.com/treytartt/honeydue-api/internal/apperrors"
"github.com/treytartt/honeydue-api/internal/config"
"github.com/treytartt/honeydue-api/internal/models" "github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/services" "github.com/treytartt/honeydue-api/internal/services"
) )
@@ -28,24 +29,56 @@ const (
// UserCacheTTL is how long full user records are cached in memory to // UserCacheTTL is how long full user records are cached in memory to
// avoid hitting the database on every authenticated request. // avoid hitting the database on every authenticated request.
UserCacheTTL = 30 * time.Second UserCacheTTL = 30 * time.Second
// DefaultTokenExpiryDays is the default number of days before a token expires.
DefaultTokenExpiryDays = 90
) )
// AuthMiddleware provides token authentication middleware // AuthMiddleware provides token authentication middleware
type AuthMiddleware struct { type AuthMiddleware struct {
db *gorm.DB db *gorm.DB
cache *services.CacheService cache *services.CacheService
userCache *UserCache userCache *UserCache
tokenExpiryDays int
} }
// NewAuthMiddleware creates a new auth middleware instance // NewAuthMiddleware creates a new auth middleware instance
func NewAuthMiddleware(db *gorm.DB, cache *services.CacheService) *AuthMiddleware { func NewAuthMiddleware(db *gorm.DB, cache *services.CacheService) *AuthMiddleware {
return &AuthMiddleware{ return &AuthMiddleware{
db: db, db: db,
cache: cache, cache: cache,
userCache: NewUserCache(UserCacheTTL), userCache: NewUserCache(UserCacheTTL),
tokenExpiryDays: DefaultTokenExpiryDays,
} }
} }
// NewAuthMiddlewareWithConfig creates a new auth middleware instance with configuration
func NewAuthMiddlewareWithConfig(db *gorm.DB, cache *services.CacheService, cfg *config.Config) *AuthMiddleware {
expiryDays := DefaultTokenExpiryDays
if cfg != nil && cfg.Security.TokenExpiryDays > 0 {
expiryDays = cfg.Security.TokenExpiryDays
}
return &AuthMiddleware{
db: db,
cache: cache,
userCache: NewUserCache(UserCacheTTL),
tokenExpiryDays: expiryDays,
}
}
// TokenExpiryDuration returns the token expiry duration.
func (m *AuthMiddleware) TokenExpiryDuration() time.Duration {
return time.Duration(m.tokenExpiryDays) * 24 * time.Hour
}
// isTokenExpired checks if a token's created timestamp indicates expiry.
func (m *AuthMiddleware) isTokenExpired(created time.Time) bool {
if created.IsZero() {
return false // Legacy tokens without created time are not expired
}
return time.Since(created) > m.TokenExpiryDuration()
}
// TokenAuth returns an Echo middleware that validates token authentication // TokenAuth returns an Echo middleware that validates token authentication
func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc { func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@@ -56,7 +89,7 @@ func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc {
return apperrors.Unauthorized("error.not_authenticated") return apperrors.Unauthorized("error.not_authenticated")
} }
// Try to get user from cache first // Try to get user from cache first (includes expiry check)
user, err := m.getUserFromCache(c.Request().Context(), token) user, err := m.getUserFromCache(c.Request().Context(), token)
if err == nil && user != nil { if err == nil && user != nil {
// Cache hit - set user in context and continue // Cache hit - set user in context and continue
@@ -65,16 +98,27 @@ func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc {
return next(c) return next(c)
} }
// Check if the cache indicated token expiry
if err != nil && err.Error() == "token expired" {
return apperrors.Unauthorized("error.token_expired")
}
// Cache miss - look up token in database // Cache miss - look up token in database
user, err = m.getUserFromDatabase(token) user, authToken, err := m.getUserFromDatabaseWithToken(token)
if err != nil { if err != nil {
log.Debug().Err(err).Str("token", truncateToken(token)).Msg("Token authentication failed") log.Debug().Err(err).Str("token", truncateToken(token)).Msg("Token authentication failed")
return apperrors.Unauthorized("error.invalid_token") return apperrors.Unauthorized("error.invalid_token")
} }
// Cache the user ID for future requests // Check token expiry
if cacheErr := m.cacheUserID(c.Request().Context(), token, user.ID); cacheErr != nil { if m.isTokenExpired(authToken.Created) {
log.Warn().Err(cacheErr).Msg("Failed to cache user ID") log.Debug().Str("token", truncateToken(token)).Time("created", authToken.Created).Msg("Token expired")
return apperrors.Unauthorized("error.token_expired")
}
// Cache the user ID and token creation time for future requests
if cacheErr := m.cacheTokenInfo(c.Request().Context(), token, user.ID, authToken.Created); cacheErr != nil {
log.Warn().Err(cacheErr).Msg("Failed to cache token info")
} }
// Set user in context // Set user in context
@@ -104,9 +148,9 @@ func (m *AuthMiddleware) OptionalTokenAuth() echo.MiddlewareFunc {
} }
// Try database // Try database
user, err = m.getUserFromDatabase(token) user, authToken, err := m.getUserFromDatabaseWithToken(token)
if err == nil { if err == nil && !m.isTokenExpired(authToken.Created) {
m.cacheUserID(c.Request().Context(), token, user.ID) m.cacheTokenInfo(c.Request().Context(), token, user.ID, authToken.Created)
c.Set(AuthUserKey, user) c.Set(AuthUserKey, user)
c.Set(AuthTokenKey, token) c.Set(AuthTokenKey, token)
} }
@@ -145,12 +189,13 @@ func extractToken(c echo.Context) (string, error) {
// getUserFromCache tries to get user from Redis cache, then from the // getUserFromCache tries to get user from Redis cache, then from the
// in-memory user cache, before falling back to the database. // in-memory user cache, before falling back to the database.
// Returns a "token expired" error if the cached creation time indicates expiry.
func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*models.User, error) { func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*models.User, error) {
if m.cache == nil { if m.cache == nil {
return nil, fmt.Errorf("cache not available") return nil, fmt.Errorf("cache not available")
} }
userID, err := m.cache.GetCachedAuthToken(ctx, token) userID, createdUnix, err := m.cache.GetCachedAuthTokenWithCreated(ctx, token)
if err != nil { if err != nil {
if err == redis.Nil { if err == redis.Nil {
return nil, fmt.Errorf("token not in cache") return nil, fmt.Errorf("token not in cache")
@@ -158,6 +203,15 @@ func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*m
return nil, err return nil, err
} }
// Check token expiry from cached creation time
if createdUnix > 0 {
created := time.Unix(createdUnix, 0)
if m.isTokenExpired(created) {
m.cache.InvalidateAuthToken(ctx, token)
return nil, fmt.Errorf("token expired")
}
}
// Try in-memory user cache first to avoid a DB round-trip // Try in-memory user cache first to avoid a DB round-trip
if cached := m.userCache.Get(userID); cached != nil { if cached := m.userCache.Get(userID); cached != nil {
if !cached.IsActive { if !cached.IsActive {
@@ -187,22 +241,38 @@ func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*m
return &user, nil return &user, nil
} }
// getUserFromDatabase looks up the token in the database and caches the // getUserFromDatabaseWithToken looks up the token in the database and returns
// resulting user record in memory. // both the user and the auth token record (for expiry checking).
func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error) { func (m *AuthMiddleware) getUserFromDatabaseWithToken(token string) (*models.User, *models.AuthToken, error) {
var authToken models.AuthToken var authToken models.AuthToken
if err := m.db.Preload("User").Where("key = ?", token).First(&authToken).Error; err != nil { if err := m.db.Preload("User").Where("key = ?", token).First(&authToken).Error; err != nil {
return nil, fmt.Errorf("token not found") return nil, nil, fmt.Errorf("token not found")
} }
// Check if user is active // Check if user is active
if !authToken.User.IsActive { if !authToken.User.IsActive {
return nil, fmt.Errorf("user is inactive") return nil, nil, fmt.Errorf("user is inactive")
} }
// Store in in-memory cache for subsequent requests // Store in in-memory cache for subsequent requests
m.userCache.Set(&authToken.User) m.userCache.Set(&authToken.User)
return &authToken.User, nil return &authToken.User, &authToken, nil
}
// getUserFromDatabase looks up the token in the database and caches the
// resulting user record in memory.
// Deprecated: Use getUserFromDatabaseWithToken for new code paths that need expiry checking.
func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error) {
user, _, err := m.getUserFromDatabaseWithToken(token)
return user, err
}
// cacheTokenInfo caches the user ID and token creation time for a token
func (m *AuthMiddleware) cacheTokenInfo(ctx context.Context, token string, userID uint, created time.Time) error {
if m.cache == nil {
return nil
}
return m.cache.CacheAuthTokenWithCreated(ctx, token, userID, created.Unix())
} }
// cacheUserID caches the user ID for a token // cacheUserID caches the user ID for a token

View File

@@ -0,0 +1,165 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/treytartt/honeydue-api/internal/models"
)
// setupTestDB creates a temporary in-memory SQLite database with the required
// tables for auth middleware tests.
func setupTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
require.NoError(t, err)
err = db.AutoMigrate(&models.User{}, &models.AuthToken{})
require.NoError(t, err)
return db
}
// createTestUserAndToken creates a user and an auth token, then backdates the
// token's Created timestamp by the specified number of days.
func createTestUserAndToken(t *testing.T, db *gorm.DB, username string, ageDays int) (*models.User, *models.AuthToken) {
t.Helper()
user := &models.User{
Username: username,
Email: username + "@test.com",
IsActive: true,
}
require.NoError(t, user.SetPassword("password123"))
require.NoError(t, db.Create(user).Error)
token := &models.AuthToken{
UserID: user.ID,
}
require.NoError(t, db.Create(token).Error)
// Backdate the token's Created timestamp after creation to bypass autoCreateTime
backdated := time.Now().UTC().AddDate(0, 0, -ageDays)
require.NoError(t, db.Model(token).Update("created", backdated).Error)
token.Created = backdated
return user, token
}
func TestTokenAuth_RejectsExpiredToken(t *testing.T) {
db := setupTestDB(t)
_, token := createTestUserAndToken(t, db, "expired_user", 91) // 91 days old > 90 day expiry
m := NewAuthMiddleware(db, nil) // No Redis cache for these tests
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Key)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.Error(t, err)
assert.Contains(t, err.Error(), "error.token_expired")
}
func TestTokenAuth_AcceptsValidToken(t *testing.T) {
db := setupTestDB(t)
_, token := createTestUserAndToken(t, db, "valid_user", 30) // 30 days old < 90 day expiry
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Key)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
// Verify user was set in context
user := GetAuthUser(c)
require.NotNil(t, user)
assert.Equal(t, "valid_user", user.Username)
}
func TestTokenAuth_AcceptsTokenAtBoundary(t *testing.T) {
db := setupTestDB(t)
_, token := createTestUserAndToken(t, db, "boundary_user", 89) // 89 days old, just under 90 day expiry
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Key)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
}
func TestTokenAuth_RejectsInvalidToken(t *testing.T) {
db := setupTestDB(t)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token nonexistent-token")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.Error(t, err)
assert.Contains(t, err.Error(), "error.invalid_token")
}
func TestTokenAuth_RejectsNoAuthHeader(t *testing.T) {
db := setupTestDB(t)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.Error(t, err)
assert.Contains(t, err.Error(), "error.not_authenticated")
}

View File

@@ -0,0 +1,48 @@
package models
import (
"database/sql/driver"
"encoding/json"
"errors"
"time"
)
// JSONMap is a custom type for JSONB columns that handles JSON serialization
type JSONMap map[string]interface{}
// Value implements driver.Valuer for database writes
func (j JSONMap) Value() (driver.Value, error) {
if j == nil {
return nil, nil
}
return json.Marshal(j)
}
// Scan implements sql.Scanner for database reads
func (j *JSONMap) Scan(value interface{}) error {
if value == nil {
*j = nil
return nil
}
bytes, ok := value.([]byte)
if !ok {
return errors.New("audit_log: failed to scan JSONMap value")
}
return json.Unmarshal(bytes, j)
}
// AuditLog represents the audit_log table for tracking security-relevant events
type AuditLog struct {
ID uint `gorm:"primaryKey" json:"id"`
UserID *uint `gorm:"column:user_id" json:"user_id"`
EventType string `gorm:"column:event_type;size:50;not null" json:"event_type"`
IPAddress string `gorm:"column:ip_address;size:45" json:"ip_address"`
UserAgent string `gorm:"column:user_agent;type:text" json:"user_agent"`
Details JSONMap `gorm:"column:details;type:jsonb" json:"details"`
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
}
// TableName returns the table name for GORM
func (AuditLog) TableName() string {
return "audit_log"
}

View File

@@ -0,0 +1,167 @@
package push
import (
"errors"
"sync"
"time"
)
// Circuit breaker states
const (
stateClosed = iota // Normal operation, requests pass through
stateOpen // Too many failures, requests are rejected
stateHalfOpen // Testing recovery, one request allowed through
)
// Default circuit breaker settings
const (
defaultFailureThreshold = 5 // Open after this many consecutive failures
defaultRecoveryTimeout = 30 * time.Second // Try again after this duration
)
// ErrCircuitOpen is returned when the circuit breaker is open and rejecting requests.
var ErrCircuitOpen = errors.New("circuit breaker is open")
// CircuitBreaker implements a simple circuit breaker pattern for external service calls.
// It is thread-safe and requires no external dependencies.
//
// States:
// - Closed: normal operation, all requests pass through. Consecutive failures are counted.
// - Open: after reaching the failure threshold, all requests are immediately rejected
// with ErrCircuitOpen until the recovery timeout elapses.
// - Half-Open: after the recovery timeout, one request is allowed through. If it
// succeeds the breaker resets to Closed; if it fails it returns to Open.
type CircuitBreaker struct {
mu sync.Mutex
state int
failureCount int
failureThreshold int
recoveryTimeout time.Duration
lastFailureTime time.Time
name string // For logging
}
// CircuitBreakerOption configures a CircuitBreaker.
type CircuitBreakerOption func(*CircuitBreaker)
// WithFailureThreshold sets the number of consecutive failures before opening the circuit.
func WithFailureThreshold(n int) CircuitBreakerOption {
return func(cb *CircuitBreaker) {
if n > 0 {
cb.failureThreshold = n
}
}
}
// WithRecoveryTimeout sets how long the circuit stays open before trying half-open.
func WithRecoveryTimeout(d time.Duration) CircuitBreakerOption {
return func(cb *CircuitBreaker) {
if d > 0 {
cb.recoveryTimeout = d
}
}
}
// NewCircuitBreaker creates a new CircuitBreaker with the given name and options.
// The name is used for logging and identification.
func NewCircuitBreaker(name string, opts ...CircuitBreakerOption) *CircuitBreaker {
cb := &CircuitBreaker{
state: stateClosed,
failureThreshold: defaultFailureThreshold,
recoveryTimeout: defaultRecoveryTimeout,
name: name,
}
for _, opt := range opts {
opt(cb)
}
return cb
}
// Allow checks whether a request should be allowed through.
// It returns true if the request can proceed, false if the circuit is open.
// When transitioning from open to half-open, it returns true for the probe request.
func (cb *CircuitBreaker) Allow() bool {
cb.mu.Lock()
defer cb.mu.Unlock()
switch cb.state {
case stateClosed:
return true
case stateOpen:
// Check if recovery timeout has elapsed
if time.Since(cb.lastFailureTime) >= cb.recoveryTimeout {
cb.state = stateHalfOpen
return true
}
return false
case stateHalfOpen:
// Only one request at a time in half-open state.
// The first caller that got here via Allow() is already in flight;
// reject subsequent callers until that probe resolves.
return false
default:
return true
}
}
// RecordSuccess records a successful request. If the breaker is half-open, it resets to closed.
func (cb *CircuitBreaker) RecordSuccess() {
cb.mu.Lock()
defer cb.mu.Unlock()
cb.failureCount = 0
cb.state = stateClosed
}
// RecordFailure records a failed request. If the failure threshold is reached, the
// breaker transitions to the open state.
func (cb *CircuitBreaker) RecordFailure() {
cb.mu.Lock()
defer cb.mu.Unlock()
cb.failureCount++
cb.lastFailureTime = time.Now()
if cb.failureCount >= cb.failureThreshold {
cb.state = stateOpen
}
}
// State returns the current state of the circuit breaker as a human-readable string.
func (cb *CircuitBreaker) State() string {
cb.mu.Lock()
defer cb.mu.Unlock()
switch cb.state {
case stateClosed:
return "closed"
case stateOpen:
return "open"
case stateHalfOpen:
return "half-open"
default:
return "unknown"
}
}
// Name returns the circuit breaker's name.
func (cb *CircuitBreaker) Name() string {
return cb.name
}
// Reset resets the circuit breaker to the closed state with zero failures.
func (cb *CircuitBreaker) Reset() {
cb.mu.Lock()
defer cb.mu.Unlock()
cb.state = stateClosed
cb.failureCount = 0
cb.lastFailureTime = time.Time{}
}
// Counts returns the current failure count (useful for testing and monitoring).
func (cb *CircuitBreaker) Counts() int {
cb.mu.Lock()
defer cb.mu.Unlock()
return cb.failureCount
}

View File

@@ -0,0 +1,275 @@
package push
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCircuitBreaker_StartsInClosedState(t *testing.T) {
cb := NewCircuitBreaker("test")
assert.Equal(t, "closed", cb.State())
assert.True(t, cb.Allow())
}
func TestCircuitBreaker_OpensAfterThresholdFailures(t *testing.T) {
cb := NewCircuitBreaker("test", WithFailureThreshold(3))
// First two failures should keep it closed
cb.RecordFailure()
assert.Equal(t, "closed", cb.State())
assert.True(t, cb.Allow())
cb.RecordFailure()
assert.Equal(t, "closed", cb.State())
assert.True(t, cb.Allow())
// Third failure should open it
cb.RecordFailure()
assert.Equal(t, "open", cb.State())
assert.False(t, cb.Allow())
}
func TestCircuitBreaker_DefaultThresholdIsFive(t *testing.T) {
cb := NewCircuitBreaker("test")
for i := 0; i < 4; i++ {
cb.RecordFailure()
assert.Equal(t, "closed", cb.State())
}
cb.RecordFailure() // 5th failure
assert.Equal(t, "open", cb.State())
}
func TestCircuitBreaker_RejectsRequestsWhenOpen(t *testing.T) {
cb := NewCircuitBreaker("test", WithFailureThreshold(1))
cb.RecordFailure()
assert.Equal(t, "open", cb.State())
// Multiple calls should all be rejected
for i := 0; i < 10; i++ {
assert.False(t, cb.Allow())
}
}
func TestCircuitBreaker_TransitionsToHalfOpenAfterRecoveryTimeout(t *testing.T) {
cb := NewCircuitBreaker("test",
WithFailureThreshold(1),
WithRecoveryTimeout(50*time.Millisecond),
)
cb.RecordFailure()
assert.Equal(t, "open", cb.State())
assert.False(t, cb.Allow())
// Wait for recovery timeout
time.Sleep(60 * time.Millisecond)
// Should now allow one request (half-open)
assert.True(t, cb.Allow())
assert.Equal(t, "half-open", cb.State())
}
func TestCircuitBreaker_HalfOpenRejectsSecondRequest(t *testing.T) {
cb := NewCircuitBreaker("test",
WithFailureThreshold(1),
WithRecoveryTimeout(50*time.Millisecond),
)
cb.RecordFailure()
time.Sleep(60 * time.Millisecond)
// First request allowed (probe)
assert.True(t, cb.Allow())
assert.Equal(t, "half-open", cb.State())
// Second request rejected while probe is in flight
assert.False(t, cb.Allow())
}
func TestCircuitBreaker_HalfOpenSuccess_ResetsToClosed(t *testing.T) {
cb := NewCircuitBreaker("test",
WithFailureThreshold(1),
WithRecoveryTimeout(50*time.Millisecond),
)
cb.RecordFailure()
time.Sleep(60 * time.Millisecond)
// Probe request
assert.True(t, cb.Allow())
// Probe succeeds
cb.RecordSuccess()
assert.Equal(t, "closed", cb.State())
assert.Equal(t, 0, cb.Counts())
// Normal operation resumes
assert.True(t, cb.Allow())
assert.True(t, cb.Allow())
}
func TestCircuitBreaker_HalfOpenFailure_ReturnsToOpen(t *testing.T) {
cb := NewCircuitBreaker("test",
WithFailureThreshold(2),
WithRecoveryTimeout(50*time.Millisecond),
)
// Open the circuit
cb.RecordFailure()
cb.RecordFailure()
assert.Equal(t, "open", cb.State())
time.Sleep(60 * time.Millisecond)
// Probe request
assert.True(t, cb.Allow())
assert.Equal(t, "half-open", cb.State())
// Probe fails - the failure count is now 3 which is >= threshold of 2
cb.RecordFailure()
assert.Equal(t, "open", cb.State())
assert.False(t, cb.Allow())
}
func TestCircuitBreaker_SuccessResetsFailureCount(t *testing.T) {
cb := NewCircuitBreaker("test", WithFailureThreshold(3))
cb.RecordFailure()
cb.RecordFailure()
assert.Equal(t, 2, cb.Counts())
// A success should reset the counter
cb.RecordSuccess()
assert.Equal(t, 0, cb.Counts())
assert.Equal(t, "closed", cb.State())
// Now it should take 3 more failures to open
cb.RecordFailure()
cb.RecordFailure()
assert.Equal(t, "closed", cb.State())
cb.RecordFailure()
assert.Equal(t, "open", cb.State())
}
func TestCircuitBreaker_Reset(t *testing.T) {
cb := NewCircuitBreaker("test", WithFailureThreshold(1))
cb.RecordFailure()
assert.Equal(t, "open", cb.State())
cb.Reset()
assert.Equal(t, "closed", cb.State())
assert.Equal(t, 0, cb.Counts())
assert.True(t, cb.Allow())
}
func TestCircuitBreaker_Name(t *testing.T) {
cb := NewCircuitBreaker("apns-breaker")
assert.Equal(t, "apns-breaker", cb.Name())
}
func TestCircuitBreaker_CustomOptions(t *testing.T) {
cb := NewCircuitBreaker("test",
WithFailureThreshold(10),
WithRecoveryTimeout(5*time.Minute),
)
// Should take 10 failures to open
for i := 0; i < 9; i++ {
cb.RecordFailure()
assert.Equal(t, "closed", cb.State())
}
cb.RecordFailure()
assert.Equal(t, "open", cb.State())
}
func TestCircuitBreaker_InvalidOptionsIgnored(t *testing.T) {
cb := NewCircuitBreaker("test",
WithFailureThreshold(0), // Should be ignored (keeps default)
WithRecoveryTimeout(-1), // Should be ignored (keeps default)
)
// Default threshold of 5 should still apply
for i := 0; i < 4; i++ {
cb.RecordFailure()
assert.Equal(t, "closed", cb.State())
}
cb.RecordFailure()
assert.Equal(t, "open", cb.State())
}
func TestCircuitBreaker_ThreadSafety(t *testing.T) {
cb := NewCircuitBreaker("test",
WithFailureThreshold(100),
WithRecoveryTimeout(10*time.Millisecond),
)
var wg sync.WaitGroup
const goroutines = 50
const iterations = 100
// Hammer it from many goroutines
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
cb.Allow()
if j%2 == 0 {
cb.RecordFailure()
} else {
cb.RecordSuccess()
}
_ = cb.State()
_ = cb.Counts()
}
}(i)
}
wg.Wait()
// Should not panic or deadlock. State should be valid.
state := cb.State()
require.Contains(t, []string{"closed", "open", "half-open"}, state)
}
func TestCircuitBreaker_FullLifecycle(t *testing.T) {
cb := NewCircuitBreaker("lifecycle-test",
WithFailureThreshold(3),
WithRecoveryTimeout(50*time.Millisecond),
)
// 1. Closed: requests flow normally
assert.True(t, cb.Allow())
cb.RecordSuccess()
assert.Equal(t, "closed", cb.State())
// 2. Accumulate failures
cb.RecordFailure()
cb.RecordFailure()
assert.Equal(t, "closed", cb.State())
// 3. Third failure opens the circuit
cb.RecordFailure()
assert.Equal(t, "open", cb.State())
assert.False(t, cb.Allow())
// 4. Wait for recovery
time.Sleep(60 * time.Millisecond)
// 5. Half-open: probe request allowed
assert.True(t, cb.Allow())
assert.Equal(t, "half-open", cb.State())
// 6. Probe succeeds, back to closed
cb.RecordSuccess()
assert.Equal(t, "closed", cb.State())
assert.True(t, cb.Allow())
}

View File

@@ -2,6 +2,7 @@ package push
import ( import (
"context" "context"
"time"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@@ -14,16 +15,25 @@ const (
PlatformAndroid = "android" PlatformAndroid = "android"
) )
// Timeout for individual push notification send operations.
const pushSendTimeout = 15 * time.Second
// Client provides a unified interface for sending push notifications // Client provides a unified interface for sending push notifications
type Client struct { type Client struct {
apns *APNsClient apns *APNsClient
fcm *FCMClient fcm *FCMClient
enabled bool enabled bool
apnsBreaker *CircuitBreaker
fcmBreaker *CircuitBreaker
} }
// NewClient creates a new unified push notification client // NewClient creates a new unified push notification client
func NewClient(cfg *config.PushConfig, enabled bool) (*Client, error) { func NewClient(cfg *config.PushConfig, enabled bool) (*Client, error) {
client := &Client{enabled: enabled} client := &Client{
enabled: enabled,
apnsBreaker: NewCircuitBreaker("apns"),
fcmBreaker: NewCircuitBreaker("fcm"),
}
// Initialize APNs client (iOS) // Initialize APNs client (iOS)
if cfg.APNSKeyPath != "" && cfg.APNSKeyID != "" && cfg.APNSTeamID != "" { if cfg.APNSKeyPath != "" && cfg.APNSKeyID != "" && cfg.APNSTeamID != "" {
@@ -54,7 +64,8 @@ func NewClient(cfg *config.PushConfig, enabled bool) (*Client, error) {
return client, nil return client, nil
} }
// SendToIOS sends a push notification to iOS devices // SendToIOS sends a push notification to iOS devices.
// The call is guarded by a circuit breaker and uses a context timeout.
func (c *Client) SendToIOS(ctx context.Context, tokens []string, title, message string, data map[string]string) error { func (c *Client) SendToIOS(ctx context.Context, tokens []string, title, message string, data map[string]string) error {
if !c.enabled { if !c.enabled {
log.Debug().Msg("Push notifications disabled by feature flag") log.Debug().Msg("Push notifications disabled by feature flag")
@@ -64,10 +75,26 @@ func (c *Client) SendToIOS(ctx context.Context, tokens []string, title, message
log.Warn().Msg("APNs client not initialized, skipping iOS push") log.Warn().Msg("APNs client not initialized, skipping iOS push")
return nil return nil
} }
return c.apns.Send(ctx, tokens, title, message, data) if !c.apnsBreaker.Allow() {
log.Warn().Str("breaker", c.apnsBreaker.Name()).Msg("APNs circuit breaker is open, skipping iOS push")
return ErrCircuitOpen
}
sendCtx, cancel := context.WithTimeout(ctx, pushSendTimeout)
defer cancel()
err := c.apns.Send(sendCtx, tokens, title, message, data)
if err != nil {
c.apnsBreaker.RecordFailure()
log.Warn().Err(err).Str("breaker_state", c.apnsBreaker.State()).Msg("APNs send failed, recorded circuit breaker failure")
return err
}
c.apnsBreaker.RecordSuccess()
return nil
} }
// SendToAndroid sends a push notification to Android devices // SendToAndroid sends a push notification to Android devices.
// The call is guarded by a circuit breaker and uses a context timeout.
func (c *Client) SendToAndroid(ctx context.Context, tokens []string, title, message string, data map[string]string) error { func (c *Client) SendToAndroid(ctx context.Context, tokens []string, title, message string, data map[string]string) error {
if !c.enabled { if !c.enabled {
log.Debug().Msg("Push notifications disabled by feature flag") log.Debug().Msg("Push notifications disabled by feature flag")
@@ -77,7 +104,22 @@ func (c *Client) SendToAndroid(ctx context.Context, tokens []string, title, mess
log.Warn().Msg("FCM client not initialized, skipping Android push") log.Warn().Msg("FCM client not initialized, skipping Android push")
return nil return nil
} }
return c.fcm.Send(ctx, tokens, title, message, data) if !c.fcmBreaker.Allow() {
log.Warn().Str("breaker", c.fcmBreaker.Name()).Msg("FCM circuit breaker is open, skipping Android push")
return ErrCircuitOpen
}
sendCtx, cancel := context.WithTimeout(ctx, pushSendTimeout)
defer cancel()
err := c.fcm.Send(sendCtx, tokens, title, message, data)
if err != nil {
c.fcmBreaker.RecordFailure()
log.Warn().Err(err).Str("breaker_state", c.fcmBreaker.State()).Msg("FCM send failed, recorded circuit breaker failure")
return err
}
c.fcmBreaker.RecordSuccess()
return nil
} }
// SendToAll sends a push notification to both iOS and Android devices // SendToAll sends a push notification to both iOS and Android devices
@@ -115,8 +157,9 @@ func (c *Client) IsAndroidEnabled() bool {
return c.fcm != nil return c.fcm != nil
} }
// SendActionableNotification sends notifications with action button support // SendActionableNotification sends notifications with action button support.
// iOS receives a category for actionable notifications, Android handles actions via data payload // iOS receives a category for actionable notifications, Android handles actions via data payload.
// Both platforms are guarded by their respective circuit breakers.
func (c *Client) SendActionableNotification(ctx context.Context, iosTokens, androidTokens []string, title, message string, data map[string]string, iosCategoryID string) error { func (c *Client) SendActionableNotification(ctx context.Context, iosTokens, androidTokens []string, title, message string, data map[string]string, iosCategoryID string) error {
if !c.enabled { if !c.enabled {
log.Debug().Msg("Push notifications disabled by feature flag") log.Debug().Msg("Push notifications disabled by feature flag")
@@ -127,10 +170,19 @@ func (c *Client) SendActionableNotification(ctx context.Context, iosTokens, andr
if len(iosTokens) > 0 { if len(iosTokens) > 0 {
if c.apns == nil { if c.apns == nil {
log.Warn().Msg("APNs client not initialized, skipping iOS actionable push") log.Warn().Msg("APNs client not initialized, skipping iOS actionable push")
} else if !c.apnsBreaker.Allow() {
log.Warn().Str("breaker", c.apnsBreaker.Name()).Msg("APNs circuit breaker is open, skipping iOS actionable push")
lastErr = ErrCircuitOpen
} else { } else {
if err := c.apns.SendWithCategory(ctx, iosTokens, title, message, data, iosCategoryID); err != nil { sendCtx, cancel := context.WithTimeout(ctx, pushSendTimeout)
err := c.apns.SendWithCategory(sendCtx, iosTokens, title, message, data, iosCategoryID)
cancel()
if err != nil {
c.apnsBreaker.RecordFailure()
log.Error().Err(err).Msg("Failed to send iOS actionable notifications") log.Error().Err(err).Msg("Failed to send iOS actionable notifications")
lastErr = err lastErr = err
} else {
c.apnsBreaker.RecordSuccess()
} }
} }
} }

View File

@@ -165,7 +165,7 @@ func NewFCMClient(cfg *config.PushConfig) (*FCMClient, error) {
} }
httpClient := &http.Client{ httpClient := &http.Client{
Timeout: 30 * time.Second, Timeout: 15 * time.Second,
Transport: transport, Transport: transport,
} }

View File

@@ -173,6 +173,27 @@ func (r *UserRepository) GetOrCreateToken(userID uint) (*models.AuthToken, error
return &token, nil return &token, nil
} }
// FindTokenByKey looks up an auth token by its key value.
func (r *UserRepository) FindTokenByKey(key string) (*models.AuthToken, error) {
var token models.AuthToken
if err := r.db.Where("key = ?", key).First(&token).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTokenNotFound
}
return nil, err
}
return &token, nil
}
// CreateToken creates a new auth token for a user.
func (r *UserRepository) CreateToken(userID uint) (*models.AuthToken, error) {
token := models.AuthToken{UserID: userID}
if err := r.db.Create(&token).Error; err != nil {
return nil, err
}
return &token, nil
}
// DeleteToken deletes an auth token // DeleteToken deletes an auth token
func (r *UserRepository) DeleteToken(token string) error { func (r *UserRepository) DeleteToken(token string) error {
result := r.db.Where("key = ?", token).Delete(&models.AuthToken{}) result := r.db.Where("key = ?", token).Delete(&models.AuthToken{})

View File

@@ -0,0 +1,136 @@
package router
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// setupTestDB creates an in-memory SQLite database for health check tests
func setupTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
require.NoError(t, err)
return db
}
func TestLiveCheck_Returns200(t *testing.T) {
e := echo.New()
e.GET("/api/health/live", liveCheck)
req := httptest.NewRequest(http.MethodGet, "/api/health/live", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var resp map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &resp)
require.NoError(t, err)
assert.Equal(t, "alive", resp["status"])
assert.Equal(t, Version, resp["version"])
assert.Contains(t, resp, "timestamp")
}
func TestReadinessCheck_HealthyDB_NilCache_Returns200(t *testing.T) {
db := setupTestDB(t)
deps := &Dependencies{
DB: db,
Cache: nil, // No Redis configured
}
e := echo.New()
e.GET("/api/health/", readinessCheck(deps))
req := httptest.NewRequest(http.MethodGet, "/api/health/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var resp map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &resp)
require.NoError(t, err)
assert.Equal(t, "healthy", resp["status"])
assert.Equal(t, Version, resp["version"])
checks := resp["checks"].(map[string]interface{})
assert.Equal(t, "ok", checks["postgres"])
assert.Equal(t, "not configured", checks["redis"])
}
func TestReadinessCheck_DBDown_Returns503(t *testing.T) {
// Open an in-memory SQLite DB, then close the underlying sql.DB to simulate failure
db := setupTestDB(t)
sqlDB, err := db.DB()
require.NoError(t, err)
sqlDB.Close()
deps := &Dependencies{
DB: db,
Cache: nil,
}
e := echo.New()
e.GET("/api/health/", readinessCheck(deps))
req := httptest.NewRequest(http.MethodGet, "/api/health/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
var resp map[string]interface{}
err = json.Unmarshal(rec.Body.Bytes(), &resp)
require.NoError(t, err)
assert.Equal(t, "unhealthy", resp["status"])
checks := resp["checks"].(map[string]interface{})
assert.Contains(t, checks["postgres"], "ping failed")
}
func TestReadinessCheck_ResponseFormat(t *testing.T) {
db := setupTestDB(t)
deps := &Dependencies{
DB: db,
Cache: nil,
}
e := echo.New()
e.GET("/api/health/", readinessCheck(deps))
req := httptest.NewRequest(http.MethodGet, "/api/health/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
var resp map[string]interface{}
err := json.Unmarshal(rec.Body.Bytes(), &resp)
require.NoError(t, err)
// Verify all expected fields are present
assert.Contains(t, resp, "status")
assert.Contains(t, resp, "version")
assert.Contains(t, resp, "checks")
assert.Contains(t, resp, "timestamp")
// Verify checks is a map with expected keys
checks, ok := resp["checks"].(map[string]interface{})
require.True(t, ok, "checks should be a map")
assert.Contains(t, checks, "postgres")
assert.Contains(t, checks, "redis")
}

View File

@@ -1,6 +1,7 @@
package router package router
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@@ -62,11 +63,12 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
// Security headers (X-Frame-Options, X-Content-Type-Options, X-XSS-Protection, etc.) // Security headers (X-Frame-Options, X-Content-Type-Options, X-XSS-Protection, etc.)
e.Use(middleware.SecureWithConfig(middleware.SecureConfig{ e.Use(middleware.SecureWithConfig(middleware.SecureConfig{
XSSProtection: "1; mode=block", XSSProtection: "1; mode=block",
ContentTypeNosniff: "nosniff", ContentTypeNosniff: "nosniff",
XFrameOptions: "SAMEORIGIN", XFrameOptions: "SAMEORIGIN",
HSTSMaxAge: 31536000, // 1 year in seconds HSTSMaxAge: 31536000, // 1 year in seconds
ReferrerPolicy: "strict-origin-when-cross-origin", ReferrerPolicy: "strict-origin-when-cross-origin",
ContentSecurityPolicy: "default-src 'none'; frame-ancestors 'none'",
})) }))
e.Use(middleware.BodyLimitWithConfig(middleware.BodyLimitConfig{ e.Use(middleware.BodyLimitWithConfig(middleware.BodyLimitConfig{
Limit: "1M", // 1MB default for JSON payloads Limit: "1M", // 1MB default for JSON payloads
@@ -93,6 +95,14 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
e.Use(corsMiddleware(cfg)) e.Use(corsMiddleware(cfg))
e.Use(i18n.Middleware()) e.Use(i18n.Middleware())
// Gzip compression (skip media endpoints since they serve binary files)
e.Use(middleware.GzipWithConfig(middleware.GzipConfig{
Level: 5,
Skipper: func(c echo.Context) bool {
return strings.HasPrefix(c.Request().URL.Path, "/api/media/")
},
}))
// Monitoring metrics middleware (if monitoring is enabled) // Monitoring metrics middleware (if monitoring is enabled)
if deps.MonitoringService != nil { if deps.MonitoringService != nil {
if metricsMiddleware := deps.MonitoringService.MetricsMiddleware(); metricsMiddleware != nil { if metricsMiddleware := deps.MonitoringService.MetricsMiddleware(); metricsMiddleware != nil {
@@ -114,8 +124,9 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
}) })
} }
// Health check endpoint (no auth required) // Health check endpoints (no auth required)
e.GET("/api/health/", healthCheck) e.GET("/api/health/", readinessCheck(deps))
e.GET("/api/health/live", liveCheck)
// Initialize onboarding email service for tracking handler // Initialize onboarding email service for tracking handler
onboardingService := services.NewOnboardingEmailService(deps.DB, deps.EmailService, cfg.Server.BaseURL) onboardingService := services.NewOnboardingEmailService(deps.DB, deps.EmailService, cfg.Server.BaseURL)
@@ -172,17 +183,21 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
subscriptionWebhookHandler.SetStripeService(stripeService) subscriptionWebhookHandler.SetStripeService(stripeService)
// Initialize middleware // Initialize middleware
authMiddleware := custommiddleware.NewAuthMiddleware(deps.DB, deps.Cache) authMiddleware := custommiddleware.NewAuthMiddlewareWithConfig(deps.DB, deps.Cache, cfg)
// Initialize Apple auth service // Initialize Apple auth service
appleAuthService := services.NewAppleAuthService(deps.Cache, cfg) appleAuthService := services.NewAppleAuthService(deps.Cache, cfg)
googleAuthService := services.NewGoogleAuthService(deps.Cache, cfg) googleAuthService := services.NewGoogleAuthService(deps.Cache, cfg)
// Initialize audit service for security event logging
auditService := services.NewAuditService(deps.DB)
// Initialize handlers // Initialize handlers
authHandler := handlers.NewAuthHandler(authService, deps.EmailService, deps.Cache) authHandler := handlers.NewAuthHandler(authService, deps.EmailService, deps.Cache)
authHandler.SetAppleAuthService(appleAuthService) authHandler.SetAppleAuthService(appleAuthService)
authHandler.SetGoogleAuthService(googleAuthService) authHandler.SetGoogleAuthService(googleAuthService)
authHandler.SetStorageService(deps.StorageService) authHandler.SetStorageService(deps.StorageService)
authHandler.SetAuditService(auditService)
userHandler := handlers.NewUserHandler(userService) userHandler := handlers.NewUserHandler(userService)
residenceHandler := handlers.NewResidenceHandler(residenceService, deps.PDFService, deps.EmailService, cfg.Features.PDFReportsEnabled) residenceHandler := handlers.NewResidenceHandler(residenceService, deps.PDFService, deps.EmailService, cfg.Features.PDFReportsEnabled)
taskHandler := handlers.NewTaskHandler(taskService, deps.StorageService) taskHandler := handlers.NewTaskHandler(taskService, deps.StorageService)
@@ -201,6 +216,11 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
mediaHandler = handlers.NewMediaHandler(documentRepo, taskRepo, residenceRepo, deps.StorageService) mediaHandler = handlers.NewMediaHandler(documentRepo, taskRepo, residenceRepo, deps.StorageService)
} }
// Prometheus metrics endpoint (no auth required, for scraping)
if deps.MonitoringService != nil {
e.GET("/metrics", prometheusMetrics(deps.MonitoringService))
}
// Set up admin routes with monitoring handler (if available) // Set up admin routes with monitoring handler (if available)
var monitoringHandler *monitoring.Handler var monitoringHandler *monitoring.Handler
if deps.MonitoringService != nil { if deps.MonitoringService != nil {
@@ -295,16 +315,126 @@ func corsMiddleware(cfg *config.Config) echo.MiddlewareFunc {
}) })
} }
// healthCheck returns API health status // liveCheck returns a simple 200 for Kubernetes liveness probes
func healthCheck(c echo.Context) error { func liveCheck(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]interface{}{ return c.JSON(http.StatusOK, map[string]interface{}{
"status": "healthy", "status": "alive",
"version": Version, "version": Version,
"framework": "Echo",
"timestamp": time.Now().UTC().Format(time.RFC3339), "timestamp": time.Now().UTC().Format(time.RFC3339),
}) })
} }
// readinessCheck returns 200 if PostgreSQL and Redis are reachable, 503 otherwise.
// This is used by Kubernetes readiness probes and load balancers.
func readinessCheck(deps *Dependencies) echo.HandlerFunc {
return func(c echo.Context) error {
ctx, cancel := context.WithTimeout(c.Request().Context(), 5*time.Second)
defer cancel()
status := "healthy"
httpStatus := http.StatusOK
checks := make(map[string]string)
// Check PostgreSQL
sqlDB, err := deps.DB.DB()
if err != nil {
checks["postgres"] = fmt.Sprintf("failed to get sql.DB: %v", err)
status = "unhealthy"
httpStatus = http.StatusServiceUnavailable
} else if err := sqlDB.PingContext(ctx); err != nil {
checks["postgres"] = fmt.Sprintf("ping failed: %v", err)
status = "unhealthy"
httpStatus = http.StatusServiceUnavailable
} else {
checks["postgres"] = "ok"
}
// Check Redis (if cache service is available)
if deps.Cache != nil {
if err := deps.Cache.Client().Ping(ctx).Err(); err != nil {
checks["redis"] = fmt.Sprintf("ping failed: %v", err)
status = "unhealthy"
httpStatus = http.StatusServiceUnavailable
} else {
checks["redis"] = "ok"
}
} else {
checks["redis"] = "not configured"
}
return c.JSON(httpStatus, map[string]interface{}{
"status": status,
"version": Version,
"checks": checks,
"timestamp": time.Now().UTC().Format(time.RFC3339),
})
}
}
// prometheusMetrics returns an Echo handler that outputs metrics in Prometheus text format.
// It uses the existing monitoring service's HTTP stats collector to avoid adding external dependencies.
func prometheusMetrics(monSvc *monitoring.Service) echo.HandlerFunc {
return func(c echo.Context) error {
httpCollector := monSvc.HTTPCollector()
if httpCollector == nil {
return c.String(http.StatusOK, "# No HTTP metrics available (collector not initialized)\n")
}
stats := httpCollector.GetStats()
var b strings.Builder
// Request count by method+path+status
b.WriteString("# HELP http_requests_total Total number of HTTP requests.\n")
b.WriteString("# TYPE http_requests_total counter\n")
for statusCode, count := range stats.ByStatusCode {
fmt.Fprintf(&b, "http_requests_total{status_code=\"%d\"} %d\n", statusCode, count)
}
// Per-endpoint request count
b.WriteString("# HELP http_endpoint_requests_total Total requests per endpoint.\n")
b.WriteString("# TYPE http_endpoint_requests_total counter\n")
for endpoint, epStats := range stats.ByEndpoint {
// endpoint is "METHOD /path"
parts := strings.SplitN(endpoint, " ", 2)
method := endpoint
path := ""
if len(parts) == 2 {
method = parts[0]
path = parts[1]
}
fmt.Fprintf(&b, "http_endpoint_requests_total{method=\"%s\",path=\"%s\"} %d\n", method, path, epStats.Count)
}
// Request duration (avg latency as a gauge, since we don't have raw histogram buckets)
b.WriteString("# HELP http_request_duration_ms Average request duration in milliseconds per endpoint.\n")
b.WriteString("# TYPE http_request_duration_ms gauge\n")
for endpoint, epStats := range stats.ByEndpoint {
parts := strings.SplitN(endpoint, " ", 2)
method := endpoint
path := ""
if len(parts) == 2 {
method = parts[0]
path = parts[1]
}
fmt.Fprintf(&b, "http_request_duration_ms{method=\"%s\",path=\"%s\",quantile=\"avg\"} %.2f\n", method, path, epStats.AvgLatencyMs)
fmt.Fprintf(&b, "http_request_duration_ms{method=\"%s\",path=\"%s\",quantile=\"p95\"} %.2f\n", method, path, epStats.P95LatencyMs)
}
// Error rate
b.WriteString("# HELP http_error_rate Overall error rate (4xx+5xx / total).\n")
b.WriteString("# TYPE http_error_rate gauge\n")
fmt.Fprintf(&b, "http_error_rate %.4f\n", stats.ErrorRate)
// Requests per minute
b.WriteString("# HELP http_requests_per_minute Current request rate.\n")
b.WriteString("# TYPE http_requests_per_minute gauge\n")
fmt.Fprintf(&b, "http_requests_per_minute %.2f\n", stats.RequestsPerMinute)
c.Response().Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
return c.String(http.StatusOK, b.String())
}
}
// setupPublicAuthRoutes configures public authentication routes with // setupPublicAuthRoutes configures public authentication routes with
// per-endpoint rate limiters to mitigate brute-force and credential-stuffing. // per-endpoint rate limiters to mitigate brute-force and credential-stuffing.
// Rate limiters are disabled in debug mode to allow UI test suites to run // Rate limiters are disabled in debug mode to allow UI test suites to run
@@ -342,6 +472,7 @@ func setupProtectedAuthRoutes(api *echo.Group, authHandler *handlers.AuthHandler
auth := api.Group("/auth") auth := api.Group("/auth")
{ {
auth.POST("/logout/", authHandler.Logout) auth.POST("/logout/", authHandler.Logout)
auth.POST("/refresh/", authHandler.RefreshToken)
auth.GET("/me/", authHandler.CurrentUser) auth.GET("/me/", authHandler.CurrentUser)
auth.PUT("/profile/", authHandler.UpdateProfile) auth.PUT("/profile/", authHandler.UpdateProfile)
auth.PATCH("/profile/", authHandler.UpdateProfile) auth.PATCH("/profile/", authHandler.UpdateProfile)

View File

@@ -0,0 +1,93 @@
package services
import (
"sync"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"github.com/treytartt/honeydue-api/internal/models"
)
// Audit event type constants
const (
AuditEventLogin = "auth.login"
AuditEventLoginFailed = "auth.login_failed"
AuditEventRegister = "auth.register"
AuditEventLogout = "auth.logout"
AuditEventPasswordReset = "auth.password_reset"
AuditEventPasswordChanged = "auth.password_changed"
AuditEventAccountDeleted = "auth.account_deleted"
)
// AuditService handles audit logging for security-relevant events.
// It writes audit log entries asynchronously via a buffered channel to avoid
// blocking request handlers.
type AuditService struct {
db *gorm.DB
logChan chan *models.AuditLog
done chan struct{}
stopOnce sync.Once
}
// NewAuditService creates a new audit service with a buffered channel for async writes.
// Call Stop() when shutting down to flush remaining entries.
func NewAuditService(db *gorm.DB) *AuditService {
s := &AuditService{
db: db,
logChan: make(chan *models.AuditLog, 256),
done: make(chan struct{}),
}
go s.processLogs()
return s
}
// processLogs drains the log channel and writes entries to the database.
func (s *AuditService) processLogs() {
defer close(s.done)
for entry := range s.logChan {
if err := s.db.Create(entry).Error; err != nil {
log.Error().Err(err).
Str("event_type", entry.EventType).
Msg("Failed to write audit log entry")
}
}
}
// Stop closes the log channel and waits for all pending entries to be written.
// It is safe to call Stop multiple times.
func (s *AuditService) Stop() {
s.stopOnce.Do(func() {
close(s.logChan)
})
<-s.done
}
// LogEvent records an audit event. It extracts the client IP and User-Agent from
// the Echo context and sends the entry to the background writer. If the channel
// is full the entry is dropped and an error is logged (non-blocking).
func (s *AuditService) LogEvent(c echo.Context, userID *uint, eventType string, details map[string]interface{}) {
var ip, ua string
if c != nil {
ip = c.RealIP()
ua = c.Request().Header.Get("User-Agent")
}
entry := &models.AuditLog{
UserID: userID,
EventType: eventType,
IPAddress: ip,
UserAgent: ua,
Details: models.JSONMap(details),
}
select {
case s.logChan <- entry:
// sent
default:
log.Warn().
Str("event_type", eventType).
Msg("Audit log channel full, dropping entry")
}
}

View File

@@ -0,0 +1,176 @@
package services
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/testutil"
)
func TestAuditService_LogEvent_WritesToDatabase(t *testing.T) {
db := testutil.SetupTestDB(t)
svc := NewAuditService(db)
defer svc.Stop()
// Create a fake Echo context
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
req.Header.Set("User-Agent", "TestAgent/1.0")
req.RemoteAddr = "192.168.1.1:12345"
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
userID := uint(42)
svc.LogEvent(c, &userID, AuditEventLogin, map[string]interface{}{
"method": "password",
})
// Stop flushes the channel
svc.Stop()
var entries []models.AuditLog
err := db.Find(&entries).Error
require.NoError(t, err)
require.Len(t, entries, 1)
entry := entries[0]
assert.Equal(t, uint(42), *entry.UserID)
assert.Equal(t, AuditEventLogin, entry.EventType)
assert.Equal(t, "TestAgent/1.0", entry.UserAgent)
assert.NotEmpty(t, entry.IPAddress)
assert.Equal(t, "password", entry.Details["method"])
assert.False(t, entry.CreatedAt.IsZero())
}
func TestAuditService_LogEvent_NilUserID(t *testing.T) {
db := testutil.SetupTestDB(t)
svc := NewAuditService(db)
defer svc.Stop()
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
svc.LogEvent(c, nil, AuditEventLoginFailed, map[string]interface{}{
"identifier": "unknown@test.com",
})
svc.Stop()
var entries []models.AuditLog
err := db.Find(&entries).Error
require.NoError(t, err)
require.Len(t, entries, 1)
assert.Nil(t, entries[0].UserID)
assert.Equal(t, AuditEventLoginFailed, entries[0].EventType)
}
func TestAuditService_LogEvent_NilContext(t *testing.T) {
db := testutil.SetupTestDB(t)
svc := NewAuditService(db)
defer svc.Stop()
userID := uint(1)
svc.LogEvent(nil, &userID, AuditEventLogout, nil)
svc.Stop()
var entries []models.AuditLog
err := db.Find(&entries).Error
require.NoError(t, err)
require.Len(t, entries, 1)
assert.Equal(t, AuditEventLogout, entries[0].EventType)
assert.Empty(t, entries[0].IPAddress)
assert.Empty(t, entries[0].UserAgent)
}
func TestAuditService_LogEvent_MultipleEvents(t *testing.T) {
db := testutil.SetupTestDB(t)
svc := NewAuditService(db)
defer svc.Stop()
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
userID := uint(10)
svc.LogEvent(c, &userID, AuditEventRegister, nil)
svc.LogEvent(c, &userID, AuditEventLogin, nil)
svc.LogEvent(c, &userID, AuditEventLogout, nil)
svc.Stop()
var count int64
err := db.Model(&models.AuditLog{}).Count(&count).Error
require.NoError(t, err)
assert.Equal(t, int64(3), count)
}
func TestAuditService_EventTypeConstants(t *testing.T) {
// Verify all event constants have expected values
assert.Equal(t, "auth.login", AuditEventLogin)
assert.Equal(t, "auth.login_failed", AuditEventLoginFailed)
assert.Equal(t, "auth.register", AuditEventRegister)
assert.Equal(t, "auth.logout", AuditEventLogout)
assert.Equal(t, "auth.password_reset", AuditEventPasswordReset)
assert.Equal(t, "auth.password_changed", AuditEventPasswordChanged)
assert.Equal(t, "auth.account_deleted", AuditEventAccountDeleted)
}
func TestAuditService_Stop_FlushesRemainingEntries(t *testing.T) {
db := testutil.SetupTestDB(t)
svc := NewAuditService(db)
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// Send many events quickly
for i := 0; i < 50; i++ {
uid := uint(i)
svc.LogEvent(c, &uid, AuditEventLogin, nil)
}
// Stop should block until all entries are written
svc.Stop()
var count int64
err := db.Model(&models.AuditLog{}).Count(&count).Error
require.NoError(t, err)
assert.Equal(t, int64(50), count)
}
func TestAuditLog_TableName(t *testing.T) {
log := models.AuditLog{}
assert.Equal(t, "audit_log", log.TableName())
}
func TestAuditLog_JSONMap_NilHandling(t *testing.T) {
db := testutil.SetupTestDB(t)
// Create entry with nil details
entry := &models.AuditLog{
EventType: "test",
CreatedAt: time.Now().UTC(),
}
err := db.Create(entry).Error
require.NoError(t, err)
// Read it back
var found models.AuditLog
err = db.First(&found, entry.ID).Error
require.NoError(t, err)
assert.Nil(t, found.Details)
}

View File

@@ -0,0 +1,173 @@
package services
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/treytartt/honeydue-api/internal/config"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/repositories"
)
func setupRefreshTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
require.NoError(t, err)
err = db.AutoMigrate(&models.User{}, &models.UserProfile{}, &models.AuthToken{})
require.NoError(t, err)
return db
}
func createRefreshTestUser(t *testing.T, db *gorm.DB) *models.User {
t.Helper()
user := &models.User{
Username: "refreshtest",
Email: "refresh@test.com",
IsActive: true,
}
require.NoError(t, user.SetPassword("password123"))
require.NoError(t, db.Create(user).Error)
return user
}
func createTokenWithAge(t *testing.T, db *gorm.DB, userID uint, ageDays int) *models.AuthToken {
t.Helper()
token := &models.AuthToken{
UserID: userID,
}
require.NoError(t, db.Create(token).Error)
// Backdate the token's Created timestamp after creation to bypass autoCreateTime
backdated := time.Now().UTC().AddDate(0, 0, -ageDays)
require.NoError(t, db.Model(token).Update("created", backdated).Error)
token.Created = backdated
return token
}
func newTestAuthService(db *gorm.DB) *AuthService {
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{
SecretKey: "test-secret",
TokenExpiryDays: 90,
TokenRefreshDays: 60,
},
}
return NewAuthService(userRepo, cfg)
}
func TestRefreshToken_FreshToken_ReturnsExisting(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
token := createTokenWithAge(t, db, user.ID, 30) // 30 days old, well within fresh window
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(token.Key, user.ID)
require.NoError(t, err)
assert.Equal(t, token.Key, resp.Token, "fresh token should return the same token")
assert.Contains(t, resp.Message, "still valid")
}
func TestRefreshToken_InRenewalWindow_ReturnsNewToken(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
token := createTokenWithAge(t, db, user.ID, 75) // 75 days old, in renewal window (60-90)
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(token.Key, user.ID)
require.NoError(t, err)
assert.NotEqual(t, token.Key, resp.Token, "should return a new token")
assert.Contains(t, resp.Message, "refreshed")
// Verify old token was deleted
var count int64
db.Model(&models.AuthToken{}).Where("key = ?", token.Key).Count(&count)
assert.Equal(t, int64(0), count, "old token should be deleted")
// Verify new token exists in DB
db.Model(&models.AuthToken{}).Where("key = ?", resp.Token).Count(&count)
assert.Equal(t, int64(1), count, "new token should exist in DB")
// Verify new token belongs to the same user
var newToken models.AuthToken
require.NoError(t, db.Where("key = ?", resp.Token).First(&newToken).Error)
assert.Equal(t, user.ID, newToken.UserID)
}
func TestRefreshToken_ExpiredToken_Returns401(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
token := createTokenWithAge(t, db, user.ID, 91) // 91 days old, past 90-day expiry
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(token.Key, user.ID)
require.Error(t, err)
assert.Nil(t, resp)
assert.Contains(t, err.Error(), "error.token_expired")
}
func TestRefreshToken_AtExactBoundary60Days(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
// Exactly 60 days: token age == refreshDays, so tokenAge < refreshDuration is false,
// meaning it enters the renewal window
token := createTokenWithAge(t, db, user.ID, 61)
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(token.Key, user.ID)
require.NoError(t, err)
assert.NotEqual(t, token.Key, resp.Token, "token at 61 days should be refreshed")
}
func TestRefreshToken_InvalidToken_Returns401(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
svc := newTestAuthService(db)
resp, err := svc.RefreshToken("nonexistent-token-key", user.ID)
require.Error(t, err)
assert.Nil(t, resp)
assert.Contains(t, err.Error(), "error.invalid_token")
}
func TestRefreshToken_WrongUser_Returns401(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
token := createTokenWithAge(t, db, user.ID, 75)
svc := newTestAuthService(db)
// Try to refresh with a different user ID
resp, err := svc.RefreshToken(token.Key, user.ID+999)
require.Error(t, err)
assert.Nil(t, resp)
assert.Contains(t, err.Error(), "error.invalid_token")
}
func TestRefreshToken_FreshTokenAt59Days_ReturnsExisting(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
token := createTokenWithAge(t, db, user.ID, 59) // 59 days, just under the 60-day threshold
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(token.Key, user.ID)
require.NoError(t, err)
assert.Equal(t, token.Key, resp.Token, "token at 59 days should NOT be refreshed")
}

View File

@@ -188,6 +188,66 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
}, code, nil }, code, nil
} }
// RefreshToken handles token refresh logic.
// - If token is expired (> expiryDays old), returns error (must re-login).
// - If token is in the renewal window (> refreshDays old), generates a new token.
// - If token is still fresh (< refreshDays old), returns the existing token (no-op).
func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.RefreshTokenResponse, error) {
expiryDays := s.cfg.Security.TokenExpiryDays
if expiryDays <= 0 {
expiryDays = 90
}
refreshDays := s.cfg.Security.TokenRefreshDays
if refreshDays <= 0 {
refreshDays = 60
}
// Look up the token
authToken, err := s.userRepo.FindTokenByKey(tokenKey)
if err != nil {
return nil, apperrors.Unauthorized("error.invalid_token")
}
// Verify ownership
if authToken.UserID != userID {
return nil, apperrors.Unauthorized("error.invalid_token")
}
tokenAge := time.Since(authToken.Created)
expiryDuration := time.Duration(expiryDays) * 24 * time.Hour
refreshDuration := time.Duration(refreshDays) * 24 * time.Hour
// Token is expired — must re-login
if tokenAge > expiryDuration {
return nil, apperrors.Unauthorized("error.token_expired")
}
// Token is still fresh — no-op refresh
if tokenAge < refreshDuration {
return &responses.RefreshTokenResponse{
Token: tokenKey,
Message: "Token is still valid.",
}, nil
}
// Token is in the renewal window — generate a new one
// Delete the old token
if err := s.userRepo.DeleteToken(tokenKey); err != nil {
log.Warn().Err(err).Str("token", tokenKey[:8]+"...").Msg("Failed to delete old token during refresh")
}
// Create a new token
newToken, err := s.userRepo.CreateToken(userID)
if err != nil {
return nil, apperrors.Internal(err)
}
return &responses.RefreshTokenResponse{
Token: newToken.Key,
Message: "Token refreshed successfully.",
}, nil
}
// Logout invalidates a user's token // Logout invalidates a user's token
func (s *AuthService) Logout(token string) error { func (s *AuthService) Logout(token string) error {
return s.userRepo.DeleteToken(token) return s.userRepo.DeleteToken(token)

View File

@@ -141,6 +141,12 @@ func (c *CacheService) CacheAuthToken(ctx context.Context, token string, userID
return c.SetString(ctx, key, fmt.Sprintf("%d", userID), TokenCacheTTL) return c.SetString(ctx, key, fmt.Sprintf("%d", userID), TokenCacheTTL)
} }
// CacheAuthTokenWithCreated caches a user ID and token creation time for a token
func (c *CacheService) CacheAuthTokenWithCreated(ctx context.Context, token string, userID uint, createdUnix int64) error {
key := AuthTokenPrefix + token
return c.SetString(ctx, key, fmt.Sprintf("%d|%d", userID, createdUnix), TokenCacheTTL)
}
// GetCachedAuthToken gets a cached user ID for a token // GetCachedAuthToken gets a cached user ID for a token
func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (uint, error) { func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (uint, error) {
key := AuthTokenPrefix + token key := AuthTokenPrefix + token
@@ -154,6 +160,24 @@ func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (ui
return userID, err return userID, err
} }
// GetCachedAuthTokenWithCreated gets a cached user ID and token creation time.
// Returns userID, createdUnix, error. createdUnix is 0 if not stored (legacy format).
func (c *CacheService) GetCachedAuthTokenWithCreated(ctx context.Context, token string) (uint, int64, error) {
key := AuthTokenPrefix + token
val, err := c.GetString(ctx, key)
if err != nil {
return 0, 0, err
}
var userID uint
var createdUnix int64
n, _ := fmt.Sscanf(val, "%d|%d", &userID, &createdUnix)
if n < 1 {
return 0, 0, fmt.Errorf("invalid cached token format")
}
return userID, createdUnix, nil
}
// InvalidateAuthToken removes a cached token // InvalidateAuthToken removes a cached token
func (c *CacheService) InvalidateAuthToken(ctx context.Context, token string) error { func (c *CacheService) InvalidateAuthToken(ctx context.Context, token string) error {
key := AuthTokenPrefix + token key := AuthTokenPrefix + token

View File

@@ -28,6 +28,7 @@ func NewEmailService(cfg *config.EmailConfig, enabled bool) *EmailService {
mail.WithUsername(cfg.User), mail.WithUsername(cfg.User),
mail.WithPassword(cfg.Password), mail.WithPassword(cfg.Password),
mail.WithTLSPortPolicy(mail.TLSOpportunistic), mail.WithTLSPortPolicy(mail.TLSOpportunistic),
mail.WithTimeout(30*time.Second),
) )
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to create mail client - emails will not be sent") log.Error().Err(err).Msg("Failed to create mail client - emails will not be sent")

View File

@@ -69,6 +69,7 @@ func SetupTestDB(t *testing.T) *gorm.DB {
&models.FeatureBenefit{}, &models.FeatureBenefit{},
&models.UpgradeTrigger{}, &models.UpgradeTrigger{},
&models.Promotion{}, &models.Promotion{},
&models.AuditLog{},
) )
require.NoError(t, err) require.NoError(t, err)

View File

@@ -4,6 +4,7 @@ import (
"net/http" "net/http"
"reflect" "reflect"
"strings" "strings"
"unicode"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
@@ -27,9 +28,34 @@ func NewCustomValidator() *CustomValidator {
return name return name
}) })
// Register custom password complexity validator
v.RegisterValidation("password_complexity", validatePasswordComplexity)
return &CustomValidator{validator: v} return &CustomValidator{validator: v}
} }
// validatePasswordComplexity checks that a password contains at least one
// uppercase letter, one lowercase letter, and one digit.
// Minimum length is enforced separately via the "min" tag.
func validatePasswordComplexity(fl validator.FieldLevel) bool {
password := fl.Field().String()
var hasUpper, hasLower, hasDigit bool
for _, ch := range password {
switch {
case unicode.IsUpper(ch):
hasUpper = true
case unicode.IsLower(ch):
hasLower = true
case unicode.IsDigit(ch):
hasDigit = true
}
if hasUpper && hasLower && hasDigit {
return true
}
}
return hasUpper && hasLower && hasDigit
}
// Validate implements echo.Validator interface // Validate implements echo.Validator interface
func (cv *CustomValidator) Validate(i interface{}) error { func (cv *CustomValidator) Validate(i interface{}) error {
if err := cv.validator.Struct(i); err != nil { if err := cv.validator.Struct(i); err != nil {
@@ -96,6 +122,8 @@ func formatMessage(fe validator.FieldError) string {
return "Must be a valid URL" return "Must be a valid URL"
case "uuid": case "uuid":
return "Must be a valid UUID" return "Must be a valid UUID"
case "password_complexity":
return "Password must be at least 8 characters with at least one uppercase letter, one lowercase letter, and one digit"
default: default:
return "Invalid value" return "Invalid value"
} }

View File

@@ -0,0 +1,115 @@
package validator
import (
"testing"
govalidator "github.com/go-playground/validator/v10"
)
func TestValidatePasswordComplexity(t *testing.T) {
tests := []struct {
name string
password string
valid bool
}{
{"valid password", "Password1", true},
{"valid complex password", "MyP@ssw0rd!", true},
{"missing uppercase", "password1", false},
{"missing lowercase", "PASSWORD1", false},
{"missing digit", "Password", false},
{"only digits", "12345678", false},
{"only lowercase", "abcdefgh", false},
{"only uppercase", "ABCDEFGH", false},
{"empty string", "", false},
{"single valid char each", "aA1", true},
{"unicode uppercase with digit and lower", "Über1abc", true},
}
v := govalidator.New()
v.RegisterValidation("password_complexity", validatePasswordComplexity)
type testStruct struct {
Password string `validate:"password_complexity"`
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
s := testStruct{Password: tc.password}
err := v.Struct(s)
if tc.valid && err != nil {
t.Errorf("expected password %q to be valid, got error: %v", tc.password, err)
}
if !tc.valid && err == nil {
t.Errorf("expected password %q to be invalid, got nil error", tc.password)
}
})
}
}
func TestValidatePasswordComplexityWithMinLength(t *testing.T) {
v := govalidator.New()
v.RegisterValidation("password_complexity", validatePasswordComplexity)
type request struct {
Password string `validate:"required,min=8,password_complexity"`
}
tests := []struct {
name string
password string
valid bool
}{
{"valid 8+ chars with complexity", "Abcdefg1", true},
{"too short but complex", "Ab1", false},
{"long but no uppercase", "abcdefgh1", false},
{"long but no lowercase", "ABCDEFGH1", false},
{"long but no digit", "Abcdefghi", false},
{"empty", "", false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
r := request{Password: tc.password}
err := v.Struct(r)
if tc.valid && err != nil {
t.Errorf("expected %q to be valid, got error: %v", tc.password, err)
}
if !tc.valid && err == nil {
t.Errorf("expected %q to be invalid, got nil", tc.password)
}
})
}
}
func TestFormatMessagePasswordComplexity(t *testing.T) {
cv := NewCustomValidator()
type request struct {
Password string `json:"password" validate:"required,min=8,password_complexity"`
}
r := request{Password: "lowercase1"}
err := cv.Validate(r)
if err == nil {
t.Fatal("expected validation error for password without uppercase")
}
resp := FormatValidationErrors(err)
if resp == nil {
t.Fatal("expected non-nil error response")
}
field, ok := resp.Fields["password"]
if !ok {
t.Fatal("expected 'password' field in error response")
}
expectedMsg := "Password must be at least 8 characters with at least one uppercase letter, one lowercase letter, and one digit"
if field.Message != expectedMsg {
t.Errorf("expected message %q, got %q", expectedMsg, field.Message)
}
if field.Tag != "password_complexity" {
t.Errorf("expected tag 'password_complexity', got %q", field.Tag)
}
}

View File

@@ -0,0 +1,2 @@
-- No-op: the created column is part of Django's original schema and should not
-- be removed.

View File

@@ -0,0 +1,6 @@
-- Ensure created column exists on user_authtoken (Django already creates it,
-- but this migration guarantees it for fresh Go-only deployments).
ALTER TABLE user_authtoken ADD COLUMN IF NOT EXISTS created TIMESTAMP WITH TIME ZONE DEFAULT NOW();
-- Backfill any rows that may have a NULL created timestamp.
UPDATE user_authtoken SET created = NOW() WHERE created IS NULL;

View File

@@ -0,0 +1,40 @@
-- Rollback: 017_fk_indexes
-- Drop all FK indexes added in the up migration.
-- auth / user tables
DROP INDEX IF EXISTS idx_authtoken_user_id;
DROP INDEX IF EXISTS idx_userprofile_user_id;
DROP INDEX IF EXISTS idx_confirmationcode_user_id;
DROP INDEX IF EXISTS idx_passwordresetcode_user_id;
DROP INDEX IF EXISTS idx_applesocialauth_user_id;
DROP INDEX IF EXISTS idx_googlesocialauth_user_id;
-- push notification device tables
DROP INDEX IF EXISTS idx_apnsdevice_user_id;
DROP INDEX IF EXISTS idx_gcmdevice_user_id;
-- notification tables
DROP INDEX IF EXISTS idx_notificationpreference_user_id;
-- subscription tables
DROP INDEX IF EXISTS idx_subscription_user_id;
-- residence tables
DROP INDEX IF EXISTS idx_residence_owner_id;
DROP INDEX IF EXISTS idx_sharecode_residence_id;
DROP INDEX IF EXISTS idx_sharecode_created_by_id;
-- task tables
DROP INDEX IF EXISTS idx_task_created_by_id;
DROP INDEX IF EXISTS idx_task_assigned_to_id;
DROP INDEX IF EXISTS idx_task_category_id;
DROP INDEX IF EXISTS idx_task_priority_id;
DROP INDEX IF EXISTS idx_task_frequency_id;
DROP INDEX IF EXISTS idx_task_contractor_id;
DROP INDEX IF EXISTS idx_task_parent_task_id;
DROP INDEX IF EXISTS idx_completionimage_completion_id;
DROP INDEX IF EXISTS idx_document_created_by_id;
DROP INDEX IF EXISTS idx_document_task_id;
DROP INDEX IF EXISTS idx_documentimage_document_id;
DROP INDEX IF EXISTS idx_contractor_residence_id;
DROP INDEX IF EXISTS idx_reminderlog_notification_id;

View File

@@ -0,0 +1,131 @@
-- Migration: 017_fk_indexes
-- Add indexes on all foreign key columns that are not already covered by existing indexes.
-- Uses CREATE INDEX IF NOT EXISTS to be idempotent (safe to re-run).
-- =====================================================
-- auth / user tables
-- =====================================================
-- user_authtoken: user_id (unique FK, but ensure index exists)
CREATE UNIQUE INDEX IF NOT EXISTS idx_authtoken_user_id
ON user_authtoken (user_id);
-- user_userprofile: user_id (unique FK)
CREATE UNIQUE INDEX IF NOT EXISTS idx_userprofile_user_id
ON user_userprofile (user_id);
-- user_confirmationcode: user_id
CREATE INDEX IF NOT EXISTS idx_confirmationcode_user_id
ON user_confirmationcode (user_id);
-- user_passwordresetcode: user_id
CREATE INDEX IF NOT EXISTS idx_passwordresetcode_user_id
ON user_passwordresetcode (user_id);
-- user_applesocialauth: user_id (unique FK)
CREATE UNIQUE INDEX IF NOT EXISTS idx_applesocialauth_user_id
ON user_applesocialauth (user_id);
-- user_googlesocialauth: user_id (unique FK)
CREATE UNIQUE INDEX IF NOT EXISTS idx_googlesocialauth_user_id
ON user_googlesocialauth (user_id);
-- =====================================================
-- push notification device tables
-- =====================================================
-- push_notifications_apnsdevice: user_id
CREATE INDEX IF NOT EXISTS idx_apnsdevice_user_id
ON push_notifications_apnsdevice (user_id);
-- push_notifications_gcmdevice: user_id
CREATE INDEX IF NOT EXISTS idx_gcmdevice_user_id
ON push_notifications_gcmdevice (user_id);
-- =====================================================
-- notification tables
-- =====================================================
-- notifications_notificationpreference: user_id (unique FK)
CREATE UNIQUE INDEX IF NOT EXISTS idx_notificationpreference_user_id
ON notifications_notificationpreference (user_id);
-- =====================================================
-- subscription tables
-- =====================================================
-- subscription_usersubscription: user_id (unique FK)
CREATE UNIQUE INDEX IF NOT EXISTS idx_subscription_user_id
ON subscription_usersubscription (user_id);
-- =====================================================
-- residence tables
-- =====================================================
-- residence_residence: owner_id
CREATE INDEX IF NOT EXISTS idx_residence_owner_id
ON residence_residence (owner_id);
-- residence_residencesharecode: residence_id (may already exist from model index tag via GORM)
CREATE INDEX IF NOT EXISTS idx_sharecode_residence_id
ON residence_residencesharecode (residence_id);
-- residence_residencesharecode: created_by_id
CREATE INDEX IF NOT EXISTS idx_sharecode_created_by_id
ON residence_residencesharecode (created_by_id);
-- =====================================================
-- task tables
-- =====================================================
-- task_task: created_by_id
CREATE INDEX IF NOT EXISTS idx_task_created_by_id
ON task_task (created_by_id);
-- task_task: assigned_to_id
CREATE INDEX IF NOT EXISTS idx_task_assigned_to_id
ON task_task (assigned_to_id);
-- task_task: category_id
CREATE INDEX IF NOT EXISTS idx_task_category_id
ON task_task (category_id);
-- task_task: priority_id
CREATE INDEX IF NOT EXISTS idx_task_priority_id
ON task_task (priority_id);
-- task_task: frequency_id
CREATE INDEX IF NOT EXISTS idx_task_frequency_id
ON task_task (frequency_id);
-- task_task: contractor_id
CREATE INDEX IF NOT EXISTS idx_task_contractor_id
ON task_task (contractor_id);
-- task_task: parent_task_id
CREATE INDEX IF NOT EXISTS idx_task_parent_task_id
ON task_task (parent_task_id);
-- task_taskcompletionimage: completion_id
CREATE INDEX IF NOT EXISTS idx_completionimage_completion_id
ON task_taskcompletionimage (completion_id);
-- task_document: created_by_id
CREATE INDEX IF NOT EXISTS idx_document_created_by_id
ON task_document (created_by_id);
-- task_document: task_id
CREATE INDEX IF NOT EXISTS idx_document_task_id
ON task_document (task_id);
-- task_documentimage: document_id
CREATE INDEX IF NOT EXISTS idx_documentimage_document_id
ON task_documentimage (document_id);
-- task_contractor: residence_id
CREATE INDEX IF NOT EXISTS idx_contractor_residence_id
ON task_contractor (residence_id);
-- task_reminderlog: notification_id
CREATE INDEX IF NOT EXISTS idx_reminderlog_notification_id
ON task_reminderlog (notification_id);

View File

@@ -0,0 +1,2 @@
-- Rollback: 018_audit_log
DROP TABLE IF EXISTS audit_log;

View File

@@ -0,0 +1,16 @@
-- Migration: 018_audit_log
-- Create audit_log table for tracking security-relevant events (login, register, etc.)
CREATE TABLE IF NOT EXISTS audit_log (
id SERIAL PRIMARY KEY,
user_id INTEGER,
event_type VARCHAR(50) NOT NULL,
ip_address VARCHAR(45),
user_agent TEXT,
details JSONB,
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);
CREATE INDEX idx_audit_log_user_id ON audit_log(user_id);
CREATE INDEX idx_audit_log_event_type ON audit_log(event_type);
CREATE INDEX idx_audit_log_created_at ON audit_log(created_at);