diff --git a/internal/config/config.go b/internal/config/config.go index 45b72f8..148f63f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -134,6 +134,8 @@ type SecurityConfig struct { PasswordResetExpiry time.Duration ConfirmationExpiry time.Duration MaxPasswordResetRate int // per hour + TokenExpiryDays int // Number of days before auth tokens expire (default 90) + TokenRefreshDays int // Token must be at least this many days old before refresh (default 60) } // StorageConfig holds file storage settings @@ -262,6 +264,8 @@ func Load() (*Config, error) { PasswordResetExpiry: 15 * time.Minute, ConfirmationExpiry: 24 * time.Hour, MaxPasswordResetRate: 3, + TokenExpiryDays: viper.GetInt("TOKEN_EXPIRY_DAYS"), + TokenRefreshDays: viper.GetInt("TOKEN_REFRESH_DAYS"), }, Storage: StorageConfig{ UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"), @@ -369,6 +373,10 @@ func setDefaults() { viper.SetDefault("OVERDUE_REMINDER_HOUR", 15) // 9:00 AM UTC viper.SetDefault("DAILY_DIGEST_HOUR", 3) // 3:00 AM UTC + // Token expiry defaults + viper.SetDefault("TOKEN_EXPIRY_DAYS", 90) // Tokens expire after 90 days + viper.SetDefault("TOKEN_REFRESH_DAYS", 60) // Tokens can be refreshed after 60 days + // Storage defaults viper.SetDefault("STORAGE_UPLOAD_DIR", "./uploads") viper.SetDefault("STORAGE_BASE_URL", "/uploads") diff --git a/internal/dto/requests/auth.go b/internal/dto/requests/auth.go index 92767e2..de745f2 100644 --- a/internal/dto/requests/auth.go +++ b/internal/dto/requests/auth.go @@ -11,7 +11,7 @@ type LoginRequest struct { type RegisterRequest struct { Username string `json:"username" validate:"required,min=3,max=150"` Email string `json:"email" validate:"required,email,max=254"` - Password string `json:"password" validate:"required,min=8"` + Password string `json:"password" validate:"required,min=8,password_complexity"` FirstName string `json:"first_name" validate:"max=150"` LastName string `json:"last_name" validate:"max=150"` } @@ -35,7 +35,7 @@ 
type VerifyResetCodeRequest struct { // ResetPasswordRequest represents the reset password request body type ResetPasswordRequest struct { ResetToken string `json:"reset_token" validate:"required"` - NewPassword string `json:"new_password" validate:"required,min=8"` + NewPassword string `json:"new_password" validate:"required,min=8,password_complexity"` } // UpdateProfileRequest represents the profile update request body diff --git a/internal/dto/responses/auth.go b/internal/dto/responses/auth.go index 2a4dae2..bf9e974 100644 --- a/internal/dto/responses/auth.go +++ b/internal/dto/responses/auth.go @@ -79,6 +79,12 @@ type ResetPasswordResponse struct { Message string `json:"message"` } +// RefreshTokenResponse represents the token refresh response +type RefreshTokenResponse struct { + Token string `json:"token"` + Message string `json:"message"` +} + // MessageResponse represents a simple message response type MessageResponse struct { Message string `json:"message"` diff --git a/internal/handlers/auth_handler.go b/internal/handlers/auth_handler.go index 03df15b..7ae2bcd 100644 --- a/internal/handlers/auth_handler.go +++ b/internal/handlers/auth_handler.go @@ -23,6 +23,7 @@ type AuthHandler struct { appleAuthService *services.AppleAuthService googleAuthService *services.GoogleAuthService storageService *services.StorageService + auditService *services.AuditService } // NewAuthHandler creates a new auth handler @@ -49,6 +50,11 @@ func (h *AuthHandler) SetStorageService(storageService *services.StorageService) h.storageService = storageService } +// SetAuditService sets the audit service for logging security events +func (h *AuthHandler) SetAuditService(auditService *services.AuditService) { + h.auditService = auditService +} + // Login handles POST /api/auth/login/ func (h *AuthHandler) Login(c echo.Context) error { var req requests.LoginRequest @@ -62,9 +68,19 @@ func (h *AuthHandler) Login(c echo.Context) error { response, err := h.authService.Login(&req) if err != 
nil { log.Debug().Err(err).Str("identifier", req.Username).Msg("Login failed") + if h.auditService != nil { + h.auditService.LogEvent(c, nil, services.AuditEventLoginFailed, map[string]interface{}{ + "identifier": req.Username, + }) + } return err } + if h.auditService != nil { + userID := response.User.ID + h.auditService.LogEvent(c, &userID, services.AuditEventLogin, nil) + } + return c.JSON(http.StatusOK, response) } @@ -84,6 +100,14 @@ func (h *AuthHandler) Register(c echo.Context) error { return err } + if h.auditService != nil { + userID := response.User.ID + h.auditService.LogEvent(c, &userID, services.AuditEventRegister, map[string]interface{}{ + "username": req.Username, + "email": req.Email, + }) + } + // Send welcome email with confirmation code (async) if h.emailService != nil && confirmationCode != "" { go func() { @@ -108,6 +132,14 @@ func (h *AuthHandler) Logout(c echo.Context) error { return apperrors.Unauthorized("error.not_authenticated") } + // Log audit event before invalidating the token + if h.auditService != nil { + user := middleware.GetAuthUser(c) + if user != nil { + h.auditService.LogEvent(c, &user.ID, services.AuditEventLogout, nil) + } + } + // Invalidate token in database if err := h.authService.Logout(token); err != nil { log.Warn().Err(err).Msg("Failed to delete token from database") @@ -270,6 +302,12 @@ func (h *AuthHandler) ForgotPassword(c echo.Context) error { }() } + if h.auditService != nil { + h.auditService.LogEvent(c, nil, services.AuditEventPasswordReset, map[string]interface{}{ + "email": req.Email, + }) + } + // Always return success to prevent email enumeration return c.JSON(http.StatusOK, responses.ForgotPasswordResponse{ Message: "Password reset email sent", @@ -314,6 +352,12 @@ func (h *AuthHandler) ResetPassword(c echo.Context) error { return err } + if h.auditService != nil { + h.auditService.LogEvent(c, nil, services.AuditEventPasswordChanged, map[string]interface{}{ + "method": "reset_token", + }) + } + return 
c.JSON(http.StatusOK, responses.ResetPasswordResponse{ Message: "Password reset successful", }) @@ -413,6 +457,34 @@ func (h *AuthHandler) GoogleSignIn(c echo.Context) error { return c.JSON(http.StatusOK, response) } +// RefreshToken handles POST /api/auth/refresh/ +func (h *AuthHandler) RefreshToken(c echo.Context) error { + user, err := middleware.MustGetAuthUser(c) + if err != nil { + return err + } + + token := middleware.GetAuthToken(c) + if token == "" { + return apperrors.Unauthorized("error.not_authenticated") + } + + response, err := h.authService.RefreshToken(token, user.ID) + if err != nil { + log.Debug().Err(err).Uint("user_id", user.ID).Msg("Token refresh failed") + return err + } + + // If the token was refreshed (new token), invalidate the old one from cache + if response.Token != token && h.cache != nil { + if cacheErr := h.cache.InvalidateAuthToken(c.Request().Context(), token); cacheErr != nil { + log.Warn().Err(cacheErr).Msg("Failed to invalidate old token from cache during refresh") + } + } + + return c.JSON(http.StatusOK, response) +} + // DeleteAccount handles DELETE /api/auth/account/ func (h *AuthHandler) DeleteAccount(c echo.Context) error { user, err := middleware.MustGetAuthUser(c) @@ -431,6 +503,14 @@ func (h *AuthHandler) DeleteAccount(c echo.Context) error { return err } + if h.auditService != nil { + h.auditService.LogEvent(c, &user.ID, services.AuditEventAccountDeleted, map[string]interface{}{ + "user_id": user.ID, + "username": user.Username, + "email": user.Email, + }) + } + // Delete files from disk (best effort, don't fail the request) if h.storageService != nil && len(fileURLs) > 0 { go func() { diff --git a/internal/i18n/translations/en.json b/internal/i18n/translations/en.json index 30564b3..f31f383 100644 --- a/internal/i18n/translations/en.json +++ b/internal/i18n/translations/en.json @@ -6,8 +6,11 @@ "error.email_taken": "Email already registered", "error.email_already_taken": "Email already taken", 
"error.registration_failed": "Registration failed", + "error.password_complexity": "Password must be at least 8 characters with at least one uppercase letter, one lowercase letter, and one digit", "error.not_authenticated": "Not authenticated", "error.invalid_token": "Invalid token", + "error.token_expired": "Your session has expired. Please log in again.", + "error.token_refresh_not_needed": "Token is still valid.", "error.failed_to_get_user": "Failed to get user", "error.failed_to_update_profile": "Failed to update profile", "error.invalid_verification_code": "Invalid verification code", diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index df4a6b7..dca1987 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -12,6 +12,7 @@ import ( "gorm.io/gorm" "github.com/treytartt/honeydue-api/internal/apperrors" + "github.com/treytartt/honeydue-api/internal/config" "github.com/treytartt/honeydue-api/internal/models" "github.com/treytartt/honeydue-api/internal/services" ) @@ -28,24 +29,56 @@ const ( // UserCacheTTL is how long full user records are cached in memory to // avoid hitting the database on every authenticated request. UserCacheTTL = 30 * time.Second + + // DefaultTokenExpiryDays is the default number of days before a token expires. 
+ DefaultTokenExpiryDays = 90 ) // AuthMiddleware provides token authentication middleware type AuthMiddleware struct { - db *gorm.DB - cache *services.CacheService - userCache *UserCache + db *gorm.DB + cache *services.CacheService + userCache *UserCache + tokenExpiryDays int } // NewAuthMiddleware creates a new auth middleware instance func NewAuthMiddleware(db *gorm.DB, cache *services.CacheService) *AuthMiddleware { return &AuthMiddleware{ - db: db, - cache: cache, - userCache: NewUserCache(UserCacheTTL), + db: db, + cache: cache, + userCache: NewUserCache(UserCacheTTL), + tokenExpiryDays: DefaultTokenExpiryDays, } } +// NewAuthMiddlewareWithConfig creates a new auth middleware instance with configuration +func NewAuthMiddlewareWithConfig(db *gorm.DB, cache *services.CacheService, cfg *config.Config) *AuthMiddleware { + expiryDays := DefaultTokenExpiryDays + if cfg != nil && cfg.Security.TokenExpiryDays > 0 { + expiryDays = cfg.Security.TokenExpiryDays + } + return &AuthMiddleware{ + db: db, + cache: cache, + userCache: NewUserCache(UserCacheTTL), + tokenExpiryDays: expiryDays, + } +} + +// TokenExpiryDuration returns the token expiry duration. +func (m *AuthMiddleware) TokenExpiryDuration() time.Duration { + return time.Duration(m.tokenExpiryDays) * 24 * time.Hour +} + +// isTokenExpired checks if a token's created timestamp indicates expiry. 
+func (m *AuthMiddleware) isTokenExpired(created time.Time) bool { + if created.IsZero() { + return false // Legacy tokens without created time are not expired + } + return time.Since(created) > m.TokenExpiryDuration() +} + // TokenAuth returns an Echo middleware that validates token authentication func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -56,7 +89,7 @@ func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc { return apperrors.Unauthorized("error.not_authenticated") } - // Try to get user from cache first + // Try to get user from cache first (includes expiry check) user, err := m.getUserFromCache(c.Request().Context(), token) if err == nil && user != nil { // Cache hit - set user in context and continue @@ -65,16 +98,27 @@ func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc { return next(c) } + // Check if the cache indicated token expiry + if err != nil && err.Error() == "token expired" { + return apperrors.Unauthorized("error.token_expired") + } + // Cache miss - look up token in database - user, err = m.getUserFromDatabase(token) + user, authToken, err := m.getUserFromDatabaseWithToken(token) if err != nil { log.Debug().Err(err).Str("token", truncateToken(token)).Msg("Token authentication failed") return apperrors.Unauthorized("error.invalid_token") } - // Cache the user ID for future requests - if cacheErr := m.cacheUserID(c.Request().Context(), token, user.ID); cacheErr != nil { - log.Warn().Err(cacheErr).Msg("Failed to cache user ID") + // Check token expiry + if m.isTokenExpired(authToken.Created) { + log.Debug().Str("token", truncateToken(token)).Time("created", authToken.Created).Msg("Token expired") + return apperrors.Unauthorized("error.token_expired") + } + + // Cache the user ID and token creation time for future requests + if cacheErr := m.cacheTokenInfo(c.Request().Context(), token, user.ID, authToken.Created); cacheErr != nil { + log.Warn().Err(cacheErr).Msg("Failed to 
cache token info") } // Set user in context @@ -104,9 +148,9 @@ func (m *AuthMiddleware) OptionalTokenAuth() echo.MiddlewareFunc { } // Try database - user, err = m.getUserFromDatabase(token) - if err == nil { - m.cacheUserID(c.Request().Context(), token, user.ID) + user, authToken, err := m.getUserFromDatabaseWithToken(token) + if err == nil && !m.isTokenExpired(authToken.Created) { + m.cacheTokenInfo(c.Request().Context(), token, user.ID, authToken.Created) c.Set(AuthUserKey, user) c.Set(AuthTokenKey, token) } @@ -145,12 +189,13 @@ func extractToken(c echo.Context) (string, error) { // getUserFromCache tries to get user from Redis cache, then from the // in-memory user cache, before falling back to the database. +// Returns a "token expired" error if the cached creation time indicates expiry. func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*models.User, error) { if m.cache == nil { return nil, fmt.Errorf("cache not available") } - userID, err := m.cache.GetCachedAuthToken(ctx, token) + userID, createdUnix, err := m.cache.GetCachedAuthTokenWithCreated(ctx, token) if err != nil { if err == redis.Nil { return nil, fmt.Errorf("token not in cache") @@ -158,6 +203,15 @@ func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*m return nil, err } + // Check token expiry from cached creation time + if createdUnix > 0 { + created := time.Unix(createdUnix, 0) + if m.isTokenExpired(created) { + m.cache.InvalidateAuthToken(ctx, token) + return nil, fmt.Errorf("token expired") + } + } + // Try in-memory user cache first to avoid a DB round-trip if cached := m.userCache.Get(userID); cached != nil { if !cached.IsActive { @@ -187,22 +241,38 @@ func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*m return &user, nil } -// getUserFromDatabase looks up the token in the database and caches the -// resulting user record in memory. 
-func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error) { +// getUserFromDatabaseWithToken looks up the token in the database and returns +// both the user and the auth token record (for expiry checking). +func (m *AuthMiddleware) getUserFromDatabaseWithToken(token string) (*models.User, *models.AuthToken, error) { var authToken models.AuthToken if err := m.db.Preload("User").Where("key = ?", token).First(&authToken).Error; err != nil { - return nil, fmt.Errorf("token not found") + return nil, nil, fmt.Errorf("token not found") } // Check if user is active if !authToken.User.IsActive { - return nil, fmt.Errorf("user is inactive") + return nil, nil, fmt.Errorf("user is inactive") } // Store in in-memory cache for subsequent requests m.userCache.Set(&authToken.User) - return &authToken.User, nil + return &authToken.User, &authToken, nil +} + +// getUserFromDatabase looks up the token in the database and caches the +// resulting user record in memory. +// Deprecated: Use getUserFromDatabaseWithToken for new code paths that need expiry checking. 
+func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error) { + user, _, err := m.getUserFromDatabaseWithToken(token) + return user, err +} + +// cacheTokenInfo caches the user ID and token creation time for a token +func (m *AuthMiddleware) cacheTokenInfo(ctx context.Context, token string, userID uint, created time.Time) error { + if m.cache == nil { + return nil + } + return m.cache.CacheAuthTokenWithCreated(ctx, token, userID, created.Unix()) } // cacheUserID caches the user ID for a token diff --git a/internal/middleware/auth_expiry_test.go b/internal/middleware/auth_expiry_test.go new file mode 100644 index 0000000..b699982 --- /dev/null +++ b/internal/middleware/auth_expiry_test.go @@ -0,0 +1,165 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" + + "github.com/treytartt/honeydue-api/internal/models" +) + +// setupTestDB creates a temporary in-memory SQLite database with the required +// tables for auth middleware tests. +func setupTestDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + require.NoError(t, err) + + err = db.AutoMigrate(&models.User{}, &models.AuthToken{}) + require.NoError(t, err) + + return db +} + +// createTestUserAndToken creates a user and an auth token, then backdates the +// token's Created timestamp by the specified number of days. 
+func createTestUserAndToken(t *testing.T, db *gorm.DB, username string, ageDays int) (*models.User, *models.AuthToken) { + t.Helper() + + user := &models.User{ + Username: username, + Email: username + "@test.com", + IsActive: true, + } + require.NoError(t, user.SetPassword("password123")) + require.NoError(t, db.Create(user).Error) + + token := &models.AuthToken{ + UserID: user.ID, + } + require.NoError(t, db.Create(token).Error) + + // Backdate the token's Created timestamp after creation to bypass autoCreateTime + backdated := time.Now().UTC().AddDate(0, 0, -ageDays) + require.NoError(t, db.Model(token).Update("created", backdated).Error) + token.Created = backdated + + return user, token +} + +func TestTokenAuth_RejectsExpiredToken(t *testing.T) { + db := setupTestDB(t) + _, token := createTestUserAndToken(t, db, "expired_user", 91) // 91 days old > 90 day expiry + + m := NewAuthMiddleware(db, nil) // No Redis cache for these tests + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) + req.Header.Set("Authorization", "Token "+token.Key) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := m.TokenAuth()(func(c echo.Context) error { + return c.String(http.StatusOK, "ok") + }) + + err := handler(c) + require.Error(t, err) + assert.Contains(t, err.Error(), "error.token_expired") +} + +func TestTokenAuth_AcceptsValidToken(t *testing.T) { + db := setupTestDB(t) + _, token := createTestUserAndToken(t, db, "valid_user", 30) // 30 days old < 90 day expiry + + m := NewAuthMiddleware(db, nil) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) + req.Header.Set("Authorization", "Token "+token.Key) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := m.TokenAuth()(func(c echo.Context) error { + return c.String(http.StatusOK, "ok") + }) + + err := handler(c) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + + // Verify user was set in context + 
user := GetAuthUser(c) + require.NotNil(t, user) + assert.Equal(t, "valid_user", user.Username) +} + +func TestTokenAuth_AcceptsTokenAtBoundary(t *testing.T) { + db := setupTestDB(t) + _, token := createTestUserAndToken(t, db, "boundary_user", 89) // 89 days old, just under 90 day expiry + + m := NewAuthMiddleware(db, nil) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) + req.Header.Set("Authorization", "Token "+token.Key) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := m.TokenAuth()(func(c echo.Context) error { + return c.String(http.StatusOK, "ok") + }) + + err := handler(c) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) +} + +func TestTokenAuth_RejectsInvalidToken(t *testing.T) { + db := setupTestDB(t) + + m := NewAuthMiddleware(db, nil) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) + req.Header.Set("Authorization", "Token nonexistent-token") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := m.TokenAuth()(func(c echo.Context) error { + return c.String(http.StatusOK, "ok") + }) + + err := handler(c) + require.Error(t, err) + assert.Contains(t, err.Error(), "error.invalid_token") +} + +func TestTokenAuth_RejectsNoAuthHeader(t *testing.T) { + db := setupTestDB(t) + + m := NewAuthMiddleware(db, nil) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := m.TokenAuth()(func(c echo.Context) error { + return c.String(http.StatusOK, "ok") + }) + + err := handler(c) + require.Error(t, err) + assert.Contains(t, err.Error(), "error.not_authenticated") +} diff --git a/internal/models/audit_log.go b/internal/models/audit_log.go new file mode 100644 index 0000000..023ee1e --- /dev/null +++ b/internal/models/audit_log.go @@ -0,0 +1,48 @@ +package models + +import ( + "database/sql/driver" + "encoding/json" + "errors" + 
"time" +) + +// JSONMap is a custom type for JSONB columns that handles JSON serialization +type JSONMap map[string]interface{} + +// Value implements driver.Valuer for database writes +func (j JSONMap) Value() (driver.Value, error) { + if j == nil { + return nil, nil + } + return json.Marshal(j) +} + +// Scan implements sql.Scanner for database reads +func (j *JSONMap) Scan(value interface{}) error { + if value == nil { + *j = nil + return nil + } + bytes, ok := value.([]byte) + if !ok { + return errors.New("audit_log: failed to scan JSONMap value") + } + return json.Unmarshal(bytes, j) +} + +// AuditLog represents the audit_log table for tracking security-relevant events +type AuditLog struct { + ID uint `gorm:"primaryKey" json:"id"` + UserID *uint `gorm:"column:user_id" json:"user_id"` + EventType string `gorm:"column:event_type;size:50;not null" json:"event_type"` + IPAddress string `gorm:"column:ip_address;size:45" json:"ip_address"` + UserAgent string `gorm:"column:user_agent;type:text" json:"user_agent"` + Details JSONMap `gorm:"column:details;type:jsonb" json:"details"` + CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"` +} + +// TableName returns the table name for GORM +func (AuditLog) TableName() string { + return "audit_log" +} diff --git a/internal/push/circuit_breaker.go b/internal/push/circuit_breaker.go new file mode 100644 index 0000000..95af56f --- /dev/null +++ b/internal/push/circuit_breaker.go @@ -0,0 +1,167 @@ +package push + +import ( + "errors" + "sync" + "time" +) + +// Circuit breaker states +const ( + stateClosed = iota // Normal operation, requests pass through + stateOpen // Too many failures, requests are rejected + stateHalfOpen // Testing recovery, one request allowed through +) + +// Default circuit breaker settings +const ( + defaultFailureThreshold = 5 // Open after this many consecutive failures + defaultRecoveryTimeout = 30 * time.Second // Try again after this duration +) + +// ErrCircuitOpen is 
returned when the circuit breaker is open and rejecting requests. +var ErrCircuitOpen = errors.New("circuit breaker is open") + +// CircuitBreaker implements a simple circuit breaker pattern for external service calls. +// It is thread-safe and requires no external dependencies. +// +// States: +// - Closed: normal operation, all requests pass through. Consecutive failures are counted. +// - Open: after reaching the failure threshold, all requests are immediately rejected +// with ErrCircuitOpen until the recovery timeout elapses. +// - Half-Open: after the recovery timeout, one request is allowed through. If it +// succeeds the breaker resets to Closed; if it fails it returns to Open. +type CircuitBreaker struct { + mu sync.Mutex + state int + failureCount int + failureThreshold int + recoveryTimeout time.Duration + lastFailureTime time.Time + name string // For logging +} + +// CircuitBreakerOption configures a CircuitBreaker. +type CircuitBreakerOption func(*CircuitBreaker) + +// WithFailureThreshold sets the number of consecutive failures before opening the circuit. +func WithFailureThreshold(n int) CircuitBreakerOption { + return func(cb *CircuitBreaker) { + if n > 0 { + cb.failureThreshold = n + } + } +} + +// WithRecoveryTimeout sets how long the circuit stays open before trying half-open. +func WithRecoveryTimeout(d time.Duration) CircuitBreakerOption { + return func(cb *CircuitBreaker) { + if d > 0 { + cb.recoveryTimeout = d + } + } +} + +// NewCircuitBreaker creates a new CircuitBreaker with the given name and options. +// The name is used for logging and identification. +func NewCircuitBreaker(name string, opts ...CircuitBreakerOption) *CircuitBreaker { + cb := &CircuitBreaker{ + state: stateClosed, + failureThreshold: defaultFailureThreshold, + recoveryTimeout: defaultRecoveryTimeout, + name: name, + } + for _, opt := range opts { + opt(cb) + } + return cb +} + +// Allow checks whether a request should be allowed through. 
+// It returns true if the request can proceed, false if the circuit is open. +// When transitioning from open to half-open, it returns true for the probe request. +func (cb *CircuitBreaker) Allow() bool { + cb.mu.Lock() + defer cb.mu.Unlock() + + switch cb.state { + case stateClosed: + return true + case stateOpen: + // Check if recovery timeout has elapsed + if time.Since(cb.lastFailureTime) >= cb.recoveryTimeout { + cb.state = stateHalfOpen + return true + } + return false + case stateHalfOpen: + // Only one request at a time in half-open state. + // The first caller that got here via Allow() is already in flight; + // reject subsequent callers until that probe resolves. + return false + default: + return true + } +} + +// RecordSuccess records a successful request. If the breaker is half-open, it resets to closed. +func (cb *CircuitBreaker) RecordSuccess() { + cb.mu.Lock() + defer cb.mu.Unlock() + + cb.failureCount = 0 + cb.state = stateClosed +} + +// RecordFailure records a failed request. If the failure threshold is reached, the +// breaker transitions to the open state. +func (cb *CircuitBreaker) RecordFailure() { + cb.mu.Lock() + defer cb.mu.Unlock() + + cb.failureCount++ + cb.lastFailureTime = time.Now() + + if cb.failureCount >= cb.failureThreshold { + cb.state = stateOpen + } +} + +// State returns the current state of the circuit breaker as a human-readable string. +func (cb *CircuitBreaker) State() string { + cb.mu.Lock() + defer cb.mu.Unlock() + + switch cb.state { + case stateClosed: + return "closed" + case stateOpen: + return "open" + case stateHalfOpen: + return "half-open" + default: + return "unknown" + } +} + +// Name returns the circuit breaker's name. +func (cb *CircuitBreaker) Name() string { + return cb.name +} + +// Reset resets the circuit breaker to the closed state with zero failures. 
+func (cb *CircuitBreaker) Reset() { + cb.mu.Lock() + defer cb.mu.Unlock() + + cb.state = stateClosed + cb.failureCount = 0 + cb.lastFailureTime = time.Time{} +} + +// Counts returns the current failure count (useful for testing and monitoring). +func (cb *CircuitBreaker) Counts() int { + cb.mu.Lock() + defer cb.mu.Unlock() + return cb.failureCount +} diff --git a/internal/push/circuit_breaker_test.go b/internal/push/circuit_breaker_test.go new file mode 100644 index 0000000..5c13e45 --- /dev/null +++ b/internal/push/circuit_breaker_test.go @@ -0,0 +1,275 @@ +package push + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCircuitBreaker_StartsInClosedState(t *testing.T) { + cb := NewCircuitBreaker("test") + assert.Equal(t, "closed", cb.State()) + assert.True(t, cb.Allow()) +} + +func TestCircuitBreaker_OpensAfterThresholdFailures(t *testing.T) { + cb := NewCircuitBreaker("test", WithFailureThreshold(3)) + + // First two failures should keep it closed + cb.RecordFailure() + assert.Equal(t, "closed", cb.State()) + assert.True(t, cb.Allow()) + + cb.RecordFailure() + assert.Equal(t, "closed", cb.State()) + assert.True(t, cb.Allow()) + + // Third failure should open it + cb.RecordFailure() + assert.Equal(t, "open", cb.State()) + assert.False(t, cb.Allow()) +} + +func TestCircuitBreaker_DefaultThresholdIsFive(t *testing.T) { + cb := NewCircuitBreaker("test") + + for i := 0; i < 4; i++ { + cb.RecordFailure() + assert.Equal(t, "closed", cb.State()) + } + + cb.RecordFailure() // 5th failure + assert.Equal(t, "open", cb.State()) +} + +func TestCircuitBreaker_RejectsRequestsWhenOpen(t *testing.T) { + cb := NewCircuitBreaker("test", WithFailureThreshold(1)) + + cb.RecordFailure() + assert.Equal(t, "open", cb.State()) + + // Multiple calls should all be rejected + for i := 0; i < 10; i++ { + assert.False(t, cb.Allow()) + } +} + +func 
TestCircuitBreaker_TransitionsToHalfOpenAfterRecoveryTimeout(t *testing.T) { + cb := NewCircuitBreaker("test", + WithFailureThreshold(1), + WithRecoveryTimeout(50*time.Millisecond), + ) + + cb.RecordFailure() + assert.Equal(t, "open", cb.State()) + assert.False(t, cb.Allow()) + + // Wait for recovery timeout + time.Sleep(60 * time.Millisecond) + + // Should now allow one request (half-open) + assert.True(t, cb.Allow()) + assert.Equal(t, "half-open", cb.State()) +} + +func TestCircuitBreaker_HalfOpenRejectsSecondRequest(t *testing.T) { + cb := NewCircuitBreaker("test", + WithFailureThreshold(1), + WithRecoveryTimeout(50*time.Millisecond), + ) + + cb.RecordFailure() + time.Sleep(60 * time.Millisecond) + + // First request allowed (probe) + assert.True(t, cb.Allow()) + assert.Equal(t, "half-open", cb.State()) + + // Second request rejected while probe is in flight + assert.False(t, cb.Allow()) +} + +func TestCircuitBreaker_HalfOpenSuccess_ResetsToClosed(t *testing.T) { + cb := NewCircuitBreaker("test", + WithFailureThreshold(1), + WithRecoveryTimeout(50*time.Millisecond), + ) + + cb.RecordFailure() + time.Sleep(60 * time.Millisecond) + + // Probe request + assert.True(t, cb.Allow()) + + // Probe succeeds + cb.RecordSuccess() + assert.Equal(t, "closed", cb.State()) + assert.Equal(t, 0, cb.Counts()) + + // Normal operation resumes + assert.True(t, cb.Allow()) + assert.True(t, cb.Allow()) +} + +func TestCircuitBreaker_HalfOpenFailure_ReturnsToOpen(t *testing.T) { + cb := NewCircuitBreaker("test", + WithFailureThreshold(2), + WithRecoveryTimeout(50*time.Millisecond), + ) + + // Open the circuit + cb.RecordFailure() + cb.RecordFailure() + assert.Equal(t, "open", cb.State()) + + time.Sleep(60 * time.Millisecond) + + // Probe request + assert.True(t, cb.Allow()) + assert.Equal(t, "half-open", cb.State()) + + // Probe fails - the failure count is now 3 which is >= threshold of 2 + cb.RecordFailure() + assert.Equal(t, "open", cb.State()) + assert.False(t, cb.Allow()) +} + 
+func TestCircuitBreaker_SuccessResetsFailureCount(t *testing.T) { + cb := NewCircuitBreaker("test", WithFailureThreshold(3)) + + cb.RecordFailure() + cb.RecordFailure() + assert.Equal(t, 2, cb.Counts()) + + // A success should reset the counter + cb.RecordSuccess() + assert.Equal(t, 0, cb.Counts()) + assert.Equal(t, "closed", cb.State()) + + // Now it should take 3 more failures to open + cb.RecordFailure() + cb.RecordFailure() + assert.Equal(t, "closed", cb.State()) + + cb.RecordFailure() + assert.Equal(t, "open", cb.State()) +} + +func TestCircuitBreaker_Reset(t *testing.T) { + cb := NewCircuitBreaker("test", WithFailureThreshold(1)) + + cb.RecordFailure() + assert.Equal(t, "open", cb.State()) + + cb.Reset() + assert.Equal(t, "closed", cb.State()) + assert.Equal(t, 0, cb.Counts()) + assert.True(t, cb.Allow()) +} + +func TestCircuitBreaker_Name(t *testing.T) { + cb := NewCircuitBreaker("apns-breaker") + assert.Equal(t, "apns-breaker", cb.Name()) +} + +func TestCircuitBreaker_CustomOptions(t *testing.T) { + cb := NewCircuitBreaker("test", + WithFailureThreshold(10), + WithRecoveryTimeout(5*time.Minute), + ) + + // Should take 10 failures to open + for i := 0; i < 9; i++ { + cb.RecordFailure() + assert.Equal(t, "closed", cb.State()) + } + cb.RecordFailure() + assert.Equal(t, "open", cb.State()) +} + +func TestCircuitBreaker_InvalidOptionsIgnored(t *testing.T) { + cb := NewCircuitBreaker("test", + WithFailureThreshold(0), // Should be ignored (keeps default) + WithRecoveryTimeout(-1), // Should be ignored (keeps default) + ) + + // Default threshold of 5 should still apply + for i := 0; i < 4; i++ { + cb.RecordFailure() + assert.Equal(t, "closed", cb.State()) + } + cb.RecordFailure() + assert.Equal(t, "open", cb.State()) +} + +func TestCircuitBreaker_ThreadSafety(t *testing.T) { + cb := NewCircuitBreaker("test", + WithFailureThreshold(100), + WithRecoveryTimeout(10*time.Millisecond), + ) + + var wg sync.WaitGroup + const goroutines = 50 + const iterations = 100 + + 
// Hammer it from many goroutines + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < iterations; j++ { + cb.Allow() + if j%2 == 0 { + cb.RecordFailure() + } else { + cb.RecordSuccess() + } + _ = cb.State() + _ = cb.Counts() + } + }(i) + } + + wg.Wait() + + // Should not panic or deadlock. State should be valid. + state := cb.State() + require.Contains(t, []string{"closed", "open", "half-open"}, state) +} + +func TestCircuitBreaker_FullLifecycle(t *testing.T) { + cb := NewCircuitBreaker("lifecycle-test", + WithFailureThreshold(3), + WithRecoveryTimeout(50*time.Millisecond), + ) + + // 1. Closed: requests flow normally + assert.True(t, cb.Allow()) + cb.RecordSuccess() + assert.Equal(t, "closed", cb.State()) + + // 2. Accumulate failures + cb.RecordFailure() + cb.RecordFailure() + assert.Equal(t, "closed", cb.State()) + + // 3. Third failure opens the circuit + cb.RecordFailure() + assert.Equal(t, "open", cb.State()) + assert.False(t, cb.Allow()) + + // 4. Wait for recovery + time.Sleep(60 * time.Millisecond) + + // 5. Half-open: probe request allowed + assert.True(t, cb.Allow()) + assert.Equal(t, "half-open", cb.State()) + + // 6. Probe succeeds, back to closed + cb.RecordSuccess() + assert.Equal(t, "closed", cb.State()) + assert.True(t, cb.Allow()) +} diff --git a/internal/push/client.go b/internal/push/client.go index a6aabbe..d542ddd 100644 --- a/internal/push/client.go +++ b/internal/push/client.go @@ -2,6 +2,7 @@ package push import ( "context" + "time" "github.com/rs/zerolog/log" @@ -14,16 +15,25 @@ const ( PlatformAndroid = "android" ) +// Timeout for individual push notification send operations. 
+const pushSendTimeout = 15 * time.Second + // Client provides a unified interface for sending push notifications type Client struct { - apns *APNsClient - fcm *FCMClient - enabled bool + apns *APNsClient + fcm *FCMClient + enabled bool + apnsBreaker *CircuitBreaker + fcmBreaker *CircuitBreaker } // NewClient creates a new unified push notification client func NewClient(cfg *config.PushConfig, enabled bool) (*Client, error) { - client := &Client{enabled: enabled} + client := &Client{ + enabled: enabled, + apnsBreaker: NewCircuitBreaker("apns"), + fcmBreaker: NewCircuitBreaker("fcm"), + } // Initialize APNs client (iOS) if cfg.APNSKeyPath != "" && cfg.APNSKeyID != "" && cfg.APNSTeamID != "" { @@ -54,7 +64,8 @@ func NewClient(cfg *config.PushConfig, enabled bool) (*Client, error) { return client, nil } -// SendToIOS sends a push notification to iOS devices +// SendToIOS sends a push notification to iOS devices. +// The call is guarded by a circuit breaker and uses a context timeout. func (c *Client) SendToIOS(ctx context.Context, tokens []string, title, message string, data map[string]string) error { if !c.enabled { log.Debug().Msg("Push notifications disabled by feature flag") @@ -64,10 +75,26 @@ func (c *Client) SendToIOS(ctx context.Context, tokens []string, title, message log.Warn().Msg("APNs client not initialized, skipping iOS push") return nil } - return c.apns.Send(ctx, tokens, title, message, data) + if !c.apnsBreaker.Allow() { + log.Warn().Str("breaker", c.apnsBreaker.Name()).Msg("APNs circuit breaker is open, skipping iOS push") + return ErrCircuitOpen + } + + sendCtx, cancel := context.WithTimeout(ctx, pushSendTimeout) + defer cancel() + + err := c.apns.Send(sendCtx, tokens, title, message, data) + if err != nil { + c.apnsBreaker.RecordFailure() + log.Warn().Err(err).Str("breaker_state", c.apnsBreaker.State()).Msg("APNs send failed, recorded circuit breaker failure") + return err + } + c.apnsBreaker.RecordSuccess() + return nil } -// SendToAndroid sends a 
push notification to Android devices +// SendToAndroid sends a push notification to Android devices. +// The call is guarded by a circuit breaker and uses a context timeout. func (c *Client) SendToAndroid(ctx context.Context, tokens []string, title, message string, data map[string]string) error { if !c.enabled { log.Debug().Msg("Push notifications disabled by feature flag") @@ -77,7 +104,22 @@ func (c *Client) SendToAndroid(ctx context.Context, tokens []string, title, mess log.Warn().Msg("FCM client not initialized, skipping Android push") return nil } - return c.fcm.Send(ctx, tokens, title, message, data) + if !c.fcmBreaker.Allow() { + log.Warn().Str("breaker", c.fcmBreaker.Name()).Msg("FCM circuit breaker is open, skipping Android push") + return ErrCircuitOpen + } + + sendCtx, cancel := context.WithTimeout(ctx, pushSendTimeout) + defer cancel() + + err := c.fcm.Send(sendCtx, tokens, title, message, data) + if err != nil { + c.fcmBreaker.RecordFailure() + log.Warn().Err(err).Str("breaker_state", c.fcmBreaker.State()).Msg("FCM send failed, recorded circuit breaker failure") + return err + } + c.fcmBreaker.RecordSuccess() + return nil } // SendToAll sends a push notification to both iOS and Android devices @@ -115,8 +157,9 @@ func (c *Client) IsAndroidEnabled() bool { return c.fcm != nil } -// SendActionableNotification sends notifications with action button support -// iOS receives a category for actionable notifications, Android handles actions via data payload +// SendActionableNotification sends notifications with action button support. +// iOS receives a category for actionable notifications, Android handles actions via data payload. +// Both platforms are guarded by their respective circuit breakers. 
func (c *Client) SendActionableNotification(ctx context.Context, iosTokens, androidTokens []string, title, message string, data map[string]string, iosCategoryID string) error { if !c.enabled { log.Debug().Msg("Push notifications disabled by feature flag") @@ -127,10 +170,19 @@ func (c *Client) SendActionableNotification(ctx context.Context, iosTokens, andr if len(iosTokens) > 0 { if c.apns == nil { log.Warn().Msg("APNs client not initialized, skipping iOS actionable push") + } else if !c.apnsBreaker.Allow() { + log.Warn().Str("breaker", c.apnsBreaker.Name()).Msg("APNs circuit breaker is open, skipping iOS actionable push") + lastErr = ErrCircuitOpen } else { - if err := c.apns.SendWithCategory(ctx, iosTokens, title, message, data, iosCategoryID); err != nil { + sendCtx, cancel := context.WithTimeout(ctx, pushSendTimeout) + err := c.apns.SendWithCategory(sendCtx, iosTokens, title, message, data, iosCategoryID) + cancel() + if err != nil { + c.apnsBreaker.RecordFailure() log.Error().Err(err).Msg("Failed to send iOS actionable notifications") lastErr = err + } else { + c.apnsBreaker.RecordSuccess() } } } diff --git a/internal/push/fcm.go b/internal/push/fcm.go index c2fcfeb..f7cfff2 100644 --- a/internal/push/fcm.go +++ b/internal/push/fcm.go @@ -165,7 +165,7 @@ func NewFCMClient(cfg *config.PushConfig) (*FCMClient, error) { } httpClient := &http.Client{ - Timeout: 30 * time.Second, + Timeout: 15 * time.Second, Transport: transport, } diff --git a/internal/repositories/user_repo.go b/internal/repositories/user_repo.go index 1670005..3558a32 100644 --- a/internal/repositories/user_repo.go +++ b/internal/repositories/user_repo.go @@ -173,6 +173,27 @@ func (r *UserRepository) GetOrCreateToken(userID uint) (*models.AuthToken, error return &token, nil } +// FindTokenByKey looks up an auth token by its key value. 
+func (r *UserRepository) FindTokenByKey(key string) (*models.AuthToken, error) { + var token models.AuthToken + if err := r.db.Where("key = ?", key).First(&token).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrTokenNotFound + } + return nil, err + } + return &token, nil +} + +// CreateToken creates a new auth token for a user. +func (r *UserRepository) CreateToken(userID uint) (*models.AuthToken, error) { + token := models.AuthToken{UserID: userID} + if err := r.db.Create(&token).Error; err != nil { + return nil, err + } + return &token, nil +} + // DeleteToken deletes an auth token func (r *UserRepository) DeleteToken(token string) error { result := r.db.Where("key = ?", token).Delete(&models.AuthToken{}) diff --git a/internal/router/health_test.go b/internal/router/health_test.go new file mode 100644 index 0000000..7dbbc95 --- /dev/null +++ b/internal/router/health_test.go @@ -0,0 +1,136 @@ +package router + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +// setupTestDB creates an in-memory SQLite database for health check tests +func setupTestDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + require.NoError(t, err) + return db +} + +func TestLiveCheck_Returns200(t *testing.T) { + e := echo.New() + e.GET("/api/health/live", liveCheck) + + req := httptest.NewRequest(http.MethodGet, "/api/health/live", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + + assert.Equal(t, "alive", resp["status"]) + assert.Equal(t, Version, resp["version"]) + 
assert.Contains(t, resp, "timestamp") +} + +func TestReadinessCheck_HealthyDB_NilCache_Returns200(t *testing.T) { + db := setupTestDB(t) + + deps := &Dependencies{ + DB: db, + Cache: nil, // No Redis configured + } + + e := echo.New() + e.GET("/api/health/", readinessCheck(deps)) + + req := httptest.NewRequest(http.MethodGet, "/api/health/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp map[string]interface{} + err := json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + + assert.Equal(t, "healthy", resp["status"]) + assert.Equal(t, Version, resp["version"]) + + checks := resp["checks"].(map[string]interface{}) + assert.Equal(t, "ok", checks["postgres"]) + assert.Equal(t, "not configured", checks["redis"]) +} + +func TestReadinessCheck_DBDown_Returns503(t *testing.T) { + // Open an in-memory SQLite DB, then close the underlying sql.DB to simulate failure + db := setupTestDB(t) + sqlDB, err := db.DB() + require.NoError(t, err) + sqlDB.Close() + + deps := &Dependencies{ + DB: db, + Cache: nil, + } + + e := echo.New() + e.GET("/api/health/", readinessCheck(deps)) + + req := httptest.NewRequest(http.MethodGet, "/api/health/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) + + var resp map[string]interface{} + err = json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + + assert.Equal(t, "unhealthy", resp["status"]) + + checks := resp["checks"].(map[string]interface{}) + assert.Contains(t, checks["postgres"], "ping failed") +} + +func TestReadinessCheck_ResponseFormat(t *testing.T) { + db := setupTestDB(t) + + deps := &Dependencies{ + DB: db, + Cache: nil, + } + + e := echo.New() + e.GET("/api/health/", readinessCheck(deps)) + + req := httptest.NewRequest(http.MethodGet, "/api/health/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + var resp map[string]interface{} + err := 
json.Unmarshal(rec.Body.Bytes(), &resp) + require.NoError(t, err) + + // Verify all expected fields are present + assert.Contains(t, resp, "status") + assert.Contains(t, resp, "version") + assert.Contains(t, resp, "checks") + assert.Contains(t, resp, "timestamp") + + // Verify checks is a map with expected keys + checks, ok := resp["checks"].(map[string]interface{}) + require.True(t, ok, "checks should be a map") + assert.Contains(t, checks, "postgres") + assert.Contains(t, checks, "redis") +} diff --git a/internal/router/router.go b/internal/router/router.go index 32399fb..441a79e 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -1,6 +1,7 @@ package router import ( + "context" "errors" "fmt" "net/http" @@ -62,11 +63,12 @@ func SetupRouter(deps *Dependencies) *echo.Echo { // Security headers (X-Frame-Options, X-Content-Type-Options, X-XSS-Protection, etc.) e.Use(middleware.SecureWithConfig(middleware.SecureConfig{ - XSSProtection: "1; mode=block", - ContentTypeNosniff: "nosniff", - XFrameOptions: "SAMEORIGIN", - HSTSMaxAge: 31536000, // 1 year in seconds - ReferrerPolicy: "strict-origin-when-cross-origin", + XSSProtection: "1; mode=block", + ContentTypeNosniff: "nosniff", + XFrameOptions: "SAMEORIGIN", + HSTSMaxAge: 31536000, // 1 year in seconds + ReferrerPolicy: "strict-origin-when-cross-origin", + ContentSecurityPolicy: "default-src 'none'; frame-ancestors 'none'", })) e.Use(middleware.BodyLimitWithConfig(middleware.BodyLimitConfig{ Limit: "1M", // 1MB default for JSON payloads @@ -93,6 +95,14 @@ func SetupRouter(deps *Dependencies) *echo.Echo { e.Use(corsMiddleware(cfg)) e.Use(i18n.Middleware()) + // Gzip compression (skip media endpoints since they serve binary files) + e.Use(middleware.GzipWithConfig(middleware.GzipConfig{ + Level: 5, + Skipper: func(c echo.Context) bool { + return strings.HasPrefix(c.Request().URL.Path, "/api/media/") + }, + })) + // Monitoring metrics middleware (if monitoring is enabled) if deps.MonitoringService 
!= nil { if metricsMiddleware := deps.MonitoringService.MetricsMiddleware(); metricsMiddleware != nil { @@ -114,8 +124,9 @@ func SetupRouter(deps *Dependencies) *echo.Echo { }) } - // Health check endpoint (no auth required) - e.GET("/api/health/", healthCheck) + // Health check endpoints (no auth required) + e.GET("/api/health/", readinessCheck(deps)) + e.GET("/api/health/live", liveCheck) // Initialize onboarding email service for tracking handler onboardingService := services.NewOnboardingEmailService(deps.DB, deps.EmailService, cfg.Server.BaseURL) @@ -172,17 +183,21 @@ func SetupRouter(deps *Dependencies) *echo.Echo { subscriptionWebhookHandler.SetStripeService(stripeService) // Initialize middleware - authMiddleware := custommiddleware.NewAuthMiddleware(deps.DB, deps.Cache) + authMiddleware := custommiddleware.NewAuthMiddlewareWithConfig(deps.DB, deps.Cache, cfg) // Initialize Apple auth service appleAuthService := services.NewAppleAuthService(deps.Cache, cfg) googleAuthService := services.NewGoogleAuthService(deps.Cache, cfg) + // Initialize audit service for security event logging + auditService := services.NewAuditService(deps.DB) + // Initialize handlers authHandler := handlers.NewAuthHandler(authService, deps.EmailService, deps.Cache) authHandler.SetAppleAuthService(appleAuthService) authHandler.SetGoogleAuthService(googleAuthService) authHandler.SetStorageService(deps.StorageService) + authHandler.SetAuditService(auditService) userHandler := handlers.NewUserHandler(userService) residenceHandler := handlers.NewResidenceHandler(residenceService, deps.PDFService, deps.EmailService, cfg.Features.PDFReportsEnabled) taskHandler := handlers.NewTaskHandler(taskService, deps.StorageService) @@ -201,6 +216,11 @@ func SetupRouter(deps *Dependencies) *echo.Echo { mediaHandler = handlers.NewMediaHandler(documentRepo, taskRepo, residenceRepo, deps.StorageService) } + // Prometheus metrics endpoint (no auth required, for scraping) + if deps.MonitoringService != nil { + 
e.GET("/metrics", prometheusMetrics(deps.MonitoringService)) + } + // Set up admin routes with monitoring handler (if available) var monitoringHandler *monitoring.Handler if deps.MonitoringService != nil { @@ -295,16 +315,126 @@ func corsMiddleware(cfg *config.Config) echo.MiddlewareFunc { }) } -// healthCheck returns API health status -func healthCheck(c echo.Context) error { +// liveCheck returns a simple 200 for Kubernetes liveness probes +func liveCheck(c echo.Context) error { return c.JSON(http.StatusOK, map[string]interface{}{ - "status": "healthy", + "status": "alive", "version": Version, - "framework": "Echo", "timestamp": time.Now().UTC().Format(time.RFC3339), }) } +// readinessCheck returns 200 if PostgreSQL and Redis are reachable, 503 otherwise. +// This is used by Kubernetes readiness probes and load balancers. +func readinessCheck(deps *Dependencies) echo.HandlerFunc { + return func(c echo.Context) error { + ctx, cancel := context.WithTimeout(c.Request().Context(), 5*time.Second) + defer cancel() + + status := "healthy" + httpStatus := http.StatusOK + checks := make(map[string]string) + + // Check PostgreSQL + sqlDB, err := deps.DB.DB() + if err != nil { + checks["postgres"] = fmt.Sprintf("failed to get sql.DB: %v", err) + status = "unhealthy" + httpStatus = http.StatusServiceUnavailable + } else if err := sqlDB.PingContext(ctx); err != nil { + checks["postgres"] = fmt.Sprintf("ping failed: %v", err) + status = "unhealthy" + httpStatus = http.StatusServiceUnavailable + } else { + checks["postgres"] = "ok" + } + + // Check Redis (if cache service is available) + if deps.Cache != nil { + if err := deps.Cache.Client().Ping(ctx).Err(); err != nil { + checks["redis"] = fmt.Sprintf("ping failed: %v", err) + status = "unhealthy" + httpStatus = http.StatusServiceUnavailable + } else { + checks["redis"] = "ok" + } + } else { + checks["redis"] = "not configured" + } + + return c.JSON(httpStatus, map[string]interface{}{ + "status": status, + "version": Version, 
+ "checks": checks, + "timestamp": time.Now().UTC().Format(time.RFC3339), + }) + } +} + +// prometheusMetrics returns an Echo handler that outputs metrics in Prometheus text format. +// It uses the existing monitoring service's HTTP stats collector to avoid adding external dependencies. +func prometheusMetrics(monSvc *monitoring.Service) echo.HandlerFunc { + return func(c echo.Context) error { + httpCollector := monSvc.HTTPCollector() + if httpCollector == nil { + return c.String(http.StatusOK, "# No HTTP metrics available (collector not initialized)\n") + } + + stats := httpCollector.GetStats() + var b strings.Builder + + // Request count by method+path+status + b.WriteString("# HELP http_requests_total Total number of HTTP requests.\n") + b.WriteString("# TYPE http_requests_total counter\n") + for statusCode, count := range stats.ByStatusCode { + fmt.Fprintf(&b, "http_requests_total{status_code=\"%d\"} %d\n", statusCode, count) + } + + // Per-endpoint request count + b.WriteString("# HELP http_endpoint_requests_total Total requests per endpoint.\n") + b.WriteString("# TYPE http_endpoint_requests_total counter\n") + for endpoint, epStats := range stats.ByEndpoint { + // endpoint is "METHOD /path" + parts := strings.SplitN(endpoint, " ", 2) + method := endpoint + path := "" + if len(parts) == 2 { + method = parts[0] + path = parts[1] + } + fmt.Fprintf(&b, "http_endpoint_requests_total{method=\"%s\",path=\"%s\"} %d\n", method, path, epStats.Count) + } + + // Request duration (avg latency as a gauge, since we don't have raw histogram buckets) + b.WriteString("# HELP http_request_duration_ms Average request duration in milliseconds per endpoint.\n") + b.WriteString("# TYPE http_request_duration_ms gauge\n") + for endpoint, epStats := range stats.ByEndpoint { + parts := strings.SplitN(endpoint, " ", 2) + method := endpoint + path := "" + if len(parts) == 2 { + method = parts[0] + path = parts[1] + } + fmt.Fprintf(&b, 
"http_request_duration_ms{method=\"%s\",path=\"%s\",quantile=\"avg\"} %.2f\n", method, path, epStats.AvgLatencyMs) + fmt.Fprintf(&b, "http_request_duration_ms{method=\"%s\",path=\"%s\",quantile=\"p95\"} %.2f\n", method, path, epStats.P95LatencyMs) + } + + // Error rate + b.WriteString("# HELP http_error_rate Overall error rate (4xx+5xx / total).\n") + b.WriteString("# TYPE http_error_rate gauge\n") + fmt.Fprintf(&b, "http_error_rate %.4f\n", stats.ErrorRate) + + // Requests per minute + b.WriteString("# HELP http_requests_per_minute Current request rate.\n") + b.WriteString("# TYPE http_requests_per_minute gauge\n") + fmt.Fprintf(&b, "http_requests_per_minute %.2f\n", stats.RequestsPerMinute) + + c.Response().Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8") + return c.String(http.StatusOK, b.String()) + } +} + // setupPublicAuthRoutes configures public authentication routes with // per-endpoint rate limiters to mitigate brute-force and credential-stuffing. // Rate limiters are disabled in debug mode to allow UI test suites to run @@ -342,6 +472,7 @@ func setupProtectedAuthRoutes(api *echo.Group, authHandler *handlers.AuthHandler auth := api.Group("/auth") { auth.POST("/logout/", authHandler.Logout) + auth.POST("/refresh/", authHandler.RefreshToken) auth.GET("/me/", authHandler.CurrentUser) auth.PUT("/profile/", authHandler.UpdateProfile) auth.PATCH("/profile/", authHandler.UpdateProfile) diff --git a/internal/services/audit_service.go b/internal/services/audit_service.go new file mode 100644 index 0000000..703e1d2 --- /dev/null +++ b/internal/services/audit_service.go @@ -0,0 +1,93 @@ +package services + +import ( + "sync" + + "github.com/labstack/echo/v4" + "github.com/rs/zerolog/log" + "gorm.io/gorm" + + "github.com/treytartt/honeydue-api/internal/models" +) + +// Audit event type constants +const ( + AuditEventLogin = "auth.login" + AuditEventLoginFailed = "auth.login_failed" + AuditEventRegister = "auth.register" + AuditEventLogout = 
"auth.logout" + AuditEventPasswordReset = "auth.password_reset" + AuditEventPasswordChanged = "auth.password_changed" + AuditEventAccountDeleted = "auth.account_deleted" +) + +// AuditService handles audit logging for security-relevant events. +// It writes audit log entries asynchronously via a buffered channel to avoid +// blocking request handlers. +type AuditService struct { + db *gorm.DB + logChan chan *models.AuditLog + done chan struct{} + stopOnce sync.Once +} + +// NewAuditService creates a new audit service with a buffered channel for async writes. +// Call Stop() when shutting down to flush remaining entries. +func NewAuditService(db *gorm.DB) *AuditService { + s := &AuditService{ + db: db, + logChan: make(chan *models.AuditLog, 256), + done: make(chan struct{}), + } + go s.processLogs() + return s +} + +// processLogs drains the log channel and writes entries to the database. +func (s *AuditService) processLogs() { + defer close(s.done) + for entry := range s.logChan { + if err := s.db.Create(entry).Error; err != nil { + log.Error().Err(err). + Str("event_type", entry.EventType). + Msg("Failed to write audit log entry") + } + } +} + +// Stop closes the log channel and waits for all pending entries to be written. +// It is safe to call Stop multiple times. +func (s *AuditService) Stop() { + s.stopOnce.Do(func() { + close(s.logChan) + }) + <-s.done +} + +// LogEvent records an audit event. It extracts the client IP and User-Agent from +// the Echo context and sends the entry to the background writer. If the channel +// is full the entry is dropped and an error is logged (non-blocking). 
+func (s *AuditService) LogEvent(c echo.Context, userID *uint, eventType string, details map[string]interface{}) { + var ip, ua string + if c != nil { + ip = c.RealIP() + ua = c.Request().Header.Get("User-Agent") + } + + entry := &models.AuditLog{ + UserID: userID, + EventType: eventType, + IPAddress: ip, + UserAgent: ua, + Details: models.JSONMap(details), + } + + select { + case s.logChan <- entry: + // sent + default: + log.Warn(). + Str("event_type", eventType). + Msg("Audit log channel full, dropping entry") + } +} diff --git a/internal/services/audit_service_test.go b/internal/services/audit_service_test.go new file mode 100644 index 0000000..4072ae7 --- /dev/null +++ b/internal/services/audit_service_test.go @@ -0,0 +1,176 @@ +package services + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/treytartt/honeydue-api/internal/models" + "github.com/treytartt/honeydue-api/internal/testutil" +) + +func TestAuditService_LogEvent_WritesToDatabase(t *testing.T) { + db := testutil.SetupTestDB(t) + svc := NewAuditService(db) + defer svc.Stop() + + // Create a fake Echo context + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil) + req.Header.Set("User-Agent", "TestAgent/1.0") + req.RemoteAddr = "192.168.1.1:12345" + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + userID := uint(42) + svc.LogEvent(c, &userID, AuditEventLogin, map[string]interface{}{ + "method": "password", + }) + + // Stop flushes the channel + svc.Stop() + + var entries []models.AuditLog + err := db.Find(&entries).Error + require.NoError(t, err) + require.Len(t, entries, 1) + + entry := entries[0] + assert.Equal(t, uint(42), *entry.UserID) + assert.Equal(t, AuditEventLogin, entry.EventType) + assert.Equal(t, "TestAgent/1.0", entry.UserAgent) + assert.NotEmpty(t, entry.IPAddress) + assert.Equal(t, 
"password", entry.Details["method"]) + assert.False(t, entry.CreatedAt.IsZero()) +} + +func TestAuditService_LogEvent_NilUserID(t *testing.T) { + db := testutil.SetupTestDB(t) + svc := NewAuditService(db) + defer svc.Stop() + + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + svc.LogEvent(c, nil, AuditEventLoginFailed, map[string]interface{}{ + "identifier": "unknown@test.com", + }) + + svc.Stop() + + var entries []models.AuditLog + err := db.Find(&entries).Error + require.NoError(t, err) + require.Len(t, entries, 1) + + assert.Nil(t, entries[0].UserID) + assert.Equal(t, AuditEventLoginFailed, entries[0].EventType) +} + +func TestAuditService_LogEvent_NilContext(t *testing.T) { + db := testutil.SetupTestDB(t) + svc := NewAuditService(db) + defer svc.Stop() + + userID := uint(1) + svc.LogEvent(nil, &userID, AuditEventLogout, nil) + + svc.Stop() + + var entries []models.AuditLog + err := db.Find(&entries).Error + require.NoError(t, err) + require.Len(t, entries, 1) + + assert.Equal(t, AuditEventLogout, entries[0].EventType) + assert.Empty(t, entries[0].IPAddress) + assert.Empty(t, entries[0].UserAgent) +} + +func TestAuditService_LogEvent_MultipleEvents(t *testing.T) { + db := testutil.SetupTestDB(t) + svc := NewAuditService(db) + defer svc.Stop() + + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + userID := uint(10) + svc.LogEvent(c, &userID, AuditEventRegister, nil) + svc.LogEvent(c, &userID, AuditEventLogin, nil) + svc.LogEvent(c, &userID, AuditEventLogout, nil) + + svc.Stop() + + var count int64 + err := db.Model(&models.AuditLog{}).Count(&count).Error + require.NoError(t, err) + assert.Equal(t, int64(3), count) +} + +func TestAuditService_EventTypeConstants(t *testing.T) { + // Verify all event constants have expected values + assert.Equal(t, "auth.login", 
AuditEventLogin) + assert.Equal(t, "auth.login_failed", AuditEventLoginFailed) + assert.Equal(t, "auth.register", AuditEventRegister) + assert.Equal(t, "auth.logout", AuditEventLogout) + assert.Equal(t, "auth.password_reset", AuditEventPasswordReset) + assert.Equal(t, "auth.password_changed", AuditEventPasswordChanged) + assert.Equal(t, "auth.account_deleted", AuditEventAccountDeleted) +} + +func TestAuditService_Stop_FlushesRemainingEntries(t *testing.T) { + db := testutil.SetupTestDB(t) + svc := NewAuditService(db) + + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Send many events quickly + for i := 0; i < 50; i++ { + uid := uint(i) + svc.LogEvent(c, &uid, AuditEventLogin, nil) + } + + // Stop should block until all entries are written + svc.Stop() + + var count int64 + err := db.Model(&models.AuditLog{}).Count(&count).Error + require.NoError(t, err) + assert.Equal(t, int64(50), count) +} + +func TestAuditLog_TableName(t *testing.T) { + log := models.AuditLog{} + assert.Equal(t, "audit_log", log.TableName()) +} + +func TestAuditLog_JSONMap_NilHandling(t *testing.T) { + db := testutil.SetupTestDB(t) + + // Create entry with nil details + entry := &models.AuditLog{ + EventType: "test", + CreatedAt: time.Now().UTC(), + } + err := db.Create(entry).Error + require.NoError(t, err) + + // Read it back + var found models.AuditLog + err = db.First(&found, entry.ID).Error + require.NoError(t, err) + assert.Nil(t, found.Details) +} diff --git a/internal/services/auth_refresh_test.go b/internal/services/auth_refresh_test.go new file mode 100644 index 0000000..ce6d6d8 --- /dev/null +++ b/internal/services/auth_refresh_test.go @@ -0,0 +1,173 @@ +package services + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" + + 
"github.com/treytartt/honeydue-api/internal/config" + "github.com/treytartt/honeydue-api/internal/models" + "github.com/treytartt/honeydue-api/internal/repositories" +) + +func setupRefreshTestDB(t *testing.T) *gorm.DB { + t.Helper() + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + require.NoError(t, err) + + err = db.AutoMigrate(&models.User{}, &models.UserProfile{}, &models.AuthToken{}) + require.NoError(t, err) + + return db +} + +func createRefreshTestUser(t *testing.T, db *gorm.DB) *models.User { + t.Helper() + user := &models.User{ + Username: "refreshtest", + Email: "refresh@test.com", + IsActive: true, + } + require.NoError(t, user.SetPassword("password123")) + require.NoError(t, db.Create(user).Error) + return user +} + +func createTokenWithAge(t *testing.T, db *gorm.DB, userID uint, ageDays int) *models.AuthToken { + t.Helper() + token := &models.AuthToken{ + UserID: userID, + } + require.NoError(t, db.Create(token).Error) + + // Backdate the token's Created timestamp after creation to bypass autoCreateTime + backdated := time.Now().UTC().AddDate(0, 0, -ageDays) + require.NoError(t, db.Model(token).Update("created", backdated).Error) + token.Created = backdated + + return token +} + +func newTestAuthService(db *gorm.DB) *AuthService { + userRepo := repositories.NewUserRepository(db) + cfg := &config.Config{ + Security: config.SecurityConfig{ + SecretKey: "test-secret", + TokenExpiryDays: 90, + TokenRefreshDays: 60, + }, + } + return NewAuthService(userRepo, cfg) +} + +func TestRefreshToken_FreshToken_ReturnsExisting(t *testing.T) { + db := setupRefreshTestDB(t) + user := createRefreshTestUser(t, db) + token := createTokenWithAge(t, db, user.ID, 30) // 30 days old, well within fresh window + + svc := newTestAuthService(db) + + resp, err := svc.RefreshToken(token.Key, user.ID) + require.NoError(t, err) + assert.Equal(t, token.Key, resp.Token, "fresh token should return the same token") + 
assert.Contains(t, resp.Message, "still valid")
+}
+
+func TestRefreshToken_InRenewalWindow_ReturnsNewToken(t *testing.T) {
+	db := setupRefreshTestDB(t)
+	user := createRefreshTestUser(t, db)
+	token := createTokenWithAge(t, db, user.ID, 75) // 75 days old, in renewal window (60-90)
+
+	svc := newTestAuthService(db)
+
+	resp, err := svc.RefreshToken(token.Key, user.ID)
+	require.NoError(t, err)
+	assert.NotEqual(t, token.Key, resp.Token, "should return a new token")
+	assert.Contains(t, resp.Message, "refreshed")
+
+	// Verify old token was deleted
+	var count int64
+	db.Model(&models.AuthToken{}).Where("key = ?", token.Key).Count(&count)
+	assert.Equal(t, int64(0), count, "old token should be deleted")
+
+	// Verify new token exists in DB
+	db.Model(&models.AuthToken{}).Where("key = ?", resp.Token).Count(&count)
+	assert.Equal(t, int64(1), count, "new token should exist in DB")
+
+	// Verify new token belongs to the same user
+	var newToken models.AuthToken
+	require.NoError(t, db.Where("key = ?", resp.Token).First(&newToken).Error)
+	assert.Equal(t, user.ID, newToken.UserID)
+}
+
+func TestRefreshToken_ExpiredToken_Returns401(t *testing.T) {
+	db := setupRefreshTestDB(t)
+	user := createRefreshTestUser(t, db)
+	token := createTokenWithAge(t, db, user.ID, 91) // 91 days old, past 90-day expiry
+
+	svc := newTestAuthService(db)
+
+	resp, err := svc.RefreshToken(token.Key, user.ID)
+	require.Error(t, err)
+	assert.Nil(t, resp)
+	assert.Contains(t, err.Error(), "error.token_expired")
+}
+
+func TestRefreshToken_AtExactBoundary60Days(t *testing.T) {
+	db := setupRefreshTestDB(t)
+	user := createRefreshTestUser(t, db)
+	// Just past the 60-day refresh boundary: at 61 days, tokenAge < refreshDuration is false,
+	// so the token falls into the renewal window (exactly 60 days would be timing-flaky)
+	token := createTokenWithAge(t, db, user.ID, 61)
+
+	svc := newTestAuthService(db)
+
+	resp, err := svc.RefreshToken(token.Key, user.ID)
+	require.NoError(t, err)
+	assert.NotEqual(t, token.Key, resp.Token, "token at 61 days 
should be refreshed") +} + +func TestRefreshToken_InvalidToken_Returns401(t *testing.T) { + db := setupRefreshTestDB(t) + user := createRefreshTestUser(t, db) + + svc := newTestAuthService(db) + + resp, err := svc.RefreshToken("nonexistent-token-key", user.ID) + require.Error(t, err) + assert.Nil(t, resp) + assert.Contains(t, err.Error(), "error.invalid_token") +} + +func TestRefreshToken_WrongUser_Returns401(t *testing.T) { + db := setupRefreshTestDB(t) + user := createRefreshTestUser(t, db) + token := createTokenWithAge(t, db, user.ID, 75) + + svc := newTestAuthService(db) + + // Try to refresh with a different user ID + resp, err := svc.RefreshToken(token.Key, user.ID+999) + require.Error(t, err) + assert.Nil(t, resp) + assert.Contains(t, err.Error(), "error.invalid_token") +} + +func TestRefreshToken_FreshTokenAt59Days_ReturnsExisting(t *testing.T) { + db := setupRefreshTestDB(t) + user := createRefreshTestUser(t, db) + token := createTokenWithAge(t, db, user.ID, 59) // 59 days, just under the 60-day threshold + + svc := newTestAuthService(db) + + resp, err := svc.RefreshToken(token.Key, user.ID) + require.NoError(t, err) + assert.Equal(t, token.Key, resp.Token, "token at 59 days should NOT be refreshed") +} diff --git a/internal/services/auth_service.go b/internal/services/auth_service.go index edd2f2d..75bd21c 100644 --- a/internal/services/auth_service.go +++ b/internal/services/auth_service.go @@ -188,6 +188,66 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist }, code, nil } +// RefreshToken handles token refresh logic. +// - If token is expired (> expiryDays old), returns error (must re-login). +// - If token is in the renewal window (> refreshDays old), generates a new token. +// - If token is still fresh (< refreshDays old), returns the existing token (no-op). 
+func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.RefreshTokenResponse, error) { + expiryDays := s.cfg.Security.TokenExpiryDays + if expiryDays <= 0 { + expiryDays = 90 + } + refreshDays := s.cfg.Security.TokenRefreshDays + if refreshDays <= 0 { + refreshDays = 60 + } + + // Look up the token + authToken, err := s.userRepo.FindTokenByKey(tokenKey) + if err != nil { + return nil, apperrors.Unauthorized("error.invalid_token") + } + + // Verify ownership + if authToken.UserID != userID { + return nil, apperrors.Unauthorized("error.invalid_token") + } + + tokenAge := time.Since(authToken.Created) + expiryDuration := time.Duration(expiryDays) * 24 * time.Hour + refreshDuration := time.Duration(refreshDays) * 24 * time.Hour + + // Token is expired — must re-login + if tokenAge > expiryDuration { + return nil, apperrors.Unauthorized("error.token_expired") + } + + // Token is still fresh — no-op refresh + if tokenAge < refreshDuration { + return &responses.RefreshTokenResponse{ + Token: tokenKey, + Message: "Token is still valid.", + }, nil + } + + // Token is in the renewal window — generate a new one + // Delete the old token + if err := s.userRepo.DeleteToken(tokenKey); err != nil { + log.Warn().Err(err).Str("token", tokenKey[:8]+"...").Msg("Failed to delete old token during refresh") + } + + // Create a new token + newToken, err := s.userRepo.CreateToken(userID) + if err != nil { + return nil, apperrors.Internal(err) + } + + return &responses.RefreshTokenResponse{ + Token: newToken.Key, + Message: "Token refreshed successfully.", + }, nil +} + // Logout invalidates a user's token func (s *AuthService) Logout(token string) error { return s.userRepo.DeleteToken(token) diff --git a/internal/services/cache_service.go b/internal/services/cache_service.go index 800ebc8..7b05c20 100644 --- a/internal/services/cache_service.go +++ b/internal/services/cache_service.go @@ -141,6 +141,12 @@ func (c *CacheService) CacheAuthToken(ctx context.Context, 
token string, userID return c.SetString(ctx, key, fmt.Sprintf("%d", userID), TokenCacheTTL) } +// CacheAuthTokenWithCreated caches a user ID and token creation time for a token +func (c *CacheService) CacheAuthTokenWithCreated(ctx context.Context, token string, userID uint, createdUnix int64) error { + key := AuthTokenPrefix + token + return c.SetString(ctx, key, fmt.Sprintf("%d|%d", userID, createdUnix), TokenCacheTTL) +} + // GetCachedAuthToken gets a cached user ID for a token func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (uint, error) { key := AuthTokenPrefix + token @@ -154,6 +160,24 @@ func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (ui return userID, err } +// GetCachedAuthTokenWithCreated gets a cached user ID and token creation time. +// Returns userID, createdUnix, error. createdUnix is 0 if not stored (legacy format). +func (c *CacheService) GetCachedAuthTokenWithCreated(ctx context.Context, token string) (uint, int64, error) { + key := AuthTokenPrefix + token + val, err := c.GetString(ctx, key) + if err != nil { + return 0, 0, err + } + + var userID uint + var createdUnix int64 + n, _ := fmt.Sscanf(val, "%d|%d", &userID, &createdUnix) + if n < 1 { + return 0, 0, fmt.Errorf("invalid cached token format") + } + return userID, createdUnix, nil +} + // InvalidateAuthToken removes a cached token func (c *CacheService) InvalidateAuthToken(ctx context.Context, token string) error { key := AuthTokenPrefix + token diff --git a/internal/services/email_service.go b/internal/services/email_service.go index b3656a4..5609219 100644 --- a/internal/services/email_service.go +++ b/internal/services/email_service.go @@ -28,6 +28,7 @@ func NewEmailService(cfg *config.EmailConfig, enabled bool) *EmailService { mail.WithUsername(cfg.User), mail.WithPassword(cfg.Password), mail.WithTLSPortPolicy(mail.TLSOpportunistic), + mail.WithTimeout(30*time.Second), ) if err != nil { log.Error().Err(err).Msg("Failed to create 
mail client - emails will not be sent") diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go index eb9e6f6..663d843 100644 --- a/internal/testutil/testutil.go +++ b/internal/testutil/testutil.go @@ -69,6 +69,7 @@ func SetupTestDB(t *testing.T) *gorm.DB { &models.FeatureBenefit{}, &models.UpgradeTrigger{}, &models.Promotion{}, + &models.AuditLog{}, ) require.NoError(t, err) diff --git a/internal/validator/validator.go b/internal/validator/validator.go index 058f0bb..1a63ce9 100644 --- a/internal/validator/validator.go +++ b/internal/validator/validator.go @@ -4,6 +4,7 @@ import ( "net/http" "reflect" "strings" + "unicode" "github.com/go-playground/validator/v10" "github.com/labstack/echo/v4" @@ -27,9 +28,34 @@ func NewCustomValidator() *CustomValidator { return name }) + // Register custom password complexity validator + v.RegisterValidation("password_complexity", validatePasswordComplexity) + return &CustomValidator{validator: v} } +// validatePasswordComplexity checks that a password contains at least one +// uppercase letter, one lowercase letter, and one digit. +// Minimum length is enforced separately via the "min" tag. 
+func validatePasswordComplexity(fl validator.FieldLevel) bool { + password := fl.Field().String() + var hasUpper, hasLower, hasDigit bool + for _, ch := range password { + switch { + case unicode.IsUpper(ch): + hasUpper = true + case unicode.IsLower(ch): + hasLower = true + case unicode.IsDigit(ch): + hasDigit = true + } + if hasUpper && hasLower && hasDigit { + return true + } + } + return hasUpper && hasLower && hasDigit +} + // Validate implements echo.Validator interface func (cv *CustomValidator) Validate(i interface{}) error { if err := cv.validator.Struct(i); err != nil { @@ -96,6 +122,8 @@ func formatMessage(fe validator.FieldError) string { return "Must be a valid URL" case "uuid": return "Must be a valid UUID" + case "password_complexity": + return "Password must be at least 8 characters with at least one uppercase letter, one lowercase letter, and one digit" default: return "Invalid value" } diff --git a/internal/validator/validator_test.go b/internal/validator/validator_test.go new file mode 100644 index 0000000..de70398 --- /dev/null +++ b/internal/validator/validator_test.go @@ -0,0 +1,115 @@ +package validator + +import ( + "testing" + + govalidator "github.com/go-playground/validator/v10" +) + +func TestValidatePasswordComplexity(t *testing.T) { + tests := []struct { + name string + password string + valid bool + }{ + {"valid password", "Password1", true}, + {"valid complex password", "MyP@ssw0rd!", true}, + {"missing uppercase", "password1", false}, + {"missing lowercase", "PASSWORD1", false}, + {"missing digit", "Password", false}, + {"only digits", "12345678", false}, + {"only lowercase", "abcdefgh", false}, + {"only uppercase", "ABCDEFGH", false}, + {"empty string", "", false}, + {"single valid char each", "aA1", true}, + {"unicode uppercase with digit and lower", "Über1abc", true}, + } + + v := govalidator.New() + v.RegisterValidation("password_complexity", validatePasswordComplexity) + + type testStruct struct { + Password string 
`validate:"password_complexity"` + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + s := testStruct{Password: tc.password} + err := v.Struct(s) + if tc.valid && err != nil { + t.Errorf("expected password %q to be valid, got error: %v", tc.password, err) + } + if !tc.valid && err == nil { + t.Errorf("expected password %q to be invalid, got nil error", tc.password) + } + }) + } +} + +func TestValidatePasswordComplexityWithMinLength(t *testing.T) { + v := govalidator.New() + v.RegisterValidation("password_complexity", validatePasswordComplexity) + + type request struct { + Password string `validate:"required,min=8,password_complexity"` + } + + tests := []struct { + name string + password string + valid bool + }{ + {"valid 8+ chars with complexity", "Abcdefg1", true}, + {"too short but complex", "Ab1", false}, + {"long but no uppercase", "abcdefgh1", false}, + {"long but no lowercase", "ABCDEFGH1", false}, + {"long but no digit", "Abcdefghi", false}, + {"empty", "", false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + r := request{Password: tc.password} + err := v.Struct(r) + if tc.valid && err != nil { + t.Errorf("expected %q to be valid, got error: %v", tc.password, err) + } + if !tc.valid && err == nil { + t.Errorf("expected %q to be invalid, got nil", tc.password) + } + }) + } +} + +func TestFormatMessagePasswordComplexity(t *testing.T) { + cv := NewCustomValidator() + + type request struct { + Password string `json:"password" validate:"required,min=8,password_complexity"` + } + + r := request{Password: "lowercase1"} + err := cv.Validate(r) + if err == nil { + t.Fatal("expected validation error for password without uppercase") + } + + resp := FormatValidationErrors(err) + if resp == nil { + t.Fatal("expected non-nil error response") + } + + field, ok := resp.Fields["password"] + if !ok { + t.Fatal("expected 'password' field in error response") + } + + expectedMsg := "Password must be at least 8 characters with 
at least one uppercase letter, one lowercase letter, and one digit" + if field.Message != expectedMsg { + t.Errorf("expected message %q, got %q", expectedMsg, field.Message) + } + + if field.Tag != "password_complexity" { + t.Errorf("expected tag 'password_complexity', got %q", field.Tag) + } +} diff --git a/migrations/016_authtoken_created_at.down.sql b/migrations/016_authtoken_created_at.down.sql new file mode 100644 index 0000000..48c2acf --- /dev/null +++ b/migrations/016_authtoken_created_at.down.sql @@ -0,0 +1,2 @@ +-- No-op: the created column is part of Django's original schema and should not +-- be removed. diff --git a/migrations/016_authtoken_created_at.up.sql b/migrations/016_authtoken_created_at.up.sql new file mode 100644 index 0000000..25d2900 --- /dev/null +++ b/migrations/016_authtoken_created_at.up.sql @@ -0,0 +1,6 @@ +-- Ensure created column exists on user_authtoken (Django already creates it, +-- but this migration guarantees it for fresh Go-only deployments). +ALTER TABLE user_authtoken ADD COLUMN IF NOT EXISTS created TIMESTAMP WITH TIME ZONE DEFAULT NOW(); + +-- Backfill any rows that may have a NULL created timestamp. +UPDATE user_authtoken SET created = NOW() WHERE created IS NULL; diff --git a/migrations/017_fk_indexes.down.sql b/migrations/017_fk_indexes.down.sql new file mode 100644 index 0000000..2fa6068 --- /dev/null +++ b/migrations/017_fk_indexes.down.sql @@ -0,0 +1,40 @@ +-- Rollback: 017_fk_indexes +-- Drop all FK indexes added in the up migration. 
+ +-- auth / user tables +DROP INDEX IF EXISTS idx_authtoken_user_id; +DROP INDEX IF EXISTS idx_userprofile_user_id; +DROP INDEX IF EXISTS idx_confirmationcode_user_id; +DROP INDEX IF EXISTS idx_passwordresetcode_user_id; +DROP INDEX IF EXISTS idx_applesocialauth_user_id; +DROP INDEX IF EXISTS idx_googlesocialauth_user_id; + +-- push notification device tables +DROP INDEX IF EXISTS idx_apnsdevice_user_id; +DROP INDEX IF EXISTS idx_gcmdevice_user_id; + +-- notification tables +DROP INDEX IF EXISTS idx_notificationpreference_user_id; + +-- subscription tables +DROP INDEX IF EXISTS idx_subscription_user_id; + +-- residence tables +DROP INDEX IF EXISTS idx_residence_owner_id; +DROP INDEX IF EXISTS idx_sharecode_residence_id; +DROP INDEX IF EXISTS idx_sharecode_created_by_id; + +-- task tables +DROP INDEX IF EXISTS idx_task_created_by_id; +DROP INDEX IF EXISTS idx_task_assigned_to_id; +DROP INDEX IF EXISTS idx_task_category_id; +DROP INDEX IF EXISTS idx_task_priority_id; +DROP INDEX IF EXISTS idx_task_frequency_id; +DROP INDEX IF EXISTS idx_task_contractor_id; +DROP INDEX IF EXISTS idx_task_parent_task_id; +DROP INDEX IF EXISTS idx_completionimage_completion_id; +DROP INDEX IF EXISTS idx_document_created_by_id; +DROP INDEX IF EXISTS idx_document_task_id; +DROP INDEX IF EXISTS idx_documentimage_document_id; +DROP INDEX IF EXISTS idx_contractor_residence_id; +DROP INDEX IF EXISTS idx_reminderlog_notification_id; diff --git a/migrations/017_fk_indexes.up.sql b/migrations/017_fk_indexes.up.sql new file mode 100644 index 0000000..0e3c62f --- /dev/null +++ b/migrations/017_fk_indexes.up.sql @@ -0,0 +1,131 @@ +-- Migration: 017_fk_indexes +-- Add indexes on all foreign key columns that are not already covered by existing indexes. +-- Uses CREATE INDEX IF NOT EXISTS to be idempotent (safe to re-run). 
+ +-- ===================================================== +-- auth / user tables +-- ===================================================== + +-- user_authtoken: user_id (unique FK, but ensure index exists) +CREATE UNIQUE INDEX IF NOT EXISTS idx_authtoken_user_id + ON user_authtoken (user_id); + +-- user_userprofile: user_id (unique FK) +CREATE UNIQUE INDEX IF NOT EXISTS idx_userprofile_user_id + ON user_userprofile (user_id); + +-- user_confirmationcode: user_id +CREATE INDEX IF NOT EXISTS idx_confirmationcode_user_id + ON user_confirmationcode (user_id); + +-- user_passwordresetcode: user_id +CREATE INDEX IF NOT EXISTS idx_passwordresetcode_user_id + ON user_passwordresetcode (user_id); + +-- user_applesocialauth: user_id (unique FK) +CREATE UNIQUE INDEX IF NOT EXISTS idx_applesocialauth_user_id + ON user_applesocialauth (user_id); + +-- user_googlesocialauth: user_id (unique FK) +CREATE UNIQUE INDEX IF NOT EXISTS idx_googlesocialauth_user_id + ON user_googlesocialauth (user_id); + +-- ===================================================== +-- push notification device tables +-- ===================================================== + +-- push_notifications_apnsdevice: user_id +CREATE INDEX IF NOT EXISTS idx_apnsdevice_user_id + ON push_notifications_apnsdevice (user_id); + +-- push_notifications_gcmdevice: user_id +CREATE INDEX IF NOT EXISTS idx_gcmdevice_user_id + ON push_notifications_gcmdevice (user_id); + +-- ===================================================== +-- notification tables +-- ===================================================== + +-- notifications_notificationpreference: user_id (unique FK) +CREATE UNIQUE INDEX IF NOT EXISTS idx_notificationpreference_user_id + ON notifications_notificationpreference (user_id); + +-- ===================================================== +-- subscription tables +-- ===================================================== + +-- subscription_usersubscription: user_id (unique FK) +CREATE UNIQUE INDEX IF NOT EXISTS 
idx_subscription_user_id + ON subscription_usersubscription (user_id); + +-- ===================================================== +-- residence tables +-- ===================================================== + +-- residence_residence: owner_id +CREATE INDEX IF NOT EXISTS idx_residence_owner_id + ON residence_residence (owner_id); + +-- residence_residencesharecode: residence_id (may already exist from model index tag via GORM) +CREATE INDEX IF NOT EXISTS idx_sharecode_residence_id + ON residence_residencesharecode (residence_id); + +-- residence_residencesharecode: created_by_id +CREATE INDEX IF NOT EXISTS idx_sharecode_created_by_id + ON residence_residencesharecode (created_by_id); + +-- ===================================================== +-- task tables +-- ===================================================== + +-- task_task: created_by_id +CREATE INDEX IF NOT EXISTS idx_task_created_by_id + ON task_task (created_by_id); + +-- task_task: assigned_to_id +CREATE INDEX IF NOT EXISTS idx_task_assigned_to_id + ON task_task (assigned_to_id); + +-- task_task: category_id +CREATE INDEX IF NOT EXISTS idx_task_category_id + ON task_task (category_id); + +-- task_task: priority_id +CREATE INDEX IF NOT EXISTS idx_task_priority_id + ON task_task (priority_id); + +-- task_task: frequency_id +CREATE INDEX IF NOT EXISTS idx_task_frequency_id + ON task_task (frequency_id); + +-- task_task: contractor_id +CREATE INDEX IF NOT EXISTS idx_task_contractor_id + ON task_task (contractor_id); + +-- task_task: parent_task_id +CREATE INDEX IF NOT EXISTS idx_task_parent_task_id + ON task_task (parent_task_id); + +-- task_taskcompletionimage: completion_id +CREATE INDEX IF NOT EXISTS idx_completionimage_completion_id + ON task_taskcompletionimage (completion_id); + +-- task_document: created_by_id +CREATE INDEX IF NOT EXISTS idx_document_created_by_id + ON task_document (created_by_id); + +-- task_document: task_id +CREATE INDEX IF NOT EXISTS idx_document_task_id + ON task_document 
(task_id); + +-- task_documentimage: document_id +CREATE INDEX IF NOT EXISTS idx_documentimage_document_id + ON task_documentimage (document_id); + +-- task_contractor: residence_id +CREATE INDEX IF NOT EXISTS idx_contractor_residence_id + ON task_contractor (residence_id); + +-- task_reminderlog: notification_id +CREATE INDEX IF NOT EXISTS idx_reminderlog_notification_id + ON task_reminderlog (notification_id); diff --git a/migrations/018_audit_log.down.sql b/migrations/018_audit_log.down.sql new file mode 100644 index 0000000..55f8feb --- /dev/null +++ b/migrations/018_audit_log.down.sql @@ -0,0 +1,2 @@ +-- Rollback: 018_audit_log +DROP TABLE IF EXISTS audit_log; diff --git a/migrations/018_audit_log.up.sql b/migrations/018_audit_log.up.sql new file mode 100644 index 0000000..4a9ff9e --- /dev/null +++ b/migrations/018_audit_log.up.sql @@ -0,0 +1,16 @@ +-- Migration: 018_audit_log +-- Create audit_log table for tracking security-relevant events (login, register, etc.) + +CREATE TABLE IF NOT EXISTS audit_log ( + id SERIAL PRIMARY KEY, + user_id INTEGER, + event_type VARCHAR(50) NOT NULL, + ip_address VARCHAR(45), + user_agent TEXT, + details JSONB, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_audit_log_user_id ON audit_log(user_id); +CREATE INDEX IF NOT EXISTS idx_audit_log_event_type ON audit_log(event_type); +CREATE INDEX IF NOT EXISTS idx_audit_log_created_at ON audit_log(created_at);