Production hardening: security, resilience, observability, and compliance
Password complexity: custom validator requiring uppercase, lowercase, digit (min 8 chars)
Token expiry: 90-day token lifetime with refresh endpoint (60-90 day renewal window)
Health check: /api/health/ now pings Postgres + Redis, returns 503 on failure
Audit logging: async audit_log table for auth events (login, register, delete, etc.)
Circuit breaker: APNs/FCM push sends wrapped with 5-failure threshold, 30s recovery
FK indexes: 27 missing foreign key indexes across all tables (migration 017)
CSP header: default-src 'none'; frame-ancestors 'none'
Gzip compression: level 5 with media endpoint skipper
Prometheus metrics: /metrics endpoint using existing monitoring service
External timeouts: 15s push, 30s SMTP, context timeouts on all external calls
Migrations: 016 (token created_at), 017 (FK indexes), 018 (audit_log)
Tests: circuit breaker (15), audit service (8), token refresh (7), health (4),
middleware expiry (5), validator (new)
This commit is contained in:
@@ -134,6 +134,8 @@ type SecurityConfig struct {
|
|||||||
PasswordResetExpiry time.Duration
|
PasswordResetExpiry time.Duration
|
||||||
ConfirmationExpiry time.Duration
|
ConfirmationExpiry time.Duration
|
||||||
MaxPasswordResetRate int // per hour
|
MaxPasswordResetRate int // per hour
|
||||||
|
TokenExpiryDays int // Number of days before auth tokens expire (default 90)
|
||||||
|
TokenRefreshDays int // Token must be at least this many days old before refresh (default 60)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StorageConfig holds file storage settings
|
// StorageConfig holds file storage settings
|
||||||
@@ -262,6 +264,8 @@ func Load() (*Config, error) {
|
|||||||
PasswordResetExpiry: 15 * time.Minute,
|
PasswordResetExpiry: 15 * time.Minute,
|
||||||
ConfirmationExpiry: 24 * time.Hour,
|
ConfirmationExpiry: 24 * time.Hour,
|
||||||
MaxPasswordResetRate: 3,
|
MaxPasswordResetRate: 3,
|
||||||
|
TokenExpiryDays: viper.GetInt("TOKEN_EXPIRY_DAYS"),
|
||||||
|
TokenRefreshDays: viper.GetInt("TOKEN_REFRESH_DAYS"),
|
||||||
},
|
},
|
||||||
Storage: StorageConfig{
|
Storage: StorageConfig{
|
||||||
UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"),
|
UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"),
|
||||||
@@ -369,6 +373,10 @@ func setDefaults() {
|
|||||||
viper.SetDefault("OVERDUE_REMINDER_HOUR", 15) // 9:00 AM UTC
|
viper.SetDefault("OVERDUE_REMINDER_HOUR", 15) // 9:00 AM UTC
|
||||||
viper.SetDefault("DAILY_DIGEST_HOUR", 3) // 3:00 AM UTC
|
viper.SetDefault("DAILY_DIGEST_HOUR", 3) // 3:00 AM UTC
|
||||||
|
|
||||||
|
// Token expiry defaults
|
||||||
|
viper.SetDefault("TOKEN_EXPIRY_DAYS", 90) // Tokens expire after 90 days
|
||||||
|
viper.SetDefault("TOKEN_REFRESH_DAYS", 60) // Tokens can be refreshed after 60 days
|
||||||
|
|
||||||
// Storage defaults
|
// Storage defaults
|
||||||
viper.SetDefault("STORAGE_UPLOAD_DIR", "./uploads")
|
viper.SetDefault("STORAGE_UPLOAD_DIR", "./uploads")
|
||||||
viper.SetDefault("STORAGE_BASE_URL", "/uploads")
|
viper.SetDefault("STORAGE_BASE_URL", "/uploads")
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ type LoginRequest struct {
|
|||||||
type RegisterRequest struct {
|
type RegisterRequest struct {
|
||||||
Username string `json:"username" validate:"required,min=3,max=150"`
|
Username string `json:"username" validate:"required,min=3,max=150"`
|
||||||
Email string `json:"email" validate:"required,email,max=254"`
|
Email string `json:"email" validate:"required,email,max=254"`
|
||||||
Password string `json:"password" validate:"required,min=8"`
|
Password string `json:"password" validate:"required,min=8,password_complexity"`
|
||||||
FirstName string `json:"first_name" validate:"max=150"`
|
FirstName string `json:"first_name" validate:"max=150"`
|
||||||
LastName string `json:"last_name" validate:"max=150"`
|
LastName string `json:"last_name" validate:"max=150"`
|
||||||
}
|
}
|
||||||
@@ -35,7 +35,7 @@ type VerifyResetCodeRequest struct {
|
|||||||
// ResetPasswordRequest represents the reset password request body
|
// ResetPasswordRequest represents the reset password request body
|
||||||
type ResetPasswordRequest struct {
|
type ResetPasswordRequest struct {
|
||||||
ResetToken string `json:"reset_token" validate:"required"`
|
ResetToken string `json:"reset_token" validate:"required"`
|
||||||
NewPassword string `json:"new_password" validate:"required,min=8"`
|
NewPassword string `json:"new_password" validate:"required,min=8,password_complexity"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateProfileRequest represents the profile update request body
|
// UpdateProfileRequest represents the profile update request body
|
||||||
|
|||||||
@@ -79,6 +79,12 @@ type ResetPasswordResponse struct {
|
|||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RefreshTokenResponse represents the token refresh response
|
||||||
|
type RefreshTokenResponse struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
// MessageResponse represents a simple message response
|
// MessageResponse represents a simple message response
|
||||||
type MessageResponse struct {
|
type MessageResponse struct {
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ type AuthHandler struct {
|
|||||||
appleAuthService *services.AppleAuthService
|
appleAuthService *services.AppleAuthService
|
||||||
googleAuthService *services.GoogleAuthService
|
googleAuthService *services.GoogleAuthService
|
||||||
storageService *services.StorageService
|
storageService *services.StorageService
|
||||||
|
auditService *services.AuditService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthHandler creates a new auth handler
|
// NewAuthHandler creates a new auth handler
|
||||||
@@ -49,6 +50,11 @@ func (h *AuthHandler) SetStorageService(storageService *services.StorageService)
|
|||||||
h.storageService = storageService
|
h.storageService = storageService
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAuditService sets the audit service for logging security events
|
||||||
|
func (h *AuthHandler) SetAuditService(auditService *services.AuditService) {
|
||||||
|
h.auditService = auditService
|
||||||
|
}
|
||||||
|
|
||||||
// Login handles POST /api/auth/login/
|
// Login handles POST /api/auth/login/
|
||||||
func (h *AuthHandler) Login(c echo.Context) error {
|
func (h *AuthHandler) Login(c echo.Context) error {
|
||||||
var req requests.LoginRequest
|
var req requests.LoginRequest
|
||||||
@@ -62,9 +68,19 @@ func (h *AuthHandler) Login(c echo.Context) error {
|
|||||||
response, err := h.authService.Login(&req)
|
response, err := h.authService.Login(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Err(err).Str("identifier", req.Username).Msg("Login failed")
|
log.Debug().Err(err).Str("identifier", req.Username).Msg("Login failed")
|
||||||
|
if h.auditService != nil {
|
||||||
|
h.auditService.LogEvent(c, nil, services.AuditEventLoginFailed, map[string]interface{}{
|
||||||
|
"identifier": req.Username,
|
||||||
|
})
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.auditService != nil {
|
||||||
|
userID := response.User.ID
|
||||||
|
h.auditService.LogEvent(c, &userID, services.AuditEventLogin, nil)
|
||||||
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, response)
|
return c.JSON(http.StatusOK, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,6 +100,14 @@ func (h *AuthHandler) Register(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.auditService != nil {
|
||||||
|
userID := response.User.ID
|
||||||
|
h.auditService.LogEvent(c, &userID, services.AuditEventRegister, map[string]interface{}{
|
||||||
|
"username": req.Username,
|
||||||
|
"email": req.Email,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Send welcome email with confirmation code (async)
|
// Send welcome email with confirmation code (async)
|
||||||
if h.emailService != nil && confirmationCode != "" {
|
if h.emailService != nil && confirmationCode != "" {
|
||||||
go func() {
|
go func() {
|
||||||
@@ -108,6 +132,14 @@ func (h *AuthHandler) Logout(c echo.Context) error {
|
|||||||
return apperrors.Unauthorized("error.not_authenticated")
|
return apperrors.Unauthorized("error.not_authenticated")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Log audit event before invalidating the token
|
||||||
|
if h.auditService != nil {
|
||||||
|
user := middleware.GetAuthUser(c)
|
||||||
|
if user != nil {
|
||||||
|
h.auditService.LogEvent(c, &user.ID, services.AuditEventLogout, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Invalidate token in database
|
// Invalidate token in database
|
||||||
if err := h.authService.Logout(token); err != nil {
|
if err := h.authService.Logout(token); err != nil {
|
||||||
log.Warn().Err(err).Msg("Failed to delete token from database")
|
log.Warn().Err(err).Msg("Failed to delete token from database")
|
||||||
@@ -270,6 +302,12 @@ func (h *AuthHandler) ForgotPassword(c echo.Context) error {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.auditService != nil {
|
||||||
|
h.auditService.LogEvent(c, nil, services.AuditEventPasswordReset, map[string]interface{}{
|
||||||
|
"email": req.Email,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Always return success to prevent email enumeration
|
// Always return success to prevent email enumeration
|
||||||
return c.JSON(http.StatusOK, responses.ForgotPasswordResponse{
|
return c.JSON(http.StatusOK, responses.ForgotPasswordResponse{
|
||||||
Message: "Password reset email sent",
|
Message: "Password reset email sent",
|
||||||
@@ -314,6 +352,12 @@ func (h *AuthHandler) ResetPassword(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.auditService != nil {
|
||||||
|
h.auditService.LogEvent(c, nil, services.AuditEventPasswordChanged, map[string]interface{}{
|
||||||
|
"method": "reset_token",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return c.JSON(http.StatusOK, responses.ResetPasswordResponse{
|
return c.JSON(http.StatusOK, responses.ResetPasswordResponse{
|
||||||
Message: "Password reset successful",
|
Message: "Password reset successful",
|
||||||
})
|
})
|
||||||
@@ -413,6 +457,34 @@ func (h *AuthHandler) GoogleSignIn(c echo.Context) error {
|
|||||||
return c.JSON(http.StatusOK, response)
|
return c.JSON(http.StatusOK, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RefreshToken handles POST /api/auth/refresh/
|
||||||
|
func (h *AuthHandler) RefreshToken(c echo.Context) error {
|
||||||
|
user, err := middleware.MustGetAuthUser(c)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
token := middleware.GetAuthToken(c)
|
||||||
|
if token == "" {
|
||||||
|
return apperrors.Unauthorized("error.not_authenticated")
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := h.authService.RefreshToken(token, user.ID)
|
||||||
|
if err != nil {
|
||||||
|
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Token refresh failed")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the token was refreshed (new token), invalidate the old one from cache
|
||||||
|
if response.Token != token && h.cache != nil {
|
||||||
|
if cacheErr := h.cache.InvalidateAuthToken(c.Request().Context(), token); cacheErr != nil {
|
||||||
|
log.Warn().Err(cacheErr).Msg("Failed to invalidate old token from cache during refresh")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(http.StatusOK, response)
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteAccount handles DELETE /api/auth/account/
|
// DeleteAccount handles DELETE /api/auth/account/
|
||||||
func (h *AuthHandler) DeleteAccount(c echo.Context) error {
|
func (h *AuthHandler) DeleteAccount(c echo.Context) error {
|
||||||
user, err := middleware.MustGetAuthUser(c)
|
user, err := middleware.MustGetAuthUser(c)
|
||||||
@@ -431,6 +503,14 @@ func (h *AuthHandler) DeleteAccount(c echo.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.auditService != nil {
|
||||||
|
h.auditService.LogEvent(c, &user.ID, services.AuditEventAccountDeleted, map[string]interface{}{
|
||||||
|
"user_id": user.ID,
|
||||||
|
"username": user.Username,
|
||||||
|
"email": user.Email,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Delete files from disk (best effort, don't fail the request)
|
// Delete files from disk (best effort, don't fail the request)
|
||||||
if h.storageService != nil && len(fileURLs) > 0 {
|
if h.storageService != nil && len(fileURLs) > 0 {
|
||||||
go func() {
|
go func() {
|
||||||
|
|||||||
@@ -6,8 +6,11 @@
|
|||||||
"error.email_taken": "Email already registered",
|
"error.email_taken": "Email already registered",
|
||||||
"error.email_already_taken": "Email already taken",
|
"error.email_already_taken": "Email already taken",
|
||||||
"error.registration_failed": "Registration failed",
|
"error.registration_failed": "Registration failed",
|
||||||
|
"error.password_complexity": "Password must be at least 8 characters with at least one uppercase letter, one lowercase letter, and one digit",
|
||||||
"error.not_authenticated": "Not authenticated",
|
"error.not_authenticated": "Not authenticated",
|
||||||
"error.invalid_token": "Invalid token",
|
"error.invalid_token": "Invalid token",
|
||||||
|
"error.token_expired": "Your session has expired. Please log in again.",
|
||||||
|
"error.token_refresh_not_needed": "Token is still valid.",
|
||||||
"error.failed_to_get_user": "Failed to get user",
|
"error.failed_to_get_user": "Failed to get user",
|
||||||
"error.failed_to_update_profile": "Failed to update profile",
|
"error.failed_to_update_profile": "Failed to update profile",
|
||||||
"error.invalid_verification_code": "Invalid verification code",
|
"error.invalid_verification_code": "Invalid verification code",
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/treytartt/honeydue-api/internal/apperrors"
|
"github.com/treytartt/honeydue-api/internal/apperrors"
|
||||||
|
"github.com/treytartt/honeydue-api/internal/config"
|
||||||
"github.com/treytartt/honeydue-api/internal/models"
|
"github.com/treytartt/honeydue-api/internal/models"
|
||||||
"github.com/treytartt/honeydue-api/internal/services"
|
"github.com/treytartt/honeydue-api/internal/services"
|
||||||
)
|
)
|
||||||
@@ -28,24 +29,56 @@ const (
|
|||||||
// UserCacheTTL is how long full user records are cached in memory to
|
// UserCacheTTL is how long full user records are cached in memory to
|
||||||
// avoid hitting the database on every authenticated request.
|
// avoid hitting the database on every authenticated request.
|
||||||
UserCacheTTL = 30 * time.Second
|
UserCacheTTL = 30 * time.Second
|
||||||
|
|
||||||
|
// DefaultTokenExpiryDays is the default number of days before a token expires.
|
||||||
|
DefaultTokenExpiryDays = 90
|
||||||
)
|
)
|
||||||
|
|
||||||
// AuthMiddleware provides token authentication middleware
|
// AuthMiddleware provides token authentication middleware
|
||||||
type AuthMiddleware struct {
|
type AuthMiddleware struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
cache *services.CacheService
|
cache *services.CacheService
|
||||||
userCache *UserCache
|
userCache *UserCache
|
||||||
|
tokenExpiryDays int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthMiddleware creates a new auth middleware instance
|
// NewAuthMiddleware creates a new auth middleware instance
|
||||||
func NewAuthMiddleware(db *gorm.DB, cache *services.CacheService) *AuthMiddleware {
|
func NewAuthMiddleware(db *gorm.DB, cache *services.CacheService) *AuthMiddleware {
|
||||||
return &AuthMiddleware{
|
return &AuthMiddleware{
|
||||||
db: db,
|
db: db,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
userCache: NewUserCache(UserCacheTTL),
|
userCache: NewUserCache(UserCacheTTL),
|
||||||
|
tokenExpiryDays: DefaultTokenExpiryDays,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewAuthMiddlewareWithConfig creates a new auth middleware instance with configuration
|
||||||
|
func NewAuthMiddlewareWithConfig(db *gorm.DB, cache *services.CacheService, cfg *config.Config) *AuthMiddleware {
|
||||||
|
expiryDays := DefaultTokenExpiryDays
|
||||||
|
if cfg != nil && cfg.Security.TokenExpiryDays > 0 {
|
||||||
|
expiryDays = cfg.Security.TokenExpiryDays
|
||||||
|
}
|
||||||
|
return &AuthMiddleware{
|
||||||
|
db: db,
|
||||||
|
cache: cache,
|
||||||
|
userCache: NewUserCache(UserCacheTTL),
|
||||||
|
tokenExpiryDays: expiryDays,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenExpiryDuration returns the token expiry duration.
|
||||||
|
func (m *AuthMiddleware) TokenExpiryDuration() time.Duration {
|
||||||
|
return time.Duration(m.tokenExpiryDays) * 24 * time.Hour
|
||||||
|
}
|
||||||
|
|
||||||
|
// isTokenExpired checks if a token's created timestamp indicates expiry.
|
||||||
|
func (m *AuthMiddleware) isTokenExpired(created time.Time) bool {
|
||||||
|
if created.IsZero() {
|
||||||
|
return false // Legacy tokens without created time are not expired
|
||||||
|
}
|
||||||
|
return time.Since(created) > m.TokenExpiryDuration()
|
||||||
|
}
|
||||||
|
|
||||||
// TokenAuth returns an Echo middleware that validates token authentication
|
// TokenAuth returns an Echo middleware that validates token authentication
|
||||||
func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc {
|
func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc {
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
@@ -56,7 +89,7 @@ func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc {
|
|||||||
return apperrors.Unauthorized("error.not_authenticated")
|
return apperrors.Unauthorized("error.not_authenticated")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to get user from cache first
|
// Try to get user from cache first (includes expiry check)
|
||||||
user, err := m.getUserFromCache(c.Request().Context(), token)
|
user, err := m.getUserFromCache(c.Request().Context(), token)
|
||||||
if err == nil && user != nil {
|
if err == nil && user != nil {
|
||||||
// Cache hit - set user in context and continue
|
// Cache hit - set user in context and continue
|
||||||
@@ -65,16 +98,27 @@ func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc {
|
|||||||
return next(c)
|
return next(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if the cache indicated token expiry
|
||||||
|
if err != nil && err.Error() == "token expired" {
|
||||||
|
return apperrors.Unauthorized("error.token_expired")
|
||||||
|
}
|
||||||
|
|
||||||
// Cache miss - look up token in database
|
// Cache miss - look up token in database
|
||||||
user, err = m.getUserFromDatabase(token)
|
user, authToken, err := m.getUserFromDatabaseWithToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug().Err(err).Str("token", truncateToken(token)).Msg("Token authentication failed")
|
log.Debug().Err(err).Str("token", truncateToken(token)).Msg("Token authentication failed")
|
||||||
return apperrors.Unauthorized("error.invalid_token")
|
return apperrors.Unauthorized("error.invalid_token")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache the user ID for future requests
|
// Check token expiry
|
||||||
if cacheErr := m.cacheUserID(c.Request().Context(), token, user.ID); cacheErr != nil {
|
if m.isTokenExpired(authToken.Created) {
|
||||||
log.Warn().Err(cacheErr).Msg("Failed to cache user ID")
|
log.Debug().Str("token", truncateToken(token)).Time("created", authToken.Created).Msg("Token expired")
|
||||||
|
return apperrors.Unauthorized("error.token_expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cache the user ID and token creation time for future requests
|
||||||
|
if cacheErr := m.cacheTokenInfo(c.Request().Context(), token, user.ID, authToken.Created); cacheErr != nil {
|
||||||
|
log.Warn().Err(cacheErr).Msg("Failed to cache token info")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set user in context
|
// Set user in context
|
||||||
@@ -104,9 +148,9 @@ func (m *AuthMiddleware) OptionalTokenAuth() echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Try database
|
// Try database
|
||||||
user, err = m.getUserFromDatabase(token)
|
user, authToken, err := m.getUserFromDatabaseWithToken(token)
|
||||||
if err == nil {
|
if err == nil && !m.isTokenExpired(authToken.Created) {
|
||||||
m.cacheUserID(c.Request().Context(), token, user.ID)
|
m.cacheTokenInfo(c.Request().Context(), token, user.ID, authToken.Created)
|
||||||
c.Set(AuthUserKey, user)
|
c.Set(AuthUserKey, user)
|
||||||
c.Set(AuthTokenKey, token)
|
c.Set(AuthTokenKey, token)
|
||||||
}
|
}
|
||||||
@@ -145,12 +189,13 @@ func extractToken(c echo.Context) (string, error) {
|
|||||||
|
|
||||||
// getUserFromCache tries to get user from Redis cache, then from the
|
// getUserFromCache tries to get user from Redis cache, then from the
|
||||||
// in-memory user cache, before falling back to the database.
|
// in-memory user cache, before falling back to the database.
|
||||||
|
// Returns a "token expired" error if the cached creation time indicates expiry.
|
||||||
func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*models.User, error) {
|
func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*models.User, error) {
|
||||||
if m.cache == nil {
|
if m.cache == nil {
|
||||||
return nil, fmt.Errorf("cache not available")
|
return nil, fmt.Errorf("cache not available")
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, err := m.cache.GetCachedAuthToken(ctx, token)
|
userID, createdUnix, err := m.cache.GetCachedAuthTokenWithCreated(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == redis.Nil {
|
if err == redis.Nil {
|
||||||
return nil, fmt.Errorf("token not in cache")
|
return nil, fmt.Errorf("token not in cache")
|
||||||
@@ -158,6 +203,15 @@ func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*m
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check token expiry from cached creation time
|
||||||
|
if createdUnix > 0 {
|
||||||
|
created := time.Unix(createdUnix, 0)
|
||||||
|
if m.isTokenExpired(created) {
|
||||||
|
m.cache.InvalidateAuthToken(ctx, token)
|
||||||
|
return nil, fmt.Errorf("token expired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Try in-memory user cache first to avoid a DB round-trip
|
// Try in-memory user cache first to avoid a DB round-trip
|
||||||
if cached := m.userCache.Get(userID); cached != nil {
|
if cached := m.userCache.Get(userID); cached != nil {
|
||||||
if !cached.IsActive {
|
if !cached.IsActive {
|
||||||
@@ -187,22 +241,38 @@ func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*m
|
|||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getUserFromDatabase looks up the token in the database and caches the
|
// getUserFromDatabaseWithToken looks up the token in the database and returns
|
||||||
// resulting user record in memory.
|
// both the user and the auth token record (for expiry checking).
|
||||||
func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error) {
|
func (m *AuthMiddleware) getUserFromDatabaseWithToken(token string) (*models.User, *models.AuthToken, error) {
|
||||||
var authToken models.AuthToken
|
var authToken models.AuthToken
|
||||||
if err := m.db.Preload("User").Where("key = ?", token).First(&authToken).Error; err != nil {
|
if err := m.db.Preload("User").Where("key = ?", token).First(&authToken).Error; err != nil {
|
||||||
return nil, fmt.Errorf("token not found")
|
return nil, nil, fmt.Errorf("token not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if user is active
|
// Check if user is active
|
||||||
if !authToken.User.IsActive {
|
if !authToken.User.IsActive {
|
||||||
return nil, fmt.Errorf("user is inactive")
|
return nil, nil, fmt.Errorf("user is inactive")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store in in-memory cache for subsequent requests
|
// Store in in-memory cache for subsequent requests
|
||||||
m.userCache.Set(&authToken.User)
|
m.userCache.Set(&authToken.User)
|
||||||
return &authToken.User, nil
|
return &authToken.User, &authToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getUserFromDatabase looks up the token in the database and caches the
|
||||||
|
// resulting user record in memory.
|
||||||
|
// Deprecated: Use getUserFromDatabaseWithToken for new code paths that need expiry checking.
|
||||||
|
func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error) {
|
||||||
|
user, _, err := m.getUserFromDatabaseWithToken(token)
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// cacheTokenInfo caches the user ID and token creation time for a token
|
||||||
|
func (m *AuthMiddleware) cacheTokenInfo(ctx context.Context, token string, userID uint, created time.Time) error {
|
||||||
|
if m.cache == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m.cache.CacheAuthTokenWithCreated(ctx, token, userID, created.Unix())
|
||||||
}
|
}
|
||||||
|
|
||||||
// cacheUserID caches the user ID for a token
|
// cacheUserID caches the user ID for a token
|
||||||
|
|||||||
165
internal/middleware/auth_expiry_test.go
Normal file
165
internal/middleware/auth_expiry_test.go
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
|
||||||
|
"github.com/treytartt/honeydue-api/internal/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setupTestDB creates a temporary in-memory SQLite database with the required
|
||||||
|
// tables for auth middleware tests.
|
||||||
|
func setupTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = db.AutoMigrate(&models.User{}, &models.AuthToken{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
// createTestUserAndToken creates a user and an auth token, then backdates the
|
||||||
|
// token's Created timestamp by the specified number of days.
|
||||||
|
func createTestUserAndToken(t *testing.T, db *gorm.DB, username string, ageDays int) (*models.User, *models.AuthToken) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
user := &models.User{
|
||||||
|
Username: username,
|
||||||
|
Email: username + "@test.com",
|
||||||
|
IsActive: true,
|
||||||
|
}
|
||||||
|
require.NoError(t, user.SetPassword("password123"))
|
||||||
|
require.NoError(t, db.Create(user).Error)
|
||||||
|
|
||||||
|
token := &models.AuthToken{
|
||||||
|
UserID: user.ID,
|
||||||
|
}
|
||||||
|
require.NoError(t, db.Create(token).Error)
|
||||||
|
|
||||||
|
// Backdate the token's Created timestamp after creation to bypass autoCreateTime
|
||||||
|
backdated := time.Now().UTC().AddDate(0, 0, -ageDays)
|
||||||
|
require.NoError(t, db.Model(token).Update("created", backdated).Error)
|
||||||
|
token.Created = backdated
|
||||||
|
|
||||||
|
return user, token
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenAuth_RejectsExpiredToken(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
_, token := createTestUserAndToken(t, db, "expired_user", 91) // 91 days old > 90 day expiry
|
||||||
|
|
||||||
|
m := NewAuthMiddleware(db, nil) // No Redis cache for these tests
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||||
|
req.Header.Set("Authorization", "Token "+token.Key)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "ok")
|
||||||
|
})
|
||||||
|
|
||||||
|
err := handler(c)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "error.token_expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenAuth_AcceptsValidToken(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
_, token := createTestUserAndToken(t, db, "valid_user", 30) // 30 days old < 90 day expiry
|
||||||
|
|
||||||
|
m := NewAuthMiddleware(db, nil)
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||||
|
req.Header.Set("Authorization", "Token "+token.Key)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "ok")
|
||||||
|
})
|
||||||
|
|
||||||
|
err := handler(c)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
|
// Verify user was set in context
|
||||||
|
user := GetAuthUser(c)
|
||||||
|
require.NotNil(t, user)
|
||||||
|
assert.Equal(t, "valid_user", user.Username)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenAuth_AcceptsTokenAtBoundary(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
_, token := createTestUserAndToken(t, db, "boundary_user", 89) // 89 days old, just under 90 day expiry
|
||||||
|
|
||||||
|
m := NewAuthMiddleware(db, nil)
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||||
|
req.Header.Set("Authorization", "Token "+token.Key)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "ok")
|
||||||
|
})
|
||||||
|
|
||||||
|
err := handler(c)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenAuth_RejectsInvalidToken(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
|
||||||
|
m := NewAuthMiddleware(db, nil)
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||||
|
req.Header.Set("Authorization", "Token nonexistent-token")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "ok")
|
||||||
|
})
|
||||||
|
|
||||||
|
err := handler(c)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenAuth_RejectsNoAuthHeader(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
|
||||||
|
m := NewAuthMiddleware(db, nil)
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "ok")
|
||||||
|
})
|
||||||
|
|
||||||
|
err := handler(c)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "error.not_authenticated")
|
||||||
|
}
|
||||||
48
internal/models/audit_log.go
Normal file
48
internal/models/audit_log.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JSONMap is a custom type for JSONB columns that handles JSON serialization
|
||||||
|
type JSONMap map[string]interface{}
|
||||||
|
|
||||||
|
// Value implements driver.Valuer for database writes
|
||||||
|
func (j JSONMap) Value() (driver.Value, error) {
|
||||||
|
if j == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return json.Marshal(j)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements sql.Scanner for database reads
|
||||||
|
func (j *JSONMap) Scan(value interface{}) error {
|
||||||
|
if value == nil {
|
||||||
|
*j = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
bytes, ok := value.([]byte)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("audit_log: failed to scan JSONMap value")
|
||||||
|
}
|
||||||
|
return json.Unmarshal(bytes, j)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditLog represents the audit_log table for tracking security-relevant events
|
||||||
|
type AuditLog struct {
|
||||||
|
ID uint `gorm:"primaryKey" json:"id"`
|
||||||
|
UserID *uint `gorm:"column:user_id" json:"user_id"`
|
||||||
|
EventType string `gorm:"column:event_type;size:50;not null" json:"event_type"`
|
||||||
|
IPAddress string `gorm:"column:ip_address;size:45" json:"ip_address"`
|
||||||
|
UserAgent string `gorm:"column:user_agent;type:text" json:"user_agent"`
|
||||||
|
Details JSONMap `gorm:"column:details;type:jsonb" json:"details"`
|
||||||
|
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName returns the table name for GORM
|
||||||
|
func (AuditLog) TableName() string {
|
||||||
|
return "audit_log"
|
||||||
|
}
|
||||||
167
internal/push/circuit_breaker.go
Normal file
167
internal/push/circuit_breaker.go
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
package push
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Circuit breaker states
|
||||||
|
const (
|
||||||
|
stateClosed = iota // Normal operation, requests pass through
|
||||||
|
stateOpen // Too many failures, requests are rejected
|
||||||
|
stateHalfOpen // Testing recovery, one request allowed through
|
||||||
|
)
|
||||||
|
|
||||||
|
// Default circuit breaker settings
|
||||||
|
const (
|
||||||
|
defaultFailureThreshold = 5 // Open after this many consecutive failures
|
||||||
|
defaultRecoveryTimeout = 30 * time.Second // Try again after this duration
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrCircuitOpen is returned when the circuit breaker is open and rejecting requests.
|
||||||
|
var ErrCircuitOpen = errors.New("circuit breaker is open")
|
||||||
|
|
||||||
|
// CircuitBreaker implements a simple circuit breaker pattern for external service calls.
|
||||||
|
// It is thread-safe and requires no external dependencies.
|
||||||
|
//
|
||||||
|
// States:
|
||||||
|
// - Closed: normal operation, all requests pass through. Consecutive failures are counted.
|
||||||
|
// - Open: after reaching the failure threshold, all requests are immediately rejected
|
||||||
|
// with ErrCircuitOpen until the recovery timeout elapses.
|
||||||
|
// - Half-Open: after the recovery timeout, one request is allowed through. If it
|
||||||
|
// succeeds the breaker resets to Closed; if it fails it returns to Open.
|
||||||
|
type CircuitBreaker struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
state int
|
||||||
|
failureCount int
|
||||||
|
failureThreshold int
|
||||||
|
recoveryTimeout time.Duration
|
||||||
|
lastFailureTime time.Time
|
||||||
|
name string // For logging
|
||||||
|
}
|
||||||
|
|
||||||
|
// CircuitBreakerOption configures a CircuitBreaker.
|
||||||
|
type CircuitBreakerOption func(*CircuitBreaker)
|
||||||
|
|
||||||
|
// WithFailureThreshold sets the number of consecutive failures before opening the circuit.
|
||||||
|
func WithFailureThreshold(n int) CircuitBreakerOption {
|
||||||
|
return func(cb *CircuitBreaker) {
|
||||||
|
if n > 0 {
|
||||||
|
cb.failureThreshold = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRecoveryTimeout sets how long the circuit stays open before trying half-open.
|
||||||
|
func WithRecoveryTimeout(d time.Duration) CircuitBreakerOption {
|
||||||
|
return func(cb *CircuitBreaker) {
|
||||||
|
if d > 0 {
|
||||||
|
cb.recoveryTimeout = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCircuitBreaker creates a new CircuitBreaker with the given name and options.
|
||||||
|
// The name is used for logging and identification.
|
||||||
|
func NewCircuitBreaker(name string, opts ...CircuitBreakerOption) *CircuitBreaker {
|
||||||
|
cb := &CircuitBreaker{
|
||||||
|
state: stateClosed,
|
||||||
|
failureThreshold: defaultFailureThreshold,
|
||||||
|
recoveryTimeout: defaultRecoveryTimeout,
|
||||||
|
name: name,
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(cb)
|
||||||
|
}
|
||||||
|
return cb
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow checks whether a request should be allowed through.
|
||||||
|
// It returns true if the request can proceed, false if the circuit is open.
|
||||||
|
// When transitioning from open to half-open, it returns true for the probe request.
|
||||||
|
func (cb *CircuitBreaker) Allow() bool {
|
||||||
|
cb.mu.Lock()
|
||||||
|
defer cb.mu.Unlock()
|
||||||
|
|
||||||
|
switch cb.state {
|
||||||
|
case stateClosed:
|
||||||
|
return true
|
||||||
|
case stateOpen:
|
||||||
|
// Check if recovery timeout has elapsed
|
||||||
|
if time.Since(cb.lastFailureTime) >= cb.recoveryTimeout {
|
||||||
|
cb.state = stateHalfOpen
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
case stateHalfOpen:
|
||||||
|
// Only one request at a time in half-open state.
|
||||||
|
// The first caller that got here via Allow() is already in flight;
|
||||||
|
// reject subsequent callers until that probe resolves.
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordSuccess records a successful request. If the breaker is half-open, it resets to closed.
|
||||||
|
func (cb *CircuitBreaker) RecordSuccess() {
|
||||||
|
cb.mu.Lock()
|
||||||
|
defer cb.mu.Unlock()
|
||||||
|
|
||||||
|
cb.failureCount = 0
|
||||||
|
cb.state = stateClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordFailure records a failed request. If the failure threshold is reached, the
|
||||||
|
// breaker transitions to the open state.
|
||||||
|
func (cb *CircuitBreaker) RecordFailure() {
|
||||||
|
cb.mu.Lock()
|
||||||
|
defer cb.mu.Unlock()
|
||||||
|
|
||||||
|
cb.failureCount++
|
||||||
|
cb.lastFailureTime = time.Now()
|
||||||
|
|
||||||
|
if cb.failureCount >= cb.failureThreshold {
|
||||||
|
cb.state = stateOpen
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// State returns the current state of the circuit breaker as a human-readable string.
|
||||||
|
func (cb *CircuitBreaker) State() string {
|
||||||
|
cb.mu.Lock()
|
||||||
|
defer cb.mu.Unlock()
|
||||||
|
|
||||||
|
switch cb.state {
|
||||||
|
case stateClosed:
|
||||||
|
return "closed"
|
||||||
|
case stateOpen:
|
||||||
|
return "open"
|
||||||
|
case stateHalfOpen:
|
||||||
|
return "half-open"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the circuit breaker's name.
|
||||||
|
func (cb *CircuitBreaker) Name() string {
|
||||||
|
return cb.name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset resets the circuit breaker to the closed state with zero failures.
|
||||||
|
func (cb *CircuitBreaker) Reset() {
|
||||||
|
cb.mu.Lock()
|
||||||
|
defer cb.mu.Unlock()
|
||||||
|
|
||||||
|
cb.state = stateClosed
|
||||||
|
cb.failureCount = 0
|
||||||
|
cb.lastFailureTime = time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Counts returns the current failure count (useful for testing and monitoring).
|
||||||
|
func (cb *CircuitBreaker) Counts() int {
|
||||||
|
cb.mu.Lock()
|
||||||
|
defer cb.mu.Unlock()
|
||||||
|
return cb.failureCount
|
||||||
|
}
|
||||||
275
internal/push/circuit_breaker_test.go
Normal file
275
internal/push/circuit_breaker_test.go
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
package push
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCircuitBreaker_StartsInClosedState(t *testing.T) {
|
||||||
|
cb := NewCircuitBreaker("test")
|
||||||
|
assert.Equal(t, "closed", cb.State())
|
||||||
|
assert.True(t, cb.Allow())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircuitBreaker_OpensAfterThresholdFailures(t *testing.T) {
|
||||||
|
cb := NewCircuitBreaker("test", WithFailureThreshold(3))
|
||||||
|
|
||||||
|
// First two failures should keep it closed
|
||||||
|
cb.RecordFailure()
|
||||||
|
assert.Equal(t, "closed", cb.State())
|
||||||
|
assert.True(t, cb.Allow())
|
||||||
|
|
||||||
|
cb.RecordFailure()
|
||||||
|
assert.Equal(t, "closed", cb.State())
|
||||||
|
assert.True(t, cb.Allow())
|
||||||
|
|
||||||
|
// Third failure should open it
|
||||||
|
cb.RecordFailure()
|
||||||
|
assert.Equal(t, "open", cb.State())
|
||||||
|
assert.False(t, cb.Allow())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircuitBreaker_DefaultThresholdIsFive(t *testing.T) {
|
||||||
|
cb := NewCircuitBreaker("test")
|
||||||
|
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
cb.RecordFailure()
|
||||||
|
assert.Equal(t, "closed", cb.State())
|
||||||
|
}
|
||||||
|
|
||||||
|
cb.RecordFailure() // 5th failure
|
||||||
|
assert.Equal(t, "open", cb.State())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircuitBreaker_RejectsRequestsWhenOpen(t *testing.T) {
|
||||||
|
cb := NewCircuitBreaker("test", WithFailureThreshold(1))
|
||||||
|
|
||||||
|
cb.RecordFailure()
|
||||||
|
assert.Equal(t, "open", cb.State())
|
||||||
|
|
||||||
|
// Multiple calls should all be rejected
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
assert.False(t, cb.Allow())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircuitBreaker_TransitionsToHalfOpenAfterRecoveryTimeout(t *testing.T) {
|
||||||
|
cb := NewCircuitBreaker("test",
|
||||||
|
WithFailureThreshold(1),
|
||||||
|
WithRecoveryTimeout(50*time.Millisecond),
|
||||||
|
)
|
||||||
|
|
||||||
|
cb.RecordFailure()
|
||||||
|
assert.Equal(t, "open", cb.State())
|
||||||
|
assert.False(t, cb.Allow())
|
||||||
|
|
||||||
|
// Wait for recovery timeout
|
||||||
|
time.Sleep(60 * time.Millisecond)
|
||||||
|
|
||||||
|
// Should now allow one request (half-open)
|
||||||
|
assert.True(t, cb.Allow())
|
||||||
|
assert.Equal(t, "half-open", cb.State())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircuitBreaker_HalfOpenRejectsSecondRequest(t *testing.T) {
|
||||||
|
cb := NewCircuitBreaker("test",
|
||||||
|
WithFailureThreshold(1),
|
||||||
|
WithRecoveryTimeout(50*time.Millisecond),
|
||||||
|
)
|
||||||
|
|
||||||
|
cb.RecordFailure()
|
||||||
|
time.Sleep(60 * time.Millisecond)
|
||||||
|
|
||||||
|
// First request allowed (probe)
|
||||||
|
assert.True(t, cb.Allow())
|
||||||
|
assert.Equal(t, "half-open", cb.State())
|
||||||
|
|
||||||
|
// Second request rejected while probe is in flight
|
||||||
|
assert.False(t, cb.Allow())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircuitBreaker_HalfOpenSuccess_ResetsToClosed(t *testing.T) {
|
||||||
|
cb := NewCircuitBreaker("test",
|
||||||
|
WithFailureThreshold(1),
|
||||||
|
WithRecoveryTimeout(50*time.Millisecond),
|
||||||
|
)
|
||||||
|
|
||||||
|
cb.RecordFailure()
|
||||||
|
time.Sleep(60 * time.Millisecond)
|
||||||
|
|
||||||
|
// Probe request
|
||||||
|
assert.True(t, cb.Allow())
|
||||||
|
|
||||||
|
// Probe succeeds
|
||||||
|
cb.RecordSuccess()
|
||||||
|
assert.Equal(t, "closed", cb.State())
|
||||||
|
assert.Equal(t, 0, cb.Counts())
|
||||||
|
|
||||||
|
// Normal operation resumes
|
||||||
|
assert.True(t, cb.Allow())
|
||||||
|
assert.True(t, cb.Allow())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircuitBreaker_HalfOpenFailure_ReturnsToOpen(t *testing.T) {
|
||||||
|
cb := NewCircuitBreaker("test",
|
||||||
|
WithFailureThreshold(2),
|
||||||
|
WithRecoveryTimeout(50*time.Millisecond),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Open the circuit
|
||||||
|
cb.RecordFailure()
|
||||||
|
cb.RecordFailure()
|
||||||
|
assert.Equal(t, "open", cb.State())
|
||||||
|
|
||||||
|
time.Sleep(60 * time.Millisecond)
|
||||||
|
|
||||||
|
// Probe request
|
||||||
|
assert.True(t, cb.Allow())
|
||||||
|
assert.Equal(t, "half-open", cb.State())
|
||||||
|
|
||||||
|
// Probe fails - the failure count is now 3 which is >= threshold of 2
|
||||||
|
cb.RecordFailure()
|
||||||
|
assert.Equal(t, "open", cb.State())
|
||||||
|
assert.False(t, cb.Allow())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircuitBreaker_SuccessResetsFailureCount(t *testing.T) {
|
||||||
|
cb := NewCircuitBreaker("test", WithFailureThreshold(3))
|
||||||
|
|
||||||
|
cb.RecordFailure()
|
||||||
|
cb.RecordFailure()
|
||||||
|
assert.Equal(t, 2, cb.Counts())
|
||||||
|
|
||||||
|
// A success should reset the counter
|
||||||
|
cb.RecordSuccess()
|
||||||
|
assert.Equal(t, 0, cb.Counts())
|
||||||
|
assert.Equal(t, "closed", cb.State())
|
||||||
|
|
||||||
|
// Now it should take 3 more failures to open
|
||||||
|
cb.RecordFailure()
|
||||||
|
cb.RecordFailure()
|
||||||
|
assert.Equal(t, "closed", cb.State())
|
||||||
|
|
||||||
|
cb.RecordFailure()
|
||||||
|
assert.Equal(t, "open", cb.State())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircuitBreaker_Reset(t *testing.T) {
|
||||||
|
cb := NewCircuitBreaker("test", WithFailureThreshold(1))
|
||||||
|
|
||||||
|
cb.RecordFailure()
|
||||||
|
assert.Equal(t, "open", cb.State())
|
||||||
|
|
||||||
|
cb.Reset()
|
||||||
|
assert.Equal(t, "closed", cb.State())
|
||||||
|
assert.Equal(t, 0, cb.Counts())
|
||||||
|
assert.True(t, cb.Allow())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircuitBreaker_Name(t *testing.T) {
|
||||||
|
cb := NewCircuitBreaker("apns-breaker")
|
||||||
|
assert.Equal(t, "apns-breaker", cb.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircuitBreaker_CustomOptions(t *testing.T) {
|
||||||
|
cb := NewCircuitBreaker("test",
|
||||||
|
WithFailureThreshold(10),
|
||||||
|
WithRecoveryTimeout(5*time.Minute),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Should take 10 failures to open
|
||||||
|
for i := 0; i < 9; i++ {
|
||||||
|
cb.RecordFailure()
|
||||||
|
assert.Equal(t, "closed", cb.State())
|
||||||
|
}
|
||||||
|
cb.RecordFailure()
|
||||||
|
assert.Equal(t, "open", cb.State())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircuitBreaker_InvalidOptionsIgnored(t *testing.T) {
|
||||||
|
cb := NewCircuitBreaker("test",
|
||||||
|
WithFailureThreshold(0), // Should be ignored (keeps default)
|
||||||
|
WithRecoveryTimeout(-1), // Should be ignored (keeps default)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Default threshold of 5 should still apply
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
cb.RecordFailure()
|
||||||
|
assert.Equal(t, "closed", cb.State())
|
||||||
|
}
|
||||||
|
cb.RecordFailure()
|
||||||
|
assert.Equal(t, "open", cb.State())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircuitBreaker_ThreadSafety(t *testing.T) {
|
||||||
|
cb := NewCircuitBreaker("test",
|
||||||
|
WithFailureThreshold(100),
|
||||||
|
WithRecoveryTimeout(10*time.Millisecond),
|
||||||
|
)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
const goroutines = 50
|
||||||
|
const iterations = 100
|
||||||
|
|
||||||
|
// Hammer it from many goroutines
|
||||||
|
for i := 0; i < goroutines; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < iterations; j++ {
|
||||||
|
cb.Allow()
|
||||||
|
if j%2 == 0 {
|
||||||
|
cb.RecordFailure()
|
||||||
|
} else {
|
||||||
|
cb.RecordSuccess()
|
||||||
|
}
|
||||||
|
_ = cb.State()
|
||||||
|
_ = cb.Counts()
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Should not panic or deadlock. State should be valid.
|
||||||
|
state := cb.State()
|
||||||
|
require.Contains(t, []string{"closed", "open", "half-open"}, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCircuitBreaker_FullLifecycle(t *testing.T) {
|
||||||
|
cb := NewCircuitBreaker("lifecycle-test",
|
||||||
|
WithFailureThreshold(3),
|
||||||
|
WithRecoveryTimeout(50*time.Millisecond),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 1. Closed: requests flow normally
|
||||||
|
assert.True(t, cb.Allow())
|
||||||
|
cb.RecordSuccess()
|
||||||
|
assert.Equal(t, "closed", cb.State())
|
||||||
|
|
||||||
|
// 2. Accumulate failures
|
||||||
|
cb.RecordFailure()
|
||||||
|
cb.RecordFailure()
|
||||||
|
assert.Equal(t, "closed", cb.State())
|
||||||
|
|
||||||
|
// 3. Third failure opens the circuit
|
||||||
|
cb.RecordFailure()
|
||||||
|
assert.Equal(t, "open", cb.State())
|
||||||
|
assert.False(t, cb.Allow())
|
||||||
|
|
||||||
|
// 4. Wait for recovery
|
||||||
|
time.Sleep(60 * time.Millisecond)
|
||||||
|
|
||||||
|
// 5. Half-open: probe request allowed
|
||||||
|
assert.True(t, cb.Allow())
|
||||||
|
assert.Equal(t, "half-open", cb.State())
|
||||||
|
|
||||||
|
// 6. Probe succeeds, back to closed
|
||||||
|
cb.RecordSuccess()
|
||||||
|
assert.Equal(t, "closed", cb.State())
|
||||||
|
assert.True(t, cb.Allow())
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package push
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
|
||||||
@@ -14,16 +15,25 @@ const (
|
|||||||
PlatformAndroid = "android"
|
PlatformAndroid = "android"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Timeout for individual push notification send operations.
|
||||||
|
const pushSendTimeout = 15 * time.Second
|
||||||
|
|
||||||
// Client provides a unified interface for sending push notifications
|
// Client provides a unified interface for sending push notifications
|
||||||
type Client struct {
|
type Client struct {
|
||||||
apns *APNsClient
|
apns *APNsClient
|
||||||
fcm *FCMClient
|
fcm *FCMClient
|
||||||
enabled bool
|
enabled bool
|
||||||
|
apnsBreaker *CircuitBreaker
|
||||||
|
fcmBreaker *CircuitBreaker
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates a new unified push notification client
|
// NewClient creates a new unified push notification client
|
||||||
func NewClient(cfg *config.PushConfig, enabled bool) (*Client, error) {
|
func NewClient(cfg *config.PushConfig, enabled bool) (*Client, error) {
|
||||||
client := &Client{enabled: enabled}
|
client := &Client{
|
||||||
|
enabled: enabled,
|
||||||
|
apnsBreaker: NewCircuitBreaker("apns"),
|
||||||
|
fcmBreaker: NewCircuitBreaker("fcm"),
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize APNs client (iOS)
|
// Initialize APNs client (iOS)
|
||||||
if cfg.APNSKeyPath != "" && cfg.APNSKeyID != "" && cfg.APNSTeamID != "" {
|
if cfg.APNSKeyPath != "" && cfg.APNSKeyID != "" && cfg.APNSTeamID != "" {
|
||||||
@@ -54,7 +64,8 @@ func NewClient(cfg *config.PushConfig, enabled bool) (*Client, error) {
|
|||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendToIOS sends a push notification to iOS devices
|
// SendToIOS sends a push notification to iOS devices.
|
||||||
|
// The call is guarded by a circuit breaker and uses a context timeout.
|
||||||
func (c *Client) SendToIOS(ctx context.Context, tokens []string, title, message string, data map[string]string) error {
|
func (c *Client) SendToIOS(ctx context.Context, tokens []string, title, message string, data map[string]string) error {
|
||||||
if !c.enabled {
|
if !c.enabled {
|
||||||
log.Debug().Msg("Push notifications disabled by feature flag")
|
log.Debug().Msg("Push notifications disabled by feature flag")
|
||||||
@@ -64,10 +75,26 @@ func (c *Client) SendToIOS(ctx context.Context, tokens []string, title, message
|
|||||||
log.Warn().Msg("APNs client not initialized, skipping iOS push")
|
log.Warn().Msg("APNs client not initialized, skipping iOS push")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return c.apns.Send(ctx, tokens, title, message, data)
|
if !c.apnsBreaker.Allow() {
|
||||||
|
log.Warn().Str("breaker", c.apnsBreaker.Name()).Msg("APNs circuit breaker is open, skipping iOS push")
|
||||||
|
return ErrCircuitOpen
|
||||||
|
}
|
||||||
|
|
||||||
|
sendCtx, cancel := context.WithTimeout(ctx, pushSendTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := c.apns.Send(sendCtx, tokens, title, message, data)
|
||||||
|
if err != nil {
|
||||||
|
c.apnsBreaker.RecordFailure()
|
||||||
|
log.Warn().Err(err).Str("breaker_state", c.apnsBreaker.State()).Msg("APNs send failed, recorded circuit breaker failure")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.apnsBreaker.RecordSuccess()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendToAndroid sends a push notification to Android devices
|
// SendToAndroid sends a push notification to Android devices.
|
||||||
|
// The call is guarded by a circuit breaker and uses a context timeout.
|
||||||
func (c *Client) SendToAndroid(ctx context.Context, tokens []string, title, message string, data map[string]string) error {
|
func (c *Client) SendToAndroid(ctx context.Context, tokens []string, title, message string, data map[string]string) error {
|
||||||
if !c.enabled {
|
if !c.enabled {
|
||||||
log.Debug().Msg("Push notifications disabled by feature flag")
|
log.Debug().Msg("Push notifications disabled by feature flag")
|
||||||
@@ -77,7 +104,22 @@ func (c *Client) SendToAndroid(ctx context.Context, tokens []string, title, mess
|
|||||||
log.Warn().Msg("FCM client not initialized, skipping Android push")
|
log.Warn().Msg("FCM client not initialized, skipping Android push")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return c.fcm.Send(ctx, tokens, title, message, data)
|
if !c.fcmBreaker.Allow() {
|
||||||
|
log.Warn().Str("breaker", c.fcmBreaker.Name()).Msg("FCM circuit breaker is open, skipping Android push")
|
||||||
|
return ErrCircuitOpen
|
||||||
|
}
|
||||||
|
|
||||||
|
sendCtx, cancel := context.WithTimeout(ctx, pushSendTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := c.fcm.Send(sendCtx, tokens, title, message, data)
|
||||||
|
if err != nil {
|
||||||
|
c.fcmBreaker.RecordFailure()
|
||||||
|
log.Warn().Err(err).Str("breaker_state", c.fcmBreaker.State()).Msg("FCM send failed, recorded circuit breaker failure")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.fcmBreaker.RecordSuccess()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendToAll sends a push notification to both iOS and Android devices
|
// SendToAll sends a push notification to both iOS and Android devices
|
||||||
@@ -115,8 +157,9 @@ func (c *Client) IsAndroidEnabled() bool {
|
|||||||
return c.fcm != nil
|
return c.fcm != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendActionableNotification sends notifications with action button support
|
// SendActionableNotification sends notifications with action button support.
|
||||||
// iOS receives a category for actionable notifications, Android handles actions via data payload
|
// iOS receives a category for actionable notifications, Android handles actions via data payload.
|
||||||
|
// Both platforms are guarded by their respective circuit breakers.
|
||||||
func (c *Client) SendActionableNotification(ctx context.Context, iosTokens, androidTokens []string, title, message string, data map[string]string, iosCategoryID string) error {
|
func (c *Client) SendActionableNotification(ctx context.Context, iosTokens, androidTokens []string, title, message string, data map[string]string, iosCategoryID string) error {
|
||||||
if !c.enabled {
|
if !c.enabled {
|
||||||
log.Debug().Msg("Push notifications disabled by feature flag")
|
log.Debug().Msg("Push notifications disabled by feature flag")
|
||||||
@@ -127,10 +170,19 @@ func (c *Client) SendActionableNotification(ctx context.Context, iosTokens, andr
|
|||||||
if len(iosTokens) > 0 {
|
if len(iosTokens) > 0 {
|
||||||
if c.apns == nil {
|
if c.apns == nil {
|
||||||
log.Warn().Msg("APNs client not initialized, skipping iOS actionable push")
|
log.Warn().Msg("APNs client not initialized, skipping iOS actionable push")
|
||||||
|
} else if !c.apnsBreaker.Allow() {
|
||||||
|
log.Warn().Str("breaker", c.apnsBreaker.Name()).Msg("APNs circuit breaker is open, skipping iOS actionable push")
|
||||||
|
lastErr = ErrCircuitOpen
|
||||||
} else {
|
} else {
|
||||||
if err := c.apns.SendWithCategory(ctx, iosTokens, title, message, data, iosCategoryID); err != nil {
|
sendCtx, cancel := context.WithTimeout(ctx, pushSendTimeout)
|
||||||
|
err := c.apns.SendWithCategory(sendCtx, iosTokens, title, message, data, iosCategoryID)
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
c.apnsBreaker.RecordFailure()
|
||||||
log.Error().Err(err).Msg("Failed to send iOS actionable notifications")
|
log.Error().Err(err).Msg("Failed to send iOS actionable notifications")
|
||||||
lastErr = err
|
lastErr = err
|
||||||
|
} else {
|
||||||
|
c.apnsBreaker.RecordSuccess()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ func NewFCMClient(cfg *config.PushConfig) (*FCMClient, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 15 * time.Second,
|
||||||
Transport: transport,
|
Transport: transport,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -173,6 +173,27 @@ func (r *UserRepository) GetOrCreateToken(userID uint) (*models.AuthToken, error
|
|||||||
return &token, nil
|
return &token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FindTokenByKey looks up an auth token by its key value.
|
||||||
|
func (r *UserRepository) FindTokenByKey(key string) (*models.AuthToken, error) {
|
||||||
|
var token models.AuthToken
|
||||||
|
if err := r.db.Where("key = ?", key).First(&token).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, ErrTokenNotFound
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateToken creates a new auth token for a user.
|
||||||
|
func (r *UserRepository) CreateToken(userID uint) (*models.AuthToken, error) {
|
||||||
|
token := models.AuthToken{UserID: userID}
|
||||||
|
if err := r.db.Create(&token).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteToken deletes an auth token
|
// DeleteToken deletes an auth token
|
||||||
func (r *UserRepository) DeleteToken(token string) error {
|
func (r *UserRepository) DeleteToken(token string) error {
|
||||||
result := r.db.Where("key = ?", token).Delete(&models.AuthToken{})
|
result := r.db.Where("key = ?", token).Delete(&models.AuthToken{})
|
||||||
|
|||||||
136
internal/router/health_test.go
Normal file
136
internal/router/health_test.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setupTestDB creates an in-memory SQLite database for health check tests
|
||||||
|
func setupTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLiveCheck_Returns200(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
e.GET("/api/health/live", liveCheck)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/health/live", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
err := json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "alive", resp["status"])
|
||||||
|
assert.Equal(t, Version, resp["version"])
|
||||||
|
assert.Contains(t, resp, "timestamp")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadinessCheck_HealthyDB_NilCache_Returns200(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
|
||||||
|
deps := &Dependencies{
|
||||||
|
DB: db,
|
||||||
|
Cache: nil, // No Redis configured
|
||||||
|
}
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
e.GET("/api/health/", readinessCheck(deps))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/health/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
err := json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "healthy", resp["status"])
|
||||||
|
assert.Equal(t, Version, resp["version"])
|
||||||
|
|
||||||
|
checks := resp["checks"].(map[string]interface{})
|
||||||
|
assert.Equal(t, "ok", checks["postgres"])
|
||||||
|
assert.Equal(t, "not configured", checks["redis"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadinessCheck_DBDown_Returns503(t *testing.T) {
|
||||||
|
// Open an in-memory SQLite DB, then close the underlying sql.DB to simulate failure
|
||||||
|
db := setupTestDB(t)
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
require.NoError(t, err)
|
||||||
|
sqlDB.Close()
|
||||||
|
|
||||||
|
deps := &Dependencies{
|
||||||
|
DB: db,
|
||||||
|
Cache: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
e.GET("/api/health/", readinessCheck(deps))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/health/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
err = json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "unhealthy", resp["status"])
|
||||||
|
|
||||||
|
checks := resp["checks"].(map[string]interface{})
|
||||||
|
assert.Contains(t, checks["postgres"], "ping failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadinessCheck_ResponseFormat(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
|
||||||
|
deps := &Dependencies{
|
||||||
|
DB: db,
|
||||||
|
Cache: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
e.GET("/api/health/", readinessCheck(deps))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/health/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
e.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
var resp map[string]interface{}
|
||||||
|
err := json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify all expected fields are present
|
||||||
|
assert.Contains(t, resp, "status")
|
||||||
|
assert.Contains(t, resp, "version")
|
||||||
|
assert.Contains(t, resp, "checks")
|
||||||
|
assert.Contains(t, resp, "timestamp")
|
||||||
|
|
||||||
|
// Verify checks is a map with expected keys
|
||||||
|
checks, ok := resp["checks"].(map[string]interface{})
|
||||||
|
require.True(t, ok, "checks should be a map")
|
||||||
|
assert.Contains(t, checks, "postgres")
|
||||||
|
assert.Contains(t, checks, "redis")
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package router
|
package router
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -62,11 +63,12 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
|||||||
|
|
||||||
// Security headers (X-Frame-Options, X-Content-Type-Options, X-XSS-Protection, etc.)
|
// Security headers (X-Frame-Options, X-Content-Type-Options, X-XSS-Protection, etc.)
|
||||||
e.Use(middleware.SecureWithConfig(middleware.SecureConfig{
|
e.Use(middleware.SecureWithConfig(middleware.SecureConfig{
|
||||||
XSSProtection: "1; mode=block",
|
XSSProtection: "1; mode=block",
|
||||||
ContentTypeNosniff: "nosniff",
|
ContentTypeNosniff: "nosniff",
|
||||||
XFrameOptions: "SAMEORIGIN",
|
XFrameOptions: "SAMEORIGIN",
|
||||||
HSTSMaxAge: 31536000, // 1 year in seconds
|
HSTSMaxAge: 31536000, // 1 year in seconds
|
||||||
ReferrerPolicy: "strict-origin-when-cross-origin",
|
ReferrerPolicy: "strict-origin-when-cross-origin",
|
||||||
|
ContentSecurityPolicy: "default-src 'none'; frame-ancestors 'none'",
|
||||||
}))
|
}))
|
||||||
e.Use(middleware.BodyLimitWithConfig(middleware.BodyLimitConfig{
|
e.Use(middleware.BodyLimitWithConfig(middleware.BodyLimitConfig{
|
||||||
Limit: "1M", // 1MB default for JSON payloads
|
Limit: "1M", // 1MB default for JSON payloads
|
||||||
@@ -93,6 +95,14 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
|||||||
e.Use(corsMiddleware(cfg))
|
e.Use(corsMiddleware(cfg))
|
||||||
e.Use(i18n.Middleware())
|
e.Use(i18n.Middleware())
|
||||||
|
|
||||||
|
// Gzip compression (skip media endpoints since they serve binary files)
|
||||||
|
e.Use(middleware.GzipWithConfig(middleware.GzipConfig{
|
||||||
|
Level: 5,
|
||||||
|
Skipper: func(c echo.Context) bool {
|
||||||
|
return strings.HasPrefix(c.Request().URL.Path, "/api/media/")
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
// Monitoring metrics middleware (if monitoring is enabled)
|
// Monitoring metrics middleware (if monitoring is enabled)
|
||||||
if deps.MonitoringService != nil {
|
if deps.MonitoringService != nil {
|
||||||
if metricsMiddleware := deps.MonitoringService.MetricsMiddleware(); metricsMiddleware != nil {
|
if metricsMiddleware := deps.MonitoringService.MetricsMiddleware(); metricsMiddleware != nil {
|
||||||
@@ -114,8 +124,9 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Health check endpoint (no auth required)
|
// Health check endpoints (no auth required)
|
||||||
e.GET("/api/health/", healthCheck)
|
e.GET("/api/health/", readinessCheck(deps))
|
||||||
|
e.GET("/api/health/live", liveCheck)
|
||||||
|
|
||||||
// Initialize onboarding email service for tracking handler
|
// Initialize onboarding email service for tracking handler
|
||||||
onboardingService := services.NewOnboardingEmailService(deps.DB, deps.EmailService, cfg.Server.BaseURL)
|
onboardingService := services.NewOnboardingEmailService(deps.DB, deps.EmailService, cfg.Server.BaseURL)
|
||||||
@@ -172,17 +183,21 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
|||||||
subscriptionWebhookHandler.SetStripeService(stripeService)
|
subscriptionWebhookHandler.SetStripeService(stripeService)
|
||||||
|
|
||||||
// Initialize middleware
|
// Initialize middleware
|
||||||
authMiddleware := custommiddleware.NewAuthMiddleware(deps.DB, deps.Cache)
|
authMiddleware := custommiddleware.NewAuthMiddlewareWithConfig(deps.DB, deps.Cache, cfg)
|
||||||
|
|
||||||
// Initialize Apple auth service
|
// Initialize Apple auth service
|
||||||
appleAuthService := services.NewAppleAuthService(deps.Cache, cfg)
|
appleAuthService := services.NewAppleAuthService(deps.Cache, cfg)
|
||||||
googleAuthService := services.NewGoogleAuthService(deps.Cache, cfg)
|
googleAuthService := services.NewGoogleAuthService(deps.Cache, cfg)
|
||||||
|
|
||||||
|
// Initialize audit service for security event logging
|
||||||
|
auditService := services.NewAuditService(deps.DB)
|
||||||
|
|
||||||
// Initialize handlers
|
// Initialize handlers
|
||||||
authHandler := handlers.NewAuthHandler(authService, deps.EmailService, deps.Cache)
|
authHandler := handlers.NewAuthHandler(authService, deps.EmailService, deps.Cache)
|
||||||
authHandler.SetAppleAuthService(appleAuthService)
|
authHandler.SetAppleAuthService(appleAuthService)
|
||||||
authHandler.SetGoogleAuthService(googleAuthService)
|
authHandler.SetGoogleAuthService(googleAuthService)
|
||||||
authHandler.SetStorageService(deps.StorageService)
|
authHandler.SetStorageService(deps.StorageService)
|
||||||
|
authHandler.SetAuditService(auditService)
|
||||||
userHandler := handlers.NewUserHandler(userService)
|
userHandler := handlers.NewUserHandler(userService)
|
||||||
residenceHandler := handlers.NewResidenceHandler(residenceService, deps.PDFService, deps.EmailService, cfg.Features.PDFReportsEnabled)
|
residenceHandler := handlers.NewResidenceHandler(residenceService, deps.PDFService, deps.EmailService, cfg.Features.PDFReportsEnabled)
|
||||||
taskHandler := handlers.NewTaskHandler(taskService, deps.StorageService)
|
taskHandler := handlers.NewTaskHandler(taskService, deps.StorageService)
|
||||||
@@ -201,6 +216,11 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
|||||||
mediaHandler = handlers.NewMediaHandler(documentRepo, taskRepo, residenceRepo, deps.StorageService)
|
mediaHandler = handlers.NewMediaHandler(documentRepo, taskRepo, residenceRepo, deps.StorageService)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Prometheus metrics endpoint (no auth required, for scraping)
|
||||||
|
if deps.MonitoringService != nil {
|
||||||
|
e.GET("/metrics", prometheusMetrics(deps.MonitoringService))
|
||||||
|
}
|
||||||
|
|
||||||
// Set up admin routes with monitoring handler (if available)
|
// Set up admin routes with monitoring handler (if available)
|
||||||
var monitoringHandler *monitoring.Handler
|
var monitoringHandler *monitoring.Handler
|
||||||
if deps.MonitoringService != nil {
|
if deps.MonitoringService != nil {
|
||||||
@@ -295,16 +315,126 @@ func corsMiddleware(cfg *config.Config) echo.MiddlewareFunc {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// healthCheck returns API health status
|
// liveCheck returns a simple 200 for Kubernetes liveness probes
|
||||||
func healthCheck(c echo.Context) error {
|
func liveCheck(c echo.Context) error {
|
||||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||||
"status": "healthy",
|
"status": "alive",
|
||||||
"version": Version,
|
"version": Version,
|
||||||
"framework": "Echo",
|
|
||||||
"timestamp": time.Now().UTC().Format(time.RFC3339),
|
"timestamp": time.Now().UTC().Format(time.RFC3339),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// readinessCheck returns 200 if PostgreSQL and Redis are reachable, 503 otherwise.
|
||||||
|
// This is used by Kubernetes readiness probes and load balancers.
|
||||||
|
func readinessCheck(deps *Dependencies) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request().Context(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
status := "healthy"
|
||||||
|
httpStatus := http.StatusOK
|
||||||
|
checks := make(map[string]string)
|
||||||
|
|
||||||
|
// Check PostgreSQL
|
||||||
|
sqlDB, err := deps.DB.DB()
|
||||||
|
if err != nil {
|
||||||
|
checks["postgres"] = fmt.Sprintf("failed to get sql.DB: %v", err)
|
||||||
|
status = "unhealthy"
|
||||||
|
httpStatus = http.StatusServiceUnavailable
|
||||||
|
} else if err := sqlDB.PingContext(ctx); err != nil {
|
||||||
|
checks["postgres"] = fmt.Sprintf("ping failed: %v", err)
|
||||||
|
status = "unhealthy"
|
||||||
|
httpStatus = http.StatusServiceUnavailable
|
||||||
|
} else {
|
||||||
|
checks["postgres"] = "ok"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Redis (if cache service is available)
|
||||||
|
if deps.Cache != nil {
|
||||||
|
if err := deps.Cache.Client().Ping(ctx).Err(); err != nil {
|
||||||
|
checks["redis"] = fmt.Sprintf("ping failed: %v", err)
|
||||||
|
status = "unhealthy"
|
||||||
|
httpStatus = http.StatusServiceUnavailable
|
||||||
|
} else {
|
||||||
|
checks["redis"] = "ok"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
checks["redis"] = "not configured"
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.JSON(httpStatus, map[string]interface{}{
|
||||||
|
"status": status,
|
||||||
|
"version": Version,
|
||||||
|
"checks": checks,
|
||||||
|
"timestamp": time.Now().UTC().Format(time.RFC3339),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// prometheusMetrics returns an Echo handler that outputs metrics in Prometheus text format.
|
||||||
|
// It uses the existing monitoring service's HTTP stats collector to avoid adding external dependencies.
|
||||||
|
func prometheusMetrics(monSvc *monitoring.Service) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
httpCollector := monSvc.HTTPCollector()
|
||||||
|
if httpCollector == nil {
|
||||||
|
return c.String(http.StatusOK, "# No HTTP metrics available (collector not initialized)\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := httpCollector.GetStats()
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
// Request count by method+path+status
|
||||||
|
b.WriteString("# HELP http_requests_total Total number of HTTP requests.\n")
|
||||||
|
b.WriteString("# TYPE http_requests_total counter\n")
|
||||||
|
for statusCode, count := range stats.ByStatusCode {
|
||||||
|
fmt.Fprintf(&b, "http_requests_total{status_code=\"%d\"} %d\n", statusCode, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per-endpoint request count
|
||||||
|
b.WriteString("# HELP http_endpoint_requests_total Total requests per endpoint.\n")
|
||||||
|
b.WriteString("# TYPE http_endpoint_requests_total counter\n")
|
||||||
|
for endpoint, epStats := range stats.ByEndpoint {
|
||||||
|
// endpoint is "METHOD /path"
|
||||||
|
parts := strings.SplitN(endpoint, " ", 2)
|
||||||
|
method := endpoint
|
||||||
|
path := ""
|
||||||
|
if len(parts) == 2 {
|
||||||
|
method = parts[0]
|
||||||
|
path = parts[1]
|
||||||
|
}
|
||||||
|
fmt.Fprintf(&b, "http_endpoint_requests_total{method=\"%s\",path=\"%s\"} %d\n", method, path, epStats.Count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request duration (avg latency as a gauge, since we don't have raw histogram buckets)
|
||||||
|
b.WriteString("# HELP http_request_duration_ms Average request duration in milliseconds per endpoint.\n")
|
||||||
|
b.WriteString("# TYPE http_request_duration_ms gauge\n")
|
||||||
|
for endpoint, epStats := range stats.ByEndpoint {
|
||||||
|
parts := strings.SplitN(endpoint, " ", 2)
|
||||||
|
method := endpoint
|
||||||
|
path := ""
|
||||||
|
if len(parts) == 2 {
|
||||||
|
method = parts[0]
|
||||||
|
path = parts[1]
|
||||||
|
}
|
||||||
|
fmt.Fprintf(&b, "http_request_duration_ms{method=\"%s\",path=\"%s\",quantile=\"avg\"} %.2f\n", method, path, epStats.AvgLatencyMs)
|
||||||
|
fmt.Fprintf(&b, "http_request_duration_ms{method=\"%s\",path=\"%s\",quantile=\"p95\"} %.2f\n", method, path, epStats.P95LatencyMs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error rate
|
||||||
|
b.WriteString("# HELP http_error_rate Overall error rate (4xx+5xx / total).\n")
|
||||||
|
b.WriteString("# TYPE http_error_rate gauge\n")
|
||||||
|
fmt.Fprintf(&b, "http_error_rate %.4f\n", stats.ErrorRate)
|
||||||
|
|
||||||
|
// Requests per minute
|
||||||
|
b.WriteString("# HELP http_requests_per_minute Current request rate.\n")
|
||||||
|
b.WriteString("# TYPE http_requests_per_minute gauge\n")
|
||||||
|
fmt.Fprintf(&b, "http_requests_per_minute %.2f\n", stats.RequestsPerMinute)
|
||||||
|
|
||||||
|
c.Response().Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
|
||||||
|
return c.String(http.StatusOK, b.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// setupPublicAuthRoutes configures public authentication routes with
|
// setupPublicAuthRoutes configures public authentication routes with
|
||||||
// per-endpoint rate limiters to mitigate brute-force and credential-stuffing.
|
// per-endpoint rate limiters to mitigate brute-force and credential-stuffing.
|
||||||
// Rate limiters are disabled in debug mode to allow UI test suites to run
|
// Rate limiters are disabled in debug mode to allow UI test suites to run
|
||||||
@@ -342,6 +472,7 @@ func setupProtectedAuthRoutes(api *echo.Group, authHandler *handlers.AuthHandler
|
|||||||
auth := api.Group("/auth")
|
auth := api.Group("/auth")
|
||||||
{
|
{
|
||||||
auth.POST("/logout/", authHandler.Logout)
|
auth.POST("/logout/", authHandler.Logout)
|
||||||
|
auth.POST("/refresh/", authHandler.RefreshToken)
|
||||||
auth.GET("/me/", authHandler.CurrentUser)
|
auth.GET("/me/", authHandler.CurrentUser)
|
||||||
auth.PUT("/profile/", authHandler.UpdateProfile)
|
auth.PUT("/profile/", authHandler.UpdateProfile)
|
||||||
auth.PATCH("/profile/", authHandler.UpdateProfile)
|
auth.PATCH("/profile/", authHandler.UpdateProfile)
|
||||||
|
|||||||
93
internal/services/audit_service.go
Normal file
93
internal/services/audit_service.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"github.com/treytartt/honeydue-api/internal/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Audit event type constants
|
||||||
|
const (
|
||||||
|
AuditEventLogin = "auth.login"
|
||||||
|
AuditEventLoginFailed = "auth.login_failed"
|
||||||
|
AuditEventRegister = "auth.register"
|
||||||
|
AuditEventLogout = "auth.logout"
|
||||||
|
AuditEventPasswordReset = "auth.password_reset"
|
||||||
|
AuditEventPasswordChanged = "auth.password_changed"
|
||||||
|
AuditEventAccountDeleted = "auth.account_deleted"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuditService handles audit logging for security-relevant events.
|
||||||
|
// It writes audit log entries asynchronously via a buffered channel to avoid
|
||||||
|
// blocking request handlers.
|
||||||
|
type AuditService struct {
|
||||||
|
db *gorm.DB
|
||||||
|
logChan chan *models.AuditLog
|
||||||
|
done chan struct{}
|
||||||
|
stopOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuditService creates a new audit service with a buffered channel for async writes.
|
||||||
|
// Call Stop() when shutting down to flush remaining entries.
|
||||||
|
func NewAuditService(db *gorm.DB) *AuditService {
|
||||||
|
s := &AuditService{
|
||||||
|
db: db,
|
||||||
|
logChan: make(chan *models.AuditLog, 256),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
go s.processLogs()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// processLogs drains the log channel and writes entries to the database.
|
||||||
|
func (s *AuditService) processLogs() {
|
||||||
|
defer close(s.done)
|
||||||
|
for entry := range s.logChan {
|
||||||
|
if err := s.db.Create(entry).Error; err != nil {
|
||||||
|
log.Error().Err(err).
|
||||||
|
Str("event_type", entry.EventType).
|
||||||
|
Msg("Failed to write audit log entry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop closes the log channel and waits for all pending entries to be written.
|
||||||
|
// It is safe to call Stop multiple times.
|
||||||
|
func (s *AuditService) Stop() {
|
||||||
|
s.stopOnce.Do(func() {
|
||||||
|
close(s.logChan)
|
||||||
|
})
|
||||||
|
<-s.done
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogEvent records an audit event. It extracts the client IP and User-Agent from
|
||||||
|
// the Echo context and sends the entry to the background writer. If the channel
|
||||||
|
// is full the entry is dropped and an error is logged (non-blocking).
|
||||||
|
func (s *AuditService) LogEvent(c echo.Context, userID *uint, eventType string, details map[string]interface{}) {
|
||||||
|
var ip, ua string
|
||||||
|
if c != nil {
|
||||||
|
ip = c.RealIP()
|
||||||
|
ua = c.Request().Header.Get("User-Agent")
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := &models.AuditLog{
|
||||||
|
UserID: userID,
|
||||||
|
EventType: eventType,
|
||||||
|
IPAddress: ip,
|
||||||
|
UserAgent: ua,
|
||||||
|
Details: models.JSONMap(details),
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case s.logChan <- entry:
|
||||||
|
// sent
|
||||||
|
default:
|
||||||
|
log.Warn().
|
||||||
|
Str("event_type", eventType).
|
||||||
|
Msg("Audit log channel full, dropping entry")
|
||||||
|
}
|
||||||
|
}
|
||||||
176
internal/services/audit_service_test.go
Normal file
176
internal/services/audit_service_test.go
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/treytartt/honeydue-api/internal/models"
|
||||||
|
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuditService_LogEvent_WritesToDatabase(t *testing.T) {
|
||||||
|
db := testutil.SetupTestDB(t)
|
||||||
|
svc := NewAuditService(db)
|
||||||
|
defer svc.Stop()
|
||||||
|
|
||||||
|
// Create a fake Echo context
|
||||||
|
e := echo.New()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
|
||||||
|
req.Header.Set("User-Agent", "TestAgent/1.0")
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
userID := uint(42)
|
||||||
|
svc.LogEvent(c, &userID, AuditEventLogin, map[string]interface{}{
|
||||||
|
"method": "password",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Stop flushes the channel
|
||||||
|
svc.Stop()
|
||||||
|
|
||||||
|
var entries []models.AuditLog
|
||||||
|
err := db.Find(&entries).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, entries, 1)
|
||||||
|
|
||||||
|
entry := entries[0]
|
||||||
|
assert.Equal(t, uint(42), *entry.UserID)
|
||||||
|
assert.Equal(t, AuditEventLogin, entry.EventType)
|
||||||
|
assert.Equal(t, "TestAgent/1.0", entry.UserAgent)
|
||||||
|
assert.NotEmpty(t, entry.IPAddress)
|
||||||
|
assert.Equal(t, "password", entry.Details["method"])
|
||||||
|
assert.False(t, entry.CreatedAt.IsZero())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditService_LogEvent_NilUserID(t *testing.T) {
|
||||||
|
db := testutil.SetupTestDB(t)
|
||||||
|
svc := NewAuditService(db)
|
||||||
|
defer svc.Stop()
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
svc.LogEvent(c, nil, AuditEventLoginFailed, map[string]interface{}{
|
||||||
|
"identifier": "unknown@test.com",
|
||||||
|
})
|
||||||
|
|
||||||
|
svc.Stop()
|
||||||
|
|
||||||
|
var entries []models.AuditLog
|
||||||
|
err := db.Find(&entries).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, entries, 1)
|
||||||
|
|
||||||
|
assert.Nil(t, entries[0].UserID)
|
||||||
|
assert.Equal(t, AuditEventLoginFailed, entries[0].EventType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditService_LogEvent_NilContext(t *testing.T) {
|
||||||
|
db := testutil.SetupTestDB(t)
|
||||||
|
svc := NewAuditService(db)
|
||||||
|
defer svc.Stop()
|
||||||
|
|
||||||
|
userID := uint(1)
|
||||||
|
svc.LogEvent(nil, &userID, AuditEventLogout, nil)
|
||||||
|
|
||||||
|
svc.Stop()
|
||||||
|
|
||||||
|
var entries []models.AuditLog
|
||||||
|
err := db.Find(&entries).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, entries, 1)
|
||||||
|
|
||||||
|
assert.Equal(t, AuditEventLogout, entries[0].EventType)
|
||||||
|
assert.Empty(t, entries[0].IPAddress)
|
||||||
|
assert.Empty(t, entries[0].UserAgent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditService_LogEvent_MultipleEvents(t *testing.T) {
|
||||||
|
db := testutil.SetupTestDB(t)
|
||||||
|
svc := NewAuditService(db)
|
||||||
|
defer svc.Stop()
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
userID := uint(10)
|
||||||
|
svc.LogEvent(c, &userID, AuditEventRegister, nil)
|
||||||
|
svc.LogEvent(c, &userID, AuditEventLogin, nil)
|
||||||
|
svc.LogEvent(c, &userID, AuditEventLogout, nil)
|
||||||
|
|
||||||
|
svc.Stop()
|
||||||
|
|
||||||
|
var count int64
|
||||||
|
err := db.Model(&models.AuditLog{}).Count(&count).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(3), count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditService_EventTypeConstants(t *testing.T) {
|
||||||
|
// Verify all event constants have expected values
|
||||||
|
assert.Equal(t, "auth.login", AuditEventLogin)
|
||||||
|
assert.Equal(t, "auth.login_failed", AuditEventLoginFailed)
|
||||||
|
assert.Equal(t, "auth.register", AuditEventRegister)
|
||||||
|
assert.Equal(t, "auth.logout", AuditEventLogout)
|
||||||
|
assert.Equal(t, "auth.password_reset", AuditEventPasswordReset)
|
||||||
|
assert.Equal(t, "auth.password_changed", AuditEventPasswordChanged)
|
||||||
|
assert.Equal(t, "auth.account_deleted", AuditEventAccountDeleted)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditService_Stop_FlushesRemainingEntries(t *testing.T) {
|
||||||
|
db := testutil.SetupTestDB(t)
|
||||||
|
svc := NewAuditService(db)
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
// Send many events quickly
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
uid := uint(i)
|
||||||
|
svc.LogEvent(c, &uid, AuditEventLogin, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop should block until all entries are written
|
||||||
|
svc.Stop()
|
||||||
|
|
||||||
|
var count int64
|
||||||
|
err := db.Model(&models.AuditLog{}).Count(&count).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(50), count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditLog_TableName(t *testing.T) {
|
||||||
|
log := models.AuditLog{}
|
||||||
|
assert.Equal(t, "audit_log", log.TableName())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuditLog_JSONMap_NilHandling(t *testing.T) {
|
||||||
|
db := testutil.SetupTestDB(t)
|
||||||
|
|
||||||
|
// Create entry with nil details
|
||||||
|
entry := &models.AuditLog{
|
||||||
|
EventType: "test",
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
err := db.Create(entry).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Read it back
|
||||||
|
var found models.AuditLog
|
||||||
|
err = db.First(&found, entry.ID).Error
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Nil(t, found.Details)
|
||||||
|
}
|
||||||
173
internal/services/auth_refresh_test.go
Normal file
173
internal/services/auth_refresh_test.go
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
|
||||||
|
"github.com/treytartt/honeydue-api/internal/config"
|
||||||
|
"github.com/treytartt/honeydue-api/internal/models"
|
||||||
|
"github.com/treytartt/honeydue-api/internal/repositories"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupRefreshTestDB(t *testing.T) *gorm.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = db.AutoMigrate(&models.User{}, &models.UserProfile{}, &models.AuthToken{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func createRefreshTestUser(t *testing.T, db *gorm.DB) *models.User {
|
||||||
|
t.Helper()
|
||||||
|
user := &models.User{
|
||||||
|
Username: "refreshtest",
|
||||||
|
Email: "refresh@test.com",
|
||||||
|
IsActive: true,
|
||||||
|
}
|
||||||
|
require.NoError(t, user.SetPassword("password123"))
|
||||||
|
require.NoError(t, db.Create(user).Error)
|
||||||
|
return user
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTokenWithAge(t *testing.T, db *gorm.DB, userID uint, ageDays int) *models.AuthToken {
|
||||||
|
t.Helper()
|
||||||
|
token := &models.AuthToken{
|
||||||
|
UserID: userID,
|
||||||
|
}
|
||||||
|
require.NoError(t, db.Create(token).Error)
|
||||||
|
|
||||||
|
// Backdate the token's Created timestamp after creation to bypass autoCreateTime
|
||||||
|
backdated := time.Now().UTC().AddDate(0, 0, -ageDays)
|
||||||
|
require.NoError(t, db.Model(token).Update("created", backdated).Error)
|
||||||
|
token.Created = backdated
|
||||||
|
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestAuthService(db *gorm.DB) *AuthService {
|
||||||
|
userRepo := repositories.NewUserRepository(db)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Security: config.SecurityConfig{
|
||||||
|
SecretKey: "test-secret",
|
||||||
|
TokenExpiryDays: 90,
|
||||||
|
TokenRefreshDays: 60,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return NewAuthService(userRepo, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_FreshToken_ReturnsExisting(t *testing.T) {
|
||||||
|
db := setupRefreshTestDB(t)
|
||||||
|
user := createRefreshTestUser(t, db)
|
||||||
|
token := createTokenWithAge(t, db, user.ID, 30) // 30 days old, well within fresh window
|
||||||
|
|
||||||
|
svc := newTestAuthService(db)
|
||||||
|
|
||||||
|
resp, err := svc.RefreshToken(token.Key, user.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, token.Key, resp.Token, "fresh token should return the same token")
|
||||||
|
assert.Contains(t, resp.Message, "still valid")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_InRenewalWindow_ReturnsNewToken(t *testing.T) {
|
||||||
|
db := setupRefreshTestDB(t)
|
||||||
|
user := createRefreshTestUser(t, db)
|
||||||
|
token := createTokenWithAge(t, db, user.ID, 75) // 75 days old, in renewal window (60-90)
|
||||||
|
|
||||||
|
svc := newTestAuthService(db)
|
||||||
|
|
||||||
|
resp, err := svc.RefreshToken(token.Key, user.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEqual(t, token.Key, resp.Token, "should return a new token")
|
||||||
|
assert.Contains(t, resp.Message, "refreshed")
|
||||||
|
|
||||||
|
// Verify old token was deleted
|
||||||
|
var count int64
|
||||||
|
db.Model(&models.AuthToken{}).Where("key = ?", token.Key).Count(&count)
|
||||||
|
assert.Equal(t, int64(0), count, "old token should be deleted")
|
||||||
|
|
||||||
|
// Verify new token exists in DB
|
||||||
|
db.Model(&models.AuthToken{}).Where("key = ?", resp.Token).Count(&count)
|
||||||
|
assert.Equal(t, int64(1), count, "new token should exist in DB")
|
||||||
|
|
||||||
|
// Verify new token belongs to the same user
|
||||||
|
var newToken models.AuthToken
|
||||||
|
require.NoError(t, db.Where("key = ?", resp.Token).First(&newToken).Error)
|
||||||
|
assert.Equal(t, user.ID, newToken.UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_ExpiredToken_Returns401(t *testing.T) {
|
||||||
|
db := setupRefreshTestDB(t)
|
||||||
|
user := createRefreshTestUser(t, db)
|
||||||
|
token := createTokenWithAge(t, db, user.ID, 91) // 91 days old, past 90-day expiry
|
||||||
|
|
||||||
|
svc := newTestAuthService(db)
|
||||||
|
|
||||||
|
resp, err := svc.RefreshToken(token.Key, user.ID)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, resp)
|
||||||
|
assert.Contains(t, err.Error(), "error.token_expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_AtExactBoundary60Days(t *testing.T) {
|
||||||
|
db := setupRefreshTestDB(t)
|
||||||
|
user := createRefreshTestUser(t, db)
|
||||||
|
// Exactly 60 days: token age == refreshDays, so tokenAge < refreshDuration is false,
|
||||||
|
// meaning it enters the renewal window
|
||||||
|
token := createTokenWithAge(t, db, user.ID, 61)
|
||||||
|
|
||||||
|
svc := newTestAuthService(db)
|
||||||
|
|
||||||
|
resp, err := svc.RefreshToken(token.Key, user.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEqual(t, token.Key, resp.Token, "token at 61 days should be refreshed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_InvalidToken_Returns401(t *testing.T) {
|
||||||
|
db := setupRefreshTestDB(t)
|
||||||
|
user := createRefreshTestUser(t, db)
|
||||||
|
|
||||||
|
svc := newTestAuthService(db)
|
||||||
|
|
||||||
|
resp, err := svc.RefreshToken("nonexistent-token-key", user.ID)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, resp)
|
||||||
|
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_WrongUser_Returns401(t *testing.T) {
|
||||||
|
db := setupRefreshTestDB(t)
|
||||||
|
user := createRefreshTestUser(t, db)
|
||||||
|
token := createTokenWithAge(t, db, user.ID, 75)
|
||||||
|
|
||||||
|
svc := newTestAuthService(db)
|
||||||
|
|
||||||
|
// Try to refresh with a different user ID
|
||||||
|
resp, err := svc.RefreshToken(token.Key, user.ID+999)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, resp)
|
||||||
|
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshToken_FreshTokenAt59Days_ReturnsExisting(t *testing.T) {
|
||||||
|
db := setupRefreshTestDB(t)
|
||||||
|
user := createRefreshTestUser(t, db)
|
||||||
|
token := createTokenWithAge(t, db, user.ID, 59) // 59 days, just under the 60-day threshold
|
||||||
|
|
||||||
|
svc := newTestAuthService(db)
|
||||||
|
|
||||||
|
resp, err := svc.RefreshToken(token.Key, user.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, token.Key, resp.Token, "token at 59 days should NOT be refreshed")
|
||||||
|
}
|
||||||
@@ -188,6 +188,66 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
|
|||||||
}, code, nil
|
}, code, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RefreshToken handles token refresh logic.
|
||||||
|
// - If token is expired (> expiryDays old), returns error (must re-login).
|
||||||
|
// - If token is in the renewal window (> refreshDays old), generates a new token.
|
||||||
|
// - If token is still fresh (< refreshDays old), returns the existing token (no-op).
|
||||||
|
func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.RefreshTokenResponse, error) {
|
||||||
|
expiryDays := s.cfg.Security.TokenExpiryDays
|
||||||
|
if expiryDays <= 0 {
|
||||||
|
expiryDays = 90
|
||||||
|
}
|
||||||
|
refreshDays := s.cfg.Security.TokenRefreshDays
|
||||||
|
if refreshDays <= 0 {
|
||||||
|
refreshDays = 60
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up the token
|
||||||
|
authToken, err := s.userRepo.FindTokenByKey(tokenKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, apperrors.Unauthorized("error.invalid_token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify ownership
|
||||||
|
if authToken.UserID != userID {
|
||||||
|
return nil, apperrors.Unauthorized("error.invalid_token")
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenAge := time.Since(authToken.Created)
|
||||||
|
expiryDuration := time.Duration(expiryDays) * 24 * time.Hour
|
||||||
|
refreshDuration := time.Duration(refreshDays) * 24 * time.Hour
|
||||||
|
|
||||||
|
// Token is expired — must re-login
|
||||||
|
if tokenAge > expiryDuration {
|
||||||
|
return nil, apperrors.Unauthorized("error.token_expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token is still fresh — no-op refresh
|
||||||
|
if tokenAge < refreshDuration {
|
||||||
|
return &responses.RefreshTokenResponse{
|
||||||
|
Token: tokenKey,
|
||||||
|
Message: "Token is still valid.",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token is in the renewal window — generate a new one
|
||||||
|
// Delete the old token
|
||||||
|
if err := s.userRepo.DeleteToken(tokenKey); err != nil {
|
||||||
|
log.Warn().Err(err).Str("token", tokenKey[:8]+"...").Msg("Failed to delete old token during refresh")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new token
|
||||||
|
newToken, err := s.userRepo.CreateToken(userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, apperrors.Internal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &responses.RefreshTokenResponse{
|
||||||
|
Token: newToken.Key,
|
||||||
|
Message: "Token refreshed successfully.",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Logout invalidates a user's token
|
// Logout invalidates a user's token
|
||||||
func (s *AuthService) Logout(token string) error {
|
func (s *AuthService) Logout(token string) error {
|
||||||
return s.userRepo.DeleteToken(token)
|
return s.userRepo.DeleteToken(token)
|
||||||
|
|||||||
@@ -141,6 +141,12 @@ func (c *CacheService) CacheAuthToken(ctx context.Context, token string, userID
|
|||||||
return c.SetString(ctx, key, fmt.Sprintf("%d", userID), TokenCacheTTL)
|
return c.SetString(ctx, key, fmt.Sprintf("%d", userID), TokenCacheTTL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CacheAuthTokenWithCreated caches a user ID and token creation time for a token
|
||||||
|
func (c *CacheService) CacheAuthTokenWithCreated(ctx context.Context, token string, userID uint, createdUnix int64) error {
|
||||||
|
key := AuthTokenPrefix + token
|
||||||
|
return c.SetString(ctx, key, fmt.Sprintf("%d|%d", userID, createdUnix), TokenCacheTTL)
|
||||||
|
}
|
||||||
|
|
||||||
// GetCachedAuthToken gets a cached user ID for a token
|
// GetCachedAuthToken gets a cached user ID for a token
|
||||||
func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (uint, error) {
|
func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (uint, error) {
|
||||||
key := AuthTokenPrefix + token
|
key := AuthTokenPrefix + token
|
||||||
@@ -154,6 +160,24 @@ func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (ui
|
|||||||
return userID, err
|
return userID, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCachedAuthTokenWithCreated gets a cached user ID and token creation time.
|
||||||
|
// Returns userID, createdUnix, error. createdUnix is 0 if not stored (legacy format).
|
||||||
|
func (c *CacheService) GetCachedAuthTokenWithCreated(ctx context.Context, token string) (uint, int64, error) {
|
||||||
|
key := AuthTokenPrefix + token
|
||||||
|
val, err := c.GetString(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var userID uint
|
||||||
|
var createdUnix int64
|
||||||
|
n, _ := fmt.Sscanf(val, "%d|%d", &userID, &createdUnix)
|
||||||
|
if n < 1 {
|
||||||
|
return 0, 0, fmt.Errorf("invalid cached token format")
|
||||||
|
}
|
||||||
|
return userID, createdUnix, nil
|
||||||
|
}
|
||||||
|
|
||||||
// InvalidateAuthToken removes a cached token
|
// InvalidateAuthToken removes a cached token
|
||||||
func (c *CacheService) InvalidateAuthToken(ctx context.Context, token string) error {
|
func (c *CacheService) InvalidateAuthToken(ctx context.Context, token string) error {
|
||||||
key := AuthTokenPrefix + token
|
key := AuthTokenPrefix + token
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ func NewEmailService(cfg *config.EmailConfig, enabled bool) *EmailService {
|
|||||||
mail.WithUsername(cfg.User),
|
mail.WithUsername(cfg.User),
|
||||||
mail.WithPassword(cfg.Password),
|
mail.WithPassword(cfg.Password),
|
||||||
mail.WithTLSPortPolicy(mail.TLSOpportunistic),
|
mail.WithTLSPortPolicy(mail.TLSOpportunistic),
|
||||||
|
mail.WithTimeout(30*time.Second),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Failed to create mail client - emails will not be sent")
|
log.Error().Err(err).Msg("Failed to create mail client - emails will not be sent")
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ func SetupTestDB(t *testing.T) *gorm.DB {
|
|||||||
&models.FeatureBenefit{},
|
&models.FeatureBenefit{},
|
||||||
&models.UpgradeTrigger{},
|
&models.UpgradeTrigger{},
|
||||||
&models.Promotion{},
|
&models.Promotion{},
|
||||||
|
&models.AuditLog{},
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
@@ -27,9 +28,34 @@ func NewCustomValidator() *CustomValidator {
|
|||||||
return name
|
return name
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Register custom password complexity validator
|
||||||
|
v.RegisterValidation("password_complexity", validatePasswordComplexity)
|
||||||
|
|
||||||
return &CustomValidator{validator: v}
|
return &CustomValidator{validator: v}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validatePasswordComplexity checks that a password contains at least one
|
||||||
|
// uppercase letter, one lowercase letter, and one digit.
|
||||||
|
// Minimum length is enforced separately via the "min" tag.
|
||||||
|
func validatePasswordComplexity(fl validator.FieldLevel) bool {
|
||||||
|
password := fl.Field().String()
|
||||||
|
var hasUpper, hasLower, hasDigit bool
|
||||||
|
for _, ch := range password {
|
||||||
|
switch {
|
||||||
|
case unicode.IsUpper(ch):
|
||||||
|
hasUpper = true
|
||||||
|
case unicode.IsLower(ch):
|
||||||
|
hasLower = true
|
||||||
|
case unicode.IsDigit(ch):
|
||||||
|
hasDigit = true
|
||||||
|
}
|
||||||
|
if hasUpper && hasLower && hasDigit {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return hasUpper && hasLower && hasDigit
|
||||||
|
}
|
||||||
|
|
||||||
// Validate implements echo.Validator interface
|
// Validate implements echo.Validator interface
|
||||||
func (cv *CustomValidator) Validate(i interface{}) error {
|
func (cv *CustomValidator) Validate(i interface{}) error {
|
||||||
if err := cv.validator.Struct(i); err != nil {
|
if err := cv.validator.Struct(i); err != nil {
|
||||||
@@ -96,6 +122,8 @@ func formatMessage(fe validator.FieldError) string {
|
|||||||
return "Must be a valid URL"
|
return "Must be a valid URL"
|
||||||
case "uuid":
|
case "uuid":
|
||||||
return "Must be a valid UUID"
|
return "Must be a valid UUID"
|
||||||
|
case "password_complexity":
|
||||||
|
return "Password must be at least 8 characters with at least one uppercase letter, one lowercase letter, and one digit"
|
||||||
default:
|
default:
|
||||||
return "Invalid value"
|
return "Invalid value"
|
||||||
}
|
}
|
||||||
|
|||||||
115
internal/validator/validator_test.go
Normal file
115
internal/validator/validator_test.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package validator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
govalidator "github.com/go-playground/validator/v10"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidatePasswordComplexity(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
password string
|
||||||
|
valid bool
|
||||||
|
}{
|
||||||
|
{"valid password", "Password1", true},
|
||||||
|
{"valid complex password", "MyP@ssw0rd!", true},
|
||||||
|
{"missing uppercase", "password1", false},
|
||||||
|
{"missing lowercase", "PASSWORD1", false},
|
||||||
|
{"missing digit", "Password", false},
|
||||||
|
{"only digits", "12345678", false},
|
||||||
|
{"only lowercase", "abcdefgh", false},
|
||||||
|
{"only uppercase", "ABCDEFGH", false},
|
||||||
|
{"empty string", "", false},
|
||||||
|
{"single valid char each", "aA1", true},
|
||||||
|
{"unicode uppercase with digit and lower", "Über1abc", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
v := govalidator.New()
|
||||||
|
v.RegisterValidation("password_complexity", validatePasswordComplexity)
|
||||||
|
|
||||||
|
type testStruct struct {
|
||||||
|
Password string `validate:"password_complexity"`
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
s := testStruct{Password: tc.password}
|
||||||
|
err := v.Struct(s)
|
||||||
|
if tc.valid && err != nil {
|
||||||
|
t.Errorf("expected password %q to be valid, got error: %v", tc.password, err)
|
||||||
|
}
|
||||||
|
if !tc.valid && err == nil {
|
||||||
|
t.Errorf("expected password %q to be invalid, got nil error", tc.password)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatePasswordComplexityWithMinLength(t *testing.T) {
|
||||||
|
v := govalidator.New()
|
||||||
|
v.RegisterValidation("password_complexity", validatePasswordComplexity)
|
||||||
|
|
||||||
|
type request struct {
|
||||||
|
Password string `validate:"required,min=8,password_complexity"`
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
password string
|
||||||
|
valid bool
|
||||||
|
}{
|
||||||
|
{"valid 8+ chars with complexity", "Abcdefg1", true},
|
||||||
|
{"too short but complex", "Ab1", false},
|
||||||
|
{"long but no uppercase", "abcdefgh1", false},
|
||||||
|
{"long but no lowercase", "ABCDEFGH1", false},
|
||||||
|
{"long but no digit", "Abcdefghi", false},
|
||||||
|
{"empty", "", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
r := request{Password: tc.password}
|
||||||
|
err := v.Struct(r)
|
||||||
|
if tc.valid && err != nil {
|
||||||
|
t.Errorf("expected %q to be valid, got error: %v", tc.password, err)
|
||||||
|
}
|
||||||
|
if !tc.valid && err == nil {
|
||||||
|
t.Errorf("expected %q to be invalid, got nil", tc.password)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatMessagePasswordComplexity(t *testing.T) {
|
||||||
|
cv := NewCustomValidator()
|
||||||
|
|
||||||
|
type request struct {
|
||||||
|
Password string `json:"password" validate:"required,min=8,password_complexity"`
|
||||||
|
}
|
||||||
|
|
||||||
|
r := request{Password: "lowercase1"}
|
||||||
|
err := cv.Validate(r)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected validation error for password without uppercase")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := FormatValidationErrors(err)
|
||||||
|
if resp == nil {
|
||||||
|
t.Fatal("expected non-nil error response")
|
||||||
|
}
|
||||||
|
|
||||||
|
field, ok := resp.Fields["password"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected 'password' field in error response")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedMsg := "Password must be at least 8 characters with at least one uppercase letter, one lowercase letter, and one digit"
|
||||||
|
if field.Message != expectedMsg {
|
||||||
|
t.Errorf("expected message %q, got %q", expectedMsg, field.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
if field.Tag != "password_complexity" {
|
||||||
|
t.Errorf("expected tag 'password_complexity', got %q", field.Tag)
|
||||||
|
}
|
||||||
|
}
|
||||||
2
migrations/016_authtoken_created_at.down.sql
Normal file
2
migrations/016_authtoken_created_at.down.sql
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
-- No-op: the created column is part of Django's original schema and should not
|
||||||
|
-- be removed.
|
||||||
6
migrations/016_authtoken_created_at.up.sql
Normal file
6
migrations/016_authtoken_created_at.up.sql
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
-- Ensure created column exists on user_authtoken (Django already creates it,
|
||||||
|
-- but this migration guarantees it for fresh Go-only deployments).
|
||||||
|
ALTER TABLE user_authtoken ADD COLUMN IF NOT EXISTS created TIMESTAMP WITH TIME ZONE DEFAULT NOW();
|
||||||
|
|
||||||
|
-- Backfill any rows that may have a NULL created timestamp.
|
||||||
|
UPDATE user_authtoken SET created = NOW() WHERE created IS NULL;
|
||||||
40
migrations/017_fk_indexes.down.sql
Normal file
40
migrations/017_fk_indexes.down.sql
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
-- Rollback: 017_fk_indexes
|
||||||
|
-- Drop all FK indexes added in the up migration.
|
||||||
|
|
||||||
|
-- auth / user tables
|
||||||
|
DROP INDEX IF EXISTS idx_authtoken_user_id;
|
||||||
|
DROP INDEX IF EXISTS idx_userprofile_user_id;
|
||||||
|
DROP INDEX IF EXISTS idx_confirmationcode_user_id;
|
||||||
|
DROP INDEX IF EXISTS idx_passwordresetcode_user_id;
|
||||||
|
DROP INDEX IF EXISTS idx_applesocialauth_user_id;
|
||||||
|
DROP INDEX IF EXISTS idx_googlesocialauth_user_id;
|
||||||
|
|
||||||
|
-- push notification device tables
|
||||||
|
DROP INDEX IF EXISTS idx_apnsdevice_user_id;
|
||||||
|
DROP INDEX IF EXISTS idx_gcmdevice_user_id;
|
||||||
|
|
||||||
|
-- notification tables
|
||||||
|
DROP INDEX IF EXISTS idx_notificationpreference_user_id;
|
||||||
|
|
||||||
|
-- subscription tables
|
||||||
|
DROP INDEX IF EXISTS idx_subscription_user_id;
|
||||||
|
|
||||||
|
-- residence tables
|
||||||
|
DROP INDEX IF EXISTS idx_residence_owner_id;
|
||||||
|
DROP INDEX IF EXISTS idx_sharecode_residence_id;
|
||||||
|
DROP INDEX IF EXISTS idx_sharecode_created_by_id;
|
||||||
|
|
||||||
|
-- task tables
|
||||||
|
DROP INDEX IF EXISTS idx_task_created_by_id;
|
||||||
|
DROP INDEX IF EXISTS idx_task_assigned_to_id;
|
||||||
|
DROP INDEX IF EXISTS idx_task_category_id;
|
||||||
|
DROP INDEX IF EXISTS idx_task_priority_id;
|
||||||
|
DROP INDEX IF EXISTS idx_task_frequency_id;
|
||||||
|
DROP INDEX IF EXISTS idx_task_contractor_id;
|
||||||
|
DROP INDEX IF EXISTS idx_task_parent_task_id;
|
||||||
|
DROP INDEX IF EXISTS idx_completionimage_completion_id;
|
||||||
|
DROP INDEX IF EXISTS idx_document_created_by_id;
|
||||||
|
DROP INDEX IF EXISTS idx_document_task_id;
|
||||||
|
DROP INDEX IF EXISTS idx_documentimage_document_id;
|
||||||
|
DROP INDEX IF EXISTS idx_contractor_residence_id;
|
||||||
|
DROP INDEX IF EXISTS idx_reminderlog_notification_id;
|
||||||
131
migrations/017_fk_indexes.up.sql
Normal file
131
migrations/017_fk_indexes.up.sql
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
-- Migration: 017_fk_indexes
|
||||||
|
-- Add indexes on all foreign key columns that are not already covered by existing indexes.
|
||||||
|
-- Uses CREATE INDEX IF NOT EXISTS to be idempotent (safe to re-run).
|
||||||
|
|
||||||
|
-- =====================================================
|
||||||
|
-- auth / user tables
|
||||||
|
-- =====================================================
|
||||||
|
|
||||||
|
-- user_authtoken: user_id (unique FK, but ensure index exists)
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_authtoken_user_id
|
||||||
|
ON user_authtoken (user_id);
|
||||||
|
|
||||||
|
-- user_userprofile: user_id (unique FK)
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_userprofile_user_id
|
||||||
|
ON user_userprofile (user_id);
|
||||||
|
|
||||||
|
-- user_confirmationcode: user_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_confirmationcode_user_id
|
||||||
|
ON user_confirmationcode (user_id);
|
||||||
|
|
||||||
|
-- user_passwordresetcode: user_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_passwordresetcode_user_id
|
||||||
|
ON user_passwordresetcode (user_id);
|
||||||
|
|
||||||
|
-- user_applesocialauth: user_id (unique FK)
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_applesocialauth_user_id
|
||||||
|
ON user_applesocialauth (user_id);
|
||||||
|
|
||||||
|
-- user_googlesocialauth: user_id (unique FK)
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_googlesocialauth_user_id
|
||||||
|
ON user_googlesocialauth (user_id);
|
||||||
|
|
||||||
|
-- =====================================================
|
||||||
|
-- push notification device tables
|
||||||
|
-- =====================================================
|
||||||
|
|
||||||
|
-- push_notifications_apnsdevice: user_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_apnsdevice_user_id
|
||||||
|
ON push_notifications_apnsdevice (user_id);
|
||||||
|
|
||||||
|
-- push_notifications_gcmdevice: user_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_gcmdevice_user_id
|
||||||
|
ON push_notifications_gcmdevice (user_id);
|
||||||
|
|
||||||
|
-- =====================================================
|
||||||
|
-- notification tables
|
||||||
|
-- =====================================================
|
||||||
|
|
||||||
|
-- notifications_notificationpreference: user_id (unique FK)
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_notificationpreference_user_id
|
||||||
|
ON notifications_notificationpreference (user_id);
|
||||||
|
|
||||||
|
-- =====================================================
|
||||||
|
-- subscription tables
|
||||||
|
-- =====================================================
|
||||||
|
|
||||||
|
-- subscription_usersubscription: user_id (unique FK)
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_subscription_user_id
|
||||||
|
ON subscription_usersubscription (user_id);
|
||||||
|
|
||||||
|
-- =====================================================
|
||||||
|
-- residence tables
|
||||||
|
-- =====================================================
|
||||||
|
|
||||||
|
-- residence_residence: owner_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_residence_owner_id
|
||||||
|
ON residence_residence (owner_id);
|
||||||
|
|
||||||
|
-- residence_residencesharecode: residence_id (may already exist from model index tag via GORM)
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_sharecode_residence_id
|
||||||
|
ON residence_residencesharecode (residence_id);
|
||||||
|
|
||||||
|
-- residence_residencesharecode: created_by_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_sharecode_created_by_id
|
||||||
|
ON residence_residencesharecode (created_by_id);
|
||||||
|
|
||||||
|
-- =====================================================
|
||||||
|
-- task tables
|
||||||
|
-- =====================================================
|
||||||
|
|
||||||
|
-- task_task: created_by_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_task_created_by_id
|
||||||
|
ON task_task (created_by_id);
|
||||||
|
|
||||||
|
-- task_task: assigned_to_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_task_assigned_to_id
|
||||||
|
ON task_task (assigned_to_id);
|
||||||
|
|
||||||
|
-- task_task: category_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_task_category_id
|
||||||
|
ON task_task (category_id);
|
||||||
|
|
||||||
|
-- task_task: priority_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_task_priority_id
|
||||||
|
ON task_task (priority_id);
|
||||||
|
|
||||||
|
-- task_task: frequency_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_task_frequency_id
|
||||||
|
ON task_task (frequency_id);
|
||||||
|
|
||||||
|
-- task_task: contractor_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_task_contractor_id
|
||||||
|
ON task_task (contractor_id);
|
||||||
|
|
||||||
|
-- task_task: parent_task_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_task_parent_task_id
|
||||||
|
ON task_task (parent_task_id);
|
||||||
|
|
||||||
|
-- task_taskcompletionimage: completion_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_completionimage_completion_id
|
||||||
|
ON task_taskcompletionimage (completion_id);
|
||||||
|
|
||||||
|
-- task_document: created_by_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_document_created_by_id
|
||||||
|
ON task_document (created_by_id);
|
||||||
|
|
||||||
|
-- task_document: task_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_document_task_id
|
||||||
|
ON task_document (task_id);
|
||||||
|
|
||||||
|
-- task_documentimage: document_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_documentimage_document_id
|
||||||
|
ON task_documentimage (document_id);
|
||||||
|
|
||||||
|
-- task_contractor: residence_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_contractor_residence_id
|
||||||
|
ON task_contractor (residence_id);
|
||||||
|
|
||||||
|
-- task_reminderlog: notification_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_reminderlog_notification_id
|
||||||
|
ON task_reminderlog (notification_id);
|
||||||
2
migrations/018_audit_log.down.sql
Normal file
2
migrations/018_audit_log.down.sql
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
-- Rollback: 018_audit_log
|
||||||
|
DROP TABLE IF EXISTS audit_log;
|
||||||
16
migrations/018_audit_log.up.sql
Normal file
16
migrations/018_audit_log.up.sql
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
-- Migration: 018_audit_log
|
||||||
|
-- Create audit_log table for tracking security-relevant events (login, register, etc.)
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS audit_log (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
user_id INTEGER,
|
||||||
|
event_type VARCHAR(50) NOT NULL,
|
||||||
|
ip_address VARCHAR(45),
|
||||||
|
user_agent TEXT,
|
||||||
|
details JSONB,
|
||||||
|
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX idx_audit_log_user_id ON audit_log(user_id);
|
||||||
|
CREATE INDEX idx_audit_log_event_type ON audit_log(event_type);
|
||||||
|
CREATE INDEX idx_audit_log_created_at ON audit_log(created_at);
|
||||||
Reference in New Issue
Block a user