Production hardening: security, resilience, observability, and compliance
Password complexity: custom validator requiring uppercase, lowercase, digit (min 8 chars)
Token expiry: 90-day token lifetime with refresh endpoint (60-90 day renewal window)
Health check: /api/health/ now pings Postgres + Redis, returns 503 on failure
Audit logging: async audit_log table for auth events (login, register, delete, etc.)
Circuit breaker: APNs/FCM push sends wrapped with 5-failure threshold, 30s recovery
FK indexes: 27 missing foreign key indexes across all tables (migration 017)
CSP header: default-src 'none'; frame-ancestors 'none'
Gzip compression: level 5 with media endpoint skipper
Prometheus metrics: /metrics endpoint using existing monitoring service
External timeouts: 15s push, 30s SMTP, context timeouts on all external calls
Migrations: 016 (token created_at), 017 (FK indexes), 018 (audit_log)
Tests: circuit breaker (15), audit service (8), token refresh (7), health (4),
middleware expiry (5), validator (new)
This commit is contained in:
@@ -134,6 +134,8 @@ type SecurityConfig struct {
|
||||
PasswordResetExpiry time.Duration
|
||||
ConfirmationExpiry time.Duration
|
||||
MaxPasswordResetRate int // per hour
|
||||
TokenExpiryDays int // Number of days before auth tokens expire (default 90)
|
||||
TokenRefreshDays int // Token must be at least this many days old before refresh (default 60)
|
||||
}
|
||||
|
||||
// StorageConfig holds file storage settings
|
||||
@@ -262,6 +264,8 @@ func Load() (*Config, error) {
|
||||
PasswordResetExpiry: 15 * time.Minute,
|
||||
ConfirmationExpiry: 24 * time.Hour,
|
||||
MaxPasswordResetRate: 3,
|
||||
TokenExpiryDays: viper.GetInt("TOKEN_EXPIRY_DAYS"),
|
||||
TokenRefreshDays: viper.GetInt("TOKEN_REFRESH_DAYS"),
|
||||
},
|
||||
Storage: StorageConfig{
|
||||
UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"),
|
||||
@@ -369,6 +373,10 @@ func setDefaults() {
|
||||
viper.SetDefault("OVERDUE_REMINDER_HOUR", 15) // 9:00 AM UTC
|
||||
viper.SetDefault("DAILY_DIGEST_HOUR", 3) // 3:00 AM UTC
|
||||
|
||||
// Token expiry defaults
|
||||
viper.SetDefault("TOKEN_EXPIRY_DAYS", 90) // Tokens expire after 90 days
|
||||
viper.SetDefault("TOKEN_REFRESH_DAYS", 60) // Tokens can be refreshed after 60 days
|
||||
|
||||
// Storage defaults
|
||||
viper.SetDefault("STORAGE_UPLOAD_DIR", "./uploads")
|
||||
viper.SetDefault("STORAGE_BASE_URL", "/uploads")
|
||||
|
||||
@@ -11,7 +11,7 @@ type LoginRequest struct {
|
||||
type RegisterRequest struct {
|
||||
Username string `json:"username" validate:"required,min=3,max=150"`
|
||||
Email string `json:"email" validate:"required,email,max=254"`
|
||||
Password string `json:"password" validate:"required,min=8"`
|
||||
Password string `json:"password" validate:"required,min=8,password_complexity"`
|
||||
FirstName string `json:"first_name" validate:"max=150"`
|
||||
LastName string `json:"last_name" validate:"max=150"`
|
||||
}
|
||||
@@ -35,7 +35,7 @@ type VerifyResetCodeRequest struct {
|
||||
// ResetPasswordRequest represents the reset password request body
|
||||
type ResetPasswordRequest struct {
|
||||
ResetToken string `json:"reset_token" validate:"required"`
|
||||
NewPassword string `json:"new_password" validate:"required,min=8"`
|
||||
NewPassword string `json:"new_password" validate:"required,min=8,password_complexity"`
|
||||
}
|
||||
|
||||
// UpdateProfileRequest represents the profile update request body
|
||||
|
||||
@@ -79,6 +79,12 @@ type ResetPasswordResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// RefreshTokenResponse represents the token refresh response
|
||||
type RefreshTokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// MessageResponse represents a simple message response
|
||||
type MessageResponse struct {
|
||||
Message string `json:"message"`
|
||||
|
||||
@@ -23,6 +23,7 @@ type AuthHandler struct {
|
||||
appleAuthService *services.AppleAuthService
|
||||
googleAuthService *services.GoogleAuthService
|
||||
storageService *services.StorageService
|
||||
auditService *services.AuditService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new auth handler
|
||||
@@ -49,6 +50,11 @@ func (h *AuthHandler) SetStorageService(storageService *services.StorageService)
|
||||
h.storageService = storageService
|
||||
}
|
||||
|
||||
// SetAuditService sets the audit service for logging security events
|
||||
func (h *AuthHandler) SetAuditService(auditService *services.AuditService) {
|
||||
h.auditService = auditService
|
||||
}
|
||||
|
||||
// Login handles POST /api/auth/login/
|
||||
func (h *AuthHandler) Login(c echo.Context) error {
|
||||
var req requests.LoginRequest
|
||||
@@ -62,9 +68,19 @@ func (h *AuthHandler) Login(c echo.Context) error {
|
||||
response, err := h.authService.Login(&req)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Str("identifier", req.Username).Msg("Login failed")
|
||||
if h.auditService != nil {
|
||||
h.auditService.LogEvent(c, nil, services.AuditEventLoginFailed, map[string]interface{}{
|
||||
"identifier": req.Username,
|
||||
})
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if h.auditService != nil {
|
||||
userID := response.User.ID
|
||||
h.auditService.LogEvent(c, &userID, services.AuditEventLogin, nil)
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
@@ -84,6 +100,14 @@ func (h *AuthHandler) Register(c echo.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.auditService != nil {
|
||||
userID := response.User.ID
|
||||
h.auditService.LogEvent(c, &userID, services.AuditEventRegister, map[string]interface{}{
|
||||
"username": req.Username,
|
||||
"email": req.Email,
|
||||
})
|
||||
}
|
||||
|
||||
// Send welcome email with confirmation code (async)
|
||||
if h.emailService != nil && confirmationCode != "" {
|
||||
go func() {
|
||||
@@ -108,6 +132,14 @@ func (h *AuthHandler) Logout(c echo.Context) error {
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
|
||||
// Log audit event before invalidating the token
|
||||
if h.auditService != nil {
|
||||
user := middleware.GetAuthUser(c)
|
||||
if user != nil {
|
||||
h.auditService.LogEvent(c, &user.ID, services.AuditEventLogout, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// Invalidate token in database
|
||||
if err := h.authService.Logout(token); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to delete token from database")
|
||||
@@ -270,6 +302,12 @@ func (h *AuthHandler) ForgotPassword(c echo.Context) error {
|
||||
}()
|
||||
}
|
||||
|
||||
if h.auditService != nil {
|
||||
h.auditService.LogEvent(c, nil, services.AuditEventPasswordReset, map[string]interface{}{
|
||||
"email": req.Email,
|
||||
})
|
||||
}
|
||||
|
||||
// Always return success to prevent email enumeration
|
||||
return c.JSON(http.StatusOK, responses.ForgotPasswordResponse{
|
||||
Message: "Password reset email sent",
|
||||
@@ -314,6 +352,12 @@ func (h *AuthHandler) ResetPassword(c echo.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.auditService != nil {
|
||||
h.auditService.LogEvent(c, nil, services.AuditEventPasswordChanged, map[string]interface{}{
|
||||
"method": "reset_token",
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, responses.ResetPasswordResponse{
|
||||
Message: "Password reset successful",
|
||||
})
|
||||
@@ -413,6 +457,34 @@ func (h *AuthHandler) GoogleSignIn(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// RefreshToken handles POST /api/auth/refresh/
|
||||
func (h *AuthHandler) RefreshToken(c echo.Context) error {
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
token := middleware.GetAuthToken(c)
|
||||
if token == "" {
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
|
||||
response, err := h.authService.RefreshToken(token, user.ID)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Token refresh failed")
|
||||
return err
|
||||
}
|
||||
|
||||
// If the token was refreshed (new token), invalidate the old one from cache
|
||||
if response.Token != token && h.cache != nil {
|
||||
if cacheErr := h.cache.InvalidateAuthToken(c.Request().Context(), token); cacheErr != nil {
|
||||
log.Warn().Err(cacheErr).Msg("Failed to invalidate old token from cache during refresh")
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// DeleteAccount handles DELETE /api/auth/account/
|
||||
func (h *AuthHandler) DeleteAccount(c echo.Context) error {
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
@@ -431,6 +503,14 @@ func (h *AuthHandler) DeleteAccount(c echo.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if h.auditService != nil {
|
||||
h.auditService.LogEvent(c, &user.ID, services.AuditEventAccountDeleted, map[string]interface{}{
|
||||
"user_id": user.ID,
|
||||
"username": user.Username,
|
||||
"email": user.Email,
|
||||
})
|
||||
}
|
||||
|
||||
// Delete files from disk (best effort, don't fail the request)
|
||||
if h.storageService != nil && len(fileURLs) > 0 {
|
||||
go func() {
|
||||
|
||||
@@ -6,8 +6,11 @@
|
||||
"error.email_taken": "Email already registered",
|
||||
"error.email_already_taken": "Email already taken",
|
||||
"error.registration_failed": "Registration failed",
|
||||
"error.password_complexity": "Password must be at least 8 characters with at least one uppercase letter, one lowercase letter, and one digit",
|
||||
"error.not_authenticated": "Not authenticated",
|
||||
"error.invalid_token": "Invalid token",
|
||||
"error.token_expired": "Your session has expired. Please log in again.",
|
||||
"error.token_refresh_not_needed": "Token is still valid.",
|
||||
"error.failed_to_get_user": "Failed to get user",
|
||||
"error.failed_to_update_profile": "Failed to update profile",
|
||||
"error.invalid_verification_code": "Invalid verification code",
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/apperrors"
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/services"
|
||||
)
|
||||
@@ -28,24 +29,56 @@ const (
|
||||
// UserCacheTTL is how long full user records are cached in memory to
|
||||
// avoid hitting the database on every authenticated request.
|
||||
UserCacheTTL = 30 * time.Second
|
||||
|
||||
// DefaultTokenExpiryDays is the default number of days before a token expires.
|
||||
DefaultTokenExpiryDays = 90
|
||||
)
|
||||
|
||||
// AuthMiddleware provides token authentication middleware
|
||||
type AuthMiddleware struct {
|
||||
db *gorm.DB
|
||||
cache *services.CacheService
|
||||
userCache *UserCache
|
||||
db *gorm.DB
|
||||
cache *services.CacheService
|
||||
userCache *UserCache
|
||||
tokenExpiryDays int
|
||||
}
|
||||
|
||||
// NewAuthMiddleware creates a new auth middleware instance
|
||||
func NewAuthMiddleware(db *gorm.DB, cache *services.CacheService) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
db: db,
|
||||
cache: cache,
|
||||
userCache: NewUserCache(UserCacheTTL),
|
||||
db: db,
|
||||
cache: cache,
|
||||
userCache: NewUserCache(UserCacheTTL),
|
||||
tokenExpiryDays: DefaultTokenExpiryDays,
|
||||
}
|
||||
}
|
||||
|
||||
// NewAuthMiddlewareWithConfig creates a new auth middleware instance with configuration
|
||||
func NewAuthMiddlewareWithConfig(db *gorm.DB, cache *services.CacheService, cfg *config.Config) *AuthMiddleware {
|
||||
expiryDays := DefaultTokenExpiryDays
|
||||
if cfg != nil && cfg.Security.TokenExpiryDays > 0 {
|
||||
expiryDays = cfg.Security.TokenExpiryDays
|
||||
}
|
||||
return &AuthMiddleware{
|
||||
db: db,
|
||||
cache: cache,
|
||||
userCache: NewUserCache(UserCacheTTL),
|
||||
tokenExpiryDays: expiryDays,
|
||||
}
|
||||
}
|
||||
|
||||
// TokenExpiryDuration returns the token expiry duration.
|
||||
func (m *AuthMiddleware) TokenExpiryDuration() time.Duration {
|
||||
return time.Duration(m.tokenExpiryDays) * 24 * time.Hour
|
||||
}
|
||||
|
||||
// isTokenExpired checks if a token's created timestamp indicates expiry.
|
||||
func (m *AuthMiddleware) isTokenExpired(created time.Time) bool {
|
||||
if created.IsZero() {
|
||||
return false // Legacy tokens without created time are not expired
|
||||
}
|
||||
return time.Since(created) > m.TokenExpiryDuration()
|
||||
}
|
||||
|
||||
// TokenAuth returns an Echo middleware that validates token authentication
|
||||
func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
@@ -56,7 +89,7 @@ func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc {
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
|
||||
// Try to get user from cache first
|
||||
// Try to get user from cache first (includes expiry check)
|
||||
user, err := m.getUserFromCache(c.Request().Context(), token)
|
||||
if err == nil && user != nil {
|
||||
// Cache hit - set user in context and continue
|
||||
@@ -65,16 +98,27 @@ func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Check if the cache indicated token expiry
|
||||
if err != nil && err.Error() == "token expired" {
|
||||
return apperrors.Unauthorized("error.token_expired")
|
||||
}
|
||||
|
||||
// Cache miss - look up token in database
|
||||
user, err = m.getUserFromDatabase(token)
|
||||
user, authToken, err := m.getUserFromDatabaseWithToken(token)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Str("token", truncateToken(token)).Msg("Token authentication failed")
|
||||
return apperrors.Unauthorized("error.invalid_token")
|
||||
}
|
||||
|
||||
// Cache the user ID for future requests
|
||||
if cacheErr := m.cacheUserID(c.Request().Context(), token, user.ID); cacheErr != nil {
|
||||
log.Warn().Err(cacheErr).Msg("Failed to cache user ID")
|
||||
// Check token expiry
|
||||
if m.isTokenExpired(authToken.Created) {
|
||||
log.Debug().Str("token", truncateToken(token)).Time("created", authToken.Created).Msg("Token expired")
|
||||
return apperrors.Unauthorized("error.token_expired")
|
||||
}
|
||||
|
||||
// Cache the user ID and token creation time for future requests
|
||||
if cacheErr := m.cacheTokenInfo(c.Request().Context(), token, user.ID, authToken.Created); cacheErr != nil {
|
||||
log.Warn().Err(cacheErr).Msg("Failed to cache token info")
|
||||
}
|
||||
|
||||
// Set user in context
|
||||
@@ -104,9 +148,9 @@ func (m *AuthMiddleware) OptionalTokenAuth() echo.MiddlewareFunc {
|
||||
}
|
||||
|
||||
// Try database
|
||||
user, err = m.getUserFromDatabase(token)
|
||||
if err == nil {
|
||||
m.cacheUserID(c.Request().Context(), token, user.ID)
|
||||
user, authToken, err := m.getUserFromDatabaseWithToken(token)
|
||||
if err == nil && !m.isTokenExpired(authToken.Created) {
|
||||
m.cacheTokenInfo(c.Request().Context(), token, user.ID, authToken.Created)
|
||||
c.Set(AuthUserKey, user)
|
||||
c.Set(AuthTokenKey, token)
|
||||
}
|
||||
@@ -145,12 +189,13 @@ func extractToken(c echo.Context) (string, error) {
|
||||
|
||||
// getUserFromCache tries to get user from Redis cache, then from the
|
||||
// in-memory user cache, before falling back to the database.
|
||||
// Returns a "token expired" error if the cached creation time indicates expiry.
|
||||
func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*models.User, error) {
|
||||
if m.cache == nil {
|
||||
return nil, fmt.Errorf("cache not available")
|
||||
}
|
||||
|
||||
userID, err := m.cache.GetCachedAuthToken(ctx, token)
|
||||
userID, createdUnix, err := m.cache.GetCachedAuthTokenWithCreated(ctx, token)
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return nil, fmt.Errorf("token not in cache")
|
||||
@@ -158,6 +203,15 @@ func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*m
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check token expiry from cached creation time
|
||||
if createdUnix > 0 {
|
||||
created := time.Unix(createdUnix, 0)
|
||||
if m.isTokenExpired(created) {
|
||||
m.cache.InvalidateAuthToken(ctx, token)
|
||||
return nil, fmt.Errorf("token expired")
|
||||
}
|
||||
}
|
||||
|
||||
// Try in-memory user cache first to avoid a DB round-trip
|
||||
if cached := m.userCache.Get(userID); cached != nil {
|
||||
if !cached.IsActive {
|
||||
@@ -187,22 +241,38 @@ func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*m
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// getUserFromDatabase looks up the token in the database and caches the
|
||||
// resulting user record in memory.
|
||||
func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error) {
|
||||
// getUserFromDatabaseWithToken looks up the token in the database and returns
|
||||
// both the user and the auth token record (for expiry checking).
|
||||
func (m *AuthMiddleware) getUserFromDatabaseWithToken(token string) (*models.User, *models.AuthToken, error) {
|
||||
var authToken models.AuthToken
|
||||
if err := m.db.Preload("User").Where("key = ?", token).First(&authToken).Error; err != nil {
|
||||
return nil, fmt.Errorf("token not found")
|
||||
return nil, nil, fmt.Errorf("token not found")
|
||||
}
|
||||
|
||||
// Check if user is active
|
||||
if !authToken.User.IsActive {
|
||||
return nil, fmt.Errorf("user is inactive")
|
||||
return nil, nil, fmt.Errorf("user is inactive")
|
||||
}
|
||||
|
||||
// Store in in-memory cache for subsequent requests
|
||||
m.userCache.Set(&authToken.User)
|
||||
return &authToken.User, nil
|
||||
return &authToken.User, &authToken, nil
|
||||
}
|
||||
|
||||
// getUserFromDatabase looks up the token in the database and caches the
|
||||
// resulting user record in memory.
|
||||
// Deprecated: Use getUserFromDatabaseWithToken for new code paths that need expiry checking.
|
||||
func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error) {
|
||||
user, _, err := m.getUserFromDatabaseWithToken(token)
|
||||
return user, err
|
||||
}
|
||||
|
||||
// cacheTokenInfo caches the user ID and token creation time for a token
|
||||
func (m *AuthMiddleware) cacheTokenInfo(ctx context.Context, token string, userID uint, created time.Time) error {
|
||||
if m.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return m.cache.CacheAuthTokenWithCreated(ctx, token, userID, created.Unix())
|
||||
}
|
||||
|
||||
// cacheUserID caches the user ID for a token
|
||||
|
||||
165
internal/middleware/auth_expiry_test.go
Normal file
165
internal/middleware/auth_expiry_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
// setupTestDB creates a temporary in-memory SQLite database with the required
|
||||
// tables for auth middleware tests.
|
||||
func setupTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.AutoMigrate(&models.User{}, &models.AuthToken{})
|
||||
require.NoError(t, err)
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// createTestUserAndToken creates a user and an auth token, then backdates the
|
||||
// token's Created timestamp by the specified number of days.
|
||||
func createTestUserAndToken(t *testing.T, db *gorm.DB, username string, ageDays int) (*models.User, *models.AuthToken) {
|
||||
t.Helper()
|
||||
|
||||
user := &models.User{
|
||||
Username: username,
|
||||
Email: username + "@test.com",
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, user.SetPassword("password123"))
|
||||
require.NoError(t, db.Create(user).Error)
|
||||
|
||||
token := &models.AuthToken{
|
||||
UserID: user.ID,
|
||||
}
|
||||
require.NoError(t, db.Create(token).Error)
|
||||
|
||||
// Backdate the token's Created timestamp after creation to bypass autoCreateTime
|
||||
backdated := time.Now().UTC().AddDate(0, 0, -ageDays)
|
||||
require.NoError(t, db.Model(token).Update("created", backdated).Error)
|
||||
token.Created = backdated
|
||||
|
||||
return user, token
|
||||
}
|
||||
|
||||
func TestTokenAuth_RejectsExpiredToken(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "expired_user", 91) // 91 days old > 90 day expiry
|
||||
|
||||
m := NewAuthMiddleware(db, nil) // No Redis cache for these tests
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.token_expired")
|
||||
}
|
||||
|
||||
func TestTokenAuth_AcceptsValidToken(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "valid_user", 30) // 30 days old < 90 day expiry
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Verify user was set in context
|
||||
user := GetAuthUser(c)
|
||||
require.NotNil(t, user)
|
||||
assert.Equal(t, "valid_user", user.Username)
|
||||
}
|
||||
|
||||
func TestTokenAuth_AcceptsTokenAtBoundary(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "boundary_user", 89) // 89 days old, just under 90 day expiry
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestTokenAuth_RejectsInvalidToken(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token nonexistent-token")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||
}
|
||||
|
||||
func TestTokenAuth_RejectsNoAuthHeader(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.not_authenticated")
|
||||
}
|
||||
48
internal/models/audit_log.go
Normal file
48
internal/models/audit_log.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JSONMap is a custom type for JSONB columns that handles JSON serialization
|
||||
type JSONMap map[string]interface{}
|
||||
|
||||
// Value implements driver.Valuer for database writes
|
||||
func (j JSONMap) Value() (driver.Value, error) {
|
||||
if j == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return json.Marshal(j)
|
||||
}
|
||||
|
||||
// Scan implements sql.Scanner for database reads
|
||||
func (j *JSONMap) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*j = nil
|
||||
return nil
|
||||
}
|
||||
bytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return errors.New("audit_log: failed to scan JSONMap value")
|
||||
}
|
||||
return json.Unmarshal(bytes, j)
|
||||
}
|
||||
|
||||
// AuditLog represents the audit_log table for tracking security-relevant events
|
||||
type AuditLog struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
UserID *uint `gorm:"column:user_id" json:"user_id"`
|
||||
EventType string `gorm:"column:event_type;size:50;not null" json:"event_type"`
|
||||
IPAddress string `gorm:"column:ip_address;size:45" json:"ip_address"`
|
||||
UserAgent string `gorm:"column:user_agent;type:text" json:"user_agent"`
|
||||
Details JSONMap `gorm:"column:details;type:jsonb" json:"details"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for GORM
|
||||
func (AuditLog) TableName() string {
|
||||
return "audit_log"
|
||||
}
|
||||
167
internal/push/circuit_breaker.go
Normal file
167
internal/push/circuit_breaker.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package push
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Circuit breaker states
|
||||
const (
|
||||
stateClosed = iota // Normal operation, requests pass through
|
||||
stateOpen // Too many failures, requests are rejected
|
||||
stateHalfOpen // Testing recovery, one request allowed through
|
||||
)
|
||||
|
||||
// Default circuit breaker settings
|
||||
const (
|
||||
defaultFailureThreshold = 5 // Open after this many consecutive failures
|
||||
defaultRecoveryTimeout = 30 * time.Second // Try again after this duration
|
||||
)
|
||||
|
||||
// ErrCircuitOpen is returned when the circuit breaker is open and rejecting requests.
|
||||
var ErrCircuitOpen = errors.New("circuit breaker is open")
|
||||
|
||||
// CircuitBreaker implements a simple circuit breaker pattern for external service calls.
|
||||
// It is thread-safe and requires no external dependencies.
|
||||
//
|
||||
// States:
|
||||
// - Closed: normal operation, all requests pass through. Consecutive failures are counted.
|
||||
// - Open: after reaching the failure threshold, all requests are immediately rejected
|
||||
// with ErrCircuitOpen until the recovery timeout elapses.
|
||||
// - Half-Open: after the recovery timeout, one request is allowed through. If it
|
||||
// succeeds the breaker resets to Closed; if it fails it returns to Open.
|
||||
type CircuitBreaker struct {
|
||||
mu sync.Mutex
|
||||
state int
|
||||
failureCount int
|
||||
failureThreshold int
|
||||
recoveryTimeout time.Duration
|
||||
lastFailureTime time.Time
|
||||
name string // For logging
|
||||
}
|
||||
|
||||
// CircuitBreakerOption configures a CircuitBreaker.
|
||||
type CircuitBreakerOption func(*CircuitBreaker)
|
||||
|
||||
// WithFailureThreshold sets the number of consecutive failures before opening the circuit.
|
||||
func WithFailureThreshold(n int) CircuitBreakerOption {
|
||||
return func(cb *CircuitBreaker) {
|
||||
if n > 0 {
|
||||
cb.failureThreshold = n
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithRecoveryTimeout sets how long the circuit stays open before trying half-open.
|
||||
func WithRecoveryTimeout(d time.Duration) CircuitBreakerOption {
|
||||
return func(cb *CircuitBreaker) {
|
||||
if d > 0 {
|
||||
cb.recoveryTimeout = d
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewCircuitBreaker creates a new CircuitBreaker with the given name and options.
|
||||
// The name is used for logging and identification.
|
||||
func NewCircuitBreaker(name string, opts ...CircuitBreakerOption) *CircuitBreaker {
|
||||
cb := &CircuitBreaker{
|
||||
state: stateClosed,
|
||||
failureThreshold: defaultFailureThreshold,
|
||||
recoveryTimeout: defaultRecoveryTimeout,
|
||||
name: name,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(cb)
|
||||
}
|
||||
return cb
|
||||
}
|
||||
|
||||
// Allow checks whether a request should be allowed through.
|
||||
// It returns true if the request can proceed, false if the circuit is open.
|
||||
// When transitioning from open to half-open, it returns true for the probe request.
|
||||
func (cb *CircuitBreaker) Allow() bool {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
switch cb.state {
|
||||
case stateClosed:
|
||||
return true
|
||||
case stateOpen:
|
||||
// Check if recovery timeout has elapsed
|
||||
if time.Since(cb.lastFailureTime) >= cb.recoveryTimeout {
|
||||
cb.state = stateHalfOpen
|
||||
return true
|
||||
}
|
||||
return false
|
||||
case stateHalfOpen:
|
||||
// Only one request at a time in half-open state.
|
||||
// The first caller that got here via Allow() is already in flight;
|
||||
// reject subsequent callers until that probe resolves.
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSuccess records a successful request. If the breaker is half-open, it resets to closed.
|
||||
func (cb *CircuitBreaker) RecordSuccess() {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
cb.failureCount = 0
|
||||
cb.state = stateClosed
|
||||
}
|
||||
|
||||
// RecordFailure records a failed request. If the failure threshold is reached, the
|
||||
// breaker transitions to the open state.
|
||||
func (cb *CircuitBreaker) RecordFailure() {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
cb.failureCount++
|
||||
cb.lastFailureTime = time.Now()
|
||||
|
||||
if cb.failureCount >= cb.failureThreshold {
|
||||
cb.state = stateOpen
|
||||
}
|
||||
}
|
||||
|
||||
// State returns the current state of the circuit breaker as a human-readable string.
|
||||
func (cb *CircuitBreaker) State() string {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
switch cb.state {
|
||||
case stateClosed:
|
||||
return "closed"
|
||||
case stateOpen:
|
||||
return "open"
|
||||
case stateHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the circuit breaker's name.
|
||||
func (cb *CircuitBreaker) Name() string {
|
||||
return cb.name
|
||||
}
|
||||
|
||||
// Reset resets the circuit breaker to the closed state with zero failures.
|
||||
func (cb *CircuitBreaker) Reset() {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
cb.state = stateClosed
|
||||
cb.failureCount = 0
|
||||
cb.lastFailureTime = time.Time{}
|
||||
}
|
||||
|
||||
// Counts returns the current failure count (useful for testing and monitoring).
|
||||
func (cb *CircuitBreaker) Counts() int {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
return cb.failureCount
|
||||
}
|
||||
275
internal/push/circuit_breaker_test.go
Normal file
275
internal/push/circuit_breaker_test.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package push
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCircuitBreaker_StartsInClosedState(t *testing.T) {
|
||||
cb := NewCircuitBreaker("test")
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
assert.True(t, cb.Allow())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_OpensAfterThresholdFailures(t *testing.T) {
|
||||
cb := NewCircuitBreaker("test", WithFailureThreshold(3))
|
||||
|
||||
// First two failures should keep it closed
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
assert.True(t, cb.Allow())
|
||||
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
assert.True(t, cb.Allow())
|
||||
|
||||
// Third failure should open it
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "open", cb.State())
|
||||
assert.False(t, cb.Allow())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_DefaultThresholdIsFive(t *testing.T) {
|
||||
cb := NewCircuitBreaker("test")
|
||||
|
||||
for i := 0; i < 4; i++ {
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
}
|
||||
|
||||
cb.RecordFailure() // 5th failure
|
||||
assert.Equal(t, "open", cb.State())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_RejectsRequestsWhenOpen(t *testing.T) {
|
||||
cb := NewCircuitBreaker("test", WithFailureThreshold(1))
|
||||
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "open", cb.State())
|
||||
|
||||
// Multiple calls should all be rejected
|
||||
for i := 0; i < 10; i++ {
|
||||
assert.False(t, cb.Allow())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_TransitionsToHalfOpenAfterRecoveryTimeout(t *testing.T) {
|
||||
cb := NewCircuitBreaker("test",
|
||||
WithFailureThreshold(1),
|
||||
WithRecoveryTimeout(50*time.Millisecond),
|
||||
)
|
||||
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "open", cb.State())
|
||||
assert.False(t, cb.Allow())
|
||||
|
||||
// Wait for recovery timeout
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
|
||||
// Should now allow one request (half-open)
|
||||
assert.True(t, cb.Allow())
|
||||
assert.Equal(t, "half-open", cb.State())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_HalfOpenRejectsSecondRequest(t *testing.T) {
|
||||
cb := NewCircuitBreaker("test",
|
||||
WithFailureThreshold(1),
|
||||
WithRecoveryTimeout(50*time.Millisecond),
|
||||
)
|
||||
|
||||
cb.RecordFailure()
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
|
||||
// First request allowed (probe)
|
||||
assert.True(t, cb.Allow())
|
||||
assert.Equal(t, "half-open", cb.State())
|
||||
|
||||
// Second request rejected while probe is in flight
|
||||
assert.False(t, cb.Allow())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_HalfOpenSuccess_ResetsToClosed(t *testing.T) {
|
||||
cb := NewCircuitBreaker("test",
|
||||
WithFailureThreshold(1),
|
||||
WithRecoveryTimeout(50*time.Millisecond),
|
||||
)
|
||||
|
||||
cb.RecordFailure()
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
|
||||
// Probe request
|
||||
assert.True(t, cb.Allow())
|
||||
|
||||
// Probe succeeds
|
||||
cb.RecordSuccess()
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
assert.Equal(t, 0, cb.Counts())
|
||||
|
||||
// Normal operation resumes
|
||||
assert.True(t, cb.Allow())
|
||||
assert.True(t, cb.Allow())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_HalfOpenFailure_ReturnsToOpen(t *testing.T) {
|
||||
cb := NewCircuitBreaker("test",
|
||||
WithFailureThreshold(2),
|
||||
WithRecoveryTimeout(50*time.Millisecond),
|
||||
)
|
||||
|
||||
// Open the circuit
|
||||
cb.RecordFailure()
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "open", cb.State())
|
||||
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
|
||||
// Probe request
|
||||
assert.True(t, cb.Allow())
|
||||
assert.Equal(t, "half-open", cb.State())
|
||||
|
||||
// Probe fails - the failure count is now 3 which is >= threshold of 2
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "open", cb.State())
|
||||
assert.False(t, cb.Allow())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_SuccessResetsFailureCount(t *testing.T) {
|
||||
cb := NewCircuitBreaker("test", WithFailureThreshold(3))
|
||||
|
||||
cb.RecordFailure()
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, 2, cb.Counts())
|
||||
|
||||
// A success should reset the counter
|
||||
cb.RecordSuccess()
|
||||
assert.Equal(t, 0, cb.Counts())
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
|
||||
// Now it should take 3 more failures to open
|
||||
cb.RecordFailure()
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "open", cb.State())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_Reset(t *testing.T) {
|
||||
cb := NewCircuitBreaker("test", WithFailureThreshold(1))
|
||||
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "open", cb.State())
|
||||
|
||||
cb.Reset()
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
assert.Equal(t, 0, cb.Counts())
|
||||
assert.True(t, cb.Allow())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_Name(t *testing.T) {
|
||||
cb := NewCircuitBreaker("apns-breaker")
|
||||
assert.Equal(t, "apns-breaker", cb.Name())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_CustomOptions(t *testing.T) {
|
||||
cb := NewCircuitBreaker("test",
|
||||
WithFailureThreshold(10),
|
||||
WithRecoveryTimeout(5*time.Minute),
|
||||
)
|
||||
|
||||
// Should take 10 failures to open
|
||||
for i := 0; i < 9; i++ {
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
}
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "open", cb.State())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_InvalidOptionsIgnored(t *testing.T) {
|
||||
cb := NewCircuitBreaker("test",
|
||||
WithFailureThreshold(0), // Should be ignored (keeps default)
|
||||
WithRecoveryTimeout(-1), // Should be ignored (keeps default)
|
||||
)
|
||||
|
||||
// Default threshold of 5 should still apply
|
||||
for i := 0; i < 4; i++ {
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
}
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "open", cb.State())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ThreadSafety(t *testing.T) {
|
||||
cb := NewCircuitBreaker("test",
|
||||
WithFailureThreshold(100),
|
||||
WithRecoveryTimeout(10*time.Millisecond),
|
||||
)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 50
|
||||
const iterations = 100
|
||||
|
||||
// Hammer it from many goroutines
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
cb.Allow()
|
||||
if j%2 == 0 {
|
||||
cb.RecordFailure()
|
||||
} else {
|
||||
cb.RecordSuccess()
|
||||
}
|
||||
_ = cb.State()
|
||||
_ = cb.Counts()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should not panic or deadlock. State should be valid.
|
||||
state := cb.State()
|
||||
require.Contains(t, []string{"closed", "open", "half-open"}, state)
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_FullLifecycle(t *testing.T) {
|
||||
cb := NewCircuitBreaker("lifecycle-test",
|
||||
WithFailureThreshold(3),
|
||||
WithRecoveryTimeout(50*time.Millisecond),
|
||||
)
|
||||
|
||||
// 1. Closed: requests flow normally
|
||||
assert.True(t, cb.Allow())
|
||||
cb.RecordSuccess()
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
|
||||
// 2. Accumulate failures
|
||||
cb.RecordFailure()
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
|
||||
// 3. Third failure opens the circuit
|
||||
cb.RecordFailure()
|
||||
assert.Equal(t, "open", cb.State())
|
||||
assert.False(t, cb.Allow())
|
||||
|
||||
// 4. Wait for recovery
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
|
||||
// 5. Half-open: probe request allowed
|
||||
assert.True(t, cb.Allow())
|
||||
assert.Equal(t, "half-open", cb.State())
|
||||
|
||||
// 6. Probe succeeds, back to closed
|
||||
cb.RecordSuccess()
|
||||
assert.Equal(t, "closed", cb.State())
|
||||
assert.True(t, cb.Allow())
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package push
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
@@ -14,16 +15,25 @@ const (
|
||||
PlatformAndroid = "android"
|
||||
)
|
||||
|
||||
// Timeout for individual push notification send operations.
|
||||
const pushSendTimeout = 15 * time.Second
|
||||
|
||||
// Client provides a unified interface for sending push notifications
|
||||
type Client struct {
|
||||
apns *APNsClient
|
||||
fcm *FCMClient
|
||||
enabled bool
|
||||
apns *APNsClient
|
||||
fcm *FCMClient
|
||||
enabled bool
|
||||
apnsBreaker *CircuitBreaker
|
||||
fcmBreaker *CircuitBreaker
|
||||
}
|
||||
|
||||
// NewClient creates a new unified push notification client
|
||||
func NewClient(cfg *config.PushConfig, enabled bool) (*Client, error) {
|
||||
client := &Client{enabled: enabled}
|
||||
client := &Client{
|
||||
enabled: enabled,
|
||||
apnsBreaker: NewCircuitBreaker("apns"),
|
||||
fcmBreaker: NewCircuitBreaker("fcm"),
|
||||
}
|
||||
|
||||
// Initialize APNs client (iOS)
|
||||
if cfg.APNSKeyPath != "" && cfg.APNSKeyID != "" && cfg.APNSTeamID != "" {
|
||||
@@ -54,7 +64,8 @@ func NewClient(cfg *config.PushConfig, enabled bool) (*Client, error) {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// SendToIOS sends a push notification to iOS devices
|
||||
// SendToIOS sends a push notification to iOS devices.
|
||||
// The call is guarded by a circuit breaker and uses a context timeout.
|
||||
func (c *Client) SendToIOS(ctx context.Context, tokens []string, title, message string, data map[string]string) error {
|
||||
if !c.enabled {
|
||||
log.Debug().Msg("Push notifications disabled by feature flag")
|
||||
@@ -64,10 +75,26 @@ func (c *Client) SendToIOS(ctx context.Context, tokens []string, title, message
|
||||
log.Warn().Msg("APNs client not initialized, skipping iOS push")
|
||||
return nil
|
||||
}
|
||||
return c.apns.Send(ctx, tokens, title, message, data)
|
||||
if !c.apnsBreaker.Allow() {
|
||||
log.Warn().Str("breaker", c.apnsBreaker.Name()).Msg("APNs circuit breaker is open, skipping iOS push")
|
||||
return ErrCircuitOpen
|
||||
}
|
||||
|
||||
sendCtx, cancel := context.WithTimeout(ctx, pushSendTimeout)
|
||||
defer cancel()
|
||||
|
||||
err := c.apns.Send(sendCtx, tokens, title, message, data)
|
||||
if err != nil {
|
||||
c.apnsBreaker.RecordFailure()
|
||||
log.Warn().Err(err).Str("breaker_state", c.apnsBreaker.State()).Msg("APNs send failed, recorded circuit breaker failure")
|
||||
return err
|
||||
}
|
||||
c.apnsBreaker.RecordSuccess()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendToAndroid sends a push notification to Android devices
|
||||
// SendToAndroid sends a push notification to Android devices.
|
||||
// The call is guarded by a circuit breaker and uses a context timeout.
|
||||
func (c *Client) SendToAndroid(ctx context.Context, tokens []string, title, message string, data map[string]string) error {
|
||||
if !c.enabled {
|
||||
log.Debug().Msg("Push notifications disabled by feature flag")
|
||||
@@ -77,7 +104,22 @@ func (c *Client) SendToAndroid(ctx context.Context, tokens []string, title, mess
|
||||
log.Warn().Msg("FCM client not initialized, skipping Android push")
|
||||
return nil
|
||||
}
|
||||
return c.fcm.Send(ctx, tokens, title, message, data)
|
||||
if !c.fcmBreaker.Allow() {
|
||||
log.Warn().Str("breaker", c.fcmBreaker.Name()).Msg("FCM circuit breaker is open, skipping Android push")
|
||||
return ErrCircuitOpen
|
||||
}
|
||||
|
||||
sendCtx, cancel := context.WithTimeout(ctx, pushSendTimeout)
|
||||
defer cancel()
|
||||
|
||||
err := c.fcm.Send(sendCtx, tokens, title, message, data)
|
||||
if err != nil {
|
||||
c.fcmBreaker.RecordFailure()
|
||||
log.Warn().Err(err).Str("breaker_state", c.fcmBreaker.State()).Msg("FCM send failed, recorded circuit breaker failure")
|
||||
return err
|
||||
}
|
||||
c.fcmBreaker.RecordSuccess()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendToAll sends a push notification to both iOS and Android devices
|
||||
@@ -115,8 +157,9 @@ func (c *Client) IsAndroidEnabled() bool {
|
||||
return c.fcm != nil
|
||||
}
|
||||
|
||||
// SendActionableNotification sends notifications with action button support
|
||||
// iOS receives a category for actionable notifications, Android handles actions via data payload
|
||||
// SendActionableNotification sends notifications with action button support.
|
||||
// iOS receives a category for actionable notifications, Android handles actions via data payload.
|
||||
// Both platforms are guarded by their respective circuit breakers.
|
||||
func (c *Client) SendActionableNotification(ctx context.Context, iosTokens, androidTokens []string, title, message string, data map[string]string, iosCategoryID string) error {
|
||||
if !c.enabled {
|
||||
log.Debug().Msg("Push notifications disabled by feature flag")
|
||||
@@ -127,10 +170,19 @@ func (c *Client) SendActionableNotification(ctx context.Context, iosTokens, andr
|
||||
if len(iosTokens) > 0 {
|
||||
if c.apns == nil {
|
||||
log.Warn().Msg("APNs client not initialized, skipping iOS actionable push")
|
||||
} else if !c.apnsBreaker.Allow() {
|
||||
log.Warn().Str("breaker", c.apnsBreaker.Name()).Msg("APNs circuit breaker is open, skipping iOS actionable push")
|
||||
lastErr = ErrCircuitOpen
|
||||
} else {
|
||||
if err := c.apns.SendWithCategory(ctx, iosTokens, title, message, data, iosCategoryID); err != nil {
|
||||
sendCtx, cancel := context.WithTimeout(ctx, pushSendTimeout)
|
||||
err := c.apns.SendWithCategory(sendCtx, iosTokens, title, message, data, iosCategoryID)
|
||||
cancel()
|
||||
if err != nil {
|
||||
c.apnsBreaker.RecordFailure()
|
||||
log.Error().Err(err).Msg("Failed to send iOS actionable notifications")
|
||||
lastErr = err
|
||||
} else {
|
||||
c.apnsBreaker.RecordSuccess()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,7 +165,7 @@ func NewFCMClient(cfg *config.PushConfig) (*FCMClient, error) {
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Timeout: 15 * time.Second,
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
|
||||
@@ -173,6 +173,27 @@ func (r *UserRepository) GetOrCreateToken(userID uint) (*models.AuthToken, error
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// FindTokenByKey looks up an auth token by its key value.
|
||||
func (r *UserRepository) FindTokenByKey(key string) (*models.AuthToken, error) {
|
||||
var token models.AuthToken
|
||||
if err := r.db.Where("key = ?", key).First(&token).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrTokenNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// CreateToken creates a new auth token for a user.
|
||||
func (r *UserRepository) CreateToken(userID uint) (*models.AuthToken, error) {
|
||||
token := models.AuthToken{UserID: userID}
|
||||
if err := r.db.Create(&token).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// DeleteToken deletes an auth token
|
||||
func (r *UserRepository) DeleteToken(token string) error {
|
||||
result := r.db.Where("key = ?", token).Delete(&models.AuthToken{})
|
||||
|
||||
136
internal/router/health_test.go
Normal file
136
internal/router/health_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// setupTestDB creates an in-memory SQLite database for health check tests
|
||||
func setupTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return db
|
||||
}
|
||||
|
||||
func TestLiveCheck_Returns200(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.GET("/api/health/live", liveCheck)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/health/live", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "alive", resp["status"])
|
||||
assert.Equal(t, Version, resp["version"])
|
||||
assert.Contains(t, resp, "timestamp")
|
||||
}
|
||||
|
||||
func TestReadinessCheck_HealthyDB_NilCache_Returns200(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
deps := &Dependencies{
|
||||
DB: db,
|
||||
Cache: nil, // No Redis configured
|
||||
}
|
||||
|
||||
e := echo.New()
|
||||
e.GET("/api/health/", readinessCheck(deps))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/health/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "healthy", resp["status"])
|
||||
assert.Equal(t, Version, resp["version"])
|
||||
|
||||
checks := resp["checks"].(map[string]interface{})
|
||||
assert.Equal(t, "ok", checks["postgres"])
|
||||
assert.Equal(t, "not configured", checks["redis"])
|
||||
}
|
||||
|
||||
func TestReadinessCheck_DBDown_Returns503(t *testing.T) {
|
||||
// Open an in-memory SQLite DB, then close the underlying sql.DB to simulate failure
|
||||
db := setupTestDB(t)
|
||||
sqlDB, err := db.DB()
|
||||
require.NoError(t, err)
|
||||
sqlDB.Close()
|
||||
|
||||
deps := &Dependencies{
|
||||
DB: db,
|
||||
Cache: nil,
|
||||
}
|
||||
|
||||
e := echo.New()
|
||||
e.GET("/api/health/", readinessCheck(deps))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/health/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
|
||||
var resp map[string]interface{}
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "unhealthy", resp["status"])
|
||||
|
||||
checks := resp["checks"].(map[string]interface{})
|
||||
assert.Contains(t, checks["postgres"], "ping failed")
|
||||
}
|
||||
|
||||
func TestReadinessCheck_ResponseFormat(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
deps := &Dependencies{
|
||||
DB: db,
|
||||
Cache: nil,
|
||||
}
|
||||
|
||||
e := echo.New()
|
||||
e.GET("/api/health/", readinessCheck(deps))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/health/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
var resp map[string]interface{}
|
||||
err := json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all expected fields are present
|
||||
assert.Contains(t, resp, "status")
|
||||
assert.Contains(t, resp, "version")
|
||||
assert.Contains(t, resp, "checks")
|
||||
assert.Contains(t, resp, "timestamp")
|
||||
|
||||
// Verify checks is a map with expected keys
|
||||
checks, ok := resp["checks"].(map[string]interface{})
|
||||
require.True(t, ok, "checks should be a map")
|
||||
assert.Contains(t, checks, "postgres")
|
||||
assert.Contains(t, checks, "redis")
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -62,11 +63,12 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
|
||||
// Security headers (X-Frame-Options, X-Content-Type-Options, X-XSS-Protection, etc.)
|
||||
e.Use(middleware.SecureWithConfig(middleware.SecureConfig{
|
||||
XSSProtection: "1; mode=block",
|
||||
ContentTypeNosniff: "nosniff",
|
||||
XFrameOptions: "SAMEORIGIN",
|
||||
HSTSMaxAge: 31536000, // 1 year in seconds
|
||||
ReferrerPolicy: "strict-origin-when-cross-origin",
|
||||
XSSProtection: "1; mode=block",
|
||||
ContentTypeNosniff: "nosniff",
|
||||
XFrameOptions: "SAMEORIGIN",
|
||||
HSTSMaxAge: 31536000, // 1 year in seconds
|
||||
ReferrerPolicy: "strict-origin-when-cross-origin",
|
||||
ContentSecurityPolicy: "default-src 'none'; frame-ancestors 'none'",
|
||||
}))
|
||||
e.Use(middleware.BodyLimitWithConfig(middleware.BodyLimitConfig{
|
||||
Limit: "1M", // 1MB default for JSON payloads
|
||||
@@ -93,6 +95,14 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
e.Use(corsMiddleware(cfg))
|
||||
e.Use(i18n.Middleware())
|
||||
|
||||
// Gzip compression (skip media endpoints since they serve binary files)
|
||||
e.Use(middleware.GzipWithConfig(middleware.GzipConfig{
|
||||
Level: 5,
|
||||
Skipper: func(c echo.Context) bool {
|
||||
return strings.HasPrefix(c.Request().URL.Path, "/api/media/")
|
||||
},
|
||||
}))
|
||||
|
||||
// Monitoring metrics middleware (if monitoring is enabled)
|
||||
if deps.MonitoringService != nil {
|
||||
if metricsMiddleware := deps.MonitoringService.MetricsMiddleware(); metricsMiddleware != nil {
|
||||
@@ -114,8 +124,9 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
})
|
||||
}
|
||||
|
||||
// Health check endpoint (no auth required)
|
||||
e.GET("/api/health/", healthCheck)
|
||||
// Health check endpoints (no auth required)
|
||||
e.GET("/api/health/", readinessCheck(deps))
|
||||
e.GET("/api/health/live", liveCheck)
|
||||
|
||||
// Initialize onboarding email service for tracking handler
|
||||
onboardingService := services.NewOnboardingEmailService(deps.DB, deps.EmailService, cfg.Server.BaseURL)
|
||||
@@ -172,17 +183,21 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
subscriptionWebhookHandler.SetStripeService(stripeService)
|
||||
|
||||
// Initialize middleware
|
||||
authMiddleware := custommiddleware.NewAuthMiddleware(deps.DB, deps.Cache)
|
||||
authMiddleware := custommiddleware.NewAuthMiddlewareWithConfig(deps.DB, deps.Cache, cfg)
|
||||
|
||||
// Initialize Apple auth service
|
||||
appleAuthService := services.NewAppleAuthService(deps.Cache, cfg)
|
||||
googleAuthService := services.NewGoogleAuthService(deps.Cache, cfg)
|
||||
|
||||
// Initialize audit service for security event logging
|
||||
auditService := services.NewAuditService(deps.DB)
|
||||
|
||||
// Initialize handlers
|
||||
authHandler := handlers.NewAuthHandler(authService, deps.EmailService, deps.Cache)
|
||||
authHandler.SetAppleAuthService(appleAuthService)
|
||||
authHandler.SetGoogleAuthService(googleAuthService)
|
||||
authHandler.SetStorageService(deps.StorageService)
|
||||
authHandler.SetAuditService(auditService)
|
||||
userHandler := handlers.NewUserHandler(userService)
|
||||
residenceHandler := handlers.NewResidenceHandler(residenceService, deps.PDFService, deps.EmailService, cfg.Features.PDFReportsEnabled)
|
||||
taskHandler := handlers.NewTaskHandler(taskService, deps.StorageService)
|
||||
@@ -201,6 +216,11 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
mediaHandler = handlers.NewMediaHandler(documentRepo, taskRepo, residenceRepo, deps.StorageService)
|
||||
}
|
||||
|
||||
// Prometheus metrics endpoint (no auth required, for scraping)
|
||||
if deps.MonitoringService != nil {
|
||||
e.GET("/metrics", prometheusMetrics(deps.MonitoringService))
|
||||
}
|
||||
|
||||
// Set up admin routes with monitoring handler (if available)
|
||||
var monitoringHandler *monitoring.Handler
|
||||
if deps.MonitoringService != nil {
|
||||
@@ -295,16 +315,126 @@ func corsMiddleware(cfg *config.Config) echo.MiddlewareFunc {
|
||||
})
|
||||
}
|
||||
|
||||
// healthCheck returns API health status
|
||||
func healthCheck(c echo.Context) error {
|
||||
// liveCheck returns a simple 200 for Kubernetes liveness probes
|
||||
func liveCheck(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||
"status": "healthy",
|
||||
"status": "alive",
|
||||
"version": Version,
|
||||
"framework": "Echo",
|
||||
"timestamp": time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
// readinessCheck returns 200 if PostgreSQL and Redis are reachable, 503 otherwise.
|
||||
// This is used by Kubernetes readiness probes and load balancers.
|
||||
func readinessCheck(deps *Dependencies) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ctx, cancel := context.WithTimeout(c.Request().Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
status := "healthy"
|
||||
httpStatus := http.StatusOK
|
||||
checks := make(map[string]string)
|
||||
|
||||
// Check PostgreSQL
|
||||
sqlDB, err := deps.DB.DB()
|
||||
if err != nil {
|
||||
checks["postgres"] = fmt.Sprintf("failed to get sql.DB: %v", err)
|
||||
status = "unhealthy"
|
||||
httpStatus = http.StatusServiceUnavailable
|
||||
} else if err := sqlDB.PingContext(ctx); err != nil {
|
||||
checks["postgres"] = fmt.Sprintf("ping failed: %v", err)
|
||||
status = "unhealthy"
|
||||
httpStatus = http.StatusServiceUnavailable
|
||||
} else {
|
||||
checks["postgres"] = "ok"
|
||||
}
|
||||
|
||||
// Check Redis (if cache service is available)
|
||||
if deps.Cache != nil {
|
||||
if err := deps.Cache.Client().Ping(ctx).Err(); err != nil {
|
||||
checks["redis"] = fmt.Sprintf("ping failed: %v", err)
|
||||
status = "unhealthy"
|
||||
httpStatus = http.StatusServiceUnavailable
|
||||
} else {
|
||||
checks["redis"] = "ok"
|
||||
}
|
||||
} else {
|
||||
checks["redis"] = "not configured"
|
||||
}
|
||||
|
||||
return c.JSON(httpStatus, map[string]interface{}{
|
||||
"status": status,
|
||||
"version": Version,
|
||||
"checks": checks,
|
||||
"timestamp": time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// prometheusMetrics returns an Echo handler that outputs metrics in Prometheus text format.
|
||||
// It uses the existing monitoring service's HTTP stats collector to avoid adding external dependencies.
|
||||
func prometheusMetrics(monSvc *monitoring.Service) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
httpCollector := monSvc.HTTPCollector()
|
||||
if httpCollector == nil {
|
||||
return c.String(http.StatusOK, "# No HTTP metrics available (collector not initialized)\n")
|
||||
}
|
||||
|
||||
stats := httpCollector.GetStats()
|
||||
var b strings.Builder
|
||||
|
||||
// Request count by method+path+status
|
||||
b.WriteString("# HELP http_requests_total Total number of HTTP requests.\n")
|
||||
b.WriteString("# TYPE http_requests_total counter\n")
|
||||
for statusCode, count := range stats.ByStatusCode {
|
||||
fmt.Fprintf(&b, "http_requests_total{status_code=\"%d\"} %d\n", statusCode, count)
|
||||
}
|
||||
|
||||
// Per-endpoint request count
|
||||
b.WriteString("# HELP http_endpoint_requests_total Total requests per endpoint.\n")
|
||||
b.WriteString("# TYPE http_endpoint_requests_total counter\n")
|
||||
for endpoint, epStats := range stats.ByEndpoint {
|
||||
// endpoint is "METHOD /path"
|
||||
parts := strings.SplitN(endpoint, " ", 2)
|
||||
method := endpoint
|
||||
path := ""
|
||||
if len(parts) == 2 {
|
||||
method = parts[0]
|
||||
path = parts[1]
|
||||
}
|
||||
fmt.Fprintf(&b, "http_endpoint_requests_total{method=\"%s\",path=\"%s\"} %d\n", method, path, epStats.Count)
|
||||
}
|
||||
|
||||
// Request duration (avg latency as a gauge, since we don't have raw histogram buckets)
|
||||
b.WriteString("# HELP http_request_duration_ms Average request duration in milliseconds per endpoint.\n")
|
||||
b.WriteString("# TYPE http_request_duration_ms gauge\n")
|
||||
for endpoint, epStats := range stats.ByEndpoint {
|
||||
parts := strings.SplitN(endpoint, " ", 2)
|
||||
method := endpoint
|
||||
path := ""
|
||||
if len(parts) == 2 {
|
||||
method = parts[0]
|
||||
path = parts[1]
|
||||
}
|
||||
fmt.Fprintf(&b, "http_request_duration_ms{method=\"%s\",path=\"%s\",quantile=\"avg\"} %.2f\n", method, path, epStats.AvgLatencyMs)
|
||||
fmt.Fprintf(&b, "http_request_duration_ms{method=\"%s\",path=\"%s\",quantile=\"p95\"} %.2f\n", method, path, epStats.P95LatencyMs)
|
||||
}
|
||||
|
||||
// Error rate
|
||||
b.WriteString("# HELP http_error_rate Overall error rate (4xx+5xx / total).\n")
|
||||
b.WriteString("# TYPE http_error_rate gauge\n")
|
||||
fmt.Fprintf(&b, "http_error_rate %.4f\n", stats.ErrorRate)
|
||||
|
||||
// Requests per minute
|
||||
b.WriteString("# HELP http_requests_per_minute Current request rate.\n")
|
||||
b.WriteString("# TYPE http_requests_per_minute gauge\n")
|
||||
fmt.Fprintf(&b, "http_requests_per_minute %.2f\n", stats.RequestsPerMinute)
|
||||
|
||||
c.Response().Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8")
|
||||
return c.String(http.StatusOK, b.String())
|
||||
}
|
||||
}
|
||||
|
||||
// setupPublicAuthRoutes configures public authentication routes with
|
||||
// per-endpoint rate limiters to mitigate brute-force and credential-stuffing.
|
||||
// Rate limiters are disabled in debug mode to allow UI test suites to run
|
||||
@@ -342,6 +472,7 @@ func setupProtectedAuthRoutes(api *echo.Group, authHandler *handlers.AuthHandler
|
||||
auth := api.Group("/auth")
|
||||
{
|
||||
auth.POST("/logout/", authHandler.Logout)
|
||||
auth.POST("/refresh/", authHandler.RefreshToken)
|
||||
auth.GET("/me/", authHandler.CurrentUser)
|
||||
auth.PUT("/profile/", authHandler.UpdateProfile)
|
||||
auth.PATCH("/profile/", authHandler.UpdateProfile)
|
||||
|
||||
93
internal/services/audit_service.go
Normal file
93
internal/services/audit_service.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
// Audit event type constants
|
||||
const (
|
||||
AuditEventLogin = "auth.login"
|
||||
AuditEventLoginFailed = "auth.login_failed"
|
||||
AuditEventRegister = "auth.register"
|
||||
AuditEventLogout = "auth.logout"
|
||||
AuditEventPasswordReset = "auth.password_reset"
|
||||
AuditEventPasswordChanged = "auth.password_changed"
|
||||
AuditEventAccountDeleted = "auth.account_deleted"
|
||||
)
|
||||
|
||||
// AuditService handles audit logging for security-relevant events.
|
||||
// It writes audit log entries asynchronously via a buffered channel to avoid
|
||||
// blocking request handlers.
|
||||
type AuditService struct {
|
||||
db *gorm.DB
|
||||
logChan chan *models.AuditLog
|
||||
done chan struct{}
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
// NewAuditService creates a new audit service with a buffered channel for async writes.
|
||||
// Call Stop() when shutting down to flush remaining entries.
|
||||
func NewAuditService(db *gorm.DB) *AuditService {
|
||||
s := &AuditService{
|
||||
db: db,
|
||||
logChan: make(chan *models.AuditLog, 256),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go s.processLogs()
|
||||
return s
|
||||
}
|
||||
|
||||
// processLogs drains the log channel and writes entries to the database.
|
||||
func (s *AuditService) processLogs() {
|
||||
defer close(s.done)
|
||||
for entry := range s.logChan {
|
||||
if err := s.db.Create(entry).Error; err != nil {
|
||||
log.Error().Err(err).
|
||||
Str("event_type", entry.EventType).
|
||||
Msg("Failed to write audit log entry")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop closes the log channel and waits for all pending entries to be written.
|
||||
// It is safe to call Stop multiple times.
|
||||
func (s *AuditService) Stop() {
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.logChan)
|
||||
})
|
||||
<-s.done
|
||||
}
|
||||
|
||||
// LogEvent records an audit event. It extracts the client IP and User-Agent from
|
||||
// the Echo context and sends the entry to the background writer. If the channel
|
||||
// is full the entry is dropped and an error is logged (non-blocking).
|
||||
func (s *AuditService) LogEvent(c echo.Context, userID *uint, eventType string, details map[string]interface{}) {
|
||||
var ip, ua string
|
||||
if c != nil {
|
||||
ip = c.RealIP()
|
||||
ua = c.Request().Header.Get("User-Agent")
|
||||
}
|
||||
|
||||
entry := &models.AuditLog{
|
||||
UserID: userID,
|
||||
EventType: eventType,
|
||||
IPAddress: ip,
|
||||
UserAgent: ua,
|
||||
Details: models.JSONMap(details),
|
||||
}
|
||||
|
||||
select {
|
||||
case s.logChan <- entry:
|
||||
// sent
|
||||
default:
|
||||
log.Warn().
|
||||
Str("event_type", eventType).
|
||||
Msg("Audit log channel full, dropping entry")
|
||||
}
|
||||
}
|
||||
176
internal/services/audit_service_test.go
Normal file
176
internal/services/audit_service_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
func TestAuditService_LogEvent_WritesToDatabase(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
svc := NewAuditService(db)
|
||||
defer svc.Stop()
|
||||
|
||||
// Create a fake Echo context
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
|
||||
req.Header.Set("User-Agent", "TestAgent/1.0")
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
userID := uint(42)
|
||||
svc.LogEvent(c, &userID, AuditEventLogin, map[string]interface{}{
|
||||
"method": "password",
|
||||
})
|
||||
|
||||
// Stop flushes the channel
|
||||
svc.Stop()
|
||||
|
||||
var entries []models.AuditLog
|
||||
err := db.Find(&entries).Error
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 1)
|
||||
|
||||
entry := entries[0]
|
||||
assert.Equal(t, uint(42), *entry.UserID)
|
||||
assert.Equal(t, AuditEventLogin, entry.EventType)
|
||||
assert.Equal(t, "TestAgent/1.0", entry.UserAgent)
|
||||
assert.NotEmpty(t, entry.IPAddress)
|
||||
assert.Equal(t, "password", entry.Details["method"])
|
||||
assert.False(t, entry.CreatedAt.IsZero())
|
||||
}
|
||||
|
||||
func TestAuditService_LogEvent_NilUserID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
svc := NewAuditService(db)
|
||||
defer svc.Stop()
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/login/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
svc.LogEvent(c, nil, AuditEventLoginFailed, map[string]interface{}{
|
||||
"identifier": "unknown@test.com",
|
||||
})
|
||||
|
||||
svc.Stop()
|
||||
|
||||
var entries []models.AuditLog
|
||||
err := db.Find(&entries).Error
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 1)
|
||||
|
||||
assert.Nil(t, entries[0].UserID)
|
||||
assert.Equal(t, AuditEventLoginFailed, entries[0].EventType)
|
||||
}
|
||||
|
||||
func TestAuditService_LogEvent_NilContext(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
svc := NewAuditService(db)
|
||||
defer svc.Stop()
|
||||
|
||||
userID := uint(1)
|
||||
svc.LogEvent(nil, &userID, AuditEventLogout, nil)
|
||||
|
||||
svc.Stop()
|
||||
|
||||
var entries []models.AuditLog
|
||||
err := db.Find(&entries).Error
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 1)
|
||||
|
||||
assert.Equal(t, AuditEventLogout, entries[0].EventType)
|
||||
assert.Empty(t, entries[0].IPAddress)
|
||||
assert.Empty(t, entries[0].UserAgent)
|
||||
}
|
||||
|
||||
func TestAuditService_LogEvent_MultipleEvents(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
svc := NewAuditService(db)
|
||||
defer svc.Stop()
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
userID := uint(10)
|
||||
svc.LogEvent(c, &userID, AuditEventRegister, nil)
|
||||
svc.LogEvent(c, &userID, AuditEventLogin, nil)
|
||||
svc.LogEvent(c, &userID, AuditEventLogout, nil)
|
||||
|
||||
svc.Stop()
|
||||
|
||||
var count int64
|
||||
err := db.Model(&models.AuditLog{}).Count(&count).Error
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(3), count)
|
||||
}
|
||||
|
||||
func TestAuditService_EventTypeConstants(t *testing.T) {
|
||||
// Verify all event constants have expected values
|
||||
assert.Equal(t, "auth.login", AuditEventLogin)
|
||||
assert.Equal(t, "auth.login_failed", AuditEventLoginFailed)
|
||||
assert.Equal(t, "auth.register", AuditEventRegister)
|
||||
assert.Equal(t, "auth.logout", AuditEventLogout)
|
||||
assert.Equal(t, "auth.password_reset", AuditEventPasswordReset)
|
||||
assert.Equal(t, "auth.password_changed", AuditEventPasswordChanged)
|
||||
assert.Equal(t, "auth.account_deleted", AuditEventAccountDeleted)
|
||||
}
|
||||
|
||||
func TestAuditService_Stop_FlushesRemainingEntries(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
svc := NewAuditService(db)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// Send many events quickly
|
||||
for i := 0; i < 50; i++ {
|
||||
uid := uint(i)
|
||||
svc.LogEvent(c, &uid, AuditEventLogin, nil)
|
||||
}
|
||||
|
||||
// Stop should block until all entries are written
|
||||
svc.Stop()
|
||||
|
||||
var count int64
|
||||
err := db.Model(&models.AuditLog{}).Count(&count).Error
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(50), count)
|
||||
}
|
||||
|
||||
func TestAuditLog_TableName(t *testing.T) {
|
||||
log := models.AuditLog{}
|
||||
assert.Equal(t, "audit_log", log.TableName())
|
||||
}
|
||||
|
||||
func TestAuditLog_JSONMap_NilHandling(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
|
||||
// Create entry with nil details
|
||||
entry := &models.AuditLog{
|
||||
EventType: "test",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err := db.Create(entry).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read it back
|
||||
var found models.AuditLog
|
||||
err = db.First(&found, entry.ID).Error
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, found.Details)
|
||||
}
|
||||
173
internal/services/auth_refresh_test.go
Normal file
173
internal/services/auth_refresh_test.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/repositories"
|
||||
)
|
||||
|
||||
func setupRefreshTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.AutoMigrate(&models.User{}, &models.UserProfile{}, &models.AuthToken{})
|
||||
require.NoError(t, err)
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func createRefreshTestUser(t *testing.T, db *gorm.DB) *models.User {
|
||||
t.Helper()
|
||||
user := &models.User{
|
||||
Username: "refreshtest",
|
||||
Email: "refresh@test.com",
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, user.SetPassword("password123"))
|
||||
require.NoError(t, db.Create(user).Error)
|
||||
return user
|
||||
}
|
||||
|
||||
func createTokenWithAge(t *testing.T, db *gorm.DB, userID uint, ageDays int) *models.AuthToken {
|
||||
t.Helper()
|
||||
token := &models.AuthToken{
|
||||
UserID: userID,
|
||||
}
|
||||
require.NoError(t, db.Create(token).Error)
|
||||
|
||||
// Backdate the token's Created timestamp after creation to bypass autoCreateTime
|
||||
backdated := time.Now().UTC().AddDate(0, 0, -ageDays)
|
||||
require.NoError(t, db.Model(token).Update("created", backdated).Error)
|
||||
token.Created = backdated
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
func newTestAuthService(db *gorm.DB) *AuthService {
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
SecretKey: "test-secret",
|
||||
TokenExpiryDays: 90,
|
||||
TokenRefreshDays: 60,
|
||||
},
|
||||
}
|
||||
return NewAuthService(userRepo, cfg)
|
||||
}
|
||||
|
||||
func TestRefreshToken_FreshToken_ReturnsExisting(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
token := createTokenWithAge(t, db, user.ID, 30) // 30 days old, well within fresh window
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(token.Key, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token.Key, resp.Token, "fresh token should return the same token")
|
||||
assert.Contains(t, resp.Message, "still valid")
|
||||
}
|
||||
|
||||
func TestRefreshToken_InRenewalWindow_ReturnsNewToken(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
token := createTokenWithAge(t, db, user.ID, 75) // 75 days old, in renewal window (60-90)
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(token.Key, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, token.Key, resp.Token, "should return a new token")
|
||||
assert.Contains(t, resp.Message, "refreshed")
|
||||
|
||||
// Verify old token was deleted
|
||||
var count int64
|
||||
db.Model(&models.AuthToken{}).Where("key = ?", token.Key).Count(&count)
|
||||
assert.Equal(t, int64(0), count, "old token should be deleted")
|
||||
|
||||
// Verify new token exists in DB
|
||||
db.Model(&models.AuthToken{}).Where("key = ?", resp.Token).Count(&count)
|
||||
assert.Equal(t, int64(1), count, "new token should exist in DB")
|
||||
|
||||
// Verify new token belongs to the same user
|
||||
var newToken models.AuthToken
|
||||
require.NoError(t, db.Where("key = ?", resp.Token).First(&newToken).Error)
|
||||
assert.Equal(t, user.ID, newToken.UserID)
|
||||
}
|
||||
|
||||
func TestRefreshToken_ExpiredToken_Returns401(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
token := createTokenWithAge(t, db, user.ID, 91) // 91 days old, past 90-day expiry
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(token.Key, user.ID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
assert.Contains(t, err.Error(), "error.token_expired")
|
||||
}
|
||||
|
||||
func TestRefreshToken_AtExactBoundary60Days(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
// Exactly 60 days: token age == refreshDays, so tokenAge < refreshDuration is false,
|
||||
// meaning it enters the renewal window
|
||||
token := createTokenWithAge(t, db, user.ID, 61)
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(token.Key, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, token.Key, resp.Token, "token at 61 days should be refreshed")
|
||||
}
|
||||
|
||||
func TestRefreshToken_InvalidToken_Returns401(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken("nonexistent-token-key", user.ID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||
}
|
||||
|
||||
func TestRefreshToken_WrongUser_Returns401(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
token := createTokenWithAge(t, db, user.ID, 75)
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
// Try to refresh with a different user ID
|
||||
resp, err := svc.RefreshToken(token.Key, user.ID+999)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||
}
|
||||
|
||||
func TestRefreshToken_FreshTokenAt59Days_ReturnsExisting(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
token := createTokenWithAge(t, db, user.ID, 59) // 59 days, just under the 60-day threshold
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(token.Key, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token.Key, resp.Token, "token at 59 days should NOT be refreshed")
|
||||
}
|
||||
@@ -188,6 +188,66 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
|
||||
}, code, nil
|
||||
}
|
||||
|
||||
// RefreshToken handles token refresh logic.
|
||||
// - If token is expired (> expiryDays old), returns error (must re-login).
|
||||
// - If token is in the renewal window (> refreshDays old), generates a new token.
|
||||
// - If token is still fresh (< refreshDays old), returns the existing token (no-op).
|
||||
func (s *AuthService) RefreshToken(tokenKey string, userID uint) (*responses.RefreshTokenResponse, error) {
|
||||
expiryDays := s.cfg.Security.TokenExpiryDays
|
||||
if expiryDays <= 0 {
|
||||
expiryDays = 90
|
||||
}
|
||||
refreshDays := s.cfg.Security.TokenRefreshDays
|
||||
if refreshDays <= 0 {
|
||||
refreshDays = 60
|
||||
}
|
||||
|
||||
// Look up the token
|
||||
authToken, err := s.userRepo.FindTokenByKey(tokenKey)
|
||||
if err != nil {
|
||||
return nil, apperrors.Unauthorized("error.invalid_token")
|
||||
}
|
||||
|
||||
// Verify ownership
|
||||
if authToken.UserID != userID {
|
||||
return nil, apperrors.Unauthorized("error.invalid_token")
|
||||
}
|
||||
|
||||
tokenAge := time.Since(authToken.Created)
|
||||
expiryDuration := time.Duration(expiryDays) * 24 * time.Hour
|
||||
refreshDuration := time.Duration(refreshDays) * 24 * time.Hour
|
||||
|
||||
// Token is expired — must re-login
|
||||
if tokenAge > expiryDuration {
|
||||
return nil, apperrors.Unauthorized("error.token_expired")
|
||||
}
|
||||
|
||||
// Token is still fresh — no-op refresh
|
||||
if tokenAge < refreshDuration {
|
||||
return &responses.RefreshTokenResponse{
|
||||
Token: tokenKey,
|
||||
Message: "Token is still valid.",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Token is in the renewal window — generate a new one
|
||||
// Delete the old token
|
||||
if err := s.userRepo.DeleteToken(tokenKey); err != nil {
|
||||
log.Warn().Err(err).Str("token", tokenKey[:8]+"...").Msg("Failed to delete old token during refresh")
|
||||
}
|
||||
|
||||
// Create a new token
|
||||
newToken, err := s.userRepo.CreateToken(userID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
return &responses.RefreshTokenResponse{
|
||||
Token: newToken.Key,
|
||||
Message: "Token refreshed successfully.",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Logout invalidates a user's token
|
||||
func (s *AuthService) Logout(token string) error {
|
||||
return s.userRepo.DeleteToken(token)
|
||||
|
||||
@@ -141,6 +141,12 @@ func (c *CacheService) CacheAuthToken(ctx context.Context, token string, userID
|
||||
return c.SetString(ctx, key, fmt.Sprintf("%d", userID), TokenCacheTTL)
|
||||
}
|
||||
|
||||
// CacheAuthTokenWithCreated caches a user ID and token creation time for a token
|
||||
func (c *CacheService) CacheAuthTokenWithCreated(ctx context.Context, token string, userID uint, createdUnix int64) error {
|
||||
key := AuthTokenPrefix + token
|
||||
return c.SetString(ctx, key, fmt.Sprintf("%d|%d", userID, createdUnix), TokenCacheTTL)
|
||||
}
|
||||
|
||||
// GetCachedAuthToken gets a cached user ID for a token
|
||||
func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (uint, error) {
|
||||
key := AuthTokenPrefix + token
|
||||
@@ -154,6 +160,24 @@ func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (ui
|
||||
return userID, err
|
||||
}
|
||||
|
||||
// GetCachedAuthTokenWithCreated gets a cached user ID and token creation time.
|
||||
// Returns userID, createdUnix, error. createdUnix is 0 if not stored (legacy format).
|
||||
func (c *CacheService) GetCachedAuthTokenWithCreated(ctx context.Context, token string) (uint, int64, error) {
|
||||
key := AuthTokenPrefix + token
|
||||
val, err := c.GetString(ctx, key)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
var userID uint
|
||||
var createdUnix int64
|
||||
n, _ := fmt.Sscanf(val, "%d|%d", &userID, &createdUnix)
|
||||
if n < 1 {
|
||||
return 0, 0, fmt.Errorf("invalid cached token format")
|
||||
}
|
||||
return userID, createdUnix, nil
|
||||
}
|
||||
|
||||
// InvalidateAuthToken removes a cached token
|
||||
func (c *CacheService) InvalidateAuthToken(ctx context.Context, token string) error {
|
||||
key := AuthTokenPrefix + token
|
||||
|
||||
@@ -28,6 +28,7 @@ func NewEmailService(cfg *config.EmailConfig, enabled bool) *EmailService {
|
||||
mail.WithUsername(cfg.User),
|
||||
mail.WithPassword(cfg.Password),
|
||||
mail.WithTLSPortPolicy(mail.TLSOpportunistic),
|
||||
mail.WithTimeout(30*time.Second),
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to create mail client - emails will not be sent")
|
||||
|
||||
@@ -69,6 +69,7 @@ func SetupTestDB(t *testing.T) *gorm.DB {
|
||||
&models.FeatureBenefit{},
|
||||
&models.UpgradeTrigger{},
|
||||
&models.Promotion{},
|
||||
&models.AuditLog{},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/labstack/echo/v4"
|
||||
@@ -27,9 +28,34 @@ func NewCustomValidator() *CustomValidator {
|
||||
return name
|
||||
})
|
||||
|
||||
// Register custom password complexity validator
|
||||
v.RegisterValidation("password_complexity", validatePasswordComplexity)
|
||||
|
||||
return &CustomValidator{validator: v}
|
||||
}
|
||||
|
||||
// validatePasswordComplexity checks that a password contains at least one
|
||||
// uppercase letter, one lowercase letter, and one digit.
|
||||
// Minimum length is enforced separately via the "min" tag.
|
||||
func validatePasswordComplexity(fl validator.FieldLevel) bool {
|
||||
password := fl.Field().String()
|
||||
var hasUpper, hasLower, hasDigit bool
|
||||
for _, ch := range password {
|
||||
switch {
|
||||
case unicode.IsUpper(ch):
|
||||
hasUpper = true
|
||||
case unicode.IsLower(ch):
|
||||
hasLower = true
|
||||
case unicode.IsDigit(ch):
|
||||
hasDigit = true
|
||||
}
|
||||
if hasUpper && hasLower && hasDigit {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return hasUpper && hasLower && hasDigit
|
||||
}
|
||||
|
||||
// Validate implements echo.Validator interface
|
||||
func (cv *CustomValidator) Validate(i interface{}) error {
|
||||
if err := cv.validator.Struct(i); err != nil {
|
||||
@@ -96,6 +122,8 @@ func formatMessage(fe validator.FieldError) string {
|
||||
return "Must be a valid URL"
|
||||
case "uuid":
|
||||
return "Must be a valid UUID"
|
||||
case "password_complexity":
|
||||
return "Password must be at least 8 characters with at least one uppercase letter, one lowercase letter, and one digit"
|
||||
default:
|
||||
return "Invalid value"
|
||||
}
|
||||
|
||||
115
internal/validator/validator_test.go
Normal file
115
internal/validator/validator_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package validator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
govalidator "github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
func TestValidatePasswordComplexity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
valid bool
|
||||
}{
|
||||
{"valid password", "Password1", true},
|
||||
{"valid complex password", "MyP@ssw0rd!", true},
|
||||
{"missing uppercase", "password1", false},
|
||||
{"missing lowercase", "PASSWORD1", false},
|
||||
{"missing digit", "Password", false},
|
||||
{"only digits", "12345678", false},
|
||||
{"only lowercase", "abcdefgh", false},
|
||||
{"only uppercase", "ABCDEFGH", false},
|
||||
{"empty string", "", false},
|
||||
{"single valid char each", "aA1", true},
|
||||
{"unicode uppercase with digit and lower", "Über1abc", true},
|
||||
}
|
||||
|
||||
v := govalidator.New()
|
||||
v.RegisterValidation("password_complexity", validatePasswordComplexity)
|
||||
|
||||
type testStruct struct {
|
||||
Password string `validate:"password_complexity"`
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
s := testStruct{Password: tc.password}
|
||||
err := v.Struct(s)
|
||||
if tc.valid && err != nil {
|
||||
t.Errorf("expected password %q to be valid, got error: %v", tc.password, err)
|
||||
}
|
||||
if !tc.valid && err == nil {
|
||||
t.Errorf("expected password %q to be invalid, got nil error", tc.password)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePasswordComplexityWithMinLength(t *testing.T) {
|
||||
v := govalidator.New()
|
||||
v.RegisterValidation("password_complexity", validatePasswordComplexity)
|
||||
|
||||
type request struct {
|
||||
Password string `validate:"required,min=8,password_complexity"`
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
valid bool
|
||||
}{
|
||||
{"valid 8+ chars with complexity", "Abcdefg1", true},
|
||||
{"too short but complex", "Ab1", false},
|
||||
{"long but no uppercase", "abcdefgh1", false},
|
||||
{"long but no lowercase", "ABCDEFGH1", false},
|
||||
{"long but no digit", "Abcdefghi", false},
|
||||
{"empty", "", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := request{Password: tc.password}
|
||||
err := v.Struct(r)
|
||||
if tc.valid && err != nil {
|
||||
t.Errorf("expected %q to be valid, got error: %v", tc.password, err)
|
||||
}
|
||||
if !tc.valid && err == nil {
|
||||
t.Errorf("expected %q to be invalid, got nil", tc.password)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatMessagePasswordComplexity(t *testing.T) {
|
||||
cv := NewCustomValidator()
|
||||
|
||||
type request struct {
|
||||
Password string `json:"password" validate:"required,min=8,password_complexity"`
|
||||
}
|
||||
|
||||
r := request{Password: "lowercase1"}
|
||||
err := cv.Validate(r)
|
||||
if err == nil {
|
||||
t.Fatal("expected validation error for password without uppercase")
|
||||
}
|
||||
|
||||
resp := FormatValidationErrors(err)
|
||||
if resp == nil {
|
||||
t.Fatal("expected non-nil error response")
|
||||
}
|
||||
|
||||
field, ok := resp.Fields["password"]
|
||||
if !ok {
|
||||
t.Fatal("expected 'password' field in error response")
|
||||
}
|
||||
|
||||
expectedMsg := "Password must be at least 8 characters with at least one uppercase letter, one lowercase letter, and one digit"
|
||||
if field.Message != expectedMsg {
|
||||
t.Errorf("expected message %q, got %q", expectedMsg, field.Message)
|
||||
}
|
||||
|
||||
if field.Tag != "password_complexity" {
|
||||
t.Errorf("expected tag 'password_complexity', got %q", field.Tag)
|
||||
}
|
||||
}
|
||||
2
migrations/016_authtoken_created_at.down.sql
Normal file
2
migrations/016_authtoken_created_at.down.sql
Normal file
@@ -0,0 +1,2 @@
|
||||
-- No-op: the created column is part of Django's original schema and should not
|
||||
-- be removed.
|
||||
6
migrations/016_authtoken_created_at.up.sql
Normal file
6
migrations/016_authtoken_created_at.up.sql
Normal file
@@ -0,0 +1,6 @@
|
||||
-- Ensure created column exists on user_authtoken (Django already creates it,
|
||||
-- but this migration guarantees it for fresh Go-only deployments).
|
||||
ALTER TABLE user_authtoken ADD COLUMN IF NOT EXISTS created TIMESTAMP WITH TIME ZONE DEFAULT NOW();
|
||||
|
||||
-- Backfill any rows that may have a NULL created timestamp.
|
||||
UPDATE user_authtoken SET created = NOW() WHERE created IS NULL;
|
||||
40
migrations/017_fk_indexes.down.sql
Normal file
40
migrations/017_fk_indexes.down.sql
Normal file
@@ -0,0 +1,40 @@
|
||||
-- Rollback: 017_fk_indexes
|
||||
-- Drop all FK indexes added in the up migration.
|
||||
|
||||
-- auth / user tables
|
||||
DROP INDEX IF EXISTS idx_authtoken_user_id;
|
||||
DROP INDEX IF EXISTS idx_userprofile_user_id;
|
||||
DROP INDEX IF EXISTS idx_confirmationcode_user_id;
|
||||
DROP INDEX IF EXISTS idx_passwordresetcode_user_id;
|
||||
DROP INDEX IF EXISTS idx_applesocialauth_user_id;
|
||||
DROP INDEX IF EXISTS idx_googlesocialauth_user_id;
|
||||
|
||||
-- push notification device tables
|
||||
DROP INDEX IF EXISTS idx_apnsdevice_user_id;
|
||||
DROP INDEX IF EXISTS idx_gcmdevice_user_id;
|
||||
|
||||
-- notification tables
|
||||
DROP INDEX IF EXISTS idx_notificationpreference_user_id;
|
||||
|
||||
-- subscription tables
|
||||
DROP INDEX IF EXISTS idx_subscription_user_id;
|
||||
|
||||
-- residence tables
|
||||
DROP INDEX IF EXISTS idx_residence_owner_id;
|
||||
DROP INDEX IF EXISTS idx_sharecode_residence_id;
|
||||
DROP INDEX IF EXISTS idx_sharecode_created_by_id;
|
||||
|
||||
-- task tables
|
||||
DROP INDEX IF EXISTS idx_task_created_by_id;
|
||||
DROP INDEX IF EXISTS idx_task_assigned_to_id;
|
||||
DROP INDEX IF EXISTS idx_task_category_id;
|
||||
DROP INDEX IF EXISTS idx_task_priority_id;
|
||||
DROP INDEX IF EXISTS idx_task_frequency_id;
|
||||
DROP INDEX IF EXISTS idx_task_contractor_id;
|
||||
DROP INDEX IF EXISTS idx_task_parent_task_id;
|
||||
DROP INDEX IF EXISTS idx_completionimage_completion_id;
|
||||
DROP INDEX IF EXISTS idx_document_created_by_id;
|
||||
DROP INDEX IF EXISTS idx_document_task_id;
|
||||
DROP INDEX IF EXISTS idx_documentimage_document_id;
|
||||
DROP INDEX IF EXISTS idx_contractor_residence_id;
|
||||
DROP INDEX IF EXISTS idx_reminderlog_notification_id;
|
||||
131
migrations/017_fk_indexes.up.sql
Normal file
131
migrations/017_fk_indexes.up.sql
Normal file
@@ -0,0 +1,131 @@
|
||||
-- Migration: 017_fk_indexes
|
||||
-- Add indexes on all foreign key columns that are not already covered by existing indexes.
|
||||
-- Uses CREATE INDEX IF NOT EXISTS to be idempotent (safe to re-run).
|
||||
|
||||
-- =====================================================
|
||||
-- auth / user tables
|
||||
-- =====================================================
|
||||
|
||||
-- user_authtoken: user_id (unique FK, but ensure index exists)
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_authtoken_user_id
|
||||
ON user_authtoken (user_id);
|
||||
|
||||
-- user_userprofile: user_id (unique FK)
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_userprofile_user_id
|
||||
ON user_userprofile (user_id);
|
||||
|
||||
-- user_confirmationcode: user_id
|
||||
CREATE INDEX IF NOT EXISTS idx_confirmationcode_user_id
|
||||
ON user_confirmationcode (user_id);
|
||||
|
||||
-- user_passwordresetcode: user_id
|
||||
CREATE INDEX IF NOT EXISTS idx_passwordresetcode_user_id
|
||||
ON user_passwordresetcode (user_id);
|
||||
|
||||
-- user_applesocialauth: user_id (unique FK)
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_applesocialauth_user_id
|
||||
ON user_applesocialauth (user_id);
|
||||
|
||||
-- user_googlesocialauth: user_id (unique FK)
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_googlesocialauth_user_id
|
||||
ON user_googlesocialauth (user_id);
|
||||
|
||||
-- =====================================================
|
||||
-- push notification device tables
|
||||
-- =====================================================
|
||||
|
||||
-- push_notifications_apnsdevice: user_id
|
||||
CREATE INDEX IF NOT EXISTS idx_apnsdevice_user_id
|
||||
ON push_notifications_apnsdevice (user_id);
|
||||
|
||||
-- push_notifications_gcmdevice: user_id
|
||||
CREATE INDEX IF NOT EXISTS idx_gcmdevice_user_id
|
||||
ON push_notifications_gcmdevice (user_id);
|
||||
|
||||
-- =====================================================
|
||||
-- notification tables
|
||||
-- =====================================================
|
||||
|
||||
-- notifications_notificationpreference: user_id (unique FK)
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_notificationpreference_user_id
|
||||
ON notifications_notificationpreference (user_id);
|
||||
|
||||
-- =====================================================
|
||||
-- subscription tables
|
||||
-- =====================================================
|
||||
|
||||
-- subscription_usersubscription: user_id (unique FK)
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_subscription_user_id
|
||||
ON subscription_usersubscription (user_id);
|
||||
|
||||
-- =====================================================
|
||||
-- residence tables
|
||||
-- =====================================================
|
||||
|
||||
-- residence_residence: owner_id
|
||||
CREATE INDEX IF NOT EXISTS idx_residence_owner_id
|
||||
ON residence_residence (owner_id);
|
||||
|
||||
-- residence_residencesharecode: residence_id (may already exist from model index tag via GORM)
|
||||
CREATE INDEX IF NOT EXISTS idx_sharecode_residence_id
|
||||
ON residence_residencesharecode (residence_id);
|
||||
|
||||
-- residence_residencesharecode: created_by_id
|
||||
CREATE INDEX IF NOT EXISTS idx_sharecode_created_by_id
|
||||
ON residence_residencesharecode (created_by_id);
|
||||
|
||||
-- =====================================================
|
||||
-- task tables
|
||||
-- =====================================================
|
||||
|
||||
-- task_task: created_by_id
|
||||
CREATE INDEX IF NOT EXISTS idx_task_created_by_id
|
||||
ON task_task (created_by_id);
|
||||
|
||||
-- task_task: assigned_to_id
|
||||
CREATE INDEX IF NOT EXISTS idx_task_assigned_to_id
|
||||
ON task_task (assigned_to_id);
|
||||
|
||||
-- task_task: category_id
|
||||
CREATE INDEX IF NOT EXISTS idx_task_category_id
|
||||
ON task_task (category_id);
|
||||
|
||||
-- task_task: priority_id
|
||||
CREATE INDEX IF NOT EXISTS idx_task_priority_id
|
||||
ON task_task (priority_id);
|
||||
|
||||
-- task_task: frequency_id
|
||||
CREATE INDEX IF NOT EXISTS idx_task_frequency_id
|
||||
ON task_task (frequency_id);
|
||||
|
||||
-- task_task: contractor_id
|
||||
CREATE INDEX IF NOT EXISTS idx_task_contractor_id
|
||||
ON task_task (contractor_id);
|
||||
|
||||
-- task_task: parent_task_id
|
||||
CREATE INDEX IF NOT EXISTS idx_task_parent_task_id
|
||||
ON task_task (parent_task_id);
|
||||
|
||||
-- task_taskcompletionimage: completion_id
|
||||
CREATE INDEX IF NOT EXISTS idx_completionimage_completion_id
|
||||
ON task_taskcompletionimage (completion_id);
|
||||
|
||||
-- task_document: created_by_id
|
||||
CREATE INDEX IF NOT EXISTS idx_document_created_by_id
|
||||
ON task_document (created_by_id);
|
||||
|
||||
-- task_document: task_id
|
||||
CREATE INDEX IF NOT EXISTS idx_document_task_id
|
||||
ON task_document (task_id);
|
||||
|
||||
-- task_documentimage: document_id
|
||||
CREATE INDEX IF NOT EXISTS idx_documentimage_document_id
|
||||
ON task_documentimage (document_id);
|
||||
|
||||
-- task_contractor: residence_id
|
||||
CREATE INDEX IF NOT EXISTS idx_contractor_residence_id
|
||||
ON task_contractor (residence_id);
|
||||
|
||||
-- task_reminderlog: notification_id
|
||||
CREATE INDEX IF NOT EXISTS idx_reminderlog_notification_id
|
||||
ON task_reminderlog (notification_id);
|
||||
2
migrations/018_audit_log.down.sql
Normal file
2
migrations/018_audit_log.down.sql
Normal file
@@ -0,0 +1,2 @@
|
||||
-- Rollback: 018_audit_log
|
||||
DROP TABLE IF EXISTS audit_log;
|
||||
16
migrations/018_audit_log.up.sql
Normal file
16
migrations/018_audit_log.up.sql
Normal file
@@ -0,0 +1,16 @@
|
||||
-- Migration: 018_audit_log
|
||||
-- Create audit_log table for tracking security-relevant events (login, register, etc.)
|
||||
|
||||
CREATE TABLE IF NOT EXISTS audit_log (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id INTEGER,
|
||||
event_type VARCHAR(50) NOT NULL,
|
||||
ip_address VARCHAR(45),
|
||||
user_agent TEXT,
|
||||
details JSONB,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX idx_audit_log_user_id ON audit_log(user_id);
|
||||
CREATE INDEX idx_audit_log_event_type ON audit_log(event_type);
|
||||
CREATE INDEX idx_audit_log_created_at ON audit_log(created_at);
|
||||
Reference in New Issue
Block a user