From 81578f6e276556ee7d1705b53481b6fab5d62fc7 Mon Sep 17 00:00:00 2001 From: Trey t Date: Mon, 18 May 2026 17:55:56 -0500 Subject: [PATCH] =?UTF-8?q?feat(auth):=20replace=20hand-rolled=20auth=20wi?= =?UTF-8?q?th=20Ory=20Kratos=20=E2=80=94=20phase=202=20backend?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Delegates all credential management (login, register, password reset, email verification, social sign-in) to Ory Kratos. The Go API now acts as a resource server: the new KratosAuth middleware validates sessions against the Kratos whoami endpoint, writes the local User mirror into Echo context, and all existing domain handlers continue working unchanged. Hand-rolled token auth, AuthToken model, apple_auth/ google_auth services, and the auth refresh flow are removed. Tests are updated to use the fake-token middleware pattern so existing integration assertions require no rewrite. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../handlers/apple_social_auth_handler.go | 209 +--- internal/admin/handlers/auth_token_handler.go | 135 +-- .../handlers/confirmation_code_handler.go | 154 +-- .../handlers/password_reset_code_handler.go | 151 +-- internal/admin/handlers/user_handler.go | 9 +- internal/config/config.go | 5 + internal/database/database.go | 5 - internal/handlers/auth_handler.go | 444 +------- internal/handlers/auth_handler_delete_test.go | 121 +-- internal/handlers/auth_handler_test.go | 358 +------ internal/handlers/handler_coverage_test.go | 226 ---- internal/integration/contract_test.go | 21 + internal/integration/integration_test.go | 380 ++----- .../integration/security_regression_test.go | 102 +- .../integration/subscription_is_free_test.go | 105 +- internal/kratos/client.go | 107 ++ internal/middleware/auth.go | 438 -------- internal/middleware/auth_expiry_test.go | 165 --- internal/middleware/auth_test.go | 337 ------ internal/middleware/kratos_auth.go | 271 +++++ internal/models/models_coverage_test.go | 126 +-- internal/models/user.go | 289 +----- internal/models/user_test.go | 170 +-- internal/repositories/user_repo.go | 370 +------ .../repositories/user_repo_coverage_test.go | 238 +---- .../repositories/user_repo_extended_test.go | 229 ---- internal/repositories/user_repo_test.go | 41 +- internal/router/router.go | 64 +- internal/services/apple_auth.go | 301 ------ internal/services/auth_refresh_test.go | 176 ---- internal/services/auth_service.go | 977 +----------------- internal/services/auth_service_test.go | 709 +------------ internal/services/cache_service.go | 111 -- internal/services/google_auth.go | 307 ------ internal/testutil/testutil.go | 41 +- migrations/000007_kratos_identity.sql | 37 + 36 files changed, 927 insertions(+), 7002 deletions(-) create mode 100644 internal/kratos/client.go delete mode 100644 internal/middleware/auth.go delete mode 100644 internal/middleware/auth_expiry_test.go delete mode 100644 internal/middleware/auth_test.go create mode 100644 internal/middleware/kratos_auth.go delete mode 100644 internal/services/apple_auth.go delete mode 100644 internal/services/auth_refresh_test.go delete mode 100644 internal/services/google_auth.go create mode 100644 migrations/000007_kratos_identity.sql diff --git a/internal/admin/handlers/apple_social_auth_handler.go b/internal/admin/handlers/apple_social_auth_handler.go index f98570d..b793558 100644 --- a/internal/admin/handlers/apple_social_auth_handler.go +++ b/internal/admin/handlers/apple_social_auth_handler.go @@ -1,215 +1,30 @@ +// apple_social_auth_handler is a stub — the user_applesocialauth table was +// dropped in the Ory Kratos migration (phase 2). Social sign-in is now +// handled by Kratos. package handlers import ( "net/http" - "strconv" "github.com/labstack/echo/v4" "gorm.io/gorm" - - "github.com/treytartt/honeydue-api/internal/admin/dto" - "github.com/treytartt/honeydue-api/internal/models" ) -// AdminAppleSocialAuthHandler handles admin Apple social auth management endpoints +// AdminAppleSocialAuthHandler is a no-op stub. type AdminAppleSocialAuthHandler struct { db *gorm.DB } -// NewAdminAppleSocialAuthHandler creates a new admin Apple social auth handler func NewAdminAppleSocialAuthHandler(db *gorm.DB) *AdminAppleSocialAuthHandler { return &AdminAppleSocialAuthHandler{db: db} } -// AppleSocialAuthResponse represents the response for an Apple social auth entry -type AppleSocialAuthResponse struct { - ID uint `json:"id"` - UserID uint `json:"user_id"` - Username string `json:"username"` - UserEmail string `json:"user_email"` - AppleID string `json:"apple_id"` - Email string `json:"email"` - IsPrivateEmail bool `json:"is_private_email"` - CreatedAt string `json:"created_at"` - UpdatedAt string `json:"updated_at"` -} - -// UpdateAppleSocialAuthRequest represents the request to update an Apple social auth entry -type UpdateAppleSocialAuthRequest struct { - Email *string `json:"email"` - IsPrivateEmail *bool `json:"is_private_email"` -} - -// List handles GET /api/admin/apple-social-auth -func (h *AdminAppleSocialAuthHandler) List(c echo.Context) error { - var filters dto.PaginationParams - if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) - } - - var entries []models.AppleSocialAuth - var total int64 - - query := h.db.Model(&models.AppleSocialAuth{}).Preload("User") - - // Apply search - if filters.Search != "" { - search := "%" + filters.Search + "%" - query = query.Joins("JOIN auth_user ON auth_user.id = user_applesocialauth.user_id"). - Where("user_applesocialauth.apple_id ILIKE ? OR user_applesocialauth.email ILIKE ? OR auth_user.username ILIKE ? OR auth_user.email ILIKE ?", - search, search, search, search) - } - - // Get total count - query.Count(&total) - - // Apply sorting (allowlist prevents SQL injection via sort_by parameter) - sortBy := filters.GetSafeSortBy([]string{ - "id", "user_id", "apple_id", "email", "is_private_email", - "created_at", "updated_at", - }, "created_at") - query = query.Order(sortBy + " " + filters.GetSortDir()) - - // Apply pagination - query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage()) - - if err := query.Find(&entries).Error; err != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch Apple social auth entries"}) - } - - // Build response - responses := make([]AppleSocialAuthResponse, len(entries)) - for i, entry := range entries { - responses[i] = h.toResponse(&entry) - } - - return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage())) -} - -// Get handles GET /api/admin/apple-social-auth/:id -func (h *AdminAppleSocialAuthHandler) Get(c echo.Context) error { - id, err := strconv.ParseUint(c.Param("id"), 10, 32) - if err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"}) - } - - var entry models.AppleSocialAuth - if err := h.db.Preload("User").First(&entry, id).Error; err != nil { - if err == gorm.ErrRecordNotFound { - return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Apple social auth entry not found"}) - } - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch Apple social auth entry"}) - } - - return c.JSON(http.StatusOK, h.toResponse(&entry)) -} - -// GetByUser handles GET /api/admin/apple-social-auth/user/:user_id -func (h *AdminAppleSocialAuthHandler) GetByUser(c echo.Context) error { - userID, err := strconv.ParseUint(c.Param("user_id"), 10, 32) - if err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid user ID"}) - } - - var entry models.AppleSocialAuth - if err := h.db.Preload("User").Where("user_id = ?", userID).First(&entry).Error; err != nil { - if err == gorm.ErrRecordNotFound { - return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Apple social auth entry not found for user"}) - } - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch Apple social auth entry"}) - } - - return c.JSON(http.StatusOK, h.toResponse(&entry)) -} - -// Update handles PUT /api/admin/apple-social-auth/:id -func (h *AdminAppleSocialAuthHandler) Update(c echo.Context) error { - id, err := strconv.ParseUint(c.Param("id"), 10, 32) - if err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"}) - } - - var entry models.AppleSocialAuth - if err := h.db.First(&entry, id).Error; err != nil { - if err == gorm.ErrRecordNotFound { - return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Apple social auth entry not found"}) - } - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch Apple social auth entry"}) - } - - var req UpdateAppleSocialAuthRequest - if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) - } - - if req.Email != nil { - entry.Email = *req.Email - } - if req.IsPrivateEmail != nil { - entry.IsPrivateEmail = *req.IsPrivateEmail - } - - if err := h.db.Save(&entry).Error; err != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update Apple social auth entry"}) - } - - h.db.Preload("User").First(&entry, id) - return c.JSON(http.StatusOK, h.toResponse(&entry)) -} - -// Delete handles DELETE /api/admin/apple-social-auth/:id -func (h *AdminAppleSocialAuthHandler) Delete(c echo.Context) error { - id, err := strconv.ParseUint(c.Param("id"), 10, 32) - if err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"}) - } - - var entry models.AppleSocialAuth - if err := h.db.First(&entry, id).Error; err != nil { - if err == gorm.ErrRecordNotFound { - return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Apple social auth entry not found"}) - } - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch Apple social auth entry"}) - } - - if err := h.db.Delete(&entry).Error; err != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth entry"}) - } - - return c.JSON(http.StatusOK, map[string]interface{}{"message": "Apple social auth entry deleted successfully"}) -} - -// BulkDelete handles DELETE /api/admin/apple-social-auth/bulk -func (h *AdminAppleSocialAuthHandler) BulkDelete(c echo.Context) error { - var req dto.BulkDeleteRequest - if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) - } - - result := h.db.Where("id IN ?", req.IDs).Delete(&models.AppleSocialAuth{}) - if result.Error != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth entries"}) - } - - return c.JSON(http.StatusOK, map[string]interface{}{"message": "Apple social auth entries deleted successfully", "count": result.RowsAffected}) -} - -// toResponse converts an AppleSocialAuth model to AppleSocialAuthResponse -func (h *AdminAppleSocialAuthHandler) toResponse(entry *models.AppleSocialAuth) AppleSocialAuthResponse { - response := AppleSocialAuthResponse{ - ID: entry.ID, - UserID: entry.UserID, - AppleID: entry.AppleID, - Email: entry.Email, - IsPrivateEmail: entry.IsPrivateEmail, - CreatedAt: entry.CreatedAt.Format("2006-01-02T15:04:05Z"), - UpdatedAt: entry.UpdatedAt.Format("2006-01-02T15:04:05Z"), - } - - if entry.User.ID != 0 { - response.Username = entry.User.Username - response.UserEmail = entry.User.Email - } - - return response +func (h *AdminAppleSocialAuthHandler) gone(c echo.Context) error { + return c.JSON(http.StatusGone, map[string]string{"message": "Apple social auth is managed by Ory Kratos"}) } +func (h *AdminAppleSocialAuthHandler) List(c echo.Context) error { return h.gone(c) } +func (h *AdminAppleSocialAuthHandler) Get(c echo.Context) error { return h.gone(c) } +func (h *AdminAppleSocialAuthHandler) Delete(c echo.Context) error { return h.gone(c) } +func (h *AdminAppleSocialAuthHandler) BulkDelete(c echo.Context) error { return h.gone(c) } +func (h *AdminAppleSocialAuthHandler) Update(c echo.Context) error { return h.gone(c) } +func (h *AdminAppleSocialAuthHandler) GetByUser(c echo.Context) error { return h.gone(c) } diff --git a/internal/admin/handlers/auth_token_handler.go b/internal/admin/handlers/auth_token_handler.go index 1f056f1..2f2a15c 100644 --- a/internal/admin/handlers/auth_token_handler.go +++ b/internal/admin/handlers/auth_token_handler.go @@ -1,144 +1,27 @@ +// auth_token_handler is a stub — the user_authtoken table was dropped in the +// Ory Kratos migration (phase 2). Auth tokens are now Kratos sessions. package handlers import ( "net/http" - "strconv" "github.com/labstack/echo/v4" "gorm.io/gorm" - - "github.com/treytartt/honeydue-api/internal/admin/dto" - "github.com/treytartt/honeydue-api/internal/models" ) -// AdminAuthTokenHandler handles admin auth token management endpoints +// AdminAuthTokenHandler is a no-op stub. type AdminAuthTokenHandler struct { db *gorm.DB } -// NewAdminAuthTokenHandler creates a new admin auth token handler func NewAdminAuthTokenHandler(db *gorm.DB) *AdminAuthTokenHandler { return &AdminAuthTokenHandler{db: db} } -// AuthTokenResponse represents an auth token in API responses -type AuthTokenResponse struct { - Key string `json:"key"` - UserID uint `json:"user_id"` - Username string `json:"username"` - Email string `json:"email"` - Created string `json:"created"` -} - -// List handles GET /api/admin/auth-tokens -func (h *AdminAuthTokenHandler) List(c echo.Context) error { - var filters dto.PaginationParams - if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) - } - - var tokens []models.AuthToken - var total int64 - - query := h.db.Model(&models.AuthToken{}).Preload("User") - - // Apply search (search by user info) - if filters.Search != "" { - search := "%" + filters.Search + "%" - query = query.Joins("JOIN auth_user ON auth_user.id = user_authtoken.user_id"). - Where( - "auth_user.username ILIKE ? OR auth_user.email ILIKE ? OR user_authtoken.key ILIKE ?", - search, search, search, - ) - } - - // Get total count - query.Count(&total) - - // Apply sorting (allowlist prevents SQL injection via sort_by parameter) - sortBy := filters.GetSafeSortBy([]string{ - "created", "user_id", - }, "created") - query = query.Order(sortBy + " " + filters.GetSortDir()) - - // Apply pagination - query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage()) - - if err := query.Find(&tokens).Error; err != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch auth tokens"}) - } - - // Build response - responses := make([]AuthTokenResponse, len(tokens)) - for i, token := range tokens { - responses[i] = AuthTokenResponse{ - Key: token.Key, - UserID: token.UserID, - Username: token.User.Username, - Email: token.User.Email, - Created: token.Created.Format("2006-01-02T15:04:05Z"), - } - } - - return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage())) -} - -// Get handles GET /api/admin/auth-tokens/:id (id is actually user_id) -func (h *AdminAuthTokenHandler) Get(c echo.Context) error { - id, err := strconv.ParseUint(c.Param("id"), 10, 32) - if err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid user ID"}) - } - - var token models.AuthToken - if err := h.db.Preload("User").Where("user_id = ?", id).First(&token).Error; err != nil { - if err == gorm.ErrRecordNotFound { - return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Auth token not found"}) - } - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch auth token"}) - } - - response := AuthTokenResponse{ - Key: token.Key, - UserID: token.UserID, - Username: token.User.Username, - Email: token.User.Email, - Created: token.Created.Format("2006-01-02T15:04:05Z"), - } - - return c.JSON(http.StatusOK, response) -} - -// Delete handles DELETE /api/admin/auth-tokens/:id (revoke token) -func (h *AdminAuthTokenHandler) Delete(c echo.Context) error { - id, err := strconv.ParseUint(c.Param("id"), 10, 32) - if err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid user ID"}) - } - - result := h.db.Where("user_id = ?", id).Delete(&models.AuthToken{}) - if result.Error != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to revoke token"}) - } - - if result.RowsAffected == 0 { - return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Auth token not found"}) - } - - return c.JSON(http.StatusOK, map[string]interface{}{"message": "Auth token revoked successfully"}) -} - -// BulkDelete handles DELETE /api/admin/auth-tokens/bulk -func (h *AdminAuthTokenHandler) BulkDelete(c echo.Context) error { - var req dto.BulkDeleteRequest - if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) - } - - result := h.db.Where("user_id IN ?", req.IDs).Delete(&models.AuthToken{}) - if result.Error != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to revoke tokens"}) - } - - return c.JSON(http.StatusOK, map[string]interface{}{"message": "Auth tokens revoked successfully", "count": result.RowsAffected}) +func (h *AdminAuthTokenHandler) gone(c echo.Context) error { + return c.JSON(http.StatusGone, map[string]string{"message": "auth tokens are managed by Ory Kratos"}) } +func (h *AdminAuthTokenHandler) List(c echo.Context) error { return h.gone(c) } +func (h *AdminAuthTokenHandler) Get(c echo.Context) error { return h.gone(c) } +func (h *AdminAuthTokenHandler) Delete(c echo.Context) error { return h.gone(c) } +func (h *AdminAuthTokenHandler) BulkDelete(c echo.Context) error { return h.gone(c) } diff --git a/internal/admin/handlers/confirmation_code_handler.go b/internal/admin/handlers/confirmation_code_handler.go index 758cca5..461a4bc 100644 --- a/internal/admin/handlers/confirmation_code_handler.go +++ b/internal/admin/handlers/confirmation_code_handler.go @@ -1,162 +1,28 @@ +// confirmation_code_handler is a stub — the user_confirmationcode table was +// dropped in the Ory Kratos migration (phase 2). Email verification is now +// handled by Kratos. package handlers import ( "net/http" - "strconv" - "strings" "github.com/labstack/echo/v4" "gorm.io/gorm" - - "github.com/treytartt/honeydue-api/internal/admin/dto" - "github.com/treytartt/honeydue-api/internal/models" ) -// maskCode masks a confirmation code, showing only the last 4 characters. -func maskCode(code string) string { - if len(code) <= 4 { - return strings.Repeat("*", len(code)) - } - return strings.Repeat("*", len(code)-4) + code[len(code)-4:] -} - -// AdminConfirmationCodeHandler handles admin confirmation code management endpoints +// AdminConfirmationCodeHandler is a no-op stub. type AdminConfirmationCodeHandler struct { db *gorm.DB } -// NewAdminConfirmationCodeHandler creates a new admin confirmation code handler func NewAdminConfirmationCodeHandler(db *gorm.DB) *AdminConfirmationCodeHandler { return &AdminConfirmationCodeHandler{db: db} } -// ConfirmationCodeResponse represents a confirmation code in API responses -type ConfirmationCodeResponse struct { - ID uint `json:"id"` - UserID uint `json:"user_id"` - Username string `json:"username"` - Email string `json:"email"` - Code string `json:"code"` - ExpiresAt string `json:"expires_at"` - IsUsed bool `json:"is_used"` - CreatedAt string `json:"created_at"` -} - -// List handles GET /api/admin/confirmation-codes -func (h *AdminConfirmationCodeHandler) List(c echo.Context) error { - var filters dto.PaginationParams - if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) - } - - var codes []models.ConfirmationCode - var total int64 - - query := h.db.Model(&models.ConfirmationCode{}).Preload("User") - - // Apply search (search by user info or code) - if filters.Search != "" { - search := "%" + filters.Search + "%" - query = query.Joins("JOIN auth_user ON auth_user.id = user_confirmationcode.user_id"). - Where( - "auth_user.username ILIKE ? OR auth_user.email ILIKE ? OR user_confirmationcode.code ILIKE ?", - search, search, search, - ) - } - - // Get total count - query.Count(&total) - - // Apply sorting (allowlist prevents SQL injection via sort_by parameter) - sortBy := filters.GetSafeSortBy([]string{ - "id", "user_id", "created_at", "expires_at", "is_used", - }, "created_at") - query = query.Order(sortBy + " " + filters.GetSortDir()) - - // Apply pagination - query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage()) - - if err := query.Find(&codes).Error; err != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch confirmation codes"}) - } - - // Build response - responses := make([]ConfirmationCodeResponse, len(codes)) - for i, code := range codes { - responses[i] = ConfirmationCodeResponse{ - ID: code.ID, - UserID: code.UserID, - Username: code.User.Username, - Email: code.User.Email, - Code: maskCode(code.Code), - ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"), - IsUsed: code.IsUsed, - CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"), - } - } - - return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage())) -} - -// Get handles GET /api/admin/confirmation-codes/:id -func (h *AdminConfirmationCodeHandler) Get(c echo.Context) error { - id, err := strconv.ParseUint(c.Param("id"), 10, 32) - if err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"}) - } - - var code models.ConfirmationCode - if err := h.db.Preload("User").First(&code, id).Error; err != nil { - if err == gorm.ErrRecordNotFound { - return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Confirmation code not found"}) - } - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch confirmation code"}) - } - - response := ConfirmationCodeResponse{ - ID: code.ID, - UserID: code.UserID, - Username: code.User.Username, - Email: code.User.Email, - Code: maskCode(code.Code), - ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"), - IsUsed: code.IsUsed, - CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"), - } - - return c.JSON(http.StatusOK, response) -} - -// Delete handles DELETE /api/admin/confirmation-codes/:id -func (h *AdminConfirmationCodeHandler) Delete(c echo.Context) error { - id, err := strconv.ParseUint(c.Param("id"), 10, 32) - if err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"}) - } - - result := h.db.Delete(&models.ConfirmationCode{}, id) - if result.Error != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete confirmation code"}) - } - - if result.RowsAffected == 0 { - return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Confirmation code not found"}) - } - - return c.JSON(http.StatusOK, map[string]interface{}{"message": "Confirmation code deleted successfully"}) -} - -// BulkDelete handles DELETE /api/admin/confirmation-codes/bulk -func (h *AdminConfirmationCodeHandler) BulkDelete(c echo.Context) error { - var req dto.BulkDeleteRequest - if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) - } - - result := h.db.Where("id IN ?", req.IDs).Delete(&models.ConfirmationCode{}) - if result.Error != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete confirmation codes"}) - } - - return c.JSON(http.StatusOK, map[string]interface{}{"message": "Confirmation codes deleted successfully", "count": result.RowsAffected}) +func (h *AdminConfirmationCodeHandler) gone(c echo.Context) error { + return c.JSON(http.StatusGone, map[string]string{"message": "confirmation codes are managed by Ory Kratos"}) } +func (h *AdminConfirmationCodeHandler) List(c echo.Context) error { return h.gone(c) } +func (h *AdminConfirmationCodeHandler) Get(c echo.Context) error { return h.gone(c) } +func (h *AdminConfirmationCodeHandler) Delete(c echo.Context) error { return h.gone(c) } +func (h *AdminConfirmationCodeHandler) BulkDelete(c echo.Context) error { return h.gone(c) } diff --git a/internal/admin/handlers/password_reset_code_handler.go b/internal/admin/handlers/password_reset_code_handler.go index 5ac139a..f7ad490 100644 --- a/internal/admin/handlers/password_reset_code_handler.go +++ b/internal/admin/handlers/password_reset_code_handler.go @@ -1,159 +1,28 @@ +// password_reset_code_handler is a stub — the user_passwordresetcode table +// was dropped in the Ory Kratos migration (phase 2). Password resets are now +// handled by Kratos. package handlers import ( "net/http" - "strconv" "github.com/labstack/echo/v4" "gorm.io/gorm" - - "github.com/treytartt/honeydue-api/internal/admin/dto" - "github.com/treytartt/honeydue-api/internal/models" ) -// AdminPasswordResetCodeHandler handles admin password reset code management endpoints +// AdminPasswordResetCodeHandler is a no-op stub. type AdminPasswordResetCodeHandler struct { db *gorm.DB } -// NewAdminPasswordResetCodeHandler creates a new admin password reset code handler func NewAdminPasswordResetCodeHandler(db *gorm.DB) *AdminPasswordResetCodeHandler { return &AdminPasswordResetCodeHandler{db: db} } -// PasswordResetCodeResponse represents a password reset code in API responses -type PasswordResetCodeResponse struct { - ID uint `json:"id"` - UserID uint `json:"user_id"` - Username string `json:"username"` - Email string `json:"email"` - ResetToken string `json:"reset_token"` - ExpiresAt string `json:"expires_at"` - Used bool `json:"used"` - Attempts int `json:"attempts"` - MaxAttempts int `json:"max_attempts"` - CreatedAt string `json:"created_at"` -} - -// List handles GET /api/admin/password-reset-codes -func (h *AdminPasswordResetCodeHandler) List(c echo.Context) error { - var filters dto.PaginationParams - if err := c.Bind(&filters); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) - } - - var codes []models.PasswordResetCode - var total int64 - - query := h.db.Model(&models.PasswordResetCode{}).Preload("User") - - // Apply search (search by user info or token) - if filters.Search != "" { - search := "%" + filters.Search + "%" - query = query.Joins("JOIN auth_user ON auth_user.id = user_passwordresetcode.user_id"). - Where( - "auth_user.username ILIKE ? OR auth_user.email ILIKE ? OR user_passwordresetcode.reset_token ILIKE ?", - search, search, search, - ) - } - - // Get total count - query.Count(&total) - - // Apply sorting (allowlist prevents SQL injection via sort_by parameter) - sortBy := filters.GetSafeSortBy([]string{ - "id", "user_id", "created_at", "expires_at", "used", - }, "created_at") - query = query.Order(sortBy + " " + filters.GetSortDir()) - - // Apply pagination - query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage()) - - if err := query.Find(&codes).Error; err != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch password reset codes"}) - } - - // Build response - responses := make([]PasswordResetCodeResponse, len(codes)) - for i, code := range codes { - responses[i] = PasswordResetCodeResponse{ - ID: code.ID, - UserID: code.UserID, - Username: code.User.Username, - Email: code.User.Email, - ResetToken: code.ResetToken[:8] + "..." + code.ResetToken[len(code.ResetToken)-4:], // Truncate for display - ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"), - Used: code.Used, - Attempts: code.Attempts, - MaxAttempts: code.MaxAttempts, - CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"), - } - } - - return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage())) -} - -// Get handles GET /api/admin/password-reset-codes/:id -func (h *AdminPasswordResetCodeHandler) Get(c echo.Context) error { - id, err := strconv.ParseUint(c.Param("id"), 10, 32) - if err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"}) - } - - var code models.PasswordResetCode - if err := h.db.Preload("User").First(&code, id).Error; err != nil { - if err == gorm.ErrRecordNotFound { - return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Password reset code not found"}) - } - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch password reset code"}) - } - - response := PasswordResetCodeResponse{ - ID: code.ID, - UserID: code.UserID, - Username: code.User.Username, - Email: code.User.Email, - ResetToken: code.ResetToken[:8] + "..." + code.ResetToken[len(code.ResetToken)-4:], - ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"), - Used: code.Used, - Attempts: code.Attempts, - MaxAttempts: code.MaxAttempts, - CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"), - } - - return c.JSON(http.StatusOK, response) -} - -// Delete handles DELETE /api/admin/password-reset-codes/:id -func (h *AdminPasswordResetCodeHandler) Delete(c echo.Context) error { - id, err := strconv.ParseUint(c.Param("id"), 10, 32) - if err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"}) - } - - result := h.db.Delete(&models.PasswordResetCode{}, id) - if result.Error != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete password reset code"}) - } - - if result.RowsAffected == 0 { - return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Password reset code not found"}) - } - - return c.JSON(http.StatusOK, map[string]interface{}{"message": "Password reset code deleted successfully"}) -} - -// BulkDelete handles DELETE /api/admin/password-reset-codes/bulk -func (h *AdminPasswordResetCodeHandler) BulkDelete(c echo.Context) error { - var req dto.BulkDeleteRequest - if err := c.Bind(&req); err != nil { - return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"}) - } - - result := h.db.Where("id IN ?", req.IDs).Delete(&models.PasswordResetCode{}) - if result.Error != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete password reset codes"}) - } - - return c.JSON(http.StatusOK, map[string]interface{}{"message": "Password reset codes deleted successfully", "count": result.RowsAffected}) +func (h *AdminPasswordResetCodeHandler) gone(c echo.Context) error { + return c.JSON(http.StatusGone, map[string]string{"message": "password reset codes are managed by Ory Kratos"}) } +func (h *AdminPasswordResetCodeHandler) List(c echo.Context) error { return h.gone(c) } +func (h *AdminPasswordResetCodeHandler) Get(c echo.Context) error { return h.gone(c) } +func (h *AdminPasswordResetCodeHandler) Delete(c echo.Context) error { return h.gone(c) } +func (h *AdminPasswordResetCodeHandler) BulkDelete(c echo.Context) error { return h.gone(c) } diff --git a/internal/admin/handlers/user_handler.go b/internal/admin/handlers/user_handler.go index 8c99c09..4737eb8 100644 --- a/internal/admin/handlers/user_handler.go +++ b/internal/admin/handlers/user_handler.go @@ -207,9 +207,7 @@ func (h *AdminUserHandler) Create(c echo.Context) error { user.IsSuperuser = *req.IsSuperuser } - if err := user.SetPassword(req.Password); err != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to hash password"}) - } + // Password management is handled by Ory Kratos; no local password hashing. if err := h.db.Create(&user).Error; err != nil { return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create user"}) @@ -284,10 +282,9 @@ func (h *AdminUserHandler) Update(c echo.Context) error { if req.IsSuperuser != nil { user.IsSuperuser = *req.IsSuperuser } + // Password management is handled by Ory Kratos; local password update ignored. if req.Password != nil { - if err := user.SetPassword(*req.Password); err != nil { - return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to hash password"}) - } + _ = req.Password // Password changes must go through Kratos admin API } if err := h.db.Save(&user).Error; err != nil { diff --git a/internal/config/config.go b/internal/config/config.go index 800e782..8f8a8b9 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -142,6 +142,9 @@ type SecurityConfig struct { MaxPasswordResetRate int // per hour TokenExpiryDays int // Number of days before auth tokens expire (default 90) TokenRefreshDays int // Token must be at least this many days old before refresh (default 60) + // KratosPublicURL is the Ory Kratos public API base URL. The auth + // middleware validates sessions against {KratosPublicURL}/sessions/whoami. + KratosPublicURL string } // StorageConfig holds file storage settings. @@ -304,6 +307,7 @@ func Load() (*Config, error) { MaxPasswordResetRate: 3, TokenExpiryDays: viper.GetInt("TOKEN_EXPIRY_DAYS"), TokenRefreshDays: viper.GetInt("TOKEN_REFRESH_DAYS"), + KratosPublicURL: viper.GetString("KRATOS_PUBLIC_URL"), }, Storage: StorageConfig{ UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"), @@ -411,6 +415,7 @@ func setDefaults() { // Token expiry defaults viper.SetDefault("TOKEN_EXPIRY_DAYS", 90) // Tokens expire after 90 days + viper.SetDefault("KRATOS_PUBLIC_URL", "http://kratos:4433") // Ory Kratos public API viper.SetDefault("TOKEN_REFRESH_DAYS", 60) // Tokens can be refreshed after 60 days // Storage defaults diff --git a/internal/database/database.go b/internal/database/database.go index 11cdd9f..4053629 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -244,12 +244,7 @@ func Migrate() error { // User and auth tables &models.User{}, - &models.AuthToken{}, &models.UserProfile{}, - &models.ConfirmationCode{}, - &models.PasswordResetCode{}, - &models.AppleSocialAuth{}, - &models.GoogleSocialAuth{}, // Admin users (separate from app users) &models.AdminUser{}, diff --git a/internal/handlers/auth_handler.go b/internal/handlers/auth_handler.go index 14b319f..c673453 100644 --- a/internal/handlers/auth_handler.go +++ b/internal/handlers/auth_handler.go @@ -1,8 +1,6 @@ package handlers import ( - "context" - "errors" "net/http" "github.com/labstack/echo/v4" @@ -16,18 +14,18 @@ import ( "github.com/treytartt/honeydue-api/internal/validator" ) -// AuthHandler handles authentication endpoints +// AuthHandler handles user profile and account management endpoints. +// Session lifecycle (login, register, logout, password reset) is delegated +// to Ory Kratos; this handler only deals with the honeyDue user record. type AuthHandler struct { - authService *services.AuthService - emailService *services.EmailService - cache *services.CacheService - appleAuthService *services.AppleAuthService - googleAuthService *services.GoogleAuthService - storageService *services.StorageService - auditService *services.AuditService + authService *services.AuthService + emailService *services.EmailService + cache *services.CacheService + storageService *services.StorageService + auditService *services.AuditService } -// NewAuthHandler creates a new auth handler +// NewAuthHandler creates a new auth handler. func NewAuthHandler(authService *services.AuthService, emailService *services.EmailService, cache *services.CacheService) *AuthHandler { return &AuthHandler{ authService: authService, @@ -36,136 +34,21 @@ func NewAuthHandler(authService *services.AuthService, emailService *services.Em } } -// SetAppleAuthService sets the Apple auth service (called after initialization) -func (h *AuthHandler) SetAppleAuthService(appleAuth *services.AppleAuthService) { - h.appleAuthService = appleAuth -} - -// SetGoogleAuthService sets the Google auth service (called after initialization) -func (h *AuthHandler) SetGoogleAuthService(googleAuth *services.GoogleAuthService) { - h.googleAuthService = googleAuth -} - -// SetStorageService sets the storage service for file deletion during account deletion +// SetStorageService sets the storage service for file deletion during account deletion. func (h *AuthHandler) SetStorageService(storageService *services.StorageService) { h.storageService = storageService } -// SetAuditService sets the audit service for logging security events +// SetAuditService sets the audit service for logging security events. func (h *AuthHandler) SetAuditService(auditService *services.AuditService) { h.auditService = auditService } -// noStore marks a response as non-cacheable (audit L2) — auth responses -// carry tokens and user data that must never sit in any cache. +// noStore marks a response as non-cacheable. func noStore(c echo.Context) { c.Response().Header().Set("Cache-Control", "no-store") } -// Login handles POST /api/auth/login/ -func (h *AuthHandler) Login(c echo.Context) error { - noStore(c) - var req requests.LoginRequest - if err := c.Bind(&req); err != nil { - return apperrors.BadRequest("error.invalid_request") - } - if err := c.Validate(&req); err != nil { - return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) - } - - response, err := h.authService.Login(c.Request().Context(), &req, c.RealIP()) - if err != nil { - log.Debug().Err(err).Str("identifier", req.Username). - Str("ip", c.RealIP()).Str("user_agent", c.Request().UserAgent()). - Msg("Login failed") - if h.auditService != nil { - h.auditService.LogEvent(c, nil, services.AuditEventLoginFailed, map[string]interface{}{ - "identifier": req.Username, - }) - } - return err - } - - if h.auditService != nil { - userID := response.User.ID - h.auditService.LogEvent(c, &userID, services.AuditEventLogin, nil) - } - - return c.JSON(http.StatusOK, response) -} - -// Register handles POST /api/auth/register/ -func (h *AuthHandler) Register(c echo.Context) error { - noStore(c) - var req requests.RegisterRequest - if err := c.Bind(&req); err != nil { - return apperrors.BadRequest("error.invalid_request") - } - if err := c.Validate(&req); err != nil { - return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) - } - - response, confirmationCode, err := h.authService.Register(c.Request().Context(), &req) - if err != nil { - log.Debug().Err(err).Msg("Registration failed") - return err - } - - if h.auditService != nil { - userID := response.User.ID - h.auditService.LogEvent(c, &userID, services.AuditEventRegister, map[string]interface{}{ - "username": req.Username, - "email": req.Email, - }) - } - - // Send welcome email with confirmation code (async) - if h.emailService != nil && confirmationCode != "" { - go func() { - defer func() { - if r := recover(); r != nil { - log.Error().Interface("panic", r).Str("email", req.Email).Msg("Panic in welcome email goroutine") - } - }() - if err := h.emailService.SendWelcomeEmail(req.Email, req.FirstName, confirmationCode); err != nil { - log.Error().Err(err).Str("email", req.Email).Msg("Failed to send welcome email") - } - }() - } - - return c.JSON(http.StatusCreated, response) -} - -// Logout handles POST /api/auth/logout/ -func (h *AuthHandler) Logout(c echo.Context) error { - token := middleware.GetAuthToken(c) - if token == "" { - return apperrors.Unauthorized("error.not_authenticated") - } - - // Log audit event before invalidating the token - if h.auditService != nil { - user := middleware.GetAuthUser(c) - if user != nil { - h.auditService.LogEvent(c, &user.ID, services.AuditEventLogout, nil) - } - } - - // Invalidate token in database - if err := h.authService.Logout(c.Request().Context(), token); err != nil { - log.Warn().Err(err).Msg("Failed to delete token from database") - } - - // Invalidate token in cache - if h.cache != nil { - if err := h.cache.InvalidateAuthToken(c.Request().Context(), token); err != nil { - log.Warn().Err(err).Msg("Failed to invalidate token in cache") - } - } - - return c.JSON(http.StatusOK, responses.MessageResponse{Message: "Logged out successfully"}) -} - // CurrentUser handles GET /api/auth/me/ func (h *AuthHandler) CurrentUser(c echo.Context) error { noStore(c) @@ -207,301 +90,6 @@ func (h *AuthHandler) UpdateProfile(c echo.Context) error { return c.JSON(http.StatusOK, response) } -// VerifyEmail handles POST /api/auth/verify-email/ -func (h *AuthHandler) VerifyEmail(c echo.Context) error { - user, err := middleware.MustGetAuthUser(c) - if err != nil { - return err - } - - var req requests.VerifyEmailRequest - if err := c.Bind(&req); err != nil { - return apperrors.BadRequest("error.invalid_request") - } - if err := c.Validate(&req); err != nil { - return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) - } - - err = h.authService.VerifyEmail(c.Request().Context(), user.ID, req.Code) - if err != nil { - log.Debug().Err(err).Uint("user_id", user.ID).Msg("Email verification failed") - return err - } - - // Send post-verification welcome email with tips (async) - if h.emailService != nil { - go func() { - defer func() { - if r := recover(); r != nil { - log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in post-verification email goroutine") - } - }() - if err := h.emailService.SendPostVerificationEmail(user.Email, user.FirstName); err != nil { - log.Error().Err(err).Str("email", user.Email).Msg("Failed to send post-verification email") - } - }() - } - - return c.JSON(http.StatusOK, responses.VerifyEmailResponse{ - Message: "Email verified successfully", - Verified: true, - }) -} - -// ResendVerification handles POST /api/auth/resend-verification/ -func (h *AuthHandler) ResendVerification(c echo.Context) error { - user, err := middleware.MustGetAuthUser(c) - if err != nil { - return err - } - - code, err := h.authService.ResendVerificationCode(c.Request().Context(), user.ID) - if err != nil { - log.Debug().Err(err).Uint("user_id", user.ID).Msg("Failed to resend verification") - return err - } - - // Send verification email (async) - if h.emailService != nil { - go func() { - defer func() { - if r := recover(); r != nil { - log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in verification email goroutine") - } - }() - if err := h.emailService.SendVerificationEmail(user.Email, user.FirstName, code); err != nil { - log.Error().Err(err).Str("email", user.Email).Msg("Failed to send verification email") - } - }() - } - - return c.JSON(http.StatusOK, responses.MessageResponse{Message: "Verification email sent"}) -} - -// ForgotPassword handles POST /api/auth/forgot-password/ -func (h *AuthHandler) ForgotPassword(c echo.Context) error { - var req requests.ForgotPasswordRequest - if err := c.Bind(&req); err != nil { - return apperrors.BadRequest("error.invalid_request") - } - if err := c.Validate(&req); err != nil { - return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) - } - - noStore(c) - - if h.auditService != nil { - h.auditService.LogEvent(c, nil, services.AuditEventPasswordReset, map[string]interface{}{ - "email": req.Email, - }) - } - - // Audit LIVE-L13: run the user lookup, code generation, and email send - // entirely in the background, then return the generic response - // immediately. This makes the response time identical whether or not - // the email belongs to a real account, defeating timing-based user - // enumeration. context.Background() is used because the request context - // is cancelled the moment this handler returns. Per-account rate - // limiting still runs inside the service; the edge auth-rate-limit - // middleware covers per-IP abuse. - email := req.Email - go func() { - defer func() { - if r := recover(); r != nil { - log.Error().Interface("panic", r).Str("email", email).Msg("Panic in forgot-password goroutine") - } - }() - code, user, err := h.authService.ForgotPassword(context.Background(), email) - if err != nil || code == "" || user == nil { - return - } - if h.emailService != nil { - if sendErr := h.emailService.SendPasswordResetEmail(user.Email, user.FirstName, code); sendErr != nil { - log.Error().Err(sendErr).Str("email", user.Email).Msg("Failed to send password reset email") - } - } - }() - - // Always return success to prevent email enumeration. - return c.JSON(http.StatusOK, responses.ForgotPasswordResponse{ - Message: "Password reset email sent", - }) -} - -// VerifyResetCode handles POST /api/auth/verify-reset-code/ -func (h *AuthHandler) VerifyResetCode(c echo.Context) error { - var req requests.VerifyResetCodeRequest - if err := c.Bind(&req); err != nil { - return apperrors.BadRequest("error.invalid_request") - } - if err := c.Validate(&req); err != nil { - return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) - } - - resetToken, err := h.authService.VerifyResetCode(c.Request().Context(), req.Email, req.Code) - if err != nil { - log.Debug().Err(err).Str("email", req.Email).Msg("Verify reset code failed") - return err - } - - return c.JSON(http.StatusOK, responses.VerifyResetCodeResponse{ - Message: "Reset code verified", - ResetToken: resetToken, - }) -} - -// ResetPassword handles POST /api/auth/reset-password/ -func (h *AuthHandler) ResetPassword(c echo.Context) error { - var req requests.ResetPasswordRequest - if err := c.Bind(&req); err != nil { - return apperrors.BadRequest("error.invalid_request") - } - if err := c.Validate(&req); err != nil { - return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) - } - - err := h.authService.ResetPassword(c.Request().Context(), req.ResetToken, req.NewPassword) - if err != nil { - log.Debug().Err(err).Msg("Password reset failed") - return err - } - - if h.auditService != nil { - h.auditService.LogEvent(c, nil, services.AuditEventPasswordChanged, map[string]interface{}{ - "method": "reset_token", - }) - } - - return c.JSON(http.StatusOK, responses.ResetPasswordResponse{ - Message: "Password reset successful", - }) -} - -// AppleSignIn handles POST /api/auth/apple-sign-in/ -func (h *AuthHandler) AppleSignIn(c echo.Context) error { - noStore(c) - var req requests.AppleSignInRequest - if err := c.Bind(&req); err != nil { - return apperrors.BadRequest("error.invalid_request") - } - if err := c.Validate(&req); err != nil { - return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) - } - - if h.appleAuthService == nil { - log.Error().Msg("Apple auth service not configured") - return &apperrors.AppError{ - Code: 500, - MessageKey: "error.apple_signin_not_configured", - } - } - - response, err := h.authService.AppleSignIn(c.Request().Context(), h.appleAuthService, &req) - if err != nil { - // Check for legacy Apple Sign In error (not yet migrated) - if errors.Is(err, services.ErrAppleSignInFailed) { - log.Debug().Err(err).Msg("Apple Sign In failed (legacy error)") - return apperrors.Unauthorized("error.invalid_apple_token") - } - - log.Debug().Err(err).Msg("Apple Sign In failed") - return err - } - - // Send welcome email for new users (async) - if response.IsNewUser && h.emailService != nil && response.User.Email != "" { - go func() { - defer func() { - if r := recover(); r != nil { - log.Error().Interface("panic", r).Str("email", response.User.Email).Msg("Panic in Apple welcome email goroutine") - } - }() - if err := h.emailService.SendAppleWelcomeEmail(response.User.Email, response.User.FirstName); err != nil { - log.Error().Err(err).Str("email", response.User.Email).Msg("Failed to send Apple welcome email") - } - }() - } - - return c.JSON(http.StatusOK, response) -} - -// GoogleSignIn handles POST /api/auth/google-sign-in/ -func (h *AuthHandler) GoogleSignIn(c echo.Context) error { - noStore(c) - var req requests.GoogleSignInRequest - if err := c.Bind(&req); err != nil { - return apperrors.BadRequest("error.invalid_request") - } - if err := c.Validate(&req); err != nil { - return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err)) - } - - if h.googleAuthService == nil { - log.Error().Msg("Google auth service not configured") - return &apperrors.AppError{ - Code: 500, - MessageKey: "error.google_signin_not_configured", - } - } - - response, err := h.authService.GoogleSignIn(c.Request().Context(), h.googleAuthService, &req) - if err != nil { - // Check for legacy Google Sign In error (not yet migrated) - if errors.Is(err, services.ErrGoogleSignInFailed) { - log.Debug().Err(err).Msg("Google Sign In failed (legacy error)") - return apperrors.Unauthorized("error.invalid_google_token") - } - - log.Debug().Err(err).Msg("Google Sign In failed") - return err - } - - // Send welcome email for new users (async) - if response.IsNewUser && h.emailService != nil && response.User.Email != "" { - go func() { - defer func() { - if r := recover(); r != nil { - log.Error().Interface("panic", r).Str("email", response.User.Email).Msg("Panic in Google welcome email goroutine") - } - }() - if err := h.emailService.SendGoogleWelcomeEmail(response.User.Email, response.User.FirstName); err != nil { - log.Error().Err(err).Str("email", response.User.Email).Msg("Failed to send Google welcome email") - } - }() - } - - return c.JSON(http.StatusOK, response) -} - -// RefreshToken handles POST /api/auth/refresh/ -func (h *AuthHandler) RefreshToken(c echo.Context) error { - noStore(c) - user, err := middleware.MustGetAuthUser(c) - if err != nil { - return err - } - - token := middleware.GetAuthToken(c) - if token == "" { - return apperrors.Unauthorized("error.not_authenticated") - } - - response, err := h.authService.RefreshToken(c.Request().Context(), token, user.ID) - if err != nil { - log.Debug().Err(err).Uint("user_id", user.ID).Msg("Token refresh failed") - return err - } - - // If the token was refreshed (new token), invalidate the old one from cache - if response.Token != token && h.cache != nil { - if cacheErr := h.cache.InvalidateAuthToken(c.Request().Context(), token); cacheErr != nil { - log.Warn().Err(cacheErr).Msg("Failed to invalidate old token from cache during refresh") - } - } - - return c.JSON(http.StatusOK, response) -} - // DeleteAccount handles DELETE /api/auth/account/ func (h *AuthHandler) DeleteAccount(c echo.Context) error { user, err := middleware.MustGetAuthUser(c) @@ -544,13 +132,5 @@ func (h *AuthHandler) DeleteAccount(c echo.Context) error { }() } - // Invalidate auth token from cache - token := middleware.GetAuthToken(c) - if h.cache != nil && token != "" { - if err := h.cache.InvalidateAuthToken(c.Request().Context(), token); err != nil { - log.Warn().Err(err).Msg("Failed to invalidate token in cache after account deletion") - } - } - return c.JSON(http.StatusOK, responses.MessageResponse{Message: "Account deleted successfully"}) } diff --git a/internal/handlers/auth_handler_delete_test.go b/internal/handlers/auth_handler_delete_test.go index 7d2c8ba..b83e500 100644 --- a/internal/handlers/auth_handler_delete_test.go +++ b/internal/handlers/auth_handler_delete_test.go @@ -35,26 +35,25 @@ func setupDeleteAccountHandler(t *testing.T) (*AuthHandler, *echo.Echo, *gorm.DB return handler, e, db } -func TestAuthHandler_DeleteAccount_EmailUser(t *testing.T) { +// TestAuthHandler_DeleteAccount_WithConfirmation verifies that DELETE /account/ +// succeeds when the user sends confirmation: "DELETE". +// Post-Kratos: all users (regardless of provider) must confirm with "DELETE". +func TestAuthHandler_DeleteAccount_WithConfirmation(t *testing.T) { handler, e, db := setupDeleteAccountHandler(t) - user := testutil.CreateTestUser(t, db, "deletetest", "delete@test.com", "Password123") + user := testutil.CreateTestUser(t, db, "deletetest", "delete@test.com", "ignored") // Create profile for the user profile := &models.UserProfile{UserID: user.ID, Verified: true} require.NoError(t, db.Create(profile).Error) - // Create auth token - testutil.CreateTestToken(t, db, user.ID) - authGroup := e.Group("/api/auth") authGroup.Use(testutil.MockAuthMiddleware(user)) authGroup.DELETE("/account/", handler.DeleteAccount) - t.Run("successful deletion with correct password", func(t *testing.T) { - password := "Password123" + t.Run("successful deletion with DELETE confirmation", func(t *testing.T) { req := map[string]interface{}{ - "password": password, + "confirmation": "DELETE", } w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token") @@ -74,106 +73,15 @@ func TestAuthHandler_DeleteAccount_EmailUser(t *testing.T) { // Verify profile is deleted db.Model(&models.UserProfile{}).Where("user_id = ?", user.ID).Count(&count) assert.Equal(t, int64(0), count) - - // Verify auth token is deleted - db.Model(&models.AuthToken{}).Where("user_id = ?", user.ID).Count(&count) - assert.Equal(t, int64(0), count) }) } -func TestAuthHandler_DeleteAccount_WrongPassword(t *testing.T) { +// TestAuthHandler_DeleteAccount_MissingConfirmation verifies that a missing +// confirmation string is rejected with 400. +func TestAuthHandler_DeleteAccount_MissingConfirmation(t *testing.T) { handler, e, db := setupDeleteAccountHandler(t) - user := testutil.CreateTestUser(t, db, "wrongpw", "wrongpw@test.com", "Password123") - - authGroup := e.Group("/api/auth") - authGroup.Use(testutil.MockAuthMiddleware(user)) - authGroup.DELETE("/account/", handler.DeleteAccount) - - t.Run("wrong password returns 401", func(t *testing.T) { - wrongPw := "wrongpassword" - req := map[string]interface{}{ - "password": wrongPw, - } - - w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token") - - testutil.AssertStatusCode(t, w, http.StatusUnauthorized) - }) -} - -func TestAuthHandler_DeleteAccount_MissingPassword(t *testing.T) { - handler, e, db := setupDeleteAccountHandler(t) - - user := testutil.CreateTestUser(t, db, "nopw", "nopw@test.com", "Password123") - - authGroup := e.Group("/api/auth") - authGroup.Use(testutil.MockAuthMiddleware(user)) - authGroup.DELETE("/account/", handler.DeleteAccount) - - t.Run("missing password returns 400", func(t *testing.T) { - req := map[string]interface{}{} - - w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token") - - testutil.AssertStatusCode(t, w, http.StatusBadRequest) - }) -} - -func TestAuthHandler_DeleteAccount_SocialAuthUser(t *testing.T) { - handler, e, db := setupDeleteAccountHandler(t) - - user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "randompassword") - - // Create Apple social auth record - appleAuth := &models.AppleSocialAuth{ - UserID: user.ID, - AppleID: "apple_sub_123", - Email: "apple@test.com", - } - require.NoError(t, db.Create(appleAuth).Error) - - // Create profile - profile := &models.UserProfile{UserID: user.ID, Verified: true} - require.NoError(t, db.Create(profile).Error) - - authGroup := e.Group("/api/auth") - authGroup.Use(testutil.MockAuthMiddleware(user)) - authGroup.DELETE("/account/", handler.DeleteAccount) - - t.Run("successful deletion with DELETE confirmation", func(t *testing.T) { - confirmation := "DELETE" - req := map[string]interface{}{ - "confirmation": confirmation, - } - - w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token") - - testutil.AssertStatusCode(t, w, http.StatusOK) - - // Verify user is deleted - var count int64 - db.Model(&models.User{}).Where("id = ?", user.ID).Count(&count) - assert.Equal(t, int64(0), count) - - // Verify apple auth is deleted - db.Model(&models.AppleSocialAuth{}).Where("user_id = ?", user.ID).Count(&count) - assert.Equal(t, int64(0), count) - }) -} - -func TestAuthHandler_DeleteAccount_SocialAuthMissingConfirmation(t *testing.T) { - handler, e, db := setupDeleteAccountHandler(t) - - user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "randompassword") - - // Create Google social auth record - googleAuth := &models.GoogleSocialAuth{ - UserID: user.ID, - GoogleID: "google_sub_456", - Email: "google@test.com", - } - require.NoError(t, db.Create(googleAuth).Error) + user := testutil.CreateTestUser(t, db, "nopw", "nopw@test.com", "ignored") authGroup := e.Group("/api/auth") authGroup.Use(testutil.MockAuthMiddleware(user)) @@ -188,9 +96,8 @@ func TestAuthHandler_DeleteAccount_SocialAuthMissingConfirmation(t *testing.T) { }) t.Run("wrong confirmation returns 400", func(t *testing.T) { - wrongConfirmation := "delete" req := map[string]interface{}{ - "confirmation": wrongConfirmation, + "confirmation": "delete", // lowercase — must be exact "DELETE" } w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token") @@ -199,6 +106,8 @@ func TestAuthHandler_DeleteAccount_SocialAuthMissingConfirmation(t *testing.T) { }) } +// TestAuthHandler_DeleteAccount_Unauthenticated verifies that 401 is returned +// when no auth middleware is set. func TestAuthHandler_DeleteAccount_Unauthenticated(t *testing.T) { handler, e, _ := setupDeleteAccountHandler(t) @@ -207,7 +116,7 @@ func TestAuthHandler_DeleteAccount_Unauthenticated(t *testing.T) { t.Run("unauthenticated request returns 401", func(t *testing.T) { req := map[string]interface{}{ - "password": "Password123", + "confirmation": "DELETE", } w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "") diff --git a/internal/handlers/auth_handler_test.go b/internal/handlers/auth_handler_test.go index 639944b..7e37fe9 100644 --- a/internal/handlers/auth_handler_test.go +++ b/internal/handlers/auth_handler_test.go @@ -1,3 +1,7 @@ +// auth_handler_test.go tests the auth handler endpoints that survived the +// Ory Kratos migration: GET /me/ and PUT/PATCH /profile/. +// Login, register, logout, forgot-password, and social sign-in are now +// handled by Kratos. package handlers import ( @@ -34,204 +38,32 @@ func setupAuthHandler(t *testing.T) (*AuthHandler, *echo.Echo, *repositories.Use return handler, e, userRepo } -func TestAuthHandler_Register(t *testing.T) { - handler, e, _ := setupAuthHandler(t) - - e.POST("/api/auth/register/", handler.Register) - - t.Run("successful registration", func(t *testing.T) { - req := requests.RegisterRequest{ - Username: "newuser", - Email: "new@test.com", - Password: "Password123", - FirstName: "New", - LastName: "User", - } - - w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "") - - testutil.AssertStatusCode(t, w, http.StatusCreated) - - var response map[string]interface{} - err := json.Unmarshal(w.Body.Bytes(), &response) - require.NoError(t, err) - - testutil.AssertJSONFieldExists(t, response, "token") - testutil.AssertJSONFieldExists(t, response, "user") - testutil.AssertJSONFieldExists(t, response, "message") - - user := response["user"].(map[string]interface{}) - assert.Equal(t, "newuser", user["username"]) - assert.Equal(t, "new@test.com", user["email"]) - assert.Equal(t, "New", user["first_name"]) - assert.Equal(t, "User", user["last_name"]) - }) - - t.Run("registration with missing fields", func(t *testing.T) { - req := map[string]string{ - "username": "test", - // Missing email and password - } - - w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "") - - testutil.AssertStatusCode(t, w, http.StatusBadRequest) - - response := testutil.ParseJSON(t, w.Body.Bytes()) - testutil.AssertJSONFieldExists(t, response, "error") - }) - - t.Run("registration with short password", func(t *testing.T) { - req := requests.RegisterRequest{ - Username: "testuser", - Email: "test@test.com", - Password: "short", // Less than 8 chars - } - - w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "") - - testutil.AssertStatusCode(t, w, http.StatusBadRequest) - }) - - t.Run("registration with duplicate username", func(t *testing.T) { - // First registration - req := requests.RegisterRequest{ - Username: "duplicate", - Email: "unique1@test.com", - Password: "Password123", - } - w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "") - testutil.AssertStatusCode(t, w, http.StatusCreated) - - // Try to register again with same username - req.Email = "unique2@test.com" - w = testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "") - testutil.AssertStatusCode(t, w, http.StatusConflict) // 409 for duplicate resource - - response := testutil.ParseJSON(t, w.Body.Bytes()) - assert.Contains(t, response["error"], "Username already taken") - }) - - t.Run("registration with duplicate email", func(t *testing.T) { - // First registration - req := requests.RegisterRequest{ - Username: "user1", - Email: "duplicate@test.com", - Password: "Password123", - } - w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "") - testutil.AssertStatusCode(t, w, http.StatusCreated) - - // Try to register again with same email - req.Username = "user2" - w = testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "") - testutil.AssertStatusCode(t, w, http.StatusConflict) // 409 for duplicate resource - - response := testutil.ParseJSON(t, w.Body.Bytes()) - assert.Contains(t, response["error"], "Email already registered") - }) -} - -func TestAuthHandler_Login(t *testing.T) { - handler, e, _ := setupAuthHandler(t) - - e.POST("/api/auth/register/", handler.Register) - e.POST("/api/auth/login/", handler.Login) - - // Create a test user - registerReq := requests.RegisterRequest{ - Username: "logintest", - Email: "login@test.com", - Password: "Password123", - FirstName: "Test", - LastName: "User", - } - w := testutil.MakeRequest(e, "POST", "/api/auth/register/", registerReq, "") - testutil.AssertStatusCode(t, w, http.StatusCreated) - - t.Run("successful login with username", func(t *testing.T) { - req := requests.LoginRequest{ - Username: "logintest", - Password: "Password123", - } - - w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "") - - testutil.AssertStatusCode(t, w, http.StatusOK) - - var response map[string]interface{} - err := json.Unmarshal(w.Body.Bytes(), &response) - require.NoError(t, err) - - testutil.AssertJSONFieldExists(t, response, "token") - testutil.AssertJSONFieldExists(t, response, "user") - - user := response["user"].(map[string]interface{}) - assert.Equal(t, "logintest", user["username"]) - assert.Equal(t, "login@test.com", user["email"]) - }) - - t.Run("successful login with email", func(t *testing.T) { - req := requests.LoginRequest{ - Username: "login@test.com", // Using email as username - Password: "Password123", - } - - w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "") - - testutil.AssertStatusCode(t, w, http.StatusOK) - }) - - t.Run("login with wrong password", func(t *testing.T) { - req := requests.LoginRequest{ - Username: "logintest", - Password: "wrongpassword", - } - - w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "") - - testutil.AssertStatusCode(t, w, http.StatusUnauthorized) - - response := testutil.ParseJSON(t, w.Body.Bytes()) - assert.Contains(t, response["error"], "Invalid credentials") - }) - - t.Run("login with non-existent user", func(t *testing.T) { - req := requests.LoginRequest{ - Username: "nonexistent", - Password: "Password123", - } - - w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "") - - testutil.AssertStatusCode(t, w, http.StatusUnauthorized) - }) - - t.Run("login with missing fields", func(t *testing.T) { - req := map[string]string{ - "username": "logintest", - // Missing password - } - - w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "") - - testutil.AssertStatusCode(t, w, http.StatusBadRequest) - }) -} - func TestAuthHandler_CurrentUser(t *testing.T) { - handler, e, userRepo := setupAuthHandler(t) + handler, e, _ := setupAuthHandler(t) db := testutil.SetupTestDB(t) - user := testutil.CreateTestUser(t, db, "metest", "me@test.com", "Password123") + user := testutil.CreateTestUser(t, db, "metest", "me@test.com", "") user.FirstName = "Test" user.LastName = "User" - userRepo.Update(user) + // Use the userRepo from setupAuthHandler's DB, but since we need the user + // in the same DB we re-create it there. + db2 := testutil.SetupTestDB(t) + user2 := testutil.CreateTestUser(t, db2, "metest2", "me2@test.com", "") + user2.FirstName = "Test" + user2.LastName = "User" + userRepo2 := repositories.NewUserRepository(db2) + require.NoError(t, userRepo2.Update(user2)) + + // Build handler against db2 + cfg := &config.Config{} + authService2 := services.NewAuthService(userRepo2, cfg) + handler2 := NewAuthHandler(authService2, nil, nil) - // Set up route with mock auth middleware authGroup := e.Group("/api/auth") - authGroup.Use(testutil.MockAuthMiddleware(user)) - authGroup.GET("/me/", handler.CurrentUser) + authGroup.Use(testutil.MockAuthMiddleware(user2)) + authGroup.GET("/me/", handler2.CurrentUser) + + _ = handler // avoid unused t.Run("get current user", func(t *testing.T) { w := testutil.MakeRequest(e, "GET", "/api/auth/me/", nil, "test-token") @@ -242,23 +74,26 @@ func TestAuthHandler_CurrentUser(t *testing.T) { err := json.Unmarshal(w.Body.Bytes(), &response) require.NoError(t, err) - assert.Equal(t, "metest", response["username"]) - assert.Equal(t, "me@test.com", response["email"]) + assert.Equal(t, "metest2", response["username"]) + assert.Equal(t, "me2@test.com", response["email"]) }) } func TestAuthHandler_UpdateProfile(t *testing.T) { - handler, e, userRepo := setupAuthHandler(t) - db := testutil.SetupTestDB(t) - user := testutil.CreateTestUser(t, db, "updatetest", "update@test.com", "Password123") - userRepo.Update(user) + userRepo := repositories.NewUserRepository(db) + cfg := &config.Config{} + authService := services.NewAuthService(userRepo, cfg) + handler := NewAuthHandler(authService, nil, nil) + e := testutil.SetupTestRouter() + + user := testutil.CreateTestUser(t, db, "updatetest", "update@test.com", "") authGroup := e.Group("/api/auth") authGroup.Use(testutil.MockAuthMiddleware(user)) authGroup.PUT("/profile/", handler.UpdateProfile) - t.Run("update profile", func(t *testing.T) { + t.Run("update first and last name", func(t *testing.T) { firstName := "Updated" lastName := "Name" req := requests.UpdateProfileRequest{ @@ -278,130 +113,3 @@ func TestAuthHandler_UpdateProfile(t *testing.T) { assert.Equal(t, "Name", response["last_name"]) }) } - -func TestAuthHandler_ForgotPassword(t *testing.T) { - handler, e, _ := setupAuthHandler(t) - - e.POST("/api/auth/register/", handler.Register) - e.POST("/api/auth/forgot-password/", handler.ForgotPassword) - - // Create a test user - registerReq := requests.RegisterRequest{ - Username: "forgottest", - Email: "forgot@test.com", - Password: "Password123", - } - testutil.MakeRequest(e, "POST", "/api/auth/register/", registerReq, "") - - t.Run("forgot password with valid email", func(t *testing.T) { - req := requests.ForgotPasswordRequest{ - Email: "forgot@test.com", - } - - w := testutil.MakeRequest(e, "POST", "/api/auth/forgot-password/", req, "") - - // Always returns 200 to prevent email enumeration - testutil.AssertStatusCode(t, w, http.StatusOK) - - response := testutil.ParseJSON(t, w.Body.Bytes()) - testutil.AssertJSONFieldExists(t, response, "message") - }) - - t.Run("forgot password with invalid email", func(t *testing.T) { - req := requests.ForgotPasswordRequest{ - Email: "nonexistent@test.com", - } - - w := testutil.MakeRequest(e, "POST", "/api/auth/forgot-password/", req, "") - - // Still returns 200 to prevent email enumeration - testutil.AssertStatusCode(t, w, http.StatusOK) - }) -} - -func TestAuthHandler_Logout(t *testing.T) { - handler, e, userRepo := setupAuthHandler(t) - - db := testutil.SetupTestDB(t) - user := testutil.CreateTestUser(t, db, "logouttest", "logout@test.com", "Password123") - userRepo.Update(user) - - authGroup := e.Group("/api/auth") - authGroup.Use(testutil.MockAuthMiddleware(user)) - authGroup.POST("/logout/", handler.Logout) - - t.Run("successful logout", func(t *testing.T) { - w := testutil.MakeRequest(e, "POST", "/api/auth/logout/", nil, "test-token") - - testutil.AssertStatusCode(t, w, http.StatusOK) - - response := testutil.ParseJSON(t, w.Body.Bytes()) - assert.Contains(t, response["message"], "Logged out successfully") - }) -} - -func TestAuthHandler_JSONResponses(t *testing.T) { - handler, e, _ := setupAuthHandler(t) - - e.POST("/api/auth/register/", handler.Register) - e.POST("/api/auth/login/", handler.Login) - - t.Run("register response has correct JSON structure", func(t *testing.T) { - req := requests.RegisterRequest{ - Username: "jsontest", - Email: "json@test.com", - Password: "Password123", - FirstName: "JSON", - LastName: "Test", - } - - w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "") - - testutil.AssertStatusCode(t, w, http.StatusCreated) - - var response map[string]interface{} - err := json.Unmarshal(w.Body.Bytes(), &response) - require.NoError(t, err) - - // Verify top-level structure - assert.Contains(t, response, "token") - assert.Contains(t, response, "user") - assert.Contains(t, response, "message") - - // Verify token is not empty - assert.NotEmpty(t, response["token"]) - - // Verify user structure - user := response["user"].(map[string]interface{}) - assert.Contains(t, user, "id") - assert.Contains(t, user, "username") - assert.Contains(t, user, "email") - assert.Contains(t, user, "first_name") - assert.Contains(t, user, "last_name") - assert.Contains(t, user, "is_active") - assert.Contains(t, user, "date_joined") - - // Verify types - assert.IsType(t, float64(0), user["id"]) // JSON numbers are float64 - assert.IsType(t, "", user["username"]) - assert.IsType(t, "", user["email"]) - assert.IsType(t, true, user["is_active"]) - }) - - t.Run("error response has correct JSON structure", func(t *testing.T) { - req := map[string]string{ - "username": "test", - } - - w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "") - - testutil.AssertStatusCode(t, w, http.StatusBadRequest) - - var response map[string]interface{} - err := json.Unmarshal(w.Body.Bytes(), &response) - require.NoError(t, err) - - assert.Contains(t, response, "error") - assert.IsType(t, "", response["error"]) - }) -} diff --git a/internal/handlers/handler_coverage_test.go b/internal/handlers/handler_coverage_test.go index 3e64bd1..e2f689f 100644 --- a/internal/handlers/handler_coverage_test.go +++ b/internal/handlers/handler_coverage_test.go @@ -506,232 +506,6 @@ func TestTaskHandler_CreateCompletion_NoTaskID(t *testing.T) { }) } -// ============================================================================= -// Auth Handler - Additional Coverage -// ============================================================================= - -func TestAuthHandler_AppleSignIn_NotConfigured(t *testing.T) { - handler, e, _ := setupAuthHandler(t) - - e.POST("/api/auth/apple-sign-in/", handler.AppleSignIn) - - t.Run("returns 500 when apple auth not configured", func(t *testing.T) { - req := map[string]interface{}{ - "id_token": "fake-token", - "user_id": "fake-user-id", - } - w := testutil.MakeRequest(e, "POST", "/api/auth/apple-sign-in/", req, "") - testutil.AssertStatusCode(t, w, http.StatusInternalServerError) - }) - - t.Run("missing identity_token returns 400", func(t *testing.T) { - req := map[string]interface{}{} - w := testutil.MakeRequest(e, "POST", "/api/auth/apple-sign-in/", req, "") - testutil.AssertStatusCode(t, w, http.StatusBadRequest) - }) -} - -func TestAuthHandler_GoogleSignIn_NotConfigured(t *testing.T) { - handler, e, _ := setupAuthHandler(t) - - e.POST("/api/auth/google-sign-in/", handler.GoogleSignIn) - - t.Run("returns 500 when google auth not configured", func(t *testing.T) { - req := map[string]interface{}{ - "id_token": "fake-token", - } - w := testutil.MakeRequest(e, "POST", "/api/auth/google-sign-in/", req, "") - testutil.AssertStatusCode(t, w, http.StatusInternalServerError) - }) - - t.Run("missing id_token returns 400", func(t *testing.T) { - req := map[string]interface{}{} - w := testutil.MakeRequest(e, "POST", "/api/auth/google-sign-in/", req, "") - testutil.AssertStatusCode(t, w, http.StatusBadRequest) - }) -} - -// setupAuthHandlerWithDB is like setupAuthHandler but also returns the underlying *gorm.DB -// for tests that need to create records like ConfirmationCode directly. -func setupAuthHandlerWithDB(t *testing.T) (*AuthHandler, *echo.Echo, *gorm.DB) { - db := testutil.SetupTestDB(t) - userRepo := repositories.NewUserRepository(db) - cfg := &config.Config{ - Security: config.SecurityConfig{ - SecretKey: "test-secret-key", - PasswordResetExpiry: 15 * time.Minute, - ConfirmationExpiry: 24 * time.Hour, - MaxPasswordResetRate: 3, - }, - } - authService := services.NewAuthService(userRepo, cfg) - handler := NewAuthHandler(authService, nil, nil) - e := testutil.SetupTestRouter() - return handler, e, db -} - -func TestAuthHandler_VerifyEmail(t *testing.T) { - handler, e, db := setupAuthHandlerWithDB(t) - - user := testutil.CreateTestUser(t, db, "verifytest", "verify@test.com", "Password123") - - // Create confirmation code - confirmCode := &models.ConfirmationCode{ - UserID: user.ID, - Code: "123456", - ExpiresAt: time.Now().Add(24 * time.Hour), - IsUsed: false, - } - require.NoError(t, db.Create(confirmCode).Error) - - authGroup := e.Group("/api/auth") - authGroup.Use(testutil.MockAuthMiddleware(user)) - authGroup.POST("/verify-email/", handler.VerifyEmail) - - t.Run("successful verification", func(t *testing.T) { - req := requests.VerifyEmailRequest{ - Code: "123456", - } - w := testutil.MakeRequest(e, "POST", "/api/auth/verify-email/", req, "test-token") - testutil.AssertStatusCode(t, w, http.StatusOK) - - var response map[string]interface{} - err := json.Unmarshal(w.Body.Bytes(), &response) - require.NoError(t, err) - assert.Equal(t, true, response["verified"]) - }) - - t.Run("wrong code returns error", func(t *testing.T) { - req := requests.VerifyEmailRequest{ - Code: "999999", - } - w := testutil.MakeRequest(e, "POST", "/api/auth/verify-email/", req, "test-token") - // Code already used or wrong code - assert.True(t, w.Code == http.StatusBadRequest || w.Code == http.StatusNotFound, - "expected 400 or 404, got %d", w.Code) - }) - - t.Run("missing code returns 400", func(t *testing.T) { - req := map[string]interface{}{} - w := testutil.MakeRequest(e, "POST", "/api/auth/verify-email/", req, "test-token") - testutil.AssertStatusCode(t, w, http.StatusBadRequest) - }) -} - -func TestAuthHandler_ResendVerification(t *testing.T) { - handler, e, db := setupAuthHandlerWithDB(t) - - user := testutil.CreateTestUser(t, db, "resendtest", "resend@test.com", "Password123") - - authGroup := e.Group("/api/auth") - authGroup.Use(testutil.MockAuthMiddleware(user)) - authGroup.POST("/resend-verification/", handler.ResendVerification) - - t.Run("successful resend", func(t *testing.T) { - w := testutil.MakeRequest(e, "POST", "/api/auth/resend-verification/", nil, "test-token") - testutil.AssertStatusCode(t, w, http.StatusOK) - - var response map[string]interface{} - err := json.Unmarshal(w.Body.Bytes(), &response) - require.NoError(t, err) - assert.Contains(t, response, "message") - }) -} - -func TestAuthHandler_RefreshToken(t *testing.T) { - handler, e, db := setupAuthHandlerWithDB(t) - - user := testutil.CreateTestUser(t, db, "refreshtest", "refresh@test.com", "Password123") - - // Create auth token and use its actual key in the middleware - authToken := testutil.CreateTestToken(t, db, user.ID) - - authGroup := e.Group("/api/auth") - authGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - c.Set("auth_user", user) - c.Set("auth_token", authToken.Plaintext) // raw token — repo hashes for lookup (audit C1) - return next(c) - } - }) - authGroup.POST("/refresh/", handler.RefreshToken) - - t.Run("successful refresh", func(t *testing.T) { - w := testutil.MakeRequest(e, "POST", "/api/auth/refresh/", nil, authToken.Plaintext) - testutil.AssertStatusCode(t, w, http.StatusOK) - - var response map[string]interface{} - err := json.Unmarshal(w.Body.Bytes(), &response) - require.NoError(t, err) - assert.Contains(t, response, "token") - }) -} - -func TestAuthHandler_VerifyResetCode(t *testing.T) { - handler, e, _ := setupAuthHandler(t) - - e.POST("/api/auth/register/", handler.Register) - e.POST("/api/auth/verify-reset-code/", handler.VerifyResetCode) - - t.Run("invalid code returns error", func(t *testing.T) { - req := requests.VerifyResetCodeRequest{ - Email: "nonexistent@test.com", - Code: "999999", - } - w := testutil.MakeRequest(e, "POST", "/api/auth/verify-reset-code/", req, "") - // Should not be 200 since no valid code exists - assert.NotEqual(t, http.StatusOK, w.Code) - }) - - t.Run("missing fields returns 400", func(t *testing.T) { - req := map[string]interface{}{} - w := testutil.MakeRequest(e, "POST", "/api/auth/verify-reset-code/", req, "") - testutil.AssertStatusCode(t, w, http.StatusBadRequest) - }) -} - -func TestAuthHandler_ResetPassword(t *testing.T) { - handler, e, _ := setupAuthHandler(t) - - e.POST("/api/auth/reset-password/", handler.ResetPassword) - - t.Run("invalid reset token returns error", func(t *testing.T) { - req := requests.ResetPasswordRequest{ - ResetToken: "invalid-token", - NewPassword: "NewPassword123", - } - w := testutil.MakeRequest(e, "POST", "/api/auth/reset-password/", req, "") - assert.NotEqual(t, http.StatusOK, w.Code) - }) - - t.Run("missing fields returns 400", func(t *testing.T) { - req := map[string]interface{}{} - w := testutil.MakeRequest(e, "POST", "/api/auth/reset-password/", req, "") - testutil.AssertStatusCode(t, w, http.StatusBadRequest) - }) - - t.Run("short password returns 400", func(t *testing.T) { - req := requests.ResetPasswordRequest{ - ResetToken: "some-token", - NewPassword: "short", - } - w := testutil.MakeRequest(e, "POST", "/api/auth/reset-password/", req, "") - testutil.AssertStatusCode(t, w, http.StatusBadRequest) - }) -} - -func TestAuthHandler_ForgotPassword_MissingEmail(t *testing.T) { - handler, e, _ := setupAuthHandler(t) - - e.POST("/api/auth/forgot-password/", handler.ForgotPassword) - - t.Run("missing email returns 400", func(t *testing.T) { - req := map[string]interface{}{} - w := testutil.MakeRequest(e, "POST", "/api/auth/forgot-password/", req, "") - testutil.AssertStatusCode(t, w, http.StatusBadRequest) - }) -} - // ============================================================================= // Residence Handler - Additional Error Paths // ============================================================================= diff --git a/internal/integration/contract_test.go b/internal/integration/contract_test.go index 706357f..8932e13 100644 --- a/internal/integration/contract_test.go +++ b/internal/integration/contract_test.go @@ -190,6 +190,27 @@ func shouldSkipSpecRoute(path string) bool { if strings.HasPrefix(path, "/uploads/") || strings.HasPrefix(path, "/media/") { return true } + + // Auth routes delegated to Ory Kratos (phase 2 auth refactor). + // These endpoints are no longer served by the Go API; the spec is retained + // as documentation of the Kratos-facing contract. + kratosRoutes := map[string]bool{ + "/auth/login/": true, + "/auth/register/": true, + "/auth/logout/": true, + "/auth/refresh/": true, + "/auth/forgot-password/": true, + "/auth/verify-reset-code/": true, + "/auth/reset-password/": true, + "/auth/verify-email/": true, + "/auth/resend-verification/": true, + "/auth/apple-sign-in/": true, + "/auth/google-sign-in/": true, + } + if kratosRoutes[path] { + return true + } + return false } diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index 7df6b17..2de9608 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -6,9 +6,11 @@ import ( "fmt" "net/http" "net/http/httptest" + "sync" "testing" "time" + "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,6 +19,7 @@ import ( "github.com/treytartt/honeydue-api/internal/config" "github.com/treytartt/honeydue-api/internal/handlers" "github.com/treytartt/honeydue-api/internal/middleware" + "github.com/treytartt/honeydue-api/internal/models" "github.com/treytartt/honeydue-api/internal/repositories" "github.com/treytartt/honeydue-api/internal/services" "github.com/treytartt/honeydue-api/internal/testutil" @@ -105,11 +108,40 @@ type TestApp struct { TaskRepo *repositories.TaskRepository ContractorRepo *repositories.ContractorRepository AuthService *services.AuthService + // tokenStore maps fake token strings to users for the test-auth middleware. + tokenStore map[string]*models.User + tokenStoreMu sync.RWMutex +} + +// fakeAuthMiddleware replaces the real Kratos middleware in integration tests. +// It looks up the "Authorization: Token " value in app.tokenStore. +func (app *TestApp) fakeAuthMiddleware() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + ah := c.Request().Header.Get("Authorization") + if ah == "" { + return apperrors.Unauthorized("error.not_authenticated") + } + tok := ah + if len(ah) > 6 && ah[:6] == "Token " { + tok = ah[6:] + } else if len(ah) > 7 && ah[:7] == "Bearer " { + tok = ah[7:] + } + app.tokenStoreMu.RLock() + user, ok := app.tokenStore[tok] + app.tokenStoreMu.RUnlock() + if !ok { + return apperrors.Unauthorized("error.not_authenticated") + } + c.Set("auth_user", user) + c.Set("auth_token", tok) + return next(c) + } + } } func setupIntegrationTest(t *testing.T) *TestApp { - // Echo does not need test mode - db := testutil.SetupTestDB(t) testutil.SeedLookupData(t, db) @@ -122,10 +154,7 @@ func setupIntegrationTest(t *testing.T) *TestApp { // Create config cfg := &config.Config{ Security: config.SecurityConfig{ - SecretKey: "test-secret-key-for-integration-tests", - PasswordResetExpiry: 15 * time.Minute, - ConfirmationExpiry: 24 * time.Hour, - MaxPasswordResetRate: 3, + SecretKey: "test-secret-key-for-integration-tests", }, } @@ -141,28 +170,33 @@ func setupIntegrationTest(t *testing.T) *TestApp { taskHandler := handlers.NewTaskHandler(taskService, nil) contractorHandler := handlers.NewContractorHandler(contractorService) - // Create router with real middleware - e := echo.New() + app := &TestApp{ + DB: db, + Router: echo.New(), + AuthHandler: authHandler, + ResidenceHandler: residenceHandler, + TaskHandler: taskHandler, + ContractorHandler: contractorHandler, + UserRepo: userRepo, + ResidenceRepo: residenceRepo, + TaskRepo: taskRepo, + ContractorRepo: contractorRepo, + AuthService: authService, + tokenStore: make(map[string]*models.User), + } + + e := app.Router e.Validator = validator.NewCustomValidator() e.HTTPErrorHandler = apperrors.HTTPErrorHandler - // Add timezone middleware globally so X-Timezone header is processed + // Timezone middleware processes X-Timezone header e.Use(middleware.TimezoneMiddleware()) - // Public routes - auth := e.Group("/api/auth") - { - auth.POST("/register", authHandler.Register) - auth.POST("/login", authHandler.Login) - } - - // Protected routes - use AuthMiddleware without Redis cache for testing - authMiddleware := middleware.NewAuthMiddleware(db, nil) + // Protected routes — guarded by the fake token middleware api := e.Group("/api") - api.Use(authMiddleware.TokenAuth()) + api.Use(app.fakeAuthMiddleware()) { api.GET("/auth/me", authHandler.CurrentUser) - api.POST("/auth/logout", authHandler.Logout) residences := api.Group("/residences") { @@ -216,19 +250,7 @@ func setupIntegrationTest(t *testing.T) *TestApp { api.GET("/contractors/by-residence/:residence_id", contractorHandler.ListContractorsByResidence) } - return &TestApp{ - DB: db, - Router: e, - AuthHandler: authHandler, - ResidenceHandler: residenceHandler, - TaskHandler: taskHandler, - ContractorHandler: contractorHandler, - UserRepo: userRepo, - ResidenceRepo: residenceRepo, - TaskRepo: taskRepo, - ContractorRepo: contractorRepo, - AuthService: authService, - } + return app } // Helper to make authenticated requests @@ -251,156 +273,16 @@ func (app *TestApp) makeAuthenticatedRequest(t *testing.T, method, path string, return w } -// Helper to register and login a user, returns token -func (app *TestApp) registerAndLogin(t *testing.T, username, email, password string) string { - // Register - registerBody := map[string]string{ - "username": username, - "email": email, - "password": password, - } - w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody, "") - require.Equal(t, http.StatusCreated, w.Code) - - // Login - loginBody := map[string]string{ - "username": username, - "password": password, - } - w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/login", loginBody, "") - require.Equal(t, http.StatusOK, w.Code) - - var loginResp map[string]interface{} - err := json.Unmarshal(w.Body.Bytes(), &loginResp) - require.NoError(t, err) - - return loginResp["token"].(string) -} - -// ============ Authentication Flow Tests ============ - -func TestIntegration_AuthenticationFlow(t *testing.T) { - app := setupIntegrationTest(t) - - // 1. Register a new user - registerBody := map[string]string{ - "username": "testuser", - "email": "test@example.com", - "password": "SecurePass123!", - } - w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody, "") - assert.Equal(t, http.StatusCreated, w.Code) - - var registerResp map[string]interface{} - err := json.Unmarshal(w.Body.Bytes(), ®isterResp) - require.NoError(t, err) - assert.NotEmpty(t, registerResp["token"]) - assert.NotNil(t, registerResp["user"]) - - // 2. Login with the same credentials - loginBody := map[string]string{ - "username": "testuser", - "password": "SecurePass123!", - } - w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/login", loginBody, "") - assert.Equal(t, http.StatusOK, w.Code) - - var loginResp map[string]interface{} - err = json.Unmarshal(w.Body.Bytes(), &loginResp) - require.NoError(t, err) - token := loginResp["token"].(string) - assert.NotEmpty(t, token) - - // 3. Get current user with token - w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, token) - assert.Equal(t, http.StatusOK, w.Code) - - var meResp map[string]interface{} - err = json.Unmarshal(w.Body.Bytes(), &meResp) - require.NoError(t, err) - assert.Equal(t, "testuser", meResp["username"]) - assert.Equal(t, "test@example.com", meResp["email"]) - - // 4. Access protected route without token should fail - w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, "") - assert.Equal(t, http.StatusUnauthorized, w.Code) - - // 5. Access protected route with invalid token should fail - w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, "invalid-token") - assert.Equal(t, http.StatusUnauthorized, w.Code) - - // 6. Logout - w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/logout", nil, token) - assert.Equal(t, http.StatusOK, w.Code) -} - -func TestIntegration_RegistrationValidation(t *testing.T) { - app := setupIntegrationTest(t) - - tests := []struct { - name string - body map[string]string - expectedStatus int - }{ - { - name: "missing username", - body: map[string]string{"email": "test@example.com", "password": "pass123"}, - expectedStatus: http.StatusBadRequest, - }, - { - name: "missing email", - body: map[string]string{"username": "testuser", "password": "pass123"}, - expectedStatus: http.StatusBadRequest, - }, - { - name: "missing password", - body: map[string]string{"username": "testuser", "email": "test@example.com"}, - expectedStatus: http.StatusBadRequest, - }, - { - name: "invalid email", - body: map[string]string{"username": "testuser", "email": "invalid", "password": "pass123"}, - expectedStatus: http.StatusBadRequest, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", tt.body, "") - assert.Equal(t, tt.expectedStatus, w.Code) - }) - } -} - -func TestIntegration_DuplicateRegistration(t *testing.T) { - app := setupIntegrationTest(t) - - // Register first user (password must be >= 8 chars) - registerBody := map[string]string{ - "username": "testuser", - "email": "test@example.com", - "password": "Password123", - } - w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody, "") - assert.Equal(t, http.StatusCreated, w.Code) - - // Try to register with same username - returns 409 (Conflict) - registerBody2 := map[string]string{ - "username": "testuser", - "email": "different@example.com", - "password": "Password123", - } - w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody2, "") - assert.Equal(t, http.StatusConflict, w.Code) - - // Try to register with same email - returns 409 (Conflict) - registerBody3 := map[string]string{ - "username": "differentuser", - "email": "test@example.com", - "password": "Password123", - } - w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody3, "") - assert.Equal(t, http.StatusConflict, w.Code) +// registerAndLogin creates a user directly in the DB and returns a synthetic token +// that the fake auth middleware will accept. No HTTP register/login endpoints are called. +func (app *TestApp) registerAndLogin(t *testing.T, username, email, _ string) string { + t.Helper() + user := testutil.CreateTestUser(t, app.DB, username, email, "") + tok := uuid.NewString() + app.tokenStoreMu.Lock() + app.tokenStore[tok] = user + app.tokenStoreMu.Unlock() + return tok } // ============ Residence Flow Tests ============ @@ -827,48 +709,16 @@ func TestIntegration_ResponseStructure(t *testing.T) { func TestIntegration_ComprehensiveE2E(t *testing.T) { app := setupIntegrationTest(t) - // ============ Phase 1: Authentication ============ - t.Log("Phase 1: Testing authentication flow") + // ============ Phase 1: User Setup ============ + t.Log("Phase 1: Setting up test user") - // Register new user - registerBody := map[string]string{ - "username": "e2e_testuser", - "email": "e2e@example.com", - "password": "SecurePass123!", - } - w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody, "") - require.Equal(t, http.StatusCreated, w.Code, "Registration should succeed") - - var registerResp map[string]interface{} - err := json.Unmarshal(w.Body.Bytes(), ®isterResp) - require.NoError(t, err) - assert.NotEmpty(t, registerResp["token"], "Registration should return token") - assert.NotNil(t, registerResp["user"], "Registration should return user") - - // Verify login with same credentials - loginBody := map[string]string{ - "username": "e2e_testuser", - "password": "SecurePass123!", - } - w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/login", loginBody, "") - require.Equal(t, http.StatusOK, w.Code, "Login should succeed") - - var loginResp map[string]interface{} - err = json.Unmarshal(w.Body.Bytes(), &loginResp) - require.NoError(t, err) - token := loginResp["token"].(string) - assert.NotEmpty(t, token, "Login should return token") + token := app.registerAndLogin(t, "e2e_testuser", "e2e@example.com", "") // Verify authenticated access - w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, token) + w := app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, token) require.Equal(t, http.StatusOK, w.Code, "Should access protected route with valid token") - var meResp map[string]interface{} - json.Unmarshal(w.Body.Bytes(), &meResp) - assert.Equal(t, "e2e_testuser", meResp["username"]) - assert.Equal(t, "e2e@example.com", meResp["email"]) - - t.Log("✓ Authentication flow verified") + t.Log("✓ User setup verified") // ============ Phase 2: Create 5 Residences ============ t.Log("Phase 2: Creating 5 residences") @@ -1244,29 +1094,9 @@ func TestIntegration_ComprehensiveE2E(t *testing.T) { t.Logf("✓ All %d visible tasks verified in correct columns by ID", expectedVisibleTasks) // ============ Phase 9: Create User B ============ - t.Log("Phase 9: Creating User B and verifying login") + t.Log("Phase 9: Creating User B") - // Register User B - registerBodyB := map[string]string{ - "username": "e2e_userb", - "email": "e2e_userb@example.com", - "password": "SecurePass456!", - } - w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBodyB, "") - require.Equal(t, http.StatusCreated, w.Code, "User B registration should succeed") - - // Login as User B - loginBodyB := map[string]string{ - "username": "e2e_userb", - "password": "SecurePass456!", - } - w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/login", loginBodyB, "") - require.Equal(t, http.StatusOK, w.Code, "User B login should succeed") - - var loginRespB map[string]interface{} - json.Unmarshal(w.Body.Bytes(), &loginRespB) - tokenB := loginRespB["token"].(string) - assert.NotEmpty(t, tokenB, "User B should have a token") + tokenB := app.registerAndLogin(t, "e2e_userb", "e2e_userb@example.com", "") // Verify User B can access their own profile w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, tokenB) @@ -1592,8 +1422,6 @@ func formatID(id float64) string { // setupContractorTest sets up a test environment including contractor routes func setupContractorTest(t *testing.T) *TestApp { - // Echo does not need test mode - db := testutil.SetupTestDB(t) testutil.SeedLookupData(t, db) @@ -1606,10 +1434,7 @@ func setupContractorTest(t *testing.T) *TestApp { // Create config cfg := &config.Config{ Security: config.SecurityConfig{ - SecretKey: "test-secret-key-for-integration-tests", - PasswordResetExpiry: 15 * time.Minute, - ConfirmationExpiry: 24 * time.Hour, - MaxPasswordResetRate: 3, + SecretKey: "test-secret-key-for-integration-tests", }, } @@ -1625,29 +1450,32 @@ func setupContractorTest(t *testing.T) *TestApp { taskHandler := handlers.NewTaskHandler(taskService, nil) contractorHandler := handlers.NewContractorHandler(contractorService) - // Create router with real middleware - e := echo.New() + app := &TestApp{ + DB: db, + Router: echo.New(), + AuthHandler: authHandler, + ResidenceHandler: residenceHandler, + TaskHandler: taskHandler, + ContractorHandler: contractorHandler, + UserRepo: userRepo, + ResidenceRepo: residenceRepo, + TaskRepo: taskRepo, + ContractorRepo: contractorRepo, + AuthService: authService, + tokenStore: make(map[string]*models.User), + } + + e := app.Router e.Validator = validator.NewCustomValidator() e.HTTPErrorHandler = apperrors.HTTPErrorHandler - // Add timezone middleware globally so X-Timezone header is processed + // Timezone middleware e.Use(middleware.TimezoneMiddleware()) - // Public routes - auth := e.Group("/api/auth") - { - auth.POST("/register", authHandler.Register) - auth.POST("/login", authHandler.Login) - } - // Protected routes - authMiddleware := middleware.NewAuthMiddleware(db, nil) api := e.Group("/api") - api.Use(authMiddleware.TokenAuth()) + api.Use(app.fakeAuthMiddleware()) { - api.GET("/auth/me", authHandler.CurrentUser) - api.POST("/auth/logout", authHandler.Logout) - residences := api.Group("/residences") { residences.GET("", residenceHandler.ListResidences) @@ -1680,19 +1508,7 @@ func setupContractorTest(t *testing.T) *TestApp { } } - return &TestApp{ - DB: db, - Router: e, - AuthHandler: authHandler, - ResidenceHandler: residenceHandler, - TaskHandler: taskHandler, - ContractorHandler: contractorHandler, - UserRepo: userRepo, - ResidenceRepo: residenceRepo, - TaskRepo: taskRepo, - ContractorRepo: contractorRepo, - AuthService: authService, - } + return app } // ============ Test 1: Recurring Task Lifecycle ============ @@ -2045,12 +1861,12 @@ func TestIntegration_MultiUserSharing(t *testing.T) { // Phase 9: Remove User B from residence 3 t.Log("Phase 9: Remove User B from residence 3") - // Get User B's ID - w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, tokenB) - require.Equal(t, http.StatusOK, w.Code) - var userBInfo map[string]interface{} - json.Unmarshal(w.Body.Bytes(), &userBInfo) - userBID := uint(userBInfo["id"].(float64)) + // Get User B's ID from the token store + app.tokenStoreMu.RLock() + userBModel := app.tokenStore[tokenB] + app.tokenStoreMu.RUnlock() + require.NotNil(t, userBModel, "User B should be in token store") + userBID := userBModel.ID // Remove User B from residence 3 w = app.makeAuthenticatedRequest(t, "DELETE", fmt.Sprintf("/api/residences/%d/users/%d", residenceIDs[2], userBID), nil, tokenA) diff --git a/internal/integration/security_regression_test.go b/internal/integration/security_regression_test.go index 12360e2..43a7c66 100644 --- a/internal/integration/security_regression_test.go +++ b/internal/integration/security_regression_test.go @@ -6,9 +6,11 @@ import ( "fmt" "net/http" "net/http/httptest" + "sync" "testing" "time" + "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -35,6 +37,48 @@ type SecurityTestApp struct { Router *echo.Echo SubscriptionService *services.SubscriptionService SubscriptionRepo *repositories.SubscriptionRepository + tokenStore map[string]*models.User + tokenStoreMu sync.RWMutex +} + +// fakeAuthMiddleware returns an Echo middleware that authenticates requests using +// the in-process tokenStore instead of calling the real Kratos session endpoint. +func (app *SecurityTestApp) fakeAuthMiddleware() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + ah := c.Request().Header.Get("Authorization") + if ah == "" { + return apperrors.Unauthorized("error.not_authenticated") + } + tok := ah + if len(ah) > 6 && ah[:6] == "Token " { + tok = ah[6:] + } else if len(ah) > 7 && ah[:7] == "Bearer " { + tok = ah[7:] + } + app.tokenStoreMu.RLock() + user, ok := app.tokenStore[tok] + app.tokenStoreMu.RUnlock() + if !ok { + return apperrors.Unauthorized("error.not_authenticated") + } + c.Set("auth_user", user) + c.Set("auth_token", tok) + return next(c) + } + } +} + +// registerAndLoginSec creates a user directly in the DB and returns a fake token +// that the fakeAuthMiddleware will accept. No HTTP register/login calls are made. +func (app *SecurityTestApp) registerAndLoginSec(t *testing.T, username, email, _ string) (string, uint) { + t.Helper() + user := testutil.CreateTestUser(t, app.DB, username, email, "") + tok := uuid.NewString() + app.tokenStoreMu.Lock() + app.tokenStore[tok] = user + app.tokenStoreMu.Unlock() + return tok, user.ID } func setupSecurityTest(t *testing.T) *SecurityTestApp { @@ -78,27 +122,25 @@ func setupSecurityTest(t *testing.T) *SecurityTestApp { notificationHandler := handlers.NewNotificationHandler(notificationService) subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, nil) - // Create router with real middleware + app := &SecurityTestApp{ + DB: db, + SubscriptionService: subscriptionService, + SubscriptionRepo: subscriptionRepo, + tokenStore: make(map[string]*models.User), + } + + // Create router with fake auth middleware e := echo.New() e.Validator = validator.NewCustomValidator() e.HTTPErrorHandler = apperrors.HTTPErrorHandler e.Use(middleware.TimezoneMiddleware()) - // Public routes - auth := e.Group("/api/auth") - { - auth.POST("/register", authHandler.Register) - auth.POST("/login", authHandler.Login) - } - // Protected routes - authMiddleware := middleware.NewAuthMiddleware(db, nil) api := e.Group("/api") - api.Use(authMiddleware.TokenAuth()) + api.Use(app.fakeAuthMiddleware()) { api.GET("/auth/me", authHandler.CurrentUser) - api.POST("/auth/logout", authHandler.Logout) residences := api.Group("/residences") { @@ -146,42 +188,8 @@ func setupSecurityTest(t *testing.T) *SecurityTestApp { } } - return &SecurityTestApp{ - DB: db, - Router: e, - SubscriptionService: subscriptionService, - SubscriptionRepo: subscriptionRepo, - } -} - -// registerAndLoginSec registers and logs in a user, returns token and user ID. -func (app *SecurityTestApp) registerAndLoginSec(t *testing.T, username, email, password string) (string, uint) { - // Register - registerBody := map[string]string{ - "username": username, - "email": email, - "password": password, - } - w := app.makeAuthReq(t, "POST", "/api/auth/register", registerBody, "") - require.Equal(t, http.StatusCreated, w.Code, "Registration should succeed for %s", username) - - // Login - loginBody := map[string]string{ - "username": username, - "password": password, - } - w = app.makeAuthReq(t, "POST", "/api/auth/login", loginBody, "") - require.Equal(t, http.StatusOK, w.Code, "Login should succeed for %s", username) - - var loginResp map[string]interface{} - err := json.Unmarshal(w.Body.Bytes(), &loginResp) - require.NoError(t, err) - - token := loginResp["token"].(string) - userMap := loginResp["user"].(map[string]interface{}) - userID := uint(userMap["id"].(float64)) - - return token, userID + app.Router = e + return app } // makeAuthReq creates and sends an HTTP request through the router. diff --git a/internal/integration/subscription_is_free_test.go b/internal/integration/subscription_is_free_test.go index 78d2584..2a8f49d 100644 --- a/internal/integration/subscription_is_free_test.go +++ b/internal/integration/subscription_is_free_test.go @@ -6,12 +6,15 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "sync" "testing" "time" + "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gorm.io/gorm" "github.com/treytartt/honeydue-api/internal/apperrors" "github.com/treytartt/honeydue-api/internal/config" @@ -22,7 +25,6 @@ import ( "github.com/treytartt/honeydue-api/internal/services" "github.com/treytartt/honeydue-api/internal/testutil" "github.com/treytartt/honeydue-api/internal/validator" - "gorm.io/gorm" ) // SubscriptionTestApp holds components for subscription integration testing @@ -31,11 +33,51 @@ type SubscriptionTestApp struct { Router *echo.Echo SubscriptionService *services.SubscriptionService SubscriptionRepo *repositories.SubscriptionRepository + tokenStore map[string]*models.User + tokenStoreMu sync.RWMutex +} + +// fakeAuthMiddleware returns an Echo middleware that authenticates requests using +// the in-process tokenStore instead of calling the real Kratos session endpoint. +func (app *SubscriptionTestApp) fakeAuthMiddleware() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + ah := c.Request().Header.Get("Authorization") + if ah == "" { + return apperrors.Unauthorized("error.not_authenticated") + } + tok := ah + if len(ah) > 6 && ah[:6] == "Token " { + tok = ah[6:] + } else if len(ah) > 7 && ah[:7] == "Bearer " { + tok = ah[7:] + } + app.tokenStoreMu.RLock() + user, ok := app.tokenStore[tok] + app.tokenStoreMu.RUnlock() + if !ok { + return apperrors.Unauthorized("error.not_authenticated") + } + c.Set("auth_user", user) + c.Set("auth_token", tok) + return next(c) + } + } +} + +// registerAndLogin creates a user directly in the DB and returns a fake token +// and user ID. No HTTP register/login calls are made. +func (app *SubscriptionTestApp) registerAndLogin(t *testing.T, username, email, _ string) (string, uint) { + t.Helper() + user := testutil.CreateTestUser(t, app.DB, username, email, "") + tok := uuid.NewString() + app.tokenStoreMu.Lock() + app.tokenStore[tok] = user + app.tokenStoreMu.Unlock() + return tok, user.ID } func setupSubscriptionTest(t *testing.T) *SubscriptionTestApp { - // Echo does not need test mode - db := testutil.SetupTestDB(t) testutil.SeedLookupData(t, db) @@ -67,22 +109,23 @@ func setupSubscriptionTest(t *testing.T) *SubscriptionTestApp { residenceHandler := handlers.NewResidenceHandler(residenceService, nil, nil, true) subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, nil) - // Create router + app := &SubscriptionTestApp{ + DB: db, + SubscriptionService: subscriptionService, + SubscriptionRepo: subscriptionRepo, + tokenStore: make(map[string]*models.User), + } + + // Create router with fake auth middleware e := echo.New() e.Validator = validator.NewCustomValidator() e.HTTPErrorHandler = apperrors.HTTPErrorHandler - // Public routes - auth := e.Group("/api/auth") - { - auth.POST("/register", authHandler.Register) - auth.POST("/login", authHandler.Login) - } + e.Use(middleware.TimezoneMiddleware()) // Protected routes - authMiddleware := middleware.NewAuthMiddleware(db, nil) api := e.Group("/api") - api.Use(authMiddleware.TokenAuth()) + api.Use(app.fakeAuthMiddleware()) { api.GET("/auth/me", authHandler.CurrentUser) @@ -98,12 +141,8 @@ func setupSubscriptionTest(t *testing.T) *SubscriptionTestApp { } } - return &SubscriptionTestApp{ - DB: db, - Router: e, - SubscriptionService: subscriptionService, - SubscriptionRepo: subscriptionRepo, - } + app.Router = e + return app } // Helper to make authenticated requests @@ -129,36 +168,6 @@ func (app *SubscriptionTestApp) makeAuthenticatedRequest(t *testing.T, method, p return w } -// Helper to register and login a user, returns token and user ID -func (app *SubscriptionTestApp) registerAndLogin(t *testing.T, username, email, password string) (string, uint) { - // Register - registerBody := map[string]string{ - "username": username, - "email": email, - "password": password, - } - w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody, "") - require.Equal(t, http.StatusCreated, w.Code) - - // Login - loginBody := map[string]string{ - "username": username, - "password": password, - } - w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/login", loginBody, "") - require.Equal(t, http.StatusOK, w.Code) - - var loginResp map[string]interface{} - err := json.Unmarshal(w.Body.Bytes(), &loginResp) - require.NoError(t, err) - - token := loginResp["token"].(string) - userMap := loginResp["user"].(map[string]interface{}) - userID := uint(userMap["id"].(float64)) - - return token, userID -} - // TestIntegration_IsFreeBypassesLimitations tests that users with IsFree=true // see limitations_enabled=false regardless of global settings func TestIntegration_IsFreeBypassesLimitations(t *testing.T) { diff --git a/internal/kratos/client.go b/internal/kratos/client.go new file mode 100644 index 0000000..aacfb19 --- /dev/null +++ b/internal/kratos/client.go @@ -0,0 +1,107 @@ +// Package kratos is a thin client for the Ory Kratos public API. honeyDue +// delegates all identity concerns (credentials, sessions, verification, +// recovery, social sign-in) to Kratos; this client only validates sessions. +package kratos + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "time" +) + +// ErrUnauthorized is returned by Whoami when the session is missing, invalid, +// or inactive — the caller should respond 401. +var ErrUnauthorized = errors.New("kratos: session invalid or inactive") + +// Client talks to the Ory Kratos public API. +type Client struct { + publicURL string + http *http.Client +} + +// NewClient builds a Kratos client for the given public-API base URL +// (e.g. http://kratos:4433 in-cluster). +func NewClient(publicURL string) *Client { + return &Client{ + publicURL: strings.TrimRight(publicURL, "/"), + http: &http.Client{Timeout: 5 * time.Second}, + } +} + +// Identity is the subset of a Kratos identity honeyDue consumes. It mirrors +// the identity schema in deploy-k3s/manifests/kratos/configmap.yaml. +type Identity struct { + ID string `json:"id"` // UUID — the stable identity identifier + Traits struct { + Email string `json:"email"` + Name struct { + First string `json:"first"` + Last string `json:"last"` + } `json:"name"` + } `json:"traits"` + VerifiableAddresses []struct { + Value string `json:"value"` + Verified bool `json:"verified"` + } `json:"verifiable_addresses"` +} + +// Session is a Kratos session as returned by GET /sessions/whoami. +type Session struct { + ID string `json:"id"` + Active bool `json:"active"` + Identity Identity `json:"identity"` +} + +// EmailVerified reports whether any of the identity's email addresses is +// verified — the source of truth for honeyDue's RequireVerified gate. +func (s *Session) EmailVerified() bool { + for _, a := range s.Identity.VerifiableAddresses { + if a.Verified { + return true + } + } + return false +} + +// Whoami validates a session against Kratos. Supply the mobile session token +// (sessionToken) OR the browser cookie header (cookie) — whichever is +// non-empty is forwarded to Kratos. Returns ErrUnauthorized for an invalid or +// inactive session. +func (c *Client) Whoami(ctx context.Context, sessionToken, cookie string) (*Session, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.publicURL+"/sessions/whoami", nil) + if err != nil { + return nil, err + } + if sessionToken != "" { + req.Header.Set("X-Session-Token", sessionToken) + } + if cookie != "" { + req.Header.Set("Cookie", cookie) + } + + resp, err := c.http.Do(req) + if err != nil { + return nil, fmt.Errorf("kratos whoami: %w", err) + } + defer resp.Body.Close() + + switch { + case resp.StatusCode == http.StatusUnauthorized, resp.StatusCode == http.StatusForbidden: + return nil, ErrUnauthorized + case resp.StatusCode != http.StatusOK: + return nil, fmt.Errorf("kratos whoami: unexpected status %d", resp.StatusCode) + } + + var s Session + if err := json.NewDecoder(resp.Body).Decode(&s); err != nil { + return nil, fmt.Errorf("kratos whoami: decode: %w", err) + } + if !s.Active || s.Identity.ID == "" { + return nil, ErrUnauthorized + } + return &s, nil +} diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go deleted file mode 100644 index 550a62d..0000000 --- a/internal/middleware/auth.go +++ /dev/null @@ -1,438 +0,0 @@ -package middleware - -import ( - "context" - "fmt" - "strings" - "time" - - "github.com/labstack/echo/v4" - "github.com/redis/go-redis/v9" - "github.com/rs/zerolog/log" - "gorm.io/gorm" - - "github.com/treytartt/honeydue-api/internal/apperrors" - "github.com/treytartt/honeydue-api/internal/config" - "github.com/treytartt/honeydue-api/internal/models" - "github.com/treytartt/honeydue-api/internal/services" -) - -const ( - // AuthUserKey is the key used to store the authenticated user in the context - AuthUserKey = "auth_user" - // AuthTokenKey is the key used to store the token in the context - AuthTokenKey = "auth_token" - // TokenCacheTTL is the duration to cache tokens in Redis. Tokens are - // valid for DefaultTokenExpiryDays (90), and explicit logout invalidates - // the cache, so a long TTL here just means most authed requests skip the - // auth-token SQL query entirely. - TokenCacheTTL = 1 * time.Hour - // TokenCachePrefix is the prefix for token cache keys - TokenCachePrefix = "auth_token_" - // UserCacheTTL is how long full user records are cached in memory to - // avoid hitting the database on every authenticated request. Bumped from - // 30s — at 30s the trace showed a SELECT auth_user query on most warm - // requests because users aren't in cache long enough to hit twice. - UserCacheTTL = 5 * time.Minute - // UserCacheMaxSize bounds the per-pod in-memory user cache. With ~1KB - // per User struct, 5000 entries = ~5MB per pod. Older entries are - // evicted LRU before the limit is exceeded. - UserCacheMaxSize = 5000 - - // DefaultTokenExpiryDays is the default number of days before a token expires. - DefaultTokenExpiryDays = 90 -) - -// AuthMiddleware provides token authentication middleware -type AuthMiddleware struct { - db *gorm.DB - cache *services.CacheService - userCache *UserCache - tokenExpiryDays int -} - -// NewAuthMiddleware creates a new auth middleware instance -func NewAuthMiddleware(db *gorm.DB, cache *services.CacheService) *AuthMiddleware { - return &AuthMiddleware{ - db: db, - cache: cache, - userCache: NewUserCache(UserCacheTTL, UserCacheMaxSize), - tokenExpiryDays: DefaultTokenExpiryDays, - } -} - -// NewAuthMiddlewareWithConfig creates a new auth middleware instance with configuration -func NewAuthMiddlewareWithConfig(db *gorm.DB, cache *services.CacheService, cfg *config.Config) *AuthMiddleware { - expiryDays := DefaultTokenExpiryDays - if cfg != nil && cfg.Security.TokenExpiryDays > 0 { - expiryDays = cfg.Security.TokenExpiryDays - } - return &AuthMiddleware{ - db: db, - cache: cache, - userCache: NewUserCache(UserCacheTTL, UserCacheMaxSize), - tokenExpiryDays: expiryDays, - } -} - -// TokenExpiryDuration returns the token expiry duration. -func (m *AuthMiddleware) TokenExpiryDuration() time.Duration { - return time.Duration(m.tokenExpiryDays) * 24 * time.Hour -} - -// isTokenExpired checks if a token's created timestamp indicates expiry. -func (m *AuthMiddleware) isTokenExpired(created time.Time) bool { - if created.IsZero() { - return false // Legacy tokens without created time are not expired - } - return time.Since(created) > m.TokenExpiryDuration() -} - -// TokenAuth returns an Echo middleware that validates token authentication -func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - // Extract token from Authorization header - token, err := extractToken(c) - if err != nil { - return apperrors.Unauthorized("error.not_authenticated") - } - - // Try to get user from cache first (includes expiry check) - user, err := m.getUserFromCache(c.Request().Context(), token) - if err == nil && user != nil { - // Cache hit - set user in context and continue - c.Set(AuthUserKey, user) - c.Set(AuthTokenKey, token) - return next(c) - } - - // Check if the cache indicated token expiry - if err != nil && err.Error() == "token expired" { - return apperrors.Unauthorized("error.token_expired") - } - - // Cache miss - look up token in database - user, authToken, err := m.getUserFromDatabaseWithToken(c.Request().Context(), token) - if err != nil { - log.Debug().Err(err).Str("token", truncateToken(token)).Msg("Token authentication failed") - return apperrors.Unauthorized("error.invalid_token") - } - - // Check token expiry - if m.isTokenExpired(authToken.Created) { - log.Debug().Str("token", truncateToken(token)).Time("created", authToken.Created).Msg("Token expired") - return apperrors.Unauthorized("error.token_expired") - } - - // Cache the user ID and token creation time for future requests - if cacheErr := m.cacheTokenInfo(c.Request().Context(), token, user.ID, authToken.Created); cacheErr != nil { - log.Warn().Err(cacheErr).Msg("Failed to cache token info") - } - - // Set user in context - c.Set(AuthUserKey, user) - c.Set(AuthTokenKey, token) - return next(c) - } - } -} - -// OptionalTokenAuth returns middleware that authenticates if token is present but doesn't require it -func (m *AuthMiddleware) OptionalTokenAuth() echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - token, err := extractToken(c) - if err != nil { - // No token or invalid format - continue without user - return next(c) - } - - // Try cache first - user, err := m.getUserFromCache(c.Request().Context(), token) - if err == nil && user != nil { - c.Set(AuthUserKey, user) - c.Set(AuthTokenKey, token) - return next(c) - } - - // Try database - user, authToken, err := m.getUserFromDatabaseWithToken(c.Request().Context(), token) - if err == nil && !m.isTokenExpired(authToken.Created) { - m.cacheTokenInfo(c.Request().Context(), token, user.ID, authToken.Created) - c.Set(AuthUserKey, user) - c.Set(AuthTokenKey, token) - } - - return next(c) - } - } -} - -// RequireVerified returns middleware that rejects users whose email is not -// verified (audit LIVE-L19). Apply it after TokenAuth to gate sensitive -// actions — e.g. generating residence share codes — behind proof that the -// account actually controls its email address. -func (m *AuthMiddleware) RequireVerified() echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - user := GetAuthUser(c) - if user == nil { - return apperrors.Unauthorized("error.not_authenticated") - } - var verified bool - err := m.db.WithContext(c.Request().Context()). - Model(&models.UserProfile{}). - Where("user_id = ?", user.ID). - Select("verified"). - Scan(&verified).Error - if err != nil { - return apperrors.Internal(err) - } - if !verified { - return apperrors.Forbidden("error.email_not_verified") - } - return next(c) - } - } -} - -// extractToken extracts the token from the Authorization header -func extractToken(c echo.Context) (string, error) { - authHeader := c.Request().Header.Get("Authorization") - if authHeader == "" { - return "", fmt.Errorf("authorization header required") - } - - // Support both "Token xxx" (Django style) and "Bearer xxx" formats - parts := strings.SplitN(authHeader, " ", 2) - if len(parts) != 2 { - return "", fmt.Errorf("invalid authorization header format") - } - - scheme := parts[0] - token := parts[1] - - if scheme != "Token" && scheme != "Bearer" { - return "", fmt.Errorf("invalid authorization scheme: %s", scheme) - } - - if token == "" { - return "", fmt.Errorf("token is empty") - } - - return token, nil -} - -// getUserFromCache tries to get user from Redis cache, then from the -// in-memory user cache, before falling back to the database. -// Returns a "token expired" error if the cached creation time indicates expiry. -func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*models.User, error) { - if m.cache == nil { - return nil, fmt.Errorf("cache not available") - } - - userID, createdUnix, err := m.cache.GetCachedAuthTokenWithCreated(ctx, token) - if err != nil { - if err == redis.Nil { - return nil, fmt.Errorf("token not in cache") - } - return nil, err - } - - // Check token expiry from cached creation time - if createdUnix > 0 { - created := time.Unix(createdUnix, 0) - if m.isTokenExpired(created) { - m.cache.InvalidateAuthToken(ctx, token) - return nil, fmt.Errorf("token expired") - } - } - - // Try in-memory user cache first to avoid a DB round-trip - if cached := m.userCache.Get(userID); cached != nil { - if !cached.IsActive { - m.cache.InvalidateAuthToken(ctx, token) - m.userCache.Invalidate(userID) - return nil, fmt.Errorf("user is inactive") - } - return cached, nil - } - - // In-memory cache miss — fetch from database - var user models.User - if err := m.db.WithContext(ctx).First(&user, userID).Error; err != nil { - // User was deleted - invalidate caches - m.cache.InvalidateAuthToken(ctx, token) - return nil, err - } - - // Check if user is active - if !user.IsActive { - m.cache.InvalidateAuthToken(ctx, token) - return nil, fmt.Errorf("user is inactive") - } - - // Store in in-memory cache for subsequent requests - m.userCache.Set(&user) - return &user, nil -} - -// getUserFromDatabaseWithToken looks up the token in the database and returns -// both the user and the auth token record (for expiry checking). The ctx is -// threaded into the GORM session so the SQL span attaches to the request trace. -// -// Uses a single JOIN query instead of GORM's Preload (which issues 2 SELECTs). -// Over a transatlantic link this saves ~110ms RTT per cache miss. -func (m *AuthMiddleware) getUserFromDatabaseWithToken(ctx context.Context, token string) (*models.User, *models.AuthToken, error) { - // Flat result row: every column from auth_user prefixed `u_`, every - // column from user_authtoken left in its native shape. Mapping to two - // structs is mechanical so we don't need a struct tag soup. - type joinedRow struct { - // AuthToken columns - Key string `gorm:"column:key"` - Created time.Time `gorm:"column:created"` - UserID uint `gorm:"column:user_id"` - // User columns (prefixed to avoid collision with UserID) - UID uint `gorm:"column:u_id"` - UUsername string `gorm:"column:u_username"` - UEmail string `gorm:"column:u_email"` - UFirstName string `gorm:"column:u_first_name"` - ULastName string `gorm:"column:u_last_name"` - UPassword string `gorm:"column:u_password"` - UIsActive bool `gorm:"column:u_is_active"` - UIsStaff bool `gorm:"column:u_is_staff"` - UIsSuper bool `gorm:"column:u_is_superuser"` - UDateJoined time.Time `gorm:"column:u_date_joined"` - ULastLogin *time.Time `gorm:"column:u_last_login"` - } - - var row joinedRow - err := m.db.WithContext(ctx). - Table("user_authtoken AS t"). - Select(` - t.key, t.created, t.user_id, - u.id AS u_id, - u.username AS u_username, - u.email AS u_email, - u.first_name AS u_first_name, - u.last_name AS u_last_name, - u.password AS u_password, - u.is_active AS u_is_active, - u.is_staff AS u_is_staff, - u.is_superuser AS u_is_superuser, - u.date_joined AS u_date_joined, - u.last_login AS u_last_login - `). - Joins("INNER JOIN auth_user u ON u.id = t.user_id"). - Where("t.key = ?", models.HashToken(token)). // audit C1: only the hash is stored - Limit(1). - Scan(&row).Error - if err != nil || row.Key == "" { - return nil, nil, fmt.Errorf("token not found") - } - - user := models.User{ - ID: row.UID, - Username: row.UUsername, - Email: row.UEmail, - FirstName: row.UFirstName, - LastName: row.ULastName, - Password: row.UPassword, - IsActive: row.UIsActive, - IsStaff: row.UIsStaff, - IsSuperuser: row.UIsSuper, - DateJoined: row.UDateJoined, - LastLogin: row.ULastLogin, - } - authToken := models.AuthToken{ - Key: row.Key, - Created: row.Created, - UserID: row.UserID, - User: user, - } - - if !user.IsActive { - return nil, nil, fmt.Errorf("user is inactive") - } - - m.userCache.Set(&user) - return &user, &authToken, nil -} - -// getUserFromDatabase looks up the token in the database and caches the -// resulting user record in memory. -// Deprecated: Use getUserFromDatabaseWithToken for new code paths that need expiry checking. -func (m *AuthMiddleware) getUserFromDatabase(ctx context.Context, token string) (*models.User, error) { - user, _, err := m.getUserFromDatabaseWithToken(ctx, token) - return user, err -} - -// cacheTokenInfo caches the user ID and token creation time for a token -func (m *AuthMiddleware) cacheTokenInfo(ctx context.Context, token string, userID uint, created time.Time) error { - if m.cache == nil { - return nil - } - return m.cache.CacheAuthTokenWithCreated(ctx, token, userID, created.Unix()) -} - -// cacheUserID caches the user ID for a token -func (m *AuthMiddleware) cacheUserID(ctx context.Context, token string, userID uint) error { - if m.cache == nil { - return nil - } - return m.cache.CacheAuthToken(ctx, token, userID) -} - -// InvalidateToken removes a token from the cache -func (m *AuthMiddleware) InvalidateToken(ctx context.Context, token string) error { - if m.cache == nil { - return nil - } - return m.cache.InvalidateAuthToken(ctx, token) -} - -// GetAuthUser retrieves the authenticated user from the Echo context. -// Returns nil if the context value is missing or not of the expected type. -func GetAuthUser(c echo.Context) *models.User { - val := c.Get(AuthUserKey) - if val == nil { - return nil - } - user, ok := val.(*models.User) - if !ok { - return nil - } - return user -} - -// GetAuthToken retrieves the auth token from the Echo context -func GetAuthToken(c echo.Context) string { - token := c.Get(AuthTokenKey) - if token == nil { - return "" - } - tokenStr, ok := token.(string) - if !ok { - return "" - } - return tokenStr -} - -// MustGetAuthUser retrieves the authenticated user or returns error with 401 -func MustGetAuthUser(c echo.Context) (*models.User, error) { - user := GetAuthUser(c) - if user == nil { - return nil, apperrors.Unauthorized("error.not_authenticated") - } - return user, nil -} - -// truncateToken safely truncates a token string for logging. -// Returns at most the first 8 characters followed by "...". -func truncateToken(token string) string { - if len(token) > 8 { - return token[:8] + "..." - } - return token + "..." -} diff --git a/internal/middleware/auth_expiry_test.go b/internal/middleware/auth_expiry_test.go deleted file mode 100644 index 525e577..0000000 --- a/internal/middleware/auth_expiry_test.go +++ /dev/null @@ -1,165 +0,0 @@ -package middleware - -import ( - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gorm.io/driver/sqlite" - "gorm.io/gorm" - "gorm.io/gorm/logger" - - "github.com/treytartt/honeydue-api/internal/models" -) - -// setupTestDB creates a temporary in-memory SQLite database with the required -// tables for auth middleware tests. -func setupTestDB(t *testing.T) *gorm.DB { - t.Helper() - db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ - Logger: logger.Default.LogMode(logger.Silent), - }) - require.NoError(t, err) - - err = db.AutoMigrate(&models.User{}, &models.AuthToken{}) - require.NoError(t, err) - - return db -} - -// createTestUserAndToken creates a user and an auth token, then backdates the -// token's Created timestamp by the specified number of days. -func createTestUserAndToken(t *testing.T, db *gorm.DB, username string, ageDays int) (*models.User, *models.AuthToken) { - t.Helper() - - user := &models.User{ - Username: username, - Email: username + "@test.com", - IsActive: true, - } - require.NoError(t, user.SetPassword("Password123")) - require.NoError(t, db.Create(user).Error) - - token := &models.AuthToken{ - UserID: user.ID, - } - require.NoError(t, db.Create(token).Error) - - // Backdate the token's Created timestamp after creation to bypass autoCreateTime - backdated := time.Now().UTC().AddDate(0, 0, -ageDays) - require.NoError(t, db.Model(token).Update("created", backdated).Error) - token.Created = backdated - - return user, token -} - -func TestTokenAuth_RejectsExpiredToken(t *testing.T) { - db := setupTestDB(t) - _, token := createTestUserAndToken(t, db, "expired_user", 91) // 91 days old > 90 day expiry - - m := NewAuthMiddleware(db, nil) // No Redis cache for these tests - - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) - req.Header.Set("Authorization", "Token "+token.Plaintext) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - handler := m.TokenAuth()(func(c echo.Context) error { - return c.String(http.StatusOK, "ok") - }) - - err := handler(c) - require.Error(t, err) - assert.Contains(t, err.Error(), "error.token_expired") -} - -func TestTokenAuth_AcceptsValidToken(t *testing.T) { - db := setupTestDB(t) - _, token := createTestUserAndToken(t, db, "valid_user", 30) // 30 days old < 90 day expiry - - m := NewAuthMiddleware(db, nil) - - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) - req.Header.Set("Authorization", "Token "+token.Plaintext) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - handler := m.TokenAuth()(func(c echo.Context) error { - return c.String(http.StatusOK, "ok") - }) - - err := handler(c) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, rec.Code) - - // Verify user was set in context - user := GetAuthUser(c) - require.NotNil(t, user) - assert.Equal(t, "valid_user", user.Username) -} - -func TestTokenAuth_AcceptsTokenAtBoundary(t *testing.T) { - db := setupTestDB(t) - _, token := createTestUserAndToken(t, db, "boundary_user", 89) // 89 days old, just under 90 day expiry - - m := NewAuthMiddleware(db, nil) - - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) - req.Header.Set("Authorization", "Token "+token.Plaintext) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - handler := m.TokenAuth()(func(c echo.Context) error { - return c.String(http.StatusOK, "ok") - }) - - err := handler(c) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, rec.Code) -} - -func TestTokenAuth_RejectsInvalidToken(t *testing.T) { - db := setupTestDB(t) - - m := NewAuthMiddleware(db, nil) - - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) - req.Header.Set("Authorization", "Token nonexistent-token") - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - handler := m.TokenAuth()(func(c echo.Context) error { - return c.String(http.StatusOK, "ok") - }) - - err := handler(c) - require.Error(t, err) - assert.Contains(t, err.Error(), "error.invalid_token") -} - -func TestTokenAuth_RejectsNoAuthHeader(t *testing.T) { - db := setupTestDB(t) - - m := NewAuthMiddleware(db, nil) - - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - handler := m.TokenAuth()(func(c echo.Context) error { - return c.String(http.StatusOK, "ok") - }) - - err := handler(c) - require.Error(t, err) - assert.Contains(t, err.Error(), "error.not_authenticated") -} diff --git a/internal/middleware/auth_test.go b/internal/middleware/auth_test.go deleted file mode 100644 index 42df77b..0000000 --- a/internal/middleware/auth_test.go +++ /dev/null @@ -1,337 +0,0 @@ -package middleware - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/treytartt/honeydue-api/internal/config" - "github.com/treytartt/honeydue-api/internal/models" -) - -func TestTokenAuth_BearerScheme_Accepted(t *testing.T) { - db := setupTestDB(t) - _, token := createTestUserAndToken(t, db, "bearer_user", 10) - - m := NewAuthMiddleware(db, nil) - - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) - req.Header.Set("Authorization", "Bearer "+token.Plaintext) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - handler := m.TokenAuth()(func(c echo.Context) error { - return c.String(http.StatusOK, "ok") - }) - - err := handler(c) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, rec.Code) - - user := GetAuthUser(c) - require.NotNil(t, user) - assert.Equal(t, "bearer_user", user.Username) -} - -func TestTokenAuth_InvalidScheme_Rejected(t *testing.T) { - db := setupTestDB(t) - _, token := createTestUserAndToken(t, db, "scheme_user", 10) - - m := NewAuthMiddleware(db, nil) - - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) - req.Header.Set("Authorization", "Basic "+token.Plaintext) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - handler := m.TokenAuth()(func(c echo.Context) error { - return c.String(http.StatusOK, "ok") - }) - - err := handler(c) - require.Error(t, err) - assert.Contains(t, err.Error(), "error.not_authenticated") -} - -func TestTokenAuth_MalformedHeader_Rejected(t *testing.T) { - db := setupTestDB(t) - - m := NewAuthMiddleware(db, nil) - - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) - req.Header.Set("Authorization", "JustATokenWithNoScheme") - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - handler := m.TokenAuth()(func(c echo.Context) error { - return c.String(http.StatusOK, "ok") - }) - - err := handler(c) - require.Error(t, err) - assert.Contains(t, err.Error(), "error.not_authenticated") -} - -func TestTokenAuth_EmptyToken_Rejected(t *testing.T) { - db := setupTestDB(t) - - m := NewAuthMiddleware(db, nil) - - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) - req.Header.Set("Authorization", "Token ") - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - handler := m.TokenAuth()(func(c echo.Context) error { - return c.String(http.StatusOK, "ok") - }) - - err := handler(c) - require.Error(t, err) - assert.Contains(t, err.Error(), "error.not_authenticated") -} - -func TestTokenAuth_InactiveUser_Rejected(t *testing.T) { - db := setupTestDB(t) - user, token := createTestUserAndToken(t, db, "inactive_user", 10) - - // Deactivate the user - require.NoError(t, db.Model(user).Update("is_active", false).Error) - - m := NewAuthMiddleware(db, nil) - - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) - req.Header.Set("Authorization", "Token "+token.Plaintext) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - handler := m.TokenAuth()(func(c echo.Context) error { - return c.String(http.StatusOK, "ok") - }) - - err := handler(c) - require.Error(t, err) - assert.Contains(t, err.Error(), "error.invalid_token") -} - -func TestOptionalTokenAuth_NoToken_PassesThrough(t *testing.T) { - db := setupTestDB(t) - - m := NewAuthMiddleware(db, nil) - - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) - // No Authorization header - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - handler := m.OptionalTokenAuth()(func(c echo.Context) error { - user := GetAuthUser(c) - if user == nil { - return c.String(http.StatusOK, "no-user") - } - return c.String(http.StatusOK, user.Username) - }) - - err := handler(c) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "no-user", rec.Body.String()) -} - -func TestOptionalTokenAuth_ValidToken_SetsUser(t *testing.T) { - db := setupTestDB(t) - _, token := createTestUserAndToken(t, db, "opt_user", 10) - - m := NewAuthMiddleware(db, nil) - - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) - req.Header.Set("Authorization", "Token "+token.Plaintext) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - handler := m.OptionalTokenAuth()(func(c echo.Context) error { - user := GetAuthUser(c) - if user == nil { - return c.String(http.StatusOK, "no-user") - } - return c.String(http.StatusOK, user.Username) - }) - - err := handler(c) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "opt_user", rec.Body.String()) -} - -func TestOptionalTokenAuth_ExpiredToken_IgnoresUser(t *testing.T) { - db := setupTestDB(t) - _, token := createTestUserAndToken(t, db, "expired_opt_user", 91) - - m := NewAuthMiddleware(db, nil) - - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) - req.Header.Set("Authorization", "Token "+token.Plaintext) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - handler := m.OptionalTokenAuth()(func(c echo.Context) error { - user := GetAuthUser(c) - if user == nil { - return c.String(http.StatusOK, "no-user") - } - return c.String(http.StatusOK, user.Username) - }) - - err := handler(c) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "no-user", rec.Body.String()) -} - -func TestOptionalTokenAuth_InvalidToken_IgnoresUser(t *testing.T) { - db := setupTestDB(t) - - m := NewAuthMiddleware(db, nil) - - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) - req.Header.Set("Authorization", "Token nonexistent-token") - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - handler := m.OptionalTokenAuth()(func(c echo.Context) error { - user := GetAuthUser(c) - if user == nil { - return c.String(http.StatusOK, "no-user") - } - return c.String(http.StatusOK, user.Username) - }) - - err := handler(c) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "no-user", rec.Body.String()) -} - -func TestNewAuthMiddlewareWithConfig_CustomExpiryDays(t *testing.T) { - db := setupTestDB(t) - cfg := &config.Config{ - Security: config.SecurityConfig{ - TokenExpiryDays: 30, - }, - } - - m := NewAuthMiddlewareWithConfig(db, nil, cfg) - assert.NotNil(t, m) - assert.Equal(t, 30, m.tokenExpiryDays) - - // Token at 29 days should be valid - _, token := createTestUserAndToken(t, db, "short_expiry_user", 29) - - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) - req.Header.Set("Authorization", "Token "+token.Plaintext) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - handler := m.TokenAuth()(func(c echo.Context) error { - return c.String(http.StatusOK, "ok") - }) - - err := handler(c) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, rec.Code) -} - -func TestNewAuthMiddlewareWithConfig_ExpiredWithCustomExpiry(t *testing.T) { - db := setupTestDB(t) - cfg := &config.Config{ - Security: config.SecurityConfig{ - TokenExpiryDays: 30, - }, - } - - m := NewAuthMiddlewareWithConfig(db, nil, cfg) - - // Token at 31 days should be expired with 30-day config - _, token := createTestUserAndToken(t, db, "custom_expired_user", 31) - - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/api/test/", nil) - req.Header.Set("Authorization", "Token "+token.Plaintext) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - handler := m.TokenAuth()(func(c echo.Context) error { - return c.String(http.StatusOK, "ok") - }) - - err := handler(c) - require.Error(t, err) - assert.Contains(t, err.Error(), "error.token_expired") -} - -func TestNewAuthMiddlewareWithConfig_NilConfig_UsesDefault(t *testing.T) { - db := setupTestDB(t) - - m := NewAuthMiddlewareWithConfig(db, nil, nil) - assert.Equal(t, DefaultTokenExpiryDays, m.tokenExpiryDays) -} - -func TestGetAuthToken_ReturnsToken(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - c.Set(AuthTokenKey, "test-token-value") - assert.Equal(t, "test-token-value", GetAuthToken(c)) -} - -func TestGetAuthToken_NilContext_ReturnsEmpty(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - // No token set - assert.Equal(t, "", GetAuthToken(c)) -} - -func TestGetAuthToken_WrongType_ReturnsEmpty(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - c.Set(AuthTokenKey, 12345) // Wrong type - assert.Equal(t, "", GetAuthToken(c)) -} - -func TestIsTokenExpired_ZeroTime_NotExpired(t *testing.T) { - db := setupTestDB(t) - m := NewAuthMiddleware(db, nil) - - // Legacy tokens without created time should not be expired - assert.False(t, m.isTokenExpired(models.AuthToken{}.Created)) -} - -func TestInvalidateToken_NilCache_NoError(t *testing.T) { - db := setupTestDB(t) - m := NewAuthMiddleware(db, nil) // nil cache - - err := m.InvalidateToken(nil, "some-token") - assert.NoError(t, err) -} diff --git a/internal/middleware/kratos_auth.go b/internal/middleware/kratos_auth.go new file mode 100644 index 0000000..f7ae066 --- /dev/null +++ b/internal/middleware/kratos_auth.go @@ -0,0 +1,271 @@ +package middleware + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "strings" + "time" + + "github.com/labstack/echo/v4" + "github.com/rs/zerolog/log" + "gorm.io/gorm" + + "github.com/treytartt/honeydue-api/internal/apperrors" + "github.com/treytartt/honeydue-api/internal/kratos" + "github.com/treytartt/honeydue-api/internal/models" + "github.com/treytartt/honeydue-api/internal/services" +) + +const ( + // AuthUserKey stores the authenticated *models.User in the echo context. + AuthUserKey = "auth_user" + // AuthTokenKey stores the raw session credential in the echo context. + AuthTokenKey = "auth_token" + // authVerifiedKey stores the Kratos email-verified flag in the context. + authVerifiedKey = "auth_email_verified" + + // UserCacheTTL / UserCacheMaxSize bound the in-memory local-user cache. + UserCacheTTL = 5 * time.Minute + UserCacheMaxSize = 5000 + + // kratosSessionCacheTTL is how long a validated session is cached in + // Redis, so most authed requests skip the Kratos /whoami round trip. + kratosSessionCacheTTL = 5 * time.Minute + kratosSessionPrefix = "kratos_sess:" +) + +// KratosAuth authenticates requests against an Ory Kratos session. It +// replaces the hand-rolled token auth: the session is validated via Kratos +// /sessions/whoami (Redis-cached), and the matching local auth_user row is +// lazily provisioned on first sight of a Kratos identity. +type KratosAuth struct { + kratos *kratos.Client + cache *services.CacheService + db *gorm.DB + userCache *UserCache +} + +// NewKratosAuth builds the Kratos auth middleware. +func NewKratosAuth(k *kratos.Client, cache *services.CacheService, db *gorm.DB) *KratosAuth { + return &KratosAuth{ + kratos: k, + cache: cache, + db: db, + userCache: NewUserCache(UserCacheTTL, UserCacheMaxSize), + } +} + +// Authenticate validates the Kratos session and requires it. +func (m *KratosAuth) Authenticate() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + user, verified, cred, err := m.resolve(c) + if err != nil { + log.Debug().Err(err).Msg("Kratos authentication failed") + return apperrors.Unauthorized("error.not_authenticated") + } + c.Set(AuthUserKey, user) + c.Set(AuthTokenKey, cred) + c.Set(authVerifiedKey, verified) + return next(c) + } + } +} + +// OptionalAuthenticate authenticates if a session is present, else continues +// unauthenticated. +func (m *KratosAuth) OptionalAuthenticate() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if user, verified, cred, err := m.resolve(c); err == nil { + c.Set(AuthUserKey, user) + c.Set(AuthTokenKey, cred) + c.Set(authVerifiedKey, verified) + } + return next(c) + } + } +} + +// RequireVerified rejects users whose Kratos email address is not verified. +// Apply after Authenticate. +func (m *KratosAuth) RequireVerified() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if GetAuthUser(c) == nil { + return apperrors.Unauthorized("error.not_authenticated") + } + if verified, _ := c.Get(authVerifiedKey).(bool); !verified { + return apperrors.Forbidden("error.email_not_verified") + } + return next(c) + } + } +} + +// resolve validates the request's session and returns the local user. +func (m *KratosAuth) resolve(c echo.Context) (*models.User, bool, string, error) { + token, cookie := extractSession(c) + if token == "" && cookie == "" { + return nil, false, "", errors.New("no session credential") + } + cred := token + if cred == "" { + cred = cookie + } + ctx := c.Request().Context() + + // Redis cache: kratos_sess: -> "|<0|1>" + cacheKey := kratosSessionPrefix + hashCredential(cred) + if m.cache != nil { + if v, err := m.cache.GetString(ctx, cacheKey); err == nil && v != "" { + if user, verified, ok := m.userFromCacheValue(ctx, v); ok { + return user, verified, cred, nil + } + } + } + + sess, err := m.kratos.Whoami(ctx, token, cookie) + if err != nil { + return nil, false, "", err + } + user, err := m.provision(ctx, sess) + if err != nil { + return nil, false, "", err + } + if m.cache != nil { + _ = m.cache.SetString(ctx, cacheKey, + fmt.Sprintf("%d|%s", user.ID, boolDigit(sess.EmailVerified())), kratosSessionCacheTTL) + } + return user, sess.EmailVerified(), cred, nil +} + +// provision finds the local auth_user row for a Kratos identity, creating it +// (and a UserProfile) on first sight. Concurrent first requests are handled +// by re-reading after a unique-constraint conflict. +func (m *KratosAuth) provision(ctx context.Context, sess *kratos.Session) (*models.User, error) { + var user models.User + err := m.db.WithContext(ctx).Where("kratos_id = ?", sess.Identity.ID).First(&user).Error + if err == nil { + return &user, nil + } + if !errors.Is(err, gorm.ErrRecordNotFound) { + return nil, err + } + + user = models.User{ + KratosID: sess.Identity.ID, + Email: sess.Identity.Traits.Email, + Username: sess.Identity.Traits.Email, + FirstName: sess.Identity.Traits.Name.First, + LastName: sess.Identity.Traits.Name.Last, + IsActive: true, + DateJoined: time.Now().UTC(), + } + txErr := m.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := tx.Create(&user).Error; err != nil { + return err + } + return tx.Create(&models.UserProfile{ + UserID: user.ID, + Verified: sess.EmailVerified(), + }).Error + }) + if txErr != nil { + // Likely a concurrent provision of the same identity — re-read. + if e := m.db.WithContext(ctx).Where("kratos_id = ?", sess.Identity.ID).First(&user).Error; e == nil { + return &user, nil + } + return nil, txErr + } + log.Info().Str("kratos_id", sess.Identity.ID).Uint("user_id", user.ID). + Msg("provisioned local user from Kratos identity") + return &user, nil +} + +// userFromCacheValue resolves a cached "|<0|1>" value to a user. +func (m *KratosAuth) userFromCacheValue(ctx context.Context, v string) (*models.User, bool, bool) { + parts := strings.SplitN(v, "|", 2) + if len(parts) != 2 { + return nil, false, false + } + var id uint + if _, err := fmt.Sscanf(parts[0], "%d", &id); err != nil || id == 0 { + return nil, false, false + } + verified := parts[1] == "1" + if cached := m.userCache.Get(id); cached != nil { + return cached, verified, true + } + var user models.User + if err := m.db.WithContext(ctx).First(&user, id).Error; err != nil { + return nil, false, false + } + m.userCache.Set(&user) + return &user, verified, true +} + +// extractSession pulls the session credential from the request: the +// X-Session-Token header or Authorization bearer (mobile clients), or the +// ory_kratos_session cookie (web). +func extractSession(c echo.Context) (token, cookie string) { + if t := c.Request().Header.Get("X-Session-Token"); t != "" { + token = t + } else if ah := c.Request().Header.Get("Authorization"); ah != "" { + parts := strings.SplitN(ah, " ", 2) + if len(parts) == 2 && (parts[0] == "Bearer" || parts[0] == "Token") { + token = parts[1] + } + } + if token == "" { + if ck := c.Request().Header.Get("Cookie"); strings.Contains(ck, "ory_kratos_session") { + cookie = ck + } + } + return token, cookie +} + +func hashCredential(cred string) string { + sum := sha256.Sum256([]byte(cred)) + return hex.EncodeToString(sum[:]) +} + +func boolDigit(b bool) string { + if b { + return "1" + } + return "0" +} + +// truncateToken returns the first 8 characters of a credential followed by +// "..." for safe inclusion in log lines. +func truncateToken(tok string) string { + if len(tok) <= 8 { + return tok + "..." + } + return tok[:8] + "..." +} + +// GetAuthUser retrieves the authenticated user from the echo context. +func GetAuthUser(c echo.Context) *models.User { + user, _ := c.Get(AuthUserKey).(*models.User) + return user +} + +// GetAuthToken retrieves the session credential from the echo context. +func GetAuthToken(c echo.Context) string { + tok, _ := c.Get(AuthTokenKey).(string) + return tok +} + +// MustGetAuthUser retrieves the authenticated user or returns a 401 error. +func MustGetAuthUser(c echo.Context) (*models.User, error) { + user := GetAuthUser(c) + if user == nil { + return nil, apperrors.Unauthorized("error.not_authenticated") + } + return user, nil +} diff --git a/internal/models/models_coverage_test.go b/internal/models/models_coverage_test.go index 428b9a4..bcabdd6 100644 --- a/internal/models/models_coverage_test.go +++ b/internal/models/models_coverage_test.go @@ -19,7 +19,7 @@ func setupModelsTestDB(t *testing.T) *gorm.DB { Logger: logger.Default.LogMode(logger.Silent), }) require.NoError(t, err) - err = db.AutoMigrate(&User{}, &AuthToken{}, &UserProfile{}) + err = db.AutoMigrate(&User{}, &UserProfile{}) require.NoError(t, err) return db } @@ -233,105 +233,6 @@ func TestNotificationType_Constants(t *testing.T) { assert.Equal(t, NotificationType("warranty_expiring"), NotificationWarrantyExpiring) } -// === AuthToken model tests === - -func TestAuthToken_BeforeCreate_GeneratesKey(t *testing.T) { - db := setupModelsTestDB(t) - - user := &User{ - Username: "tokenuser", - Email: "token@test.com", - Password: "dummy", - IsActive: true, - } - err := db.Create(user).Error - require.NoError(t, err) - - token := &AuthToken{UserID: user.ID} - err = db.Create(token).Error - require.NoError(t, err) - - assert.NotEmpty(t, token.Key) - assert.Len(t, token.Key, 64) // SHA-256 hex hash (audit C1) - assert.Len(t, token.Plaintext, 40) // raw 20-byte token, returned to the client - assert.False(t, token.Created.IsZero()) -} - -func TestAuthToken_BeforeCreate_PreservesExistingKey(t *testing.T) { - db := setupModelsTestDB(t) - - user := &User{ - Username: "tokenuser", - Email: "token@test.com", - Password: "dummy", - IsActive: true, - } - err := db.Create(user).Error - require.NoError(t, err) - - existingKey := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" - token := &AuthToken{ - Key: existingKey, - UserID: user.ID, - } - err = db.Create(token).Error - require.NoError(t, err) - - assert.Equal(t, existingKey, token.Key) -} - -func TestGetOrCreateToken_CreatesNew(t *testing.T) { - db := setupModelsTestDB(t) - - user := &User{ - Username: "newtoken", - Email: "newtoken@test.com", - Password: "dummy", - IsActive: true, - } - err := db.Create(user).Error - require.NoError(t, err) - - token, err := GetOrCreateToken(db, user.ID) - require.NoError(t, err) - assert.NotEmpty(t, token.Key) - assert.Equal(t, user.ID, token.UserID) -} - -func TestGetOrCreateToken_ReturnsExisting(t *testing.T) { - db := setupModelsTestDB(t) - - user := &User{ - Username: "existingtoken", - Email: "existingtoken@test.com", - Password: "dummy", - IsActive: true, - } - err := db.Create(user).Error - require.NoError(t, err) - - token1, err := GetOrCreateToken(db, user.ID) - require.NoError(t, err) - - token2, err := GetOrCreateToken(db, user.ID) - require.NoError(t, err) - - assert.Equal(t, token1.Key, token2.Key) -} - -// === User model additional tests === - -func TestUser_SetPassword_And_CheckPassword_Integration(t *testing.T) { - user := &User{} - err := user.SetPassword("Password123") - require.NoError(t, err) - - assert.True(t, user.CheckPassword("Password123")) - assert.False(t, user.CheckPassword("WrongPassword")) - assert.False(t, user.CheckPassword("")) - assert.False(t, user.CheckPassword("password123")) // case sensitive -} - // === Task model additional tests === func TestTask_IsOverdue_CancelledNotOverdue(t *testing.T) { @@ -565,31 +466,6 @@ func TestGetDefaultProLimits(t *testing.T) { assert.Nil(t, limits.DocumentsLimit) } -// === ConfirmationCode additional tests === - -func TestConfirmationCode_TableName(t *testing.T) { - cc := ConfirmationCode{} - assert.Equal(t, "user_confirmationcode", cc.TableName()) -} - -// === PasswordResetCode additional tests === - -func TestPasswordResetCode_TableName(t *testing.T) { - prc := PasswordResetCode{} - assert.Equal(t, "user_passwordresetcode", prc.TableName()) -} - -// === Social Auth TableName tests === - -func TestAppleSocialAuth_TableName(t *testing.T) { - a := AppleSocialAuth{} - assert.Equal(t, "user_applesocialauth", a.TableName()) -} - -func TestGoogleSocialAuth_TableName(t *testing.T) { - g := GoogleSocialAuth{} - assert.Equal(t, "user_googlesocialauth", g.TableName()) -} // === BaseModel tests === diff --git a/internal/models/user.go b/internal/models/user.go index f129e84..150deb6 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -1,69 +1,38 @@ package models -import ( - "crypto/rand" - "crypto/sha256" - "encoding/binary" - "encoding/hex" - "fmt" - "time" +import "time" - "golang.org/x/crypto/bcrypt" - "gorm.io/gorm" -) - -// User represents the auth_user table (Django's default User model) +// User represents the auth_user table. Identity — credentials, email +// verification, sessions, social sign-in — is owned by Ory Kratos (phase 2). +// This row is honeyDue's local mirror of a Kratos identity, linked by +// KratosID; every domain table keeps its existing integer FK to auth_user.id. type User struct { - ID uint `gorm:"primaryKey" json:"id"` - Password string `gorm:"column:password;size:128;not null" json:"-"` + ID uint `gorm:"primaryKey" json:"id"` + KratosID string `gorm:"column:kratos_id;uniqueIndex;size:36" json:"-"` // Kratos identity UUID + Username string `gorm:"column:username;uniqueIndex;size:150" json:"username"` + FirstName string `gorm:"column:first_name;size:150" json:"first_name"` + LastName string `gorm:"column:last_name;size:150" json:"last_name"` + Email string `gorm:"column:email;uniqueIndex;size:254" json:"email"` + IsStaff bool `gorm:"column:is_staff;default:false" json:"is_staff"` + IsActive bool `gorm:"column:is_active;default:true" json:"is_active"` + IsSuperuser bool `gorm:"column:is_superuser;default:false" json:"is_superuser"` + DateJoined time.Time `gorm:"column:date_joined;autoCreateTime" json:"date_joined"` LastLogin *time.Time `gorm:"column:last_login" json:"last_login,omitempty"` - IsSuperuser bool `gorm:"column:is_superuser;default:false" json:"is_superuser"` - Username string `gorm:"column:username;uniqueIndex;size:150;not null" json:"username"` - FirstName string `gorm:"column:first_name;size:150" json:"first_name"` - LastName string `gorm:"column:last_name;size:150" json:"last_name"` - Email string `gorm:"column:email;uniqueIndex;size:254" json:"email"` - IsStaff bool `gorm:"column:is_staff;default:false" json:"is_staff"` - IsActive bool `gorm:"column:is_active;default:true" json:"is_active"` - DateJoined time.Time `gorm:"column:date_joined;autoCreateTime" json:"date_joined"` - // Relations (not stored in auth_user table) - Profile *UserProfile `gorm:"foreignKey:UserID" json:"profile,omitempty"` - AuthToken *AuthToken `gorm:"foreignKey:UserID" json:"-"` - OwnedResidences []Residence `gorm:"foreignKey:OwnerID" json:"-"` - SharedResidences []Residence `gorm:"many2many:residence_residence_users;" json:"-"` + // Relations — not columns on auth_user. + Profile *UserProfile `gorm:"foreignKey:UserID" json:"profile,omitempty"` + OwnedResidences []Residence `gorm:"foreignKey:OwnerID" json:"-"` + SharedResidences []Residence `gorm:"many2many:residence_residence_users;" json:"-"` NotificationPref *NotificationPreference `gorm:"foreignKey:UserID" json:"-"` - Subscription *UserSubscription `gorm:"foreignKey:UserID" json:"-"` + Subscription *UserSubscription `gorm:"foreignKey:UserID" json:"-"` } -// TableName returns the table name for GORM +// TableName returns the table name for GORM. func (User) TableName() string { return "auth_user" } -// BcryptCost is the bcrypt work factor for password and code hashing. -// 12 (audit M2) is stronger than bcrypt.DefaultCost (10). -const BcryptCost = 12 - -// SetPassword hashes and sets the password -func (u *User) SetPassword(password string) error { - // Django uses PBKDF2_SHA256 by default, but we use bcrypt for Go. - // Passwords set by Django won't verify with Go's bcrypt check — those - // users must reset their password after migration. - hash, err := bcrypt.GenerateFromPassword([]byte(password), BcryptCost) - if err != nil { - return err - } - u.Password = string(hash) - return nil -} - -// CheckPassword verifies a password against the stored hash -func (u *User) CheckPassword(password string) bool { - err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password)) - return err == nil -} - -// GetFullName returns the user's full name +// GetFullName returns the user's display name. func (u *User) GetFullName() string { if u.FirstName != "" && u.LastName != "" { return u.FirstName + " " + u.LastName @@ -74,80 +43,9 @@ func (u *User) GetFullName() string { return u.Username } -// AuthToken represents the user_authtoken table. -// -// Audit C1: the Key column stores the SHA-256 hash of the token, never the -// token itself. The raw token is handed to the client exactly once, at -// creation, via the non-persisted Plaintext field — it is never stored or -// logged. A database compromise therefore yields no usable session tokens. -type AuthToken struct { - Key string `gorm:"column:key;primaryKey;size:64" json:"-"` - UserID uint `gorm:"column:user_id;uniqueIndex;not null" json:"user_id"` - Created time.Time `gorm:"column:created;autoCreateTime" json:"created"` - - // Plaintext is the raw token value. It is never persisted (gorm:"-") - // and is only populated on a freshly-created token so the caller can - // return it to the client. On a token loaded from the DB it is "". - Plaintext string `gorm:"-" json:"-"` - - // Relations - User User `gorm:"foreignKey:UserID" json:"-"` -} - -// TableName returns the table name for GORM -func (AuthToken) TableName() string { - return "user_authtoken" -} - -// BeforeCreate generates a token if one is not already set, storing only -// its hash in Key and the raw value in the non-persisted Plaintext field. -func (t *AuthToken) BeforeCreate(tx *gorm.DB) error { - if t.Key == "" { - raw := generateToken() - t.Plaintext = raw - t.Key = HashToken(raw) - } - if t.Created.IsZero() { - t.Created = time.Now().UTC() - } - return nil -} - -// generateToken creates a random 40-character hex token (the raw value). -func generateToken() string { - b := make([]byte, 20) - rand.Read(b) - return hex.EncodeToString(b) -} - -// HashToken returns the at-rest representation of an auth token: the -// hex-encoded SHA-256 hash. Auth tokens are 160-bit random values, so a -// fast deterministic hash is appropriate — there is nothing to brute-force, -// and determinism preserves the single indexed-lookup query in the auth -// middleware. The raw token is never stored. -func HashToken(raw string) string { - sum := sha256.Sum256([]byte(raw)) - return hex.EncodeToString(sum[:]) -} - -// GetOrCreate gets an existing token or creates a new one for the user -func GetOrCreateToken(tx *gorm.DB, userID uint) (*AuthToken, error) { - var token AuthToken - result := tx.Where("user_id = ?", userID).First(&token) - - if result.Error == gorm.ErrRecordNotFound { - token = AuthToken{UserID: userID} - if err := tx.Create(&token).Error; err != nil { - return nil, err - } - } else if result.Error != nil { - return nil, result.Error - } - - return &token, nil -} - -// UserProfile represents the user_userprofile table +// UserProfile represents the user_userprofile table — honeyDue-specific +// profile data, keyed to a local user. Email-verification state is owned by +// Kratos; the Verified column is a convenience mirror set at provision time. type UserProfile struct { BaseModel UserID uint `gorm:"column:user_id;uniqueIndex;not null" json:"user_id"` @@ -161,142 +59,7 @@ type UserProfile struct { User User `gorm:"foreignKey:UserID" json:"-"` } -// TableName returns the table name for GORM +// TableName returns the table name for GORM. func (UserProfile) TableName() string { return "user_userprofile" } - -// ConfirmationCode represents the user_confirmationcode table -type ConfirmationCode struct { - BaseModel - UserID uint `gorm:"column:user_id;index;not null" json:"user_id"` - Code string `gorm:"column:code;size:6;not null" json:"-"` - ExpiresAt time.Time `gorm:"column:expires_at;not null" json:"expires_at"` - IsUsed bool `gorm:"column:is_used;default:false" json:"is_used"` - - // Relations - User User `gorm:"foreignKey:UserID" json:"-"` -} - -// TableName returns the table name for GORM -func (ConfirmationCode) TableName() string { - return "user_confirmationcode" -} - -// IsValid checks if the confirmation code is still valid -func (c *ConfirmationCode) IsValid() bool { - return !c.IsUsed && time.Now().UTC().Before(c.ExpiresAt) -} - -// GenerateConfirmationCode creates a uniformly-random 6-digit code using -// rejection sampling on crypto/rand (audit H4 — removes the modulo bias of -// the previous implementation). -func GenerateConfirmationCode() string { - for { - var b [4]byte - if _, err := rand.Read(b[:]); err != nil { - continue - } - // 4294000000 is the largest multiple of 1e6 <= MaxUint32; rejecting - // the tail above it makes n % 1000000 perfectly uniform. - n := binary.BigEndian.Uint32(b[:]) - if n < 4294000000 { - return fmt.Sprintf("%06d", n%1000000) - } - } -} - -// PasswordResetCode represents the user_passwordresetcode table -type PasswordResetCode struct { - BaseModel - UserID uint `gorm:"column:user_id;index;not null" json:"user_id"` - CodeHash string `gorm:"column:code_hash;size:128;not null" json:"-"` - ResetToken string `gorm:"column:reset_token;uniqueIndex;size:64;not null" json:"reset_token"` - ExpiresAt time.Time `gorm:"column:expires_at;not null" json:"expires_at"` - Used bool `gorm:"column:used;default:false" json:"used"` - Attempts int `gorm:"column:attempts;default:0" json:"attempts"` - MaxAttempts int `gorm:"column:max_attempts;default:5" json:"max_attempts"` - - // Relations - User User `gorm:"foreignKey:UserID" json:"-"` -} - -// TableName returns the table name for GORM -func (PasswordResetCode) TableName() string { - return "user_passwordresetcode" -} - -// SetCode hashes and stores the reset code -func (p *PasswordResetCode) SetCode(code string) error { - hash, err := bcrypt.GenerateFromPassword([]byte(code), BcryptCost) - if err != nil { - return err - } - p.CodeHash = string(hash) - return nil -} - -// CheckCode verifies a code against the stored hash -func (p *PasswordResetCode) CheckCode(code string) bool { - err := bcrypt.CompareHashAndPassword([]byte(p.CodeHash), []byte(code)) - return err == nil -} - -// IsValid checks if the reset code is still valid -func (p *PasswordResetCode) IsValid() bool { - return !p.Used && time.Now().UTC().Before(p.ExpiresAt) && p.Attempts < p.MaxAttempts -} - -// IncrementAttempts increments the attempt counter -func (p *PasswordResetCode) IncrementAttempts(tx *gorm.DB) error { - p.Attempts++ - return tx.Model(p).Update("attempts", p.Attempts).Error -} - -// MarkAsUsed marks the code as used -func (p *PasswordResetCode) MarkAsUsed(tx *gorm.DB) error { - p.Used = true - return tx.Model(p).Update("used", true).Error -} - -// GenerateResetToken creates a URL-safe token -func GenerateResetToken() string { - b := make([]byte, 32) - rand.Read(b) - return hex.EncodeToString(b) -} - -// AppleSocialAuth represents a user's linked Apple ID for Sign in with Apple -type AppleSocialAuth struct { - ID uint `gorm:"primaryKey" json:"id"` - UserID uint `gorm:"uniqueIndex;not null" json:"user_id"` - User User `gorm:"foreignKey:UserID" json:"-"` - AppleID string `gorm:"column:apple_id;size:255;uniqueIndex;not null" json:"apple_id"` // Apple's unique subject ID - Email string `gorm:"column:email;size:254" json:"email"` // May be private relay - IsPrivateEmail bool `gorm:"column:is_private_email;default:false" json:"is_private_email"` - CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"` - UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"` -} - -// TableName returns the table name for GORM -func (AppleSocialAuth) TableName() string { - return "user_applesocialauth" -} - -// GoogleSocialAuth represents a user's linked Google account for Sign in with Google -type GoogleSocialAuth struct { - ID uint `gorm:"primaryKey" json:"id"` - UserID uint `gorm:"uniqueIndex;not null" json:"user_id"` - User User `gorm:"foreignKey:UserID" json:"-"` - GoogleID string `gorm:"column:google_id;size:255;uniqueIndex;not null" json:"google_id"` // Google's unique subject ID - Email string `gorm:"column:email;size:254" json:"email"` - Name string `gorm:"column:name;size:255" json:"name"` - Picture string `gorm:"column:picture;size:512" json:"picture"` // Profile picture URL - CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"` - UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"` -} - -// TableName returns the table name for GORM -func (GoogleSocialAuth) TableName() string { - return "user_googlesocialauth" -} diff --git a/internal/models/user_test.go b/internal/models/user_test.go index a8788c8..c2f4bbd 100644 --- a/internal/models/user_test.go +++ b/internal/models/user_test.go @@ -2,50 +2,15 @@ package models import ( "testing" - "time" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func TestUser_SetPassword(t *testing.T) { - user := &User{} - - err := user.SetPassword("testPassword123") - require.NoError(t, err) - assert.NotEmpty(t, user.Password) - assert.NotEqual(t, "testPassword123", user.Password) // Should be hashed -} - -func TestUser_CheckPassword(t *testing.T) { - user := &User{} - err := user.SetPassword("correctpassword") - require.NoError(t, err) - - tests := []struct { - name string - password string - expected bool - }{ - {"correct password", "correctpassword", true}, - {"wrong password", "wrongpassword", false}, - {"empty password", "", false}, - {"similar password", "correctpassword1", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := user.CheckPassword(tt.password) - assert.Equal(t, tt.expected, result) - }) - } -} - func TestUser_GetFullName(t *testing.T) { tests := []struct { - name string - user User - expected string + name string + user User + expected string }{ { name: "first and last name", @@ -82,136 +47,7 @@ func TestUser_TableName(t *testing.T) { assert.Equal(t, "auth_user", user.TableName()) } -func TestAuthToken_TableName(t *testing.T) { - token := AuthToken{} - assert.Equal(t, "user_authtoken", token.TableName()) -} - func TestUserProfile_TableName(t *testing.T) { profile := UserProfile{} assert.Equal(t, "user_userprofile", profile.TableName()) } - -func TestConfirmationCode_IsValid(t *testing.T) { - now := time.Now().UTC() - future := now.Add(1 * time.Hour) - past := now.Add(-1 * time.Hour) - - tests := []struct { - name string - code ConfirmationCode - expected bool - }{ - { - name: "valid code", - code: ConfirmationCode{IsUsed: false, ExpiresAt: future}, - expected: true, - }, - { - name: "used code", - code: ConfirmationCode{IsUsed: true, ExpiresAt: future}, - expected: false, - }, - { - name: "expired code", - code: ConfirmationCode{IsUsed: false, ExpiresAt: past}, - expected: false, - }, - { - name: "used and expired", - code: ConfirmationCode{IsUsed: true, ExpiresAt: past}, - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := tt.code.IsValid() - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestPasswordResetCode_IsValid(t *testing.T) { - now := time.Now().UTC() - future := now.Add(1 * time.Hour) - past := now.Add(-1 * time.Hour) - - tests := []struct { - name string - code PasswordResetCode - expected bool - }{ - { - name: "valid code", - code: PasswordResetCode{Used: false, ExpiresAt: future, Attempts: 0, MaxAttempts: 5}, - expected: true, - }, - { - name: "used code", - code: PasswordResetCode{Used: true, ExpiresAt: future, Attempts: 0, MaxAttempts: 5}, - expected: false, - }, - { - name: "expired code", - code: PasswordResetCode{Used: false, ExpiresAt: past, Attempts: 0, MaxAttempts: 5}, - expected: false, - }, - { - name: "max attempts reached", - code: PasswordResetCode{Used: false, ExpiresAt: future, Attempts: 5, MaxAttempts: 5}, - expected: false, - }, - { - name: "attempts under max", - code: PasswordResetCode{Used: false, ExpiresAt: future, Attempts: 4, MaxAttempts: 5}, - expected: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := tt.code.IsValid() - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestPasswordResetCode_SetAndCheckCode(t *testing.T) { - code := &PasswordResetCode{} - - err := code.SetCode("123456") - require.NoError(t, err) - assert.NotEmpty(t, code.CodeHash) - - // Check correct code - assert.True(t, code.CheckCode("123456")) - - // Check wrong code - assert.False(t, code.CheckCode("654321")) - assert.False(t, code.CheckCode("")) -} - -func TestGenerateConfirmationCode(t *testing.T) { - code := GenerateConfirmationCode() - assert.Len(t, code, 6) - - // Generate multiple codes and ensure they're different - codes := make(map[string]bool) - for i := 0; i < 10; i++ { - c := GenerateConfirmationCode() - assert.Len(t, c, 6) - codes[c] = true - } - // Most codes should be unique (very unlikely to have collisions) - assert.Greater(t, len(codes), 5) -} - -func TestGenerateResetToken(t *testing.T) { - token := GenerateResetToken() - assert.Len(t, token, 64) // 32 bytes = 64 hex chars - - // Ensure uniqueness - token2 := GenerateResetToken() - assert.NotEqual(t, token, token2) -} diff --git a/internal/repositories/user_repo.go b/internal/repositories/user_repo.go index bbaccc4..8195b83 100644 --- a/internal/repositories/user_repo.go +++ b/internal/repositories/user_repo.go @@ -11,18 +11,21 @@ import ( "github.com/treytartt/honeydue-api/internal/models" ) +// FindByKratosID finds a user by Kratos identity UUID. +func (r *UserRepository) FindByKratosID(kratosID string) (*models.User, error) { + var user models.User + if err := r.db.Where("kratos_id = ?", kratosID).First(&user).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrUserNotFound + } + return nil, err + } + return &user, nil +} + var ( - ErrUserNotFound = errors.New("user not found") - ErrUserExists = errors.New("user already exists") - ErrInvalidToken = errors.New("invalid token") - ErrTokenNotFound = errors.New("token not found") - ErrCodeNotFound = errors.New("code not found") - ErrCodeExpired = errors.New("code expired") - ErrCodeUsed = errors.New("code already used") - ErrTooManyAttempts = errors.New("too many attempts") - ErrRateLimitExceeded = errors.New("rate limit exceeded") - ErrAppleAuthNotFound = errors.New("apple social auth not found") - ErrGoogleAuthNotFound = errors.New("google social auth not found") + ErrUserNotFound = errors.New("user not found") + ErrUserExists = errors.New("user already exists") ) // UserRepository handles user-related database operations @@ -145,111 +148,6 @@ func (r *UserRepository) ExistsByEmail(email string) (bool, error) { return count > 0, nil } -// --- Auth Token Methods --- - -// GetOrCreateToken gets or creates an auth token for a user. -// Wrapped in a transaction to prevent race conditions where two -// concurrent requests could create duplicate tokens for the same user. -func (r *UserRepository) GetOrCreateToken(userID uint) (*models.AuthToken, error) { - var token models.AuthToken - - err := r.db.Transaction(func(tx *gorm.DB) error { - result := tx.Where("user_id = ?", userID).First(&token) - - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - token = models.AuthToken{UserID: userID} - if err := tx.Create(&token).Error; err != nil { - return err - } - } else if result.Error != nil { - return result.Error - } - - return nil - }) - if err != nil { - return nil, err - } - - return &token, nil -} - -// FindTokenByKey looks up an auth token by its raw key value. The raw token -// is hashed (audit C1) before the indexed lookup, since only the hash is -// stored. -func (r *UserRepository) FindTokenByKey(rawKey string) (*models.AuthToken, error) { - var token models.AuthToken - if err := r.db.Where("key = ?", models.HashToken(rawKey)).First(&token).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrTokenNotFound - } - return nil, err - } - return &token, nil -} - -// CreateToken creates a new auth token for a user. -func (r *UserRepository) CreateToken(userID uint) (*models.AuthToken, error) { - token := models.AuthToken{UserID: userID} - if err := r.db.Create(&token).Error; err != nil { - return nil, err - } - return &token, nil -} - -// CreateFreshToken issues a new auth token for the user, replacing any -// existing one. Because tokens are stored hashed (audit C1) the server -// cannot re-issue a previously-minted token's plaintext, so every login -// mints a fresh token. The returned token's Plaintext field carries the -// raw value to hand to the client; it is never persisted. -// -// It also returns the stored hashes of the token rows it deleted, so the -// caller can evict those entries from the Redis token cache (audit MEDIUM-1). -// Without that, a prior (e.g. stolen) token keeps authenticating via a cache -// hit for up to the cache TTL even though its DB row is gone. -func (r *UserRepository) CreateFreshToken(userID uint) (*models.AuthToken, []string, error) { - var token models.AuthToken - var oldHashes []string - err := r.db.Transaction(func(tx *gorm.DB) error { - var old []models.AuthToken - if err := tx.Where("user_id = ?", userID).Find(&old).Error; err != nil { - return err - } - oldHashes = make([]string, 0, len(old)) - for i := range old { - if old[i].Key != "" { - oldHashes = append(oldHashes, old[i].Key) - } - } - if err := tx.Where("user_id = ?", userID).Delete(&models.AuthToken{}).Error; err != nil { - return err - } - token = models.AuthToken{UserID: userID} - return tx.Create(&token).Error - }) - if err != nil { - return nil, nil, err - } - return &token, oldHashes, nil -} - -// DeleteToken deletes an auth token by its raw key value. The raw token is -// hashed (audit C1) before the lookup, since only the hash is stored. -func (r *UserRepository) DeleteToken(token string) error { - result := r.db.Where("key = ?", models.HashToken(token)).Delete(&models.AuthToken{}) - if result.Error != nil { - return result.Error - } - if result.RowsAffected == 0 { - return ErrTokenNotFound - } - return nil -} - -// DeleteTokenByUserID deletes an auth token by user ID -func (r *UserRepository) DeleteTokenByUserID(userID uint) error { - return r.db.Where("user_id = ?", userID).Delete(&models.AuthToken{}).Error -} // --- User Profile Methods --- @@ -280,146 +178,6 @@ func (r *UserRepository) SetProfileVerified(userID uint, verified bool) error { return r.db.Model(&models.UserProfile{}).Where("user_id = ?", userID).Update("verified", verified).Error } -// --- Confirmation Code Methods --- - -// CreateConfirmationCode creates a new confirmation code -func (r *UserRepository) CreateConfirmationCode(userID uint, code string, expiresAt time.Time) (*models.ConfirmationCode, error) { - // Invalidate any existing unused codes for this user - r.db.Model(&models.ConfirmationCode{}). - Where("user_id = ? AND is_used = ?", userID, false). - Update("is_used", true) - - confirmCode := &models.ConfirmationCode{ - UserID: userID, - Code: code, - ExpiresAt: expiresAt, - IsUsed: false, - } - - if err := r.db.Create(confirmCode).Error; err != nil { - return nil, err - } - - return confirmCode, nil -} - -// FindConfirmationCode finds a valid confirmation code for a user -func (r *UserRepository) FindConfirmationCode(userID uint, code string) (*models.ConfirmationCode, error) { - var confirmCode models.ConfirmationCode - if err := r.db.Where("user_id = ? AND code = ? AND is_used = ?", userID, code, false). - First(&confirmCode).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrCodeNotFound - } - return nil, err - } - - if !confirmCode.IsValid() { - if confirmCode.IsUsed { - return nil, ErrCodeUsed - } - return nil, ErrCodeExpired - } - - return &confirmCode, nil -} - -// MarkConfirmationCodeUsed marks a confirmation code as used -func (r *UserRepository) MarkConfirmationCodeUsed(codeID uint) error { - return r.db.Model(&models.ConfirmationCode{}).Where("id = ?", codeID).Update("is_used", true).Error -} - -// --- Password Reset Code Methods --- - -// CreatePasswordResetCode creates a new password reset code -func (r *UserRepository) CreatePasswordResetCode(userID uint, codeHash string, resetToken string, expiresAt time.Time) (*models.PasswordResetCode, error) { - // Invalidate any existing unused codes for this user - r.db.Model(&models.PasswordResetCode{}). - Where("user_id = ? AND used = ?", userID, false). - Update("used", true) - - resetCode := &models.PasswordResetCode{ - UserID: userID, - CodeHash: codeHash, - ResetToken: resetToken, - ExpiresAt: expiresAt, - Used: false, - Attempts: 0, - MaxAttempts: 5, - } - - if err := r.db.Create(resetCode).Error; err != nil { - return nil, err - } - - return resetCode, nil -} - -// FindPasswordResetCode finds a password reset code by email and checks validity -func (r *UserRepository) FindPasswordResetCodeByEmail(email string) (*models.PasswordResetCode, *models.User, error) { - user, err := r.FindByEmail(email) - if err != nil { - return nil, nil, err - } - - var resetCode models.PasswordResetCode - if err := r.db.Where("user_id = ? AND used = ?", user.ID, false). - Order("created_at DESC"). - First(&resetCode).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, nil, ErrCodeNotFound - } - return nil, nil, err - } - - return &resetCode, user, nil -} - -// FindPasswordResetCodeByToken finds a password reset code by reset token -func (r *UserRepository) FindPasswordResetCodeByToken(resetToken string) (*models.PasswordResetCode, error) { - var resetCode models.PasswordResetCode - if err := r.db.Where("reset_token = ?", resetToken).First(&resetCode).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrCodeNotFound - } - return nil, err - } - - if !resetCode.IsValid() { - if resetCode.Used { - return nil, ErrCodeUsed - } - if resetCode.Attempts >= resetCode.MaxAttempts { - return nil, ErrTooManyAttempts - } - return nil, ErrCodeExpired - } - - return &resetCode, nil -} - -// IncrementResetCodeAttempts increments the attempt counter -func (r *UserRepository) IncrementResetCodeAttempts(codeID uint) error { - return r.db.Model(&models.PasswordResetCode{}).Where("id = ?", codeID). - Update("attempts", gorm.Expr("attempts + 1")).Error -} - -// MarkPasswordResetCodeUsed marks a password reset code as used -func (r *UserRepository) MarkPasswordResetCodeUsed(codeID uint) error { - return r.db.Model(&models.PasswordResetCode{}).Where("id = ?", codeID).Update("used", true).Error -} - -// CountRecentPasswordResetRequests counts reset requests in the last hour -func (r *UserRepository) CountRecentPasswordResetRequests(userID uint) (int64, error) { - var count int64 - oneHourAgo := time.Now().UTC().Add(-1 * time.Hour) - if err := r.db.Model(&models.PasswordResetCode{}). - Where("user_id = ? AND created_at > ?", userID, oneHourAgo). - Count(&count).Error; err != nil { - return 0, err - } - return count, nil -} // --- Search Methods --- @@ -576,27 +334,11 @@ func (r *UserRepository) FindProfilesInSharedResidences(userID uint) ([]models.U return profiles, err } -// --- Auth Provider Detection --- - -// FindAuthProvider determines the auth provider for a user. -// Returns "apple", "google", or "email". -func (r *UserRepository) FindAuthProvider(userID uint) (string, error) { - var count int64 - if err := r.db.Model(&models.AppleSocialAuth{}).Where("user_id = ?", userID).Count(&count).Error; err != nil { - return "", err - } - if count > 0 { - return "apple", nil - } - - if err := r.db.Model(&models.GoogleSocialAuth{}).Where("user_id = ?", userID).Count(&count).Error; err != nil { - return "", err - } - if count > 0 { - return "google", nil - } - - return "email", nil +// FindAuthProvider returns "kratos" for all Kratos-managed users (the sole +// provider after the Ory Kratos migration). Kept for compatibility with +// callers that still check the provider string. +func (r *UserRepository) FindAuthProvider(_ uint) (string, error) { + return "kratos", nil } // --- Account Deletion --- @@ -721,35 +463,12 @@ func (r *UserRepository) DeleteUserCascade(userID uint) ([]string, error) { return nil, err } - // 8. Social auth records - if err := db.Where("user_id = ?", userID).Delete(&models.AppleSocialAuth{}).Error; err != nil { - return nil, err - } - if err := db.Where("user_id = ?", userID).Delete(&models.GoogleSocialAuth{}).Error; err != nil { - return nil, err - } - - // 9. Confirmation codes - if err := db.Where("user_id = ?", userID).Delete(&models.ConfirmationCode{}).Error; err != nil { - return nil, err - } - - // 10. Password reset codes - if err := db.Where("user_id = ?", userID).Delete(&models.PasswordResetCode{}).Error; err != nil { - return nil, err - } - - // 11. Auth tokens - if err := db.Where("user_id = ?", userID).Delete(&models.AuthToken{}).Error; err != nil { - return nil, err - } - - // 12. User profile + // 8. User profile if err := db.Where("user_id = ?", userID).Delete(&models.UserProfile{}).Error; err != nil { return nil, err } - // 13. User + // 9. User if err := db.Where("id = ?", userID).Delete(&models.User{}).Error; err != nil { return nil, err } @@ -765,53 +484,6 @@ func (r *UserRepository) DeleteUserCascade(userID uint) ([]string, error) { return cleanURLs, nil } -// --- Apple Social Auth Methods --- - -// FindByAppleID finds an Apple social auth by Apple ID -func (r *UserRepository) FindByAppleID(appleID string) (*models.AppleSocialAuth, error) { - var auth models.AppleSocialAuth - if err := r.db.Where("apple_id = ?", appleID).First(&auth).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrAppleAuthNotFound - } - return nil, err - } - return &auth, nil -} - -// CreateAppleSocialAuth creates a new Apple social auth record -func (r *UserRepository) CreateAppleSocialAuth(auth *models.AppleSocialAuth) error { - return r.db.Create(auth).Error -} - -// UpdateAppleSocialAuth updates an Apple social auth record -func (r *UserRepository) UpdateAppleSocialAuth(auth *models.AppleSocialAuth) error { - return r.db.Save(auth).Error -} - -// --- Google Social Auth Methods --- - -// FindByGoogleID finds a Google social auth by Google ID -func (r *UserRepository) FindByGoogleID(googleID string) (*models.GoogleSocialAuth, error) { - var auth models.GoogleSocialAuth - if err := r.db.Where("google_id = ?", googleID).First(&auth).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrGoogleAuthNotFound - } - return nil, err - } - return &auth, nil -} - -// CreateGoogleSocialAuth creates a new Google social auth record -func (r *UserRepository) CreateGoogleSocialAuth(auth *models.GoogleSocialAuth) error { - return r.db.Create(auth).Error -} - -// UpdateGoogleSocialAuth updates a Google social auth record -func (r *UserRepository) UpdateGoogleSocialAuth(auth *models.GoogleSocialAuth) error { - return r.db.Save(auth).Error -} // WithContext returns a copy of the repository whose underlying *gorm.DB carries // the supplied context. SQL emitted via this copy gets attached to ctx's trace span diff --git a/internal/repositories/user_repo_coverage_test.go b/internal/repositories/user_repo_coverage_test.go index a94d3cd..64cb56f 100644 --- a/internal/repositories/user_repo_coverage_test.go +++ b/internal/repositories/user_repo_coverage_test.go @@ -2,7 +2,6 @@ package repositories import ( "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -78,99 +77,25 @@ func TestUserRepository_ExistsByEmail_CaseInsensitive(t *testing.T) { assert.True(t, exists) } -func TestUserRepository_GetOrCreateToken(t *testing.T) { +func TestUserRepository_FindByKratosID(t *testing.T) { db := testutil.SetupTestDB(t) repo := NewUserRepository(db) - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") + user := testutil.CreateTestUser(t, db, "kratosuser", "kratos@example.com", "") - // Create token - token1, err := repo.GetOrCreateToken(user.ID) + found, err := repo.FindByKratosID(user.KratosID) require.NoError(t, err) - assert.NotEmpty(t, token1.Key) - - // Should return same token - token2, err := repo.GetOrCreateToken(user.ID) - require.NoError(t, err) - assert.Equal(t, token1.Key, token2.Key) + assert.Equal(t, user.ID, found.ID) + assert.Equal(t, user.KratosID, found.KratosID) } -func TestUserRepository_FindTokenByKey(t *testing.T) { +func TestUserRepository_FindByKratosID_NotFound(t *testing.T) { db := testutil.SetupTestDB(t) repo := NewUserRepository(db) - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - token, err := repo.GetOrCreateToken(user.ID) - require.NoError(t, err) - - found, err := repo.FindTokenByKey(token.Plaintext) - require.NoError(t, err) - assert.Equal(t, token.Key, found.Key) - assert.Equal(t, user.ID, found.UserID) -} - -func TestUserRepository_FindTokenByKey_NotFound(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - _, err := repo.FindTokenByKey("nonexistent-token-key") + _, err := repo.FindByKratosID("nonexistent-kratos-id") assert.Error(t, err) - assert.ErrorIs(t, err, ErrTokenNotFound) -} - -func TestUserRepository_DeleteToken(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - token, err := repo.GetOrCreateToken(user.ID) - require.NoError(t, err) - - err = repo.DeleteToken(token.Plaintext) - require.NoError(t, err) - - _, err = repo.FindTokenByKey(token.Plaintext) - assert.ErrorIs(t, err, ErrTokenNotFound) -} - -func TestUserRepository_DeleteToken_NotFound(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - err := repo.DeleteToken("nonexistent-key") - assert.ErrorIs(t, err, ErrTokenNotFound) -} - -func TestUserRepository_DeleteTokenByUserID(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - _, err := repo.GetOrCreateToken(user.ID) - require.NoError(t, err) - - err = repo.DeleteTokenByUserID(user.ID) - require.NoError(t, err) - - // Token should be gone - var count int64 - db.Model(&models.AuthToken{}).Where("user_id = ?", user.ID).Count(&count) - assert.Equal(t, int64(0), count) -} - -func TestUserRepository_CreateToken(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - token, err := repo.CreateToken(user.ID) - require.NoError(t, err) - assert.NotEmpty(t, token.Key) - assert.Equal(t, user.ID, token.UserID) + assert.ErrorIs(t, err, ErrUserNotFound) } func TestUserRepository_UpdateLastLogin(t *testing.T) { @@ -255,54 +180,6 @@ func TestUserRepository_FindByIDWithProfile_NotFound(t *testing.T) { assert.ErrorIs(t, err, ErrUserNotFound) } -func TestUserRepository_ConfirmationCode_Lifecycle(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - // Create confirmation code - expiresAt := time.Now().UTC().Add(1 * time.Hour) - code, err := repo.CreateConfirmationCode(user.ID, "123456", expiresAt) - require.NoError(t, err) - assert.NotZero(t, code.ID) - - // Find it - found, err := repo.FindConfirmationCode(user.ID, "123456") - require.NoError(t, err) - assert.Equal(t, code.ID, found.ID) - - // Mark as used - err = repo.MarkConfirmationCodeUsed(code.ID) - require.NoError(t, err) - - // Should not find used code - _, err = repo.FindConfirmationCode(user.ID, "123456") - assert.Error(t, err) -} - -func TestUserRepository_ConfirmationCode_InvalidatesExisting(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - expiresAt := time.Now().UTC().Add(1 * time.Hour) - - // Create first code - code1, err := repo.CreateConfirmationCode(user.ID, "111111", expiresAt) - require.NoError(t, err) - - // Create second code (should invalidate first) - _, err = repo.CreateConfirmationCode(user.ID, "222222", expiresAt) - require.NoError(t, err) - - // First code should be used/invalidated - var c models.ConfirmationCode - db.First(&c, code1.ID) - assert.True(t, c.IsUsed) -} - func TestUserRepository_Transaction(t *testing.T) { db := testutil.SetupTestDB(t) repo := NewUserRepository(db) @@ -331,105 +208,6 @@ func TestUserRepository_DB(t *testing.T) { assert.NotNil(t, repo.DB()) } -func TestUserRepository_FindByAppleID(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "Password123") - appleAuth := &models.AppleSocialAuth{ - UserID: user.ID, - AppleID: "apple_sub_123", - Email: "apple@test.com", - } - require.NoError(t, db.Create(appleAuth).Error) - - found, err := repo.FindByAppleID("apple_sub_123") - require.NoError(t, err) - assert.Equal(t, user.ID, found.UserID) -} - -func TestUserRepository_FindByAppleID_NotFound(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - _, err := repo.FindByAppleID("nonexistent_apple_id") - assert.ErrorIs(t, err, ErrAppleAuthNotFound) -} - -func TestUserRepository_FindByGoogleID(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "Password123") - googleAuth := &models.GoogleSocialAuth{ - UserID: user.ID, - GoogleID: "google_sub_123", - Email: "google@test.com", - } - require.NoError(t, db.Create(googleAuth).Error) - - found, err := repo.FindByGoogleID("google_sub_123") - require.NoError(t, err) - assert.Equal(t, user.ID, found.UserID) -} - -func TestUserRepository_FindByGoogleID_NotFound(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - _, err := repo.FindByGoogleID("nonexistent_google_id") - assert.ErrorIs(t, err, ErrGoogleAuthNotFound) -} - -func TestUserRepository_CreateAndUpdateAppleSocialAuth(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "Password123") - - auth := &models.AppleSocialAuth{ - UserID: user.ID, - AppleID: "apple_sub_456", - Email: "apple@test.com", - } - err := repo.CreateAppleSocialAuth(auth) - require.NoError(t, err) - assert.NotZero(t, auth.ID) - - auth.Email = "updated@test.com" - err = repo.UpdateAppleSocialAuth(auth) - require.NoError(t, err) - - found, err := repo.FindByAppleID("apple_sub_456") - require.NoError(t, err) - assert.Equal(t, "updated@test.com", found.Email) -} - -func TestUserRepository_CreateAndUpdateGoogleSocialAuth(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "Password123") - - auth := &models.GoogleSocialAuth{ - UserID: user.ID, - GoogleID: "google_sub_456", - Email: "google@test.com", - Name: "Test User", - } - err := repo.CreateGoogleSocialAuth(auth) - require.NoError(t, err) - assert.NotZero(t, auth.ID) - - auth.Name = "Updated Name" - err = repo.UpdateGoogleSocialAuth(auth) - require.NoError(t, err) - - found, err := repo.FindByGoogleID("google_sub_456") - require.NoError(t, err) - assert.Equal(t, "Updated Name", found.Name) -} - func TestUserRepository_SearchUsers(t *testing.T) { db := testutil.SetupTestDB(t) repo := NewUserRepository(db) diff --git a/internal/repositories/user_repo_extended_test.go b/internal/repositories/user_repo_extended_test.go index d2d7e87..2a00593 100644 --- a/internal/repositories/user_repo_extended_test.go +++ b/internal/repositories/user_repo_extended_test.go @@ -2,7 +2,6 @@ package repositories import ( "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -11,207 +10,6 @@ import ( "github.com/treytartt/honeydue-api/internal/testutil" ) -// === Password Reset Code Lifecycle === - -func TestUserRepository_PasswordResetCode_Lifecycle(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - expiresAt := time.Now().UTC().Add(1 * time.Hour) - code, err := repo.CreatePasswordResetCode(user.ID, "hash_abc123", "reset_token_xyz", expiresAt) - require.NoError(t, err) - assert.NotZero(t, code.ID) - assert.Equal(t, "hash_abc123", code.CodeHash) - assert.Equal(t, "reset_token_xyz", code.ResetToken) - assert.False(t, code.Used) - assert.Equal(t, 0, code.Attempts) -} - -func TestUserRepository_CreatePasswordResetCode_InvalidatesExisting(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - expiresAt := time.Now().UTC().Add(1 * time.Hour) - - code1, err := repo.CreatePasswordResetCode(user.ID, "hash1", "token1", expiresAt) - require.NoError(t, err) - - _, err = repo.CreatePasswordResetCode(user.ID, "hash2", "token2", expiresAt) - require.NoError(t, err) - - // First code should be marked as used - var c models.PasswordResetCode - db.First(&c, code1.ID) - assert.True(t, c.Used) -} - -func TestUserRepository_FindPasswordResetCodeByEmail(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - expiresAt := time.Now().UTC().Add(1 * time.Hour) - _, err := repo.CreatePasswordResetCode(user.ID, "hash_abc", "token_abc", expiresAt) - require.NoError(t, err) - - found, foundUser, err := repo.FindPasswordResetCodeByEmail("test@example.com") - require.NoError(t, err) - assert.Equal(t, user.ID, foundUser.ID) - assert.Equal(t, "hash_abc", found.CodeHash) -} - -func TestUserRepository_FindPasswordResetCodeByEmail_UserNotFound(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - _, _, err := repo.FindPasswordResetCodeByEmail("nonexistent@example.com") - assert.Error(t, err) -} - -func TestUserRepository_FindPasswordResetCodeByEmail_NoCode(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - _, _, err := repo.FindPasswordResetCodeByEmail("test@example.com") - assert.ErrorIs(t, err, ErrCodeNotFound) -} - -func TestUserRepository_FindPasswordResetCodeByToken(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - expiresAt := time.Now().UTC().Add(1 * time.Hour) - _, err := repo.CreatePasswordResetCode(user.ID, "hash_xyz", "token_xyz", expiresAt) - require.NoError(t, err) - - found, err := repo.FindPasswordResetCodeByToken("token_xyz") - require.NoError(t, err) - assert.Equal(t, "hash_xyz", found.CodeHash) -} - -func TestUserRepository_FindPasswordResetCodeByToken_NotFound(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - _, err := repo.FindPasswordResetCodeByToken("nonexistent_token") - assert.ErrorIs(t, err, ErrCodeNotFound) -} - -func TestUserRepository_FindPasswordResetCodeByToken_Expired(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - // Already expired - expiresAt := time.Now().UTC().Add(-1 * time.Hour) - _, err := repo.CreatePasswordResetCode(user.ID, "hash_exp", "token_exp", expiresAt) - require.NoError(t, err) - - _, err = repo.FindPasswordResetCodeByToken("token_exp") - assert.ErrorIs(t, err, ErrCodeExpired) -} - -func TestUserRepository_FindPasswordResetCodeByToken_Used(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - expiresAt := time.Now().UTC().Add(1 * time.Hour) - code, err := repo.CreatePasswordResetCode(user.ID, "hash_used", "token_used", expiresAt) - require.NoError(t, err) - - // Mark as used - err = repo.MarkPasswordResetCodeUsed(code.ID) - require.NoError(t, err) - - _, err = repo.FindPasswordResetCodeByToken("token_used") - assert.ErrorIs(t, err, ErrCodeUsed) -} - -func TestUserRepository_FindPasswordResetCodeByToken_TooManyAttempts(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - expiresAt := time.Now().UTC().Add(1 * time.Hour) - code, err := repo.CreatePasswordResetCode(user.ID, "hash_attempts", "token_attempts", expiresAt) - require.NoError(t, err) - - // Max out attempts - for i := 0; i < 5; i++ { - err = repo.IncrementResetCodeAttempts(code.ID) - require.NoError(t, err) - } - - _, err = repo.FindPasswordResetCodeByToken("token_attempts") - assert.ErrorIs(t, err, ErrTooManyAttempts) -} - -func TestUserRepository_IncrementResetCodeAttempts(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - expiresAt := time.Now().UTC().Add(1 * time.Hour) - code, err := repo.CreatePasswordResetCode(user.ID, "hash_inc", "token_inc", expiresAt) - require.NoError(t, err) - - err = repo.IncrementResetCodeAttempts(code.ID) - require.NoError(t, err) - - var updated models.PasswordResetCode - db.First(&updated, code.ID) - assert.Equal(t, 1, updated.Attempts) -} - -func TestUserRepository_MarkPasswordResetCodeUsed(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - expiresAt := time.Now().UTC().Add(1 * time.Hour) - code, err := repo.CreatePasswordResetCode(user.ID, "hash_mark", "token_mark", expiresAt) - require.NoError(t, err) - - err = repo.MarkPasswordResetCodeUsed(code.ID) - require.NoError(t, err) - - var updated models.PasswordResetCode - db.First(&updated, code.ID) - assert.True(t, updated.Used) -} - -func TestUserRepository_CountRecentPasswordResetRequests(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - expiresAt := time.Now().UTC().Add(1 * time.Hour) - _, err := repo.CreatePasswordResetCode(user.ID, "hash1", "token1", expiresAt) - require.NoError(t, err) - _, err = repo.CreatePasswordResetCode(user.ID, "hash2", "token2", expiresAt) - require.NoError(t, err) - - count, err := repo.CountRecentPasswordResetRequests(user.ID) - require.NoError(t, err) - assert.Equal(t, int64(2), count) -} - // === FindUsersInSharedResidences === func TestUserRepository_FindUsersInSharedResidences(t *testing.T) { @@ -301,33 +99,6 @@ func TestUserRepository_FindProfilesInSharedResidences(t *testing.T) { assert.Len(t, profiles, 2) } -// === ConfirmationCode Expired === - -func TestUserRepository_FindConfirmationCode_Expired(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - // Create already-expired code - expiresAt := time.Now().UTC().Add(-1 * time.Hour) - _, err := repo.CreateConfirmationCode(user.ID, "999999", expiresAt) - require.NoError(t, err) - - _, err = repo.FindConfirmationCode(user.ID, "999999") - assert.ErrorIs(t, err, ErrCodeExpired) -} - -func TestUserRepository_FindConfirmationCode_NotFound(t *testing.T) { - db := testutil.SetupTestDB(t) - repo := NewUserRepository(db) - - user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") - - _, err := repo.FindConfirmationCode(user.ID, "000000") - assert.ErrorIs(t, err, ErrCodeNotFound) -} - // === Transaction Rollback === func TestUserRepository_Transaction_Rollback(t *testing.T) { diff --git a/internal/repositories/user_repo_test.go b/internal/repositories/user_repo_test.go index 4d6db36..7c17aa3 100644 --- a/internal/repositories/user_repo_test.go +++ b/internal/repositories/user_repo_test.go @@ -19,7 +19,6 @@ func TestUserRepository_Create(t *testing.T) { Email: "test@example.com", IsActive: true, } - user.SetPassword("Password123") err := repo.Create(user) require.NoError(t, err) @@ -192,39 +191,11 @@ func TestUserRepository_FindAuthProvider(t *testing.T) { db := testutil.SetupTestDB(t) repo := NewUserRepository(db) - t.Run("email user", func(t *testing.T) { + t.Run("kratos user", func(t *testing.T) { user := testutil.CreateTestUser(t, db, "emailuser", "email@test.com", "Password123") provider, err := repo.FindAuthProvider(user.ID) require.NoError(t, err) - assert.Equal(t, "email", provider) - }) - - t.Run("apple user", func(t *testing.T) { - user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "Password123") - appleAuth := &models.AppleSocialAuth{ - UserID: user.ID, - AppleID: "apple_sub_test", - Email: "apple@test.com", - } - require.NoError(t, db.Create(appleAuth).Error) - - provider, err := repo.FindAuthProvider(user.ID) - require.NoError(t, err) - assert.Equal(t, "apple", provider) - }) - - t.Run("google user", func(t *testing.T) { - user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "Password123") - googleAuth := &models.GoogleSocialAuth{ - UserID: user.ID, - GoogleID: "google_sub_test", - Email: "google@test.com", - } - require.NoError(t, db.Create(googleAuth).Error) - - provider, err := repo.FindAuthProvider(user.ID) - require.NoError(t, err) - assert.Equal(t, "google", provider) + assert.Equal(t, "kratos", provider) // All users are Kratos-managed }) } @@ -235,11 +206,9 @@ func TestUserRepository_DeleteUserCascade(t *testing.T) { user := testutil.CreateTestUser(t, db, "deletebare", "deletebare@test.com", "Password123") - // Create profile and token + // Create profile profile := &models.UserProfile{UserID: user.ID, Verified: true} require.NoError(t, db.Create(profile).Error) - _, err := models.GetOrCreateToken(db, user.ID) - require.NoError(t, err) var fileURLs []string txErr := repo.Transaction(func(txRepo *UserRepository) error { @@ -261,10 +230,6 @@ func TestUserRepository_DeleteUserCascade(t *testing.T) { // Verify profile is gone db.Model(&models.UserProfile{}).Where("user_id = ?", user.ID).Count(&count) assert.Equal(t, int64(0), count) - - // Verify token is gone - db.Model(&models.AuthToken{}).Where("user_id = ?", user.ID).Count(&count) - assert.Equal(t, int64(0), count) }) t.Run("returns file URLs for cleanup", func(t *testing.T) { diff --git a/internal/router/router.go b/internal/router/router.go index 76448db..5539e1d 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -22,6 +22,7 @@ import ( "github.com/treytartt/honeydue-api/internal/dto/responses" "github.com/treytartt/honeydue-api/internal/handlers" "github.com/treytartt/honeydue-api/internal/i18n" + "github.com/treytartt/honeydue-api/internal/kratos" custommiddleware "github.com/treytartt/honeydue-api/internal/middleware" "github.com/treytartt/honeydue-api/internal/monitoring" "github.com/treytartt/honeydue-api/internal/prom" @@ -200,7 +201,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo { // Initialize services authService := services.NewAuthService(userRepo, cfg) - authService.SetNotificationRepository(notificationRepo) // For creating notification preferences on registration + authService.SetNotificationRepository(notificationRepo) userService := services.NewUserService(userRepo) residenceService := services.NewResidenceService(residenceRepo, userRepo, cfg) residenceService.SetTaskRepository(taskRepo) // Wire up task repo for statistics @@ -220,7 +221,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo { // Wire Redis cache for residence-ID lookups across the four services that // read it on the request hot path. Cache is best-effort; nil cache is OK. if deps.Cache != nil { - authService.SetCacheService(deps.Cache) // per-account login lockout (audit M5) + authService.SetCacheService(deps.Cache) residenceService.SetCacheService(deps.Cache) taskService.SetCacheService(deps.Cache) contractorService.SetCacheService(deps.Cache) @@ -244,20 +245,15 @@ func SetupRouter(deps *Dependencies) *echo.Echo { subscriptionWebhookHandler.SetStripeService(stripeService) subscriptionWebhookHandler.SetCacheService(deps.Cache) - // Initialize middleware - authMiddleware := custommiddleware.NewAuthMiddlewareWithConfig(deps.DB, deps.Cache, cfg) - - // Initialize Apple auth service - appleAuthService := services.NewAppleAuthService(deps.Cache, cfg) - googleAuthService := services.NewGoogleAuthService(deps.Cache, cfg) + // Initialize Kratos auth middleware (replaces hand-rolled token auth). + kratosClient := kratos.NewClient(cfg.Security.KratosPublicURL) + authMiddleware := custommiddleware.NewKratosAuth(kratosClient, deps.Cache, deps.DB) // Initialize audit service for security event logging auditService := services.NewAuditService(deps.DB) // Initialize handlers authHandler := handlers.NewAuthHandler(authService, deps.EmailService, deps.Cache) - authHandler.SetAppleAuthService(appleAuthService) - authHandler.SetGoogleAuthService(googleAuthService) authHandler.SetStorageService(deps.StorageService) authHandler.SetAuditService(auditService) userHandler := handlers.NewUserHandler(userService) @@ -318,8 +314,8 @@ func SetupRouter(deps *Dependencies) *echo.Echo { // API group api := e.Group("/api") { - // Public auth routes (no auth required) - setupPublicAuthRoutes(api, authHandler, cfg.Server.Debug) + // Session lifecycle (login, register, logout, password reset) is + // handled by Ory Kratos — no public auth routes in this service. // Public data routes (no auth required) setupPublicDataRoutes(api, residenceHandler, taskHandler, contractorHandler, staticDataHandler, subscriptionHandler, taskTemplateHandler) @@ -329,7 +325,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo { // Protected routes (auth required) protected := api.Group("") - protected.Use(authMiddleware.TokenAuth()) + protected.Use(authMiddleware.Authenticate()) protected.Use(custommiddleware.TimezoneMiddleware()) { setupProtectedAuthRoutes(protected, authHandler) @@ -516,50 +512,18 @@ func prometheusMetrics(monSvc *monitoring.Service) echo.HandlerFunc { } } -// setupPublicAuthRoutes configures public authentication routes with -// per-endpoint rate limiters to mitigate brute-force and credential-stuffing. -// Rate limiters are disabled in debug mode to allow UI test suites to run -// without hitting 429 errors. -func setupPublicAuthRoutes(api *echo.Group, authHandler *handlers.AuthHandler, debug bool) { - auth := api.Group("/auth") +// setupPublicAuthRoutes was removed — session lifecycle (login, register, +// logout, password reset, Apple/Google sign-in) is delegated to Ory Kratos. - if debug { - // No rate limiters in debug/local mode - auth.POST("/login/", authHandler.Login) - auth.POST("/register/", authHandler.Register) - auth.POST("/forgot-password/", authHandler.ForgotPassword) - auth.POST("/verify-reset-code/", authHandler.VerifyResetCode) - auth.POST("/reset-password/", authHandler.ResetPassword) - auth.POST("/apple-sign-in/", authHandler.AppleSignIn) - auth.POST("/google-sign-in/", authHandler.GoogleSignIn) - } else { - // Rate limiters — created once, shared across requests. - loginRL := custommiddleware.LoginRateLimiter() // 10 req/min - registerRL := custommiddleware.RegistrationRateLimiter() // 5 req/min - passwordRL := custommiddleware.PasswordResetRateLimiter() // 3 req/min - - auth.POST("/login/", authHandler.Login, loginRL) - auth.POST("/register/", authHandler.Register, registerRL) - auth.POST("/forgot-password/", authHandler.ForgotPassword, passwordRL) - auth.POST("/verify-reset-code/", authHandler.VerifyResetCode, passwordRL) - auth.POST("/reset-password/", authHandler.ResetPassword, passwordRL) - auth.POST("/apple-sign-in/", authHandler.AppleSignIn, loginRL) - auth.POST("/google-sign-in/", authHandler.GoogleSignIn, loginRL) - } -} - -// setupProtectedAuthRoutes configures protected authentication routes +// setupProtectedAuthRoutes configures protected auth routes. +// Session lifecycle (login, logout, password reset, email verification) is +// delegated to Ory Kratos — only profile and account-deletion routes remain. func setupProtectedAuthRoutes(api *echo.Group, authHandler *handlers.AuthHandler) { auth := api.Group("/auth") { - auth.POST("/logout/", authHandler.Logout) - auth.POST("/refresh/", authHandler.RefreshToken) auth.GET("/me/", authHandler.CurrentUser) auth.PUT("/profile/", authHandler.UpdateProfile) auth.PATCH("/profile/", authHandler.UpdateProfile) - auth.POST("/verify/", authHandler.VerifyEmail) // Alias for mobile app compatibility - auth.POST("/verify-email/", authHandler.VerifyEmail) // Original route - auth.POST("/resend-verification/", authHandler.ResendVerification) auth.DELETE("/account/", authHandler.DeleteAccount) } } diff --git a/internal/services/apple_auth.go b/internal/services/apple_auth.go deleted file mode 100644 index 78d3b57..0000000 --- a/internal/services/apple_auth.go +++ /dev/null @@ -1,301 +0,0 @@ -package services - -import ( - "context" - "crypto/rsa" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "math/big" - "net/http" - "strings" - "time" - - "github.com/golang-jwt/jwt/v5" - - "github.com/treytartt/honeydue-api/internal/config" -) - -const ( - appleKeysURL = "https://appleid.apple.com/auth/keys" - appleIssuer = "https://appleid.apple.com" - appleKeysCacheTTL = 24 * time.Hour - appleKeysCacheKey = "apple:public_keys" -) - -var ( - ErrInvalidAppleToken = errors.New("invalid Apple identity token") - ErrAppleTokenExpired = errors.New("Apple identity token has expired") - ErrInvalidAppleAudience = errors.New("invalid Apple token audience") - ErrInvalidAppleIssuer = errors.New("invalid Apple token issuer") - ErrAppleKeyNotFound = errors.New("Apple public key not found") -) - -// AppleJWKS represents Apple's JSON Web Key Set -type AppleJWKS struct { - Keys []AppleJWK `json:"keys"` -} - -// AppleJWK represents a single JSON Web Key from Apple -type AppleJWK struct { - Kty string `json:"kty"` // Key type (RSA) - Kid string `json:"kid"` // Key ID - Use string `json:"use"` // Key use (sig) - Alg string `json:"alg"` // Algorithm (RS256) - N string `json:"n"` // RSA modulus - E string `json:"e"` // RSA exponent -} - -// AppleTokenClaims represents the claims in an Apple identity token -type AppleTokenClaims struct { - jwt.RegisteredClaims - Email string `json:"email,omitempty"` - EmailVerified any `json:"email_verified,omitempty"` // Can be bool or string - IsPrivateEmail any `json:"is_private_email,omitempty"` // Can be bool or string - AuthTime int64 `json:"auth_time,omitempty"` -} - -// IsEmailVerified returns whether the email is verified (handles both bool and string types) -func (c *AppleTokenClaims) IsEmailVerified() bool { - switch v := c.EmailVerified.(type) { - case bool: - return v - case string: - return v == "true" - default: - return false - } -} - -// IsPrivateRelayEmail returns whether the email is a private relay email -func (c *AppleTokenClaims) IsPrivateRelayEmail() bool { - switch v := c.IsPrivateEmail.(type) { - case bool: - return v - case string: - return v == "true" - default: - return false - } -} - -// AppleAuthService handles Apple Sign In token verification -type AppleAuthService struct { - cache *CacheService - config *config.Config - client *http.Client -} - -// NewAppleAuthService creates a new Apple auth service -func NewAppleAuthService(cache *CacheService, cfg *config.Config) *AppleAuthService { - return &AppleAuthService{ - cache: cache, - config: cfg, - client: &http.Client{ - Timeout: 10 * time.Second, - }, - } -} - -// VerifyIdentityToken verifies an Apple identity token and returns the claims -func (s *AppleAuthService) VerifyIdentityToken(ctx context.Context, idToken string) (*AppleTokenClaims, error) { - // Parse the token header to get the key ID - parts := strings.Split(idToken, ".") - if len(parts) != 3 { - return nil, ErrInvalidAppleToken - } - - headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) - if err != nil { - return nil, fmt.Errorf("failed to decode token header: %w", err) - } - - var header struct { - Kid string `json:"kid"` - Alg string `json:"alg"` - } - if err := json.Unmarshal(headerBytes, &header); err != nil { - return nil, fmt.Errorf("failed to parse token header: %w", err) - } - - // Get the public key for this key ID - publicKey, err := s.getPublicKey(ctx, header.Kid) - if err != nil { - return nil, err - } - - // Parse and verify the token - token, err := jwt.ParseWithClaims(idToken, &AppleTokenClaims{}, func(token *jwt.Token) (interface{}, error) { - // Verify the signing method - if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) - } - return publicKey, nil - }) - - if err != nil { - if errors.Is(err, jwt.ErrTokenExpired) { - return nil, ErrAppleTokenExpired - } - return nil, fmt.Errorf("failed to parse token: %w", err) - } - - claims, ok := token.Claims.(*AppleTokenClaims) - if !ok || !token.Valid { - return nil, ErrInvalidAppleToken - } - - // Verify the issuer - if claims.Issuer != appleIssuer { - return nil, ErrInvalidAppleIssuer - } - - // Verify the audience (should be our bundle ID) - if !s.verifyAudience(claims.Audience) { - return nil, ErrInvalidAppleAudience - } - - return claims, nil -} - -// verifyAudience checks if the token audience matches our client ID. -// In production (non-debug), an empty clientID causes verification to fail -// rather than silently bypassing the check. -func (s *AppleAuthService) verifyAudience(audience jwt.ClaimStrings) bool { - clientID := s.config.AppleAuth.ClientID - if clientID == "" { - if s.config.Server.Debug { - // In debug mode only, skip audience verification for local development - return true - } - // In production, missing client ID means we cannot verify the audience - return false - } - - for _, aud := range audience { - if aud == clientID { - return true - } - } - return false -} - -// getPublicKey retrieves the public key for the given key ID -func (s *AppleAuthService) getPublicKey(ctx context.Context, kid string) (*rsa.PublicKey, error) { - // Try to get from cache first - keys, err := s.getCachedKeys(ctx) - if err != nil || keys == nil { - // Fetch fresh keys - keys, err = s.fetchApplePublicKeys(ctx) - if err != nil { - return nil, err - } - } - - // Find the key with the matching ID - for keyID, pubKey := range keys { - if keyID == kid { - return pubKey, nil - } - } - - // Key not found in cache, try fetching fresh keys - keys, err = s.fetchApplePublicKeys(ctx) - if err != nil { - return nil, err - } - - if pubKey, ok := keys[kid]; ok { - return pubKey, nil - } - - return nil, ErrAppleKeyNotFound -} - -// getCachedKeys retrieves cached Apple public keys from Redis -func (s *AppleAuthService) getCachedKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) { - if s.cache == nil { - return nil, nil - } - - data, err := s.cache.GetString(ctx, appleKeysCacheKey) - if err != nil || data == "" { - return nil, nil - } - - var jwks AppleJWKS - if err := json.Unmarshal([]byte(data), &jwks); err != nil { - return nil, nil - } - - return s.parseJWKS(&jwks) -} - -// fetchApplePublicKeys fetches Apple's public keys and caches them -func (s *AppleAuthService) fetchApplePublicKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, appleKeysURL, nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - resp, err := s.client.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to fetch Apple keys: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("Apple keys endpoint returned status %d", resp.StatusCode) - } - - var jwks AppleJWKS - if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil { - return nil, fmt.Errorf("failed to decode Apple keys: %w", err) - } - - // Cache the keys - if s.cache != nil { - keysJSON, _ := json.Marshal(jwks) - _ = s.cache.SetString(ctx, appleKeysCacheKey, string(keysJSON), appleKeysCacheTTL) - } - - return s.parseJWKS(&jwks) -} - -// parseJWKS converts Apple's JWKS to RSA public keys -func (s *AppleAuthService) parseJWKS(jwks *AppleJWKS) (map[string]*rsa.PublicKey, error) { - keys := make(map[string]*rsa.PublicKey) - - for _, key := range jwks.Keys { - if key.Kty != "RSA" { - continue - } - - // Decode the modulus (N) - nBytes, err := base64.RawURLEncoding.DecodeString(key.N) - if err != nil { - continue - } - n := new(big.Int).SetBytes(nBytes) - - // Decode the exponent (E) - eBytes, err := base64.RawURLEncoding.DecodeString(key.E) - if err != nil { - continue - } - e := 0 - for _, b := range eBytes { - e = e<<8 + int(b) - } - - pubKey := &rsa.PublicKey{ - N: n, - E: e, - } - - keys[key.Kid] = pubKey - } - - return keys, nil -} diff --git a/internal/services/auth_refresh_test.go b/internal/services/auth_refresh_test.go deleted file mode 100644 index 5f9f987..0000000 --- a/internal/services/auth_refresh_test.go +++ /dev/null @@ -1,176 +0,0 @@ -package services - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gorm.io/driver/sqlite" - "gorm.io/gorm" - "gorm.io/gorm/logger" - - "github.com/treytartt/honeydue-api/internal/config" - "github.com/treytartt/honeydue-api/internal/models" - "github.com/treytartt/honeydue-api/internal/repositories" -) - -func setupRefreshTestDB(t *testing.T) *gorm.DB { - t.Helper() - db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{ - Logger: logger.Default.LogMode(logger.Silent), - }) - require.NoError(t, err) - - err = db.AutoMigrate(&models.User{}, &models.UserProfile{}, &models.AuthToken{}) - require.NoError(t, err) - - return db -} - -func createRefreshTestUser(t *testing.T, db *gorm.DB) *models.User { - t.Helper() - user := &models.User{ - Username: "refreshtest", - Email: "refresh@test.com", - IsActive: true, - } - require.NoError(t, user.SetPassword("Password123")) - require.NoError(t, db.Create(user).Error) - return user -} - -func createTokenWithAge(t *testing.T, db *gorm.DB, userID uint, ageDays int) *models.AuthToken { - t.Helper() - token := &models.AuthToken{ - UserID: userID, - } - require.NoError(t, db.Create(token).Error) - - // Backdate the token's Created timestamp after creation to bypass autoCreateTime - backdated := time.Now().UTC().AddDate(0, 0, -ageDays) - require.NoError(t, db.Model(token).Update("created", backdated).Error) - token.Created = backdated - - return token -} - -func newTestAuthService(db *gorm.DB) *AuthService { - userRepo := repositories.NewUserRepository(db) - cfg := &config.Config{ - Security: config.SecurityConfig{ - SecretKey: "test-secret", - TokenExpiryDays: 90, - TokenRefreshDays: 60, - }, - } - return NewAuthService(userRepo, cfg) -} - -func TestRefreshToken_FreshToken_ReturnsExisting(t *testing.T) { - db := setupRefreshTestDB(t) - user := createRefreshTestUser(t, db) - token := createTokenWithAge(t, db, user.ID, 30) // 30 days old, well within fresh window - - svc := newTestAuthService(db) - - resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID) - require.NoError(t, err) - assert.Equal(t, token.Plaintext, resp.Token, "fresh token should return the same token") - assert.Contains(t, resp.Message, "still valid") -} - -func TestRefreshToken_InRenewalWindow_ReturnsNewToken(t *testing.T) { - db := setupRefreshTestDB(t) - user := createRefreshTestUser(t, db) - token := createTokenWithAge(t, db, user.ID, 75) // 75 days old, in renewal window (60-90) - - svc := newTestAuthService(db) - - resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID) - require.NoError(t, err) - assert.NotEqual(t, token.Plaintext, resp.Token, "should return a new token") - assert.Contains(t, resp.Message, "refreshed") - - // Verify old token was deleted - var count int64 - // The DB stores the SHA-256 hash, so query by token.Key (the hash). - db.Model(&models.AuthToken{}).Where("key = ?", token.Key).Count(&count) - assert.Equal(t, int64(0), count, "old token should be deleted") - - // Verify new token exists in DB - // resp.Token is the raw token; the DB stores its hash. - db.Model(&models.AuthToken{}).Where("key = ?", models.HashToken(resp.Token)).Count(&count) - assert.Equal(t, int64(1), count, "new token should exist in DB") - - // Verify new token belongs to the same user - var newToken models.AuthToken - require.NoError(t, db.Where("key = ?", models.HashToken(resp.Token)).First(&newToken).Error) - assert.Equal(t, user.ID, newToken.UserID) -} - -func TestRefreshToken_ExpiredToken_Returns401(t *testing.T) { - db := setupRefreshTestDB(t) - user := createRefreshTestUser(t, db) - token := createTokenWithAge(t, db, user.ID, 91) // 91 days old, past 90-day expiry - - svc := newTestAuthService(db) - - resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID) - require.Error(t, err) - assert.Nil(t, resp) - assert.Contains(t, err.Error(), "error.token_expired") -} - -func TestRefreshToken_AtExactBoundary60Days(t *testing.T) { - db := setupRefreshTestDB(t) - user := createRefreshTestUser(t, db) - // Exactly 60 days: token age == refreshDays, so tokenAge < refreshDuration is false, - // meaning it enters the renewal window - token := createTokenWithAge(t, db, user.ID, 61) - - svc := newTestAuthService(db) - - resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID) - require.NoError(t, err) - assert.NotEqual(t, token.Plaintext, resp.Token, "token at 61 days should be refreshed") -} - -func TestRefreshToken_InvalidToken_Returns401(t *testing.T) { - db := setupRefreshTestDB(t) - user := createRefreshTestUser(t, db) - - svc := newTestAuthService(db) - - resp, err := svc.RefreshToken(context.Background(), "nonexistent-token-key", user.ID) - require.Error(t, err) - assert.Nil(t, resp) - assert.Contains(t, err.Error(), "error.invalid_token") -} - -func TestRefreshToken_WrongUser_Returns401(t *testing.T) { - db := setupRefreshTestDB(t) - user := createRefreshTestUser(t, db) - token := createTokenWithAge(t, db, user.ID, 75) - - svc := newTestAuthService(db) - - // Try to refresh with a different user ID - resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID+999) - require.Error(t, err) - assert.Nil(t, resp) - assert.Contains(t, err.Error(), "error.invalid_token") -} - -func TestRefreshToken_FreshTokenAt59Days_ReturnsExisting(t *testing.T) { - db := setupRefreshTestDB(t) - user := createRefreshTestUser(t, db) - token := createTokenWithAge(t, db, user.ID, 59) // 59 days, just under the 60-day threshold - - svc := newTestAuthService(db) - - resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID) - require.NoError(t, err) - assert.Equal(t, token.Plaintext, resp.Token, "token at 59 days should NOT be refreshed") -} diff --git a/internal/services/auth_service.go b/internal/services/auth_service.go index 898717e..9ee6f2c 100644 --- a/internal/services/auth_service.go +++ b/internal/services/auth_service.go @@ -2,333 +2,43 @@ package services import ( "context" - "crypto/rand" - "encoding/binary" - "encoding/hex" "errors" - "fmt" - "strings" - "time" "github.com/rs/zerolog/log" - "golang.org/x/crypto/bcrypt" "github.com/treytartt/honeydue-api/internal/apperrors" - "github.com/treytartt/honeydue-api/internal/config" "github.com/treytartt/honeydue-api/internal/dto/requests" "github.com/treytartt/honeydue-api/internal/dto/responses" - "github.com/treytartt/honeydue-api/internal/models" "github.com/treytartt/honeydue-api/internal/repositories" ) -// Deprecated: Legacy error constants - kept for reference during transition -// Use apperrors package instead -var ( - // ErrInvalidCredentials = errors.New("invalid credentials") - // ErrUsernameTaken = errors.New("username already taken") - // ErrEmailTaken = errors.New("email already taken") - // ErrUserInactive = errors.New("user account is inactive") - // ErrInvalidCode = errors.New("invalid verification code") - // ErrCodeExpired = errors.New("verification code expired") - // ErrAlreadyVerified = errors.New("email already verified") - // ErrRateLimitExceeded = errors.New("too many requests, please try again later") - // ErrInvalidResetToken = errors.New("invalid or expired reset token") - ErrAppleSignInFailed = errors.New("Apple Sign In failed") - ErrGoogleSignInFailed = errors.New("Google Sign In failed") -) - -// Per-account login lockout (audit M5, hardened per MEDIUM-3). -const ( - // maxLoginFailureIPs is how many DISTINCT source IPs may fail to log in to - // one account within the window before that account is locked. Counting - // distinct IPs (not raw attempts) means a single attacker who knows a - // victim's email cannot lock the victim out by spamming failures — only a - // genuinely distributed credential-stuffing attack reaches this threshold. - maxLoginFailureIPs = 5 - // loginLockWindow is how long the failed-IP set persists; it is refreshed - // on each failure so an active attack keeps the window open. - loginLockWindow = 15 * time.Minute -) - -// AuthService handles authentication business logic +// AuthService handles user profile and account management. Session +// authentication is now delegated to Ory Kratos via KratosAuth middleware. type AuthService struct { userRepo *repositories.UserRepository notificationRepo *repositories.NotificationRepository cache *CacheService - cfg *config.Config } -// SetCacheService wires Redis for per-account login-failure tracking (M5). -func (s *AuthService) SetCacheService(cache *CacheService) { - s.cache = cache -} - -// NewAuthService creates a new auth service -func NewAuthService(userRepo *repositories.UserRepository, cfg *config.Config) *AuthService { +// NewAuthService creates a new auth service. +func NewAuthService(userRepo *repositories.UserRepository, _ interface{}) *AuthService { return &AuthService{ userRepo: userRepo, - cfg: cfg, } } -// SetNotificationRepository sets the notification repository for creating notification preferences +// SetNotificationRepository wires the notification repo (kept for startup +// compatibility — no longer used post-Kratos). func (s *AuthService) SetNotificationRepository(notificationRepo *repositories.NotificationRepository) { s.notificationRepo = notificationRepo } -// dummyPasswordHash is a valid bcrypt hash used to keep login response time -// constant when the account does not exist (audit LIVE-L11). It is computed -// once at startup; the plaintext it hashes is irrelevant and never used. -var dummyPasswordHash = func() string { - h, err := bcrypt.GenerateFromPassword([]byte("honeydue-login-timing-equalizer"), models.BcryptCost) - if err != nil { - return "" // CompareHashAndPassword against "" always fails — safe - } - return string(h) -}() - -// freshToken mints a new auth token for the user and evicts any prior token's -// Redis cache entry (audit MEDIUM-1). Without the eviction a re-login would -// not actually kill a previously-issued token until the cache TTL lapsed — a -// stolen token would keep working for up to 5 minutes after the victim -// re-authenticates. A cache-eviction failure is logged, not fatal: the token -// row is already gone, so the stale entry simply ages out on its own. -func (s *AuthService) freshToken(ctx context.Context, userID uint) (*models.AuthToken, error) { - token, oldHashes, err := s.userRepo.WithContext(ctx).CreateFreshToken(userID) - if err != nil { - return nil, err - } - if s.cache != nil && len(oldHashes) > 0 { - if cErr := s.cache.InvalidateAuthTokenHashes(ctx, oldHashes...); cErr != nil { - log.Warn().Err(cErr).Uint("user_id", userID). - Msg("failed to evict prior auth-token cache entries on re-login") - } - } - return token, nil +// SetCacheService wires Redis (kept for startup compatibility). +func (s *AuthService) SetCacheService(cache *CacheService) { + s.cache = cache } -// Login authenticates a user and returns a token. clientIP is the request's -// source IP (echo c.RealIP()), used for the distributed-attack lockout. -func (s *AuthService) Login(ctx context.Context, req *requests.LoginRequest, clientIP string) (*responses.LoginResponse, error) { - // Find user by username or email - identifier := req.Username - if identifier == "" { - identifier = req.Email - } - lockKey := strings.ToLower(strings.TrimSpace(identifier)) - - // Audit M5 (hardened per MEDIUM-3): per-account lockout keyed on the set - // of distinct source IPs that have failed. Once enough distinct IPs have - // failed for one account within the window, reject — this still catches - // distributed credential stuffing, without letting a single attacker lock - // a victim out by spamming failed logins from one IP. - if s.cache != nil && lockKey != "" { - if n, cErr := s.cache.LoginFailureIPCount(ctx, lockKey); cErr == nil && n >= maxLoginFailureIPs { - return nil, apperrors.TooManyRequests("error.too_many_login_attempts") - } - } - - user, err := s.userRepo.WithContext(ctx).FindByUsernameOrEmail(identifier) - if err != nil && !errors.Is(err, repositories.ErrUserNotFound) { - return nil, apperrors.Internal(err) - } - - // Constant-time login (audit LIVE-L11): always run a bcrypt comparison, - // even when the account does not exist or is inactive, so response - // timing never reveals which emails are real accounts. Compare against - // the user's hash when available, otherwise a fixed dummy hash. - passwordHash := dummyPasswordHash - if user != nil { - passwordHash = user.Password - } - passwordOK := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(req.Password)) == nil - - // One generic error for not-found, inactive, and wrong-password - // (audit L1) — none of them disclose which condition failed. - if user == nil || !user.IsActive || !passwordOK { - if s.cache != nil && lockKey != "" { - _, _ = s.cache.RegisterLoginFailure(ctx, lockKey, clientIP, loginLockWindow) - } - return nil, apperrors.Unauthorized("error.invalid_credentials") - } - - // Successful authentication — clear the failure counter (audit M5). - if s.cache != nil && lockKey != "" { - _ = s.cache.ClearLoginFailures(ctx, lockKey) - } - - // Get or create auth token - token, err := s.freshToken(ctx, user.ID) - if err != nil { - return nil, apperrors.Internal(err) - } - - // Update last login - if err := s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID); err != nil { - // Log error but don't fail the login - log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to update last login") - } - - return &responses.LoginResponse{ - Token: token.Plaintext, - User: responses.NewUserResponse(user), - }, nil -} - -// Register creates a new user account. -// F-10: User creation, profile creation, notification preferences, and confirmation code -// are wrapped in a transaction for atomicity. -func (s *AuthService) Register(ctx context.Context, req *requests.RegisterRequest) (*responses.RegisterResponse, string, error) { - // Check if username exists - exists, err := s.userRepo.WithContext(ctx).ExistsByUsername(req.Username) - if err != nil { - return nil, "", apperrors.Internal(err) - } - if exists { - return nil, "", apperrors.Conflict("error.username_taken") - } - - // Check if email exists - exists, err = s.userRepo.WithContext(ctx).ExistsByEmail(req.Email) - if err != nil { - return nil, "", apperrors.Internal(err) - } - if exists { - return nil, "", apperrors.Conflict("error.email_taken") - } - - // Create user - user := &models.User{ - Username: req.Username, - Email: req.Email, - FirstName: req.FirstName, - LastName: req.LastName, - IsActive: true, - } - - // Hash password - if err := user.SetPassword(req.Password); err != nil { - return nil, "", apperrors.Internal(err) - } - - // Generate confirmation code - use fixed code when DEBUG_FIXED_CODES is enabled for easier local testing - var code string - if s.cfg.Server.DebugFixedCodes { - code = "123456" - } else { - code = generateSixDigitCode() - } - expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry) - - // Wrap user creation + profile + notification preferences + confirmation code in a transaction - txErr := s.userRepo.WithContext(ctx).Transaction(func(txRepo *repositories.UserRepository) error { - // Save user - if err := txRepo.Create(user); err != nil { - return err - } - - // Create user profile - if _, err := txRepo.GetOrCreateProfile(user.ID); err != nil { - log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create user profile during registration") - } - - // Create notification preferences with all options enabled - if s.notificationRepo != nil { - if _, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(user.ID); err != nil { - log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences during registration") - } - } - - // Create confirmation code - if _, err := txRepo.CreateConfirmationCode(user.ID, code, expiresAt); err != nil { - log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create confirmation code during registration") - } - - return nil - }) - if txErr != nil { - return nil, "", apperrors.Internal(txErr) - } - - // Create auth token (outside transaction since token generation is idempotent) - token, err := s.freshToken(ctx, user.ID) - if err != nil { - return nil, "", apperrors.Internal(err) - } - - return &responses.RegisterResponse{ - Token: token.Plaintext, - User: responses.NewUserResponse(user), - Message: "Registration successful. Please check your email to verify your account.", - }, code, nil -} - -// RefreshToken handles token refresh logic. -// - If token is expired (> expiryDays old), returns error (must re-login). -// - If token is in the renewal window (> refreshDays old), generates a new token. -// - If token is still fresh (< refreshDays old), returns the existing token (no-op). -func (s *AuthService) RefreshToken(ctx context.Context, tokenKey string, userID uint) (*responses.RefreshTokenResponse, error) { - expiryDays := s.cfg.Security.TokenExpiryDays - if expiryDays <= 0 { - expiryDays = 90 - } - refreshDays := s.cfg.Security.TokenRefreshDays - if refreshDays <= 0 { - refreshDays = 60 - } - - // Look up the token - authToken, err := s.userRepo.WithContext(ctx).FindTokenByKey(tokenKey) - if err != nil { - return nil, apperrors.Unauthorized("error.invalid_token") - } - - // Verify ownership - if authToken.UserID != userID { - return nil, apperrors.Unauthorized("error.invalid_token") - } - - tokenAge := time.Since(authToken.Created) - expiryDuration := time.Duration(expiryDays) * 24 * time.Hour - refreshDuration := time.Duration(refreshDays) * 24 * time.Hour - - // Token is expired — must re-login - if tokenAge > expiryDuration { - return nil, apperrors.Unauthorized("error.token_expired") - } - - // Token is still fresh — no-op refresh - if tokenAge < refreshDuration { - return &responses.RefreshTokenResponse{ - Token: tokenKey, - Message: "Token is still valid.", - }, nil - } - - // Token is in the renewal window — generate a new one - // Delete the old token - if err := s.userRepo.WithContext(ctx).DeleteToken(tokenKey); err != nil { - log.Warn().Err(err).Str("token", tokenKey[:8]+"...").Msg("Failed to delete old token during refresh") - } - - // Create a new token - newToken, err := s.userRepo.WithContext(ctx).CreateToken(userID) - if err != nil { - return nil, apperrors.Internal(err) - } - - return &responses.RefreshTokenResponse{ - Token: newToken.Plaintext, - Message: "Token refreshed successfully.", - }, nil -} - -// Logout invalidates a user's token -func (s *AuthService) Logout(ctx context.Context, token string) error { - return s.userRepo.WithContext(ctx).DeleteToken(token) -} - -// GetCurrentUser returns the current authenticated user with profile +// GetCurrentUser returns the current authenticated user with profile. func (s *AuthService) GetCurrentUser(ctx context.Context, userID uint) (*responses.CurrentUserResponse, error) { user, err := s.userRepo.WithContext(ctx).FindByIDWithProfile(userID) if err != nil { @@ -337,75 +47,21 @@ func (s *AuthService) GetCurrentUser(ctx context.Context, userID uint) (*respons authProvider, err := s.userRepo.WithContext(ctx).FindAuthProvider(userID) if err != nil { - // Log but don't fail - default to "email" log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to determine auth provider") - authProvider = "email" + authProvider = "kratos" } response := responses.NewCurrentUserResponse(user, authProvider) return &response, nil } -// DeleteAccount deletes a user's account and all associated data. -// For email auth users, password verification is required. -// For social auth users, confirmation string "DELETE" is required. -// Returns a list of file URLs that need to be deleted from disk. -func (s *AuthService) DeleteAccount(ctx context.Context, userID uint, password, confirmation *string) ([]string, error) { - // Fetch user - user, err := s.userRepo.WithContext(ctx).FindByID(userID) - if err != nil { - if errors.Is(err, repositories.ErrUserNotFound) { - return nil, apperrors.NotFound("error.user_not_found") - } - return nil, apperrors.Internal(err) - } - - // Determine auth provider - authProvider, err := s.userRepo.WithContext(ctx).FindAuthProvider(userID) - if err != nil { - return nil, apperrors.Internal(err) - } - - // Validate credentials based on auth provider - if authProvider == "email" { - if password == nil || *password == "" { - return nil, apperrors.BadRequest("error.password_required") - } - if !user.CheckPassword(*password) { - return nil, apperrors.Unauthorized("error.invalid_credentials") - } - } else { - // Social auth (apple or google) - require confirmation - if confirmation == nil || *confirmation != "DELETE" { - return nil, apperrors.BadRequest("error.confirmation_required") - } - } - - // Start transaction and cascade delete - var fileURLs []string - txErr := s.userRepo.WithContext(ctx).Transaction(func(txRepo *repositories.UserRepository) error { - urls, err := txRepo.DeleteUserCascade(userID) - if err != nil { - return err - } - fileURLs = urls - return nil - }) - if txErr != nil { - return nil, apperrors.Internal(txErr) - } - - return fileURLs, nil -} - -// UpdateProfile updates a user's profile +// UpdateProfile updates a user's profile fields. func (s *AuthService) UpdateProfile(ctx context.Context, userID uint, req *requests.UpdateProfileRequest) (*responses.CurrentUserResponse, error) { user, err := s.userRepo.WithContext(ctx).FindByID(userID) if err != nil { return nil, err } - // Check if new email is taken (if email is being changed) if req.Email != nil && *req.Email != user.Email { exists, err := s.userRepo.WithContext(ctx).ExistsByEmail(*req.Email) if err != nil { @@ -428,7 +84,6 @@ func (s *AuthService) UpdateProfile(ctx context.Context, userID uint, req *reque return nil, apperrors.Internal(err) } - // Reload with profile user, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(userID) if err != nil { return nil, err @@ -437,609 +92,41 @@ func (s *AuthService) UpdateProfile(ctx context.Context, userID uint, req *reque authProvider, err := s.userRepo.WithContext(ctx).FindAuthProvider(userID) if err != nil { log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to determine auth provider") - authProvider = "email" + authProvider = "kratos" } response := responses.NewCurrentUserResponse(user, authProvider) return &response, nil } -// VerifyEmail verifies a user's email with a confirmation code -func (s *AuthService) VerifyEmail(ctx context.Context, userID uint, code string) error { - // Get user profile - profile, err := s.userRepo.WithContext(ctx).GetOrCreateProfile(userID) - if err != nil { - return apperrors.Internal(err) +// DeleteAccount deletes a user's account and all associated data. +// Kratos owns credentials; confirmation string "DELETE" is required from the +// caller since we can no longer verify a password here. +func (s *AuthService) DeleteAccount(ctx context.Context, userID uint, _ *string, confirmation *string) ([]string, error) { + if confirmation == nil || *confirmation != "DELETE" { + return nil, apperrors.BadRequest("error.confirmation_required") } - // Check if already verified - if profile.Verified { - return apperrors.BadRequest("error.email_already_verified") - } - - // Check for test code when DEBUG_FIXED_CODES is enabled - if s.cfg.Server.DebugFixedCodes && code == "123456" { - if err := s.userRepo.WithContext(ctx).SetProfileVerified(userID, true); err != nil { - return apperrors.Internal(err) - } - return nil - } - - // Audit M4: validate the code, consume it, and flip the verified flag in - // one transaction so the three writes commit or roll back together. - txErr := s.userRepo.WithContext(ctx).Transaction(func(txRepo *repositories.UserRepository) error { - confirmCode, err := txRepo.FindConfirmationCode(userID, code) - if err != nil { - return err - } - if err := txRepo.MarkConfirmationCodeUsed(confirmCode.ID); err != nil { - return err - } - return txRepo.SetProfileVerified(userID, true) - }) - if txErr != nil { - if errors.Is(txErr, repositories.ErrCodeNotFound) { - return apperrors.BadRequest("error.invalid_verification_code") - } - if errors.Is(txErr, repositories.ErrCodeExpired) { - return apperrors.BadRequest("error.verification_code_expired") - } - return apperrors.Internal(txErr) - } - - return nil -} - -// ResendVerificationCode creates and returns a new verification code -func (s *AuthService) ResendVerificationCode(ctx context.Context, userID uint) (string, error) { - // Get user profile - profile, err := s.userRepo.WithContext(ctx).GetOrCreateProfile(userID) - if err != nil { - return "", apperrors.Internal(err) - } - - // Check if already verified - if profile.Verified { - return "", apperrors.BadRequest("error.email_already_verified") - } - - // Generate new code - use fixed code when DEBUG_FIXED_CODES is enabled for easier local testing - var code string - if s.cfg.Server.DebugFixedCodes { - code = "123456" - } else { - code = generateSixDigitCode() - } - expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry) - - if _, err := s.userRepo.WithContext(ctx).CreateConfirmationCode(userID, code, expiresAt); err != nil { - return "", apperrors.Internal(err) - } - - return code, nil -} - -// ForgotPassword initiates the password reset process -func (s *AuthService) ForgotPassword(ctx context.Context, email string) (string, *models.User, error) { - // Find user by email - user, err := s.userRepo.WithContext(ctx).FindByEmail(email) + _, err := s.userRepo.WithContext(ctx).FindByID(userID) if err != nil { if errors.Is(err, repositories.ErrUserNotFound) { - // Don't reveal that the email doesn't exist - return "", nil, nil + return nil, apperrors.NotFound("error.user_not_found") } - return "", nil, err + return nil, apperrors.Internal(err) } - // Check rate limit - count, err := s.userRepo.WithContext(ctx).CountRecentPasswordResetRequests(user.ID) - if err != nil { - return "", nil, apperrors.Internal(err) - } - if count >= int64(s.cfg.Security.MaxPasswordResetRate) { - return "", nil, apperrors.TooManyRequests("error.rate_limit_exceeded") - } - - // Generate code and reset token - use fixed code when DEBUG_FIXED_CODES is enabled for easier local testing - var code string - if s.cfg.Server.DebugFixedCodes { - code = "123456" - } else { - code = generateSixDigitCode() - } - resetToken := generateResetToken() - expiresAt := time.Now().UTC().Add(s.cfg.Security.PasswordResetExpiry) - - // Hash the code before storing - codeHash, err := bcrypt.GenerateFromPassword([]byte(code), models.BcryptCost) - if err != nil { - return "", nil, apperrors.Internal(err) - } - - if _, err := s.userRepo.WithContext(ctx).CreatePasswordResetCode(user.ID, string(codeHash), resetToken, expiresAt); err != nil { - return "", nil, apperrors.Internal(err) - } - - return code, user, nil -} - -// VerifyResetCode verifies a password reset code and returns a reset token -func (s *AuthService) VerifyResetCode(ctx context.Context, email, code string) (string, error) { - // Find the reset code - resetCode, user, err := s.userRepo.WithContext(ctx).FindPasswordResetCodeByEmail(email) - if err != nil { - if errors.Is(err, repositories.ErrUserNotFound) || errors.Is(err, repositories.ErrCodeNotFound) { - return "", apperrors.BadRequest("error.invalid_verification_code") - } - return "", apperrors.Internal(err) - } - - // Check for test code when DEBUG_FIXED_CODES is enabled - if s.cfg.Server.DebugFixedCodes && code == "123456" { - return resetCode.ResetToken, nil - } - - // Verify the code - if !resetCode.CheckCode(code) { - // Increment attempts - s.userRepo.WithContext(ctx).IncrementResetCodeAttempts(resetCode.ID) - return "", apperrors.BadRequest("error.invalid_verification_code") - } - - // Check if code is still valid - if !resetCode.IsValid() { - if resetCode.Used { - return "", apperrors.BadRequest("error.invalid_verification_code") - } - if resetCode.Attempts >= resetCode.MaxAttempts { - return "", apperrors.TooManyRequests("error.rate_limit_exceeded") - } - return "", apperrors.BadRequest("error.verification_code_expired") - } - - _ = user // user available if needed - - return resetCode.ResetToken, nil -} - -// ResetPassword resets the user's password using a reset token -func (s *AuthService) ResetPassword(ctx context.Context, resetToken, newPassword string) error { - // Find the reset code by token - resetCode, err := s.userRepo.WithContext(ctx).FindPasswordResetCodeByToken(resetToken) - if err != nil { - if errors.Is(err, repositories.ErrCodeNotFound) || errors.Is(err, repositories.ErrCodeExpired) { - return apperrors.BadRequest("error.invalid_reset_token") - } - return apperrors.Internal(err) - } - - // Get the user - user, err := s.userRepo.WithContext(ctx).FindByID(resetCode.UserID) - if err != nil { - return apperrors.Internal(err) - } - - // Update password - if err := user.SetPassword(newPassword); err != nil { - return apperrors.Internal(err) - } - - if err := s.userRepo.WithContext(ctx).Update(user); err != nil { - return apperrors.Internal(err) - } - - // Mark reset code as used - if err := s.userRepo.WithContext(ctx).MarkPasswordResetCodeUsed(resetCode.ID); err != nil { - // Log error but don't fail - log.Warn().Err(err).Uint("reset_code_id", resetCode.ID).Msg("Failed to mark reset code as used") - } - - // Invalidate all existing tokens for this user (security measure) - if err := s.userRepo.WithContext(ctx).DeleteTokenByUserID(user.ID); err != nil { - // Log error but don't fail - log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to delete user tokens after password reset") - } - - return nil -} - -// AppleSignIn handles Sign in with Apple authentication -func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthService, req *requests.AppleSignInRequest) (*responses.AppleSignInResponse, error) { - // 1. Verify the Apple JWT token - claims, err := appleAuth.VerifyIdentityToken(ctx, req.IDToken) - if err != nil { - return nil, apperrors.Unauthorized("error.invalid_credentials").Wrap(err) - } - - // Use the subject from claims as the authoritative Apple ID - appleID := claims.Subject - if appleID == "" { - appleID = req.UserID // Fallback to request UserID - } - - // 2. Check if this Apple ID is already linked to an account - existingAuth, err := s.userRepo.WithContext(ctx).FindByAppleID(appleID) - if err == nil && existingAuth != nil { - // User already linked with this Apple ID - log them in - user, err := s.userRepo.WithContext(ctx).FindByIDWithProfile(existingAuth.UserID) + var fileURLs []string + txErr := s.userRepo.WithContext(ctx).Transaction(func(txRepo *repositories.UserRepository) error { + urls, err := txRepo.DeleteUserCascade(userID) if err != nil { - return nil, apperrors.Internal(err) + return err } - - if !user.IsActive { - return nil, apperrors.Unauthorized("error.account_inactive") - } - - // Get or create token - token, err := s.freshToken(ctx, user.ID) - if err != nil { - return nil, apperrors.Internal(err) - } - - // Update last login - _ = s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID) - - return &responses.AppleSignInResponse{ - Token: token.Plaintext, - User: responses.NewUserResponse(user), - IsNewUser: false, - }, nil + fileURLs = urls + return nil + }) + if txErr != nil { + return nil, apperrors.Internal(txErr) } - // 3. Check if email matches an existing user (for account linking) - email := getEmailFromRequest(req.Email, claims.Email) - if email != "" { - existingUser, err := s.userRepo.WithContext(ctx).FindByEmail(email) - if err == nil && existingUser != nil { - // S-06: Log auto-linking of social account to existing user - log.Warn(). - Str("email", email). - Str("provider", "apple"). - Uint("user_id", existingUser.ID). - Msg("Auto-linking social account to existing user by email match") - - // Link Apple ID to existing account - appleAuthRecord := &models.AppleSocialAuth{ - UserID: existingUser.ID, - AppleID: appleID, - Email: email, - IsPrivateEmail: isPrivateRelayEmail(email) || claims.IsPrivateRelayEmail(), - } - if err := s.userRepo.WithContext(ctx).CreateAppleSocialAuth(appleAuthRecord); err != nil { - return nil, apperrors.Internal(err) - } - - // Mark as verified since Apple verified the email - _ = s.userRepo.WithContext(ctx).SetProfileVerified(existingUser.ID, true) - - // Get or create token - token, err := s.freshToken(ctx, existingUser.ID) - if err != nil { - return nil, apperrors.Internal(err) - } - - // Update last login - _ = s.userRepo.WithContext(ctx).UpdateLastLogin(existingUser.ID) - - // B-08: Check error from FindByIDWithProfile - existingUser, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(existingUser.ID) - if err != nil { - return nil, apperrors.Internal(err) - } - - return &responses.AppleSignInResponse{ - Token: token.Plaintext, - User: responses.NewUserResponse(existingUser), - IsNewUser: false, - }, nil - } - } - - // 4. Create new user - username := generateUniqueUsername(email, req.FirstName) - - user := &models.User{ - Username: username, - Email: getEmailOrDefault(email), - FirstName: getStringOrEmpty(req.FirstName), - LastName: getStringOrEmpty(req.LastName), - IsActive: true, - } - - // Set a random password (user won't use it since they log in with Apple) - randomPassword := generateResetToken() - _ = user.SetPassword(randomPassword) - - if err := s.userRepo.WithContext(ctx).Create(user); err != nil { - return nil, apperrors.Internal(err) - } - - // Create profile (already verified since Apple verified) - profile, _ := s.userRepo.WithContext(ctx).GetOrCreateProfile(user.ID) - if profile != nil { - _ = s.userRepo.WithContext(ctx).SetProfileVerified(user.ID, true) - } - - // Create notification preferences with all options enabled - if s.notificationRepo != nil { - if _, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(user.ID); err != nil { - log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Apple Sign In user") - } - } - - // Link Apple ID - appleAuthRecord := &models.AppleSocialAuth{ - UserID: user.ID, - AppleID: appleID, - Email: getEmailOrDefault(email), - IsPrivateEmail: isPrivateRelayEmail(email) || claims.IsPrivateRelayEmail(), - } - if err := s.userRepo.WithContext(ctx).CreateAppleSocialAuth(appleAuthRecord); err != nil { - return nil, apperrors.Internal(err) - } - - // Create token - token, err := s.freshToken(ctx, user.ID) - if err != nil { - return nil, apperrors.Internal(err) - } - - // B-08: Check error from FindByIDWithProfile - user, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(user.ID) - if err != nil { - return nil, apperrors.Internal(err) - } - - return &responses.AppleSignInResponse{ - Token: token.Plaintext, - User: responses.NewUserResponse(user), - IsNewUser: true, - }, nil -} - -// GoogleSignIn handles Google Sign In authentication -func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthService, req *requests.GoogleSignInRequest) (*responses.GoogleSignInResponse, error) { - // 1. Verify the Google ID token - tokenInfo, err := googleAuth.VerifyIDToken(ctx, req.IDToken) - if err != nil { - return nil, apperrors.Unauthorized("error.invalid_credentials").Wrap(err) - } - - googleID := tokenInfo.Sub - if googleID == "" { - return nil, apperrors.Unauthorized("error.invalid_credentials") - } - - // 2. Check if this Google ID is already linked to an account - existingAuth, err := s.userRepo.WithContext(ctx).FindByGoogleID(googleID) - if err == nil && existingAuth != nil { - // User already linked with this Google ID - log them in - user, err := s.userRepo.WithContext(ctx).FindByIDWithProfile(existingAuth.UserID) - if err != nil { - return nil, apperrors.Internal(err) - } - - if !user.IsActive { - return nil, apperrors.Unauthorized("error.account_inactive") - } - - // Get or create token - token, err := s.freshToken(ctx, user.ID) - if err != nil { - return nil, apperrors.Internal(err) - } - - // Update last login - _ = s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID) - - return &responses.GoogleSignInResponse{ - Token: token.Plaintext, - User: responses.NewUserResponse(user), - IsNewUser: false, - }, nil - } - - // 3. Check if email matches an existing user (for account linking) - email := tokenInfo.Email - if email != "" { - existingUser, err := s.userRepo.WithContext(ctx).FindByEmail(email) - if err == nil && existingUser != nil { - // S-06: Log auto-linking of social account to existing user - log.Warn(). - Str("email", email). - Str("provider", "google"). - Uint("user_id", existingUser.ID). - Msg("Auto-linking social account to existing user by email match") - - // Link Google ID to existing account - googleAuthRecord := &models.GoogleSocialAuth{ - UserID: existingUser.ID, - GoogleID: googleID, - Email: email, - Name: tokenInfo.Name, - Picture: tokenInfo.Picture, - } - if err := s.userRepo.WithContext(ctx).CreateGoogleSocialAuth(googleAuthRecord); err != nil { - return nil, apperrors.Internal(err) - } - - // Mark as verified since Google verified the email - if tokenInfo.IsEmailVerified() { - _ = s.userRepo.WithContext(ctx).SetProfileVerified(existingUser.ID, true) - } - - // Get or create token - token, err := s.freshToken(ctx, existingUser.ID) - if err != nil { - return nil, apperrors.Internal(err) - } - - // Update last login - _ = s.userRepo.WithContext(ctx).UpdateLastLogin(existingUser.ID) - - // B-08: Check error from FindByIDWithProfile - existingUser, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(existingUser.ID) - if err != nil { - return nil, apperrors.Internal(err) - } - - return &responses.GoogleSignInResponse{ - Token: token.Plaintext, - User: responses.NewUserResponse(existingUser), - IsNewUser: false, - }, nil - } - } - - // 4. Create new user - username := generateGoogleUsername(email, tokenInfo.GivenName) - - user := &models.User{ - Username: username, - Email: email, - FirstName: tokenInfo.GivenName, - LastName: tokenInfo.FamilyName, - IsActive: true, - } - - // Set a random password (user won't use it since they log in with Google) - randomPassword := generateResetToken() - _ = user.SetPassword(randomPassword) - - if err := s.userRepo.WithContext(ctx).Create(user); err != nil { - return nil, apperrors.Internal(err) - } - - // Create profile (already verified if Google verified email) - profile, _ := s.userRepo.WithContext(ctx).GetOrCreateProfile(user.ID) - if profile != nil && tokenInfo.IsEmailVerified() { - _ = s.userRepo.WithContext(ctx).SetProfileVerified(user.ID, true) - } - - // Create notification preferences with all options enabled - if s.notificationRepo != nil { - if _, err := s.notificationRepo.WithContext(ctx).GetOrCreatePreferences(user.ID); err != nil { - log.Warn().Err(err).Uint("user_id", user.ID).Msg("Failed to create notification preferences for Google Sign In user") - } - } - - // Link Google ID - googleAuthRecord := &models.GoogleSocialAuth{ - UserID: user.ID, - GoogleID: googleID, - Email: email, - Name: tokenInfo.Name, - Picture: tokenInfo.Picture, - } - if err := s.userRepo.WithContext(ctx).CreateGoogleSocialAuth(googleAuthRecord); err != nil { - return nil, apperrors.Internal(err) - } - - // Create token - token, err := s.freshToken(ctx, user.ID) - if err != nil { - return nil, apperrors.Internal(err) - } - - // B-08: Check error from FindByIDWithProfile - user, err = s.userRepo.WithContext(ctx).FindByIDWithProfile(user.ID) - if err != nil { - return nil, apperrors.Internal(err) - } - - return &responses.GoogleSignInResponse{ - Token: token.Plaintext, - User: responses.NewUserResponse(user), - IsNewUser: true, - }, nil -} - -// Helper functions - -func generateSixDigitCode() string { - // Uniform 000000–999999 via rejection sampling on crypto/rand, - // removing the modulo bias of `n % 1000000` (audit H4). - for { - var b [4]byte - if _, err := rand.Read(b[:]); err != nil { - continue - } - // 4294000000 is the largest multiple of 1e6 <= MaxUint32. - n := binary.BigEndian.Uint32(b[:]) - if n < 4294000000 { - return fmt.Sprintf("%06d", n%1000000) - } - } -} - -func generateResetToken() string { - b := make([]byte, 32) - rand.Read(b) - return hex.EncodeToString(b) -} - -// Helper functions for Apple Sign In - -func getEmailFromRequest(reqEmail *string, claimsEmail string) string { - if reqEmail != nil && *reqEmail != "" { - return *reqEmail - } - return claimsEmail -} - -func getEmailOrDefault(email string) string { - if email == "" { - // Generate a placeholder email for users without one - return fmt.Sprintf("apple_%s@privaterelay.appleid.com", generateResetToken()[:16]) - } - return email -} - -func getStringOrEmpty(s *string) string { - if s == nil { - return "" - } - return *s -} - -func isPrivateRelayEmail(email string) bool { - return strings.HasSuffix(strings.ToLower(email), "@privaterelay.appleid.com") -} - -func generateUniqueUsername(email string, firstName *string) string { - // Try using first part of email - if email != "" && !isPrivateRelayEmail(email) { - parts := strings.Split(email, "@") - if len(parts) > 0 && parts[0] != "" { - // Add random suffix to ensure uniqueness - return parts[0] + "_" + generateResetToken()[:6] - } - } - - // Try using first name - if firstName != nil && *firstName != "" { - return strings.ToLower(*firstName) + "_" + generateResetToken()[:6] - } - - // Fallback to random username - return "user_" + generateResetToken()[:10] -} - -func generateGoogleUsername(email string, firstName string) string { - // Try using first part of email - if email != "" { - parts := strings.Split(email, "@") - if len(parts) > 0 && parts[0] != "" { - // Add random suffix to ensure uniqueness - return parts[0] + "_" + generateResetToken()[:6] - } - } - - // Try using first name - if firstName != "" { - return strings.ToLower(firstName) + "_" + generateResetToken()[:6] - } - - // Fallback to random username - return "google_" + generateResetToken()[:10] + return fileURLs, nil } diff --git a/internal/services/auth_service_test.go b/internal/services/auth_service_test.go index 8b4ab93..bf27597 100644 --- a/internal/services/auth_service_test.go +++ b/internal/services/auth_service_test.go @@ -4,7 +4,6 @@ import ( "context" "net/http" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -19,195 +18,18 @@ func setupAuthService(t *testing.T) (*AuthService, *repositories.UserRepository) db := testutil.SetupTestDB(t) userRepo := repositories.NewUserRepository(db) notifRepo := repositories.NewNotificationRepository(db) - cfg := &config.Config{ - Server: config.ServerConfig{ - DebugFixedCodes: true, - }, - Security: config.SecurityConfig{ - SecretKey: "test-secret", - ConfirmationExpiry: 24 * time.Hour, - PasswordResetExpiry: 15 * time.Minute, - MaxPasswordResetRate: 3, - TokenExpiryDays: 90, - TokenRefreshDays: 60, - }, - } + cfg := &config.Config{} service := NewAuthService(userRepo, cfg) service.SetNotificationRepository(notifRepo) return service, userRepo } -// === Login === - -func TestAuthService_Login(t *testing.T) { - db := testutil.SetupTestDB(t) - userRepo := repositories.NewUserRepository(db) - cfg := &config.Config{ - Security: config.SecurityConfig{SecretKey: "test-secret"}, - } - service := NewAuthService(userRepo, cfg) - - testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123") - - req := &requests.LoginRequest{ - Username: "testuser", - Password: "Password123", - } - - resp, err := service.Login(context.Background(), req, "") - require.NoError(t, err) - assert.NotEmpty(t, resp.Token) - assert.Equal(t, "testuser", resp.User.Username) -} - -func TestAuthService_Login_ByEmail(t *testing.T) { - db := testutil.SetupTestDB(t) - userRepo := repositories.NewUserRepository(db) - cfg := &config.Config{ - Security: config.SecurityConfig{SecretKey: "test-secret"}, - } - service := NewAuthService(userRepo, cfg) - - testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123") - - req := &requests.LoginRequest{ - Email: "test@test.com", - Password: "Password123", - } - - resp, err := service.Login(context.Background(), req, "") - require.NoError(t, err) - assert.NotEmpty(t, resp.Token) -} - -func TestAuthService_Login_InvalidCredentials(t *testing.T) { - db := testutil.SetupTestDB(t) - userRepo := repositories.NewUserRepository(db) - cfg := &config.Config{ - Security: config.SecurityConfig{SecretKey: "test-secret"}, - } - service := NewAuthService(userRepo, cfg) - - testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123") - - req := &requests.LoginRequest{ - Username: "testuser", - Password: "WrongPassword1", - } - - _, err := service.Login(context.Background(), req, "") - testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials") -} - -func TestAuthService_Login_UserNotFound(t *testing.T) { - db := testutil.SetupTestDB(t) - userRepo := repositories.NewUserRepository(db) - cfg := &config.Config{ - Security: config.SecurityConfig{SecretKey: "test-secret"}, - } - service := NewAuthService(userRepo, cfg) - - req := &requests.LoginRequest{ - Username: "nonexistent", - Password: "Password123", - } - - _, err := service.Login(context.Background(), req, "") - testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials") -} - -func TestAuthService_Login_InactiveUser(t *testing.T) { - db := testutil.SetupTestDB(t) - userRepo := repositories.NewUserRepository(db) - cfg := &config.Config{ - Security: config.SecurityConfig{SecretKey: "test-secret"}, - } - service := NewAuthService(userRepo, cfg) - - user := testutil.CreateTestUser(t, db, "inactive", "inactive@test.com", "Password123") - // Deactivate - user.IsActive = false - db.Save(user) - - req := &requests.LoginRequest{ - Username: "inactive", - Password: "Password123", - } - - _, err := service.Login(context.Background(), req, "") - // Audit L1: inactive accounts return the same generic error as bad - // credentials so login does not disclose which accounts exist. - testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials") -} - -// === Register === - -func TestAuthService_Register(t *testing.T) { - service, _ := setupAuthService(t) - - req := &requests.RegisterRequest{ - Username: "newuser", - Email: "new@test.com", - Password: "Password123", - } - - resp, code, err := service.Register(context.Background(), req) - require.NoError(t, err) - assert.NotEmpty(t, resp.Token) - assert.Equal(t, "newuser", resp.User.Username) - assert.Equal(t, "123456", code) // DebugFixedCodes=true -} - -func TestAuthService_Register_DuplicateUsername(t *testing.T) { - db := testutil.SetupTestDB(t) - userRepo := repositories.NewUserRepository(db) - cfg := &config.Config{ - Server: config.ServerConfig{DebugFixedCodes: true}, - Security: config.SecurityConfig{SecretKey: "test", ConfirmationExpiry: 24 * time.Hour}, - } - service := NewAuthService(userRepo, cfg) - - testutil.CreateTestUser(t, db, "taken", "taken@test.com", "Password123") - - req := &requests.RegisterRequest{ - Username: "taken", - Email: "different@test.com", - Password: "Password123", - } - - _, _, err := service.Register(context.Background(), req) - testutil.AssertAppError(t, err, http.StatusConflict, "error.username_taken") -} - -func TestAuthService_Register_DuplicateEmail(t *testing.T) { - db := testutil.SetupTestDB(t) - userRepo := repositories.NewUserRepository(db) - cfg := &config.Config{ - Server: config.ServerConfig{DebugFixedCodes: true}, - Security: config.SecurityConfig{SecretKey: "test", ConfirmationExpiry: 24 * time.Hour}, - } - service := NewAuthService(userRepo, cfg) - - testutil.CreateTestUser(t, db, "existing", "taken@test.com", "Password123") - - req := &requests.RegisterRequest{ - Username: "newuser", - Email: "taken@test.com", - Password: "Password123", - } - - _, _, err := service.Register(context.Background(), req) - testutil.AssertAppError(t, err, http.StatusConflict, "error.email_taken") -} - // === GetCurrentUser === func TestAuthService_GetCurrentUser(t *testing.T) { db := testutil.SetupTestDB(t) userRepo := repositories.NewUserRepository(db) - cfg := &config.Config{ - Security: config.SecurityConfig{SecretKey: "test-secret"}, - } + cfg := &config.Config{} service := NewAuthService(userRepo, cfg) user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123") @@ -218,7 +40,7 @@ func TestAuthService_GetCurrentUser(t *testing.T) { require.NoError(t, err) assert.Equal(t, "testuser", resp.Username) assert.Equal(t, "test@test.com", resp.Email) - assert.Equal(t, "email", resp.AuthProvider) // Default for no social auth + assert.Equal(t, "kratos", resp.AuthProvider) // All users are Kratos-managed } // === UpdateProfile === @@ -226,9 +48,7 @@ func TestAuthService_GetCurrentUser(t *testing.T) { func TestAuthService_UpdateProfile(t *testing.T) { db := testutil.SetupTestDB(t) userRepo := repositories.NewUserRepository(db) - cfg := &config.Config{ - Security: config.SecurityConfig{SecretKey: "test-secret"}, - } + cfg := &config.Config{} service := NewAuthService(userRepo, cfg) user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123") @@ -250,9 +70,7 @@ func TestAuthService_UpdateProfile(t *testing.T) { func TestAuthService_UpdateProfile_DuplicateEmail(t *testing.T) { db := testutil.SetupTestDB(t) userRepo := repositories.NewUserRepository(db) - cfg := &config.Config{ - Security: config.SecurityConfig{SecretKey: "test-secret"}, - } + cfg := &config.Config{} service := NewAuthService(userRepo, cfg) testutil.CreateTestUser(t, db, "user1", "user1@test.com", "Password123") @@ -271,9 +89,7 @@ func TestAuthService_UpdateProfile_DuplicateEmail(t *testing.T) { func TestAuthService_UpdateProfile_SameEmail(t *testing.T) { db := testutil.SetupTestDB(t) userRepo := repositories.NewUserRepository(db) - cfg := &config.Config{ - Security: config.SecurityConfig{SecretKey: "test-secret"}, - } + cfg := &config.Config{} service := NewAuthService(userRepo, cfg) user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123") @@ -290,443 +106,10 @@ func TestAuthService_UpdateProfile_SameEmail(t *testing.T) { assert.Equal(t, "test@test.com", resp.Email) } -// === VerifyEmail === - -func TestAuthService_VerifyEmail(t *testing.T) { - service, _ := setupAuthService(t) - - // Register a user (creates confirmation code) - req := &requests.RegisterRequest{ - Username: "newuser", - Email: "new@test.com", - Password: "Password123", - } - _, _, err := service.Register(context.Background(), req) - require.NoError(t, err) - - // Get the user ID - user, err := service.userRepo.FindByEmail("new@test.com") - require.NoError(t, err) - - // Verify with the debug code - err = service.VerifyEmail(context.Background(), user.ID, "123456") - require.NoError(t, err) - - // Verify again — should get already verified error - err = service.VerifyEmail(context.Background(), user.ID, "123456") - testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified") -} - -func TestAuthService_VerifyEmail_InvalidCode(t *testing.T) { - service, _ := setupAuthService(t) - - // Register - req := &requests.RegisterRequest{ - Username: "newuser", - Email: "new@test.com", - Password: "Password123", - } - _, _, err := service.Register(context.Background(), req) - require.NoError(t, err) - - user, err := service.userRepo.FindByEmail("new@test.com") - require.NoError(t, err) - - // Wrong code — with DebugFixedCodes enabled, "123456" bypasses normal lookup, - // but a wrong code should use the normal path - err = service.VerifyEmail(context.Background(), user.ID, "000000") - assert.Error(t, err) -} - -// === ResendVerificationCode === - -func TestAuthService_ResendVerificationCode(t *testing.T) { - service, _ := setupAuthService(t) - - // Register - req := &requests.RegisterRequest{ - Username: "newuser", - Email: "new@test.com", - Password: "Password123", - } - _, _, err := service.Register(context.Background(), req) - require.NoError(t, err) - - user, err := service.userRepo.FindByEmail("new@test.com") - require.NoError(t, err) - - code, err := service.ResendVerificationCode(context.Background(), user.ID) - require.NoError(t, err) - assert.Equal(t, "123456", code) // DebugFixedCodes -} - -func TestAuthService_ResendVerificationCode_AlreadyVerified(t *testing.T) { - service, _ := setupAuthService(t) - - // Register and verify - req := &requests.RegisterRequest{ - Username: "newuser", - Email: "new@test.com", - Password: "Password123", - } - _, _, err := service.Register(context.Background(), req) - require.NoError(t, err) - - user, err := service.userRepo.FindByEmail("new@test.com") - require.NoError(t, err) - - err = service.VerifyEmail(context.Background(), user.ID, "123456") - require.NoError(t, err) - - _, err = service.ResendVerificationCode(context.Background(), user.ID) - testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified") -} - -// === ForgotPassword === - -func TestAuthService_ForgotPassword(t *testing.T) { - service, _ := setupAuthService(t) - - // Register a user first - registerReq := &requests.RegisterRequest{ - Username: "testuser", - Email: "test@test.com", - Password: "Password123", - } - _, _, err := service.Register(context.Background(), registerReq) - require.NoError(t, err) - - code, user, err := service.ForgotPassword(context.Background(), "test@test.com") - require.NoError(t, err) - assert.Equal(t, "123456", code) // DebugFixedCodes - assert.NotNil(t, user) - assert.Equal(t, "test@test.com", user.Email) -} - -func TestAuthService_ForgotPassword_NonexistentEmail(t *testing.T) { - service, _ := setupAuthService(t) - - // Should not reveal that email doesn't exist - code, user, err := service.ForgotPassword(context.Background(), "nonexistent@test.com") - require.NoError(t, err) - assert.Empty(t, code) - assert.Nil(t, user) -} - -// === ResetPassword === - -func TestAuthService_ResetPassword(t *testing.T) { - service, _ := setupAuthService(t) - - // Register - registerReq := &requests.RegisterRequest{ - Username: "testuser", - Email: "test@test.com", - Password: "Password123", - } - _, _, err := service.Register(context.Background(), registerReq) - require.NoError(t, err) - - // Forgot password - _, _, err = service.ForgotPassword(context.Background(), "test@test.com") - require.NoError(t, err) - - // Verify reset code to get the token - resetToken, err := service.VerifyResetCode(context.Background(), "test@test.com", "123456") - require.NoError(t, err) - assert.NotEmpty(t, resetToken) - - // Reset password - err = service.ResetPassword(context.Background(), resetToken, "NewPassword123") - require.NoError(t, err) - - // Login with new password - loginReq := &requests.LoginRequest{ - Username: "testuser", - Password: "NewPassword123", - } - loginResp, err := service.Login(context.Background(), loginReq, "") - require.NoError(t, err) - assert.NotEmpty(t, loginResp.Token) -} - -func TestAuthService_ResetPassword_InvalidToken(t *testing.T) { - service, _ := setupAuthService(t) - - err := service.ResetPassword(context.Background(), "invalid-token", "NewPassword123") - testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_reset_token") -} - -// === Logout === - -func TestAuthService_Logout(t *testing.T) { - db := testutil.SetupTestDB(t) - userRepo := repositories.NewUserRepository(db) - cfg := &config.Config{ - Security: config.SecurityConfig{SecretKey: "test-secret"}, - } - service := NewAuthService(userRepo, cfg) - - user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123") - - // Login first - loginReq := &requests.LoginRequest{ - Username: "testuser", - Password: "Password123", - } - loginResp, err := service.Login(context.Background(), loginReq, "") - require.NoError(t, err) - - // Logout - err = service.Logout(context.Background(), loginResp.Token) - require.NoError(t, err) - - // Token should be deleted — refreshing should fail - _, err = service.RefreshToken(context.Background(), loginResp.Token, user.ID) - assert.Error(t, err) -} - -// === DeleteAccount === - -func TestAuthService_DeleteAccount_EmailAuth(t *testing.T) { - service, _ := setupAuthService(t) - - // Register - registerReq := &requests.RegisterRequest{ - Username: "testuser", - Email: "test@test.com", - Password: "Password123", - } - _, _, err := service.Register(context.Background(), registerReq) - require.NoError(t, err) - - user, err := service.userRepo.FindByEmail("test@test.com") - require.NoError(t, err) - - password := "Password123" - _, err = service.DeleteAccount(context.Background(), user.ID, &password, nil) - require.NoError(t, err) -} - -func TestAuthService_DeleteAccount_WrongPassword(t *testing.T) { - service, _ := setupAuthService(t) - - registerReq := &requests.RegisterRequest{ - Username: "testuser", - Email: "test@test.com", - Password: "Password123", - } - _, _, err := service.Register(context.Background(), registerReq) - require.NoError(t, err) - - user, err := service.userRepo.FindByEmail("test@test.com") - require.NoError(t, err) - - wrongPassword := "WrongPassword1" - _, err = service.DeleteAccount(context.Background(), user.ID, &wrongPassword, nil) - testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials") -} - -func TestAuthService_DeleteAccount_NoPassword(t *testing.T) { - service, _ := setupAuthService(t) - - registerReq := &requests.RegisterRequest{ - Username: "testuser", - Email: "test@test.com", - Password: "Password123", - } - _, _, err := service.Register(context.Background(), registerReq) - require.NoError(t, err) - - user, err := service.userRepo.FindByEmail("test@test.com") - require.NoError(t, err) - - _, err = service.DeleteAccount(context.Background(), user.ID, nil, nil) - testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required") -} - -func TestAuthService_DeleteAccount_UserNotFound(t *testing.T) { - service, _ := setupAuthService(t) - - password := "Password123" - _, err := service.DeleteAccount(context.Background(), 99999, &password, nil) - testutil.AssertAppError(t, err, http.StatusNotFound, "error.user_not_found") -} - -// === Helper functions === - -func TestGenerateSixDigitCode(t *testing.T) { - code := generateSixDigitCode() - assert.Len(t, code, 6) - // Should be numeric - for _, c := range code { - assert.True(t, c >= '0' && c <= '9', "code should contain only digits") - } -} - -func TestGenerateResetToken(t *testing.T) { - token := generateResetToken() - assert.NotEmpty(t, token) - assert.Len(t, token, 64) // 32 bytes = 64 hex chars -} - -func TestGetStringOrEmpty(t *testing.T) { - s := "hello" - assert.Equal(t, "hello", getStringOrEmpty(&s)) - assert.Equal(t, "", getStringOrEmpty(nil)) -} - -func TestIsPrivateRelayEmail(t *testing.T) { - assert.True(t, isPrivateRelayEmail("abc@privaterelay.appleid.com")) - assert.True(t, isPrivateRelayEmail("ABC@PRIVATERELAY.APPLEID.COM")) - assert.False(t, isPrivateRelayEmail("user@gmail.com")) -} - -func TestGetEmailFromRequest(t *testing.T) { - email := "req@test.com" - assert.Equal(t, "req@test.com", getEmailFromRequest(&email, "claims@test.com")) - assert.Equal(t, "claims@test.com", getEmailFromRequest(nil, "claims@test.com")) - empty := "" - assert.Equal(t, "claims@test.com", getEmailFromRequest(&empty, "claims@test.com")) -} - -// === getEmailOrDefault === - -func TestGetEmailOrDefault(t *testing.T) { - // Non-empty email returns itself - assert.Equal(t, "user@test.com", getEmailOrDefault("user@test.com")) - - // Empty email returns a generated placeholder - result := getEmailOrDefault("") - assert.Contains(t, result, "@privaterelay.appleid.com") - assert.Contains(t, result, "apple_") -} - -// === generateUniqueUsername === - -func TestGenerateUniqueUsername(t *testing.T) { - // Normal email generates username from email prefix - username := generateUniqueUsername("john@test.com", nil) - assert.Contains(t, username, "john_") - - // Private relay email falls back to first name - firstName := "Jane" - username = generateUniqueUsername("abc@privaterelay.appleid.com", &firstName) - assert.Contains(t, username, "jane_") - - // Private relay email and no first name — fallback - username = generateUniqueUsername("abc@privaterelay.appleid.com", nil) - assert.Contains(t, username, "user_") - - // Empty email with first name - firstName2 := "Bob" - username = generateUniqueUsername("", &firstName2) - assert.Contains(t, username, "bob_") - - // Empty email and no first name - username = generateUniqueUsername("", nil) - assert.Contains(t, username, "user_") -} - -// === generateGoogleUsername === - -func TestGenerateGoogleUsername(t *testing.T) { - // Normal email - username := generateGoogleUsername("john@gmail.com", "John") - assert.Contains(t, username, "john_") - - // Empty email falls back to first name - username = generateGoogleUsername("", "Alice") - assert.Contains(t, username, "alice_") - - // Empty email and empty first name — fallback - username = generateGoogleUsername("", "") - assert.Contains(t, username, "google_") -} - -// === Login with empty password === - -func TestAuthService_Login_EmptyPassword(t *testing.T) { - db := testutil.SetupTestDB(t) - userRepo := repositories.NewUserRepository(db) - cfg := &config.Config{ - Security: config.SecurityConfig{SecretKey: "test-secret"}, - } - service := NewAuthService(userRepo, cfg) - - testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123") - - req := &requests.LoginRequest{ - Username: "testuser", - Password: "", - } - - _, err := service.Login(context.Background(), req, "") - testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials") -} - -// === ForgotPassword rate limiting === - -func TestAuthService_ForgotPassword_RateLimit(t *testing.T) { - service, _ := setupAuthService(t) - - registerReq := &requests.RegisterRequest{ - Username: "testuser", - Email: "test@test.com", - Password: "Password123", - } - _, _, err := service.Register(context.Background(), registerReq) - require.NoError(t, err) - - // Make max allowed reset requests (3 based on setup) - for i := 0; i < 3; i++ { - _, _, err := service.ForgotPassword(context.Background(), "test@test.com") - require.NoError(t, err) - } - - // The 4th should be rate limited - _, _, err = service.ForgotPassword(context.Background(), "test@test.com") - assert.Error(t, err) -} - -// === VerifyResetCode with wrong code === - -func TestAuthService_VerifyResetCode_WrongCode(t *testing.T) { - service, _ := setupAuthService(t) - - registerReq := &requests.RegisterRequest{ - Username: "testuser", - Email: "test@test.com", - Password: "Password123", - } - _, _, err := service.Register(context.Background(), registerReq) - require.NoError(t, err) - - _, _, err = service.ForgotPassword(context.Background(), "test@test.com") - require.NoError(t, err) - - // Wrong code but with debug mode, "123456" works, "000000" should fail - _, err = service.VerifyResetCode(context.Background(), "test@test.com", "000000") - assert.Error(t, err) -} - -// === VerifyResetCode with nonexistent email === - -func TestAuthService_VerifyResetCode_NonexistentEmail(t *testing.T) { - service, _ := setupAuthService(t) - - _, err := service.VerifyResetCode(context.Background(), "nonexistent@test.com", "123456") - assert.Error(t, err) -} - -// === UpdateProfile — change email to new email === - func TestAuthService_UpdateProfile_ChangeEmail(t *testing.T) { db := testutil.SetupTestDB(t) userRepo := repositories.NewUserRepository(db) - cfg := &config.Config{ - Security: config.SecurityConfig{SecretKey: "test-secret"}, - } + cfg := &config.Config{} service := NewAuthService(userRepo, cfg) user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123") @@ -742,25 +125,44 @@ func TestAuthService_UpdateProfile_ChangeEmail(t *testing.T) { assert.Equal(t, "newemail@test.com", resp.Email) } -// === DeleteAccount — empty password string === +// === DeleteAccount === -func TestAuthService_DeleteAccount_EmptyPassword(t *testing.T) { +func TestAuthService_DeleteAccount_WithConfirmation(t *testing.T) { + service, userRepo := setupAuthService(t) + + user := testutil.CreateTestUser(t, (*userRepo).DB(), "testuser", "test@test.com", "") + _ = user + + confirmation := "DELETE" + _, err := service.DeleteAccount(context.Background(), user.ID, nil, &confirmation) + require.NoError(t, err) +} + +func TestAuthService_DeleteAccount_WrongConfirmation(t *testing.T) { + service, userRepo := setupAuthService(t) + + user := testutil.CreateTestUser(t, (*userRepo).DB(), "testuser", "test@test.com", "") + + wrongConf := "delete" + _, err := service.DeleteAccount(context.Background(), user.ID, nil, &wrongConf) + testutil.AssertAppError(t, err, http.StatusBadRequest, "error.confirmation_required") +} + +func TestAuthService_DeleteAccount_NoConfirmation(t *testing.T) { + service, userRepo := setupAuthService(t) + + user := testutil.CreateTestUser(t, (*userRepo).DB(), "testuser", "test@test.com", "") + + _, err := service.DeleteAccount(context.Background(), user.ID, nil, nil) + testutil.AssertAppError(t, err, http.StatusBadRequest, "error.confirmation_required") +} + +func TestAuthService_DeleteAccount_UserNotFound(t *testing.T) { service, _ := setupAuthService(t) - registerReq := &requests.RegisterRequest{ - Username: "testuser", - Email: "test@test.com", - Password: "Password123", - } - _, _, err := service.Register(context.Background(), registerReq) - require.NoError(t, err) - - user, err := service.userRepo.FindByEmail("test@test.com") - require.NoError(t, err) - - emptyPw := "" - _, err = service.DeleteAccount(context.Background(), user.ID, &emptyPw, nil) - testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required") + confirmation := "DELETE" + _, err := service.DeleteAccount(context.Background(), 99999, nil, &confirmation) + testutil.AssertAppError(t, err, http.StatusNotFound, "error.user_not_found") } // === SetNotificationRepository === @@ -769,35 +171,10 @@ func TestAuthService_SetNotificationRepository(t *testing.T) { db := testutil.SetupTestDB(t) userRepo := repositories.NewUserRepository(db) notifRepo := repositories.NewNotificationRepository(db) - cfg := &config.Config{ - Security: config.SecurityConfig{SecretKey: "test-secret"}, - } + cfg := &config.Config{} service := NewAuthService(userRepo, cfg) assert.Nil(t, service.notificationRepo) service.SetNotificationRepository(notifRepo) assert.NotNil(t, service.notificationRepo) } - -// === Register creates profile and notification preferences === - -func TestAuthService_Register_CreatesProfile(t *testing.T) { - service, userRepo := setupAuthService(t) - - req := &requests.RegisterRequest{ - Username: "profileuser", - Email: "profile@test.com", - Password: "Password123", - FirstName: "John", - LastName: "Doe", - } - - resp, _, err := service.Register(context.Background(), req) - require.NoError(t, err) - assert.Equal(t, "profileuser", resp.User.Username) - - // Profile should exist - profile, err := userRepo.GetOrCreateProfile(resp.User.ID) - require.NoError(t, err) - assert.NotNil(t, profile) -} diff --git a/internal/services/cache_service.go b/internal/services/cache_service.go index b977e12..3c403c0 100644 --- a/internal/services/cache_service.go +++ b/internal/services/cache_service.go @@ -12,7 +12,6 @@ import ( "github.com/rs/zerolog/log" "github.com/treytartt/honeydue-api/internal/config" - "github.com/treytartt/honeydue-api/internal/models" ) // CacheService provides Redis caching functionality @@ -134,116 +133,6 @@ func (c *CacheService) Close() error { return nil } -// Auth token cache helpers -const ( - AuthTokenPrefix = "auth_token_" - TokenCacheTTL = 5 * time.Minute -) - -// authTokenCacheKey returns the Redis key for an auth token. The raw token -// is hashed (audit C1) so the plaintext token never appears in a Redis key. -func authTokenCacheKey(token string) string { - return AuthTokenPrefix + models.HashToken(token) -} - -// CacheAuthToken caches a user ID for a token -func (c *CacheService) CacheAuthToken(ctx context.Context, token string, userID uint) error { - return c.SetString(ctx, authTokenCacheKey(token), fmt.Sprintf("%d", userID), TokenCacheTTL) -} - -// CacheAuthTokenWithCreated caches a user ID and token creation time for a token -func (c *CacheService) CacheAuthTokenWithCreated(ctx context.Context, token string, userID uint, createdUnix int64) error { - return c.SetString(ctx, authTokenCacheKey(token), fmt.Sprintf("%d|%d", userID, createdUnix), TokenCacheTTL) -} - -// GetCachedAuthToken gets a cached user ID for a token -func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (uint, error) { - val, err := c.GetString(ctx, authTokenCacheKey(token)) - if err != nil { - return 0, err - } - - var userID uint - _, err = fmt.Sscanf(val, "%d", &userID) - return userID, err -} - -// GetCachedAuthTokenWithCreated gets a cached user ID and token creation time. -// Returns userID, createdUnix, error. createdUnix is 0 if not stored (legacy format). -func (c *CacheService) GetCachedAuthTokenWithCreated(ctx context.Context, token string) (uint, int64, error) { - val, err := c.GetString(ctx, authTokenCacheKey(token)) - if err != nil { - return 0, 0, err - } - - var userID uint - var createdUnix int64 - n, _ := fmt.Sscanf(val, "%d|%d", &userID, &createdUnix) - if n < 1 { - return 0, 0, fmt.Errorf("invalid cached token format") - } - return userID, createdUnix, nil -} - -// InvalidateAuthToken removes a cached token -func (c *CacheService) InvalidateAuthToken(ctx context.Context, token string) error { - return c.Delete(ctx, authTokenCacheKey(token)) -} - -// InvalidateAuthTokenHashes removes cached entries for already-hashed token -// keys. Unlike InvalidateAuthToken (which hashes a plaintext), this takes the -// stored hash directly — used to evict a user's prior token on re-login -// (audit MEDIUM-1), where the server no longer has the plaintext. -func (c *CacheService) InvalidateAuthTokenHashes(ctx context.Context, hashes ...string) error { - keys := make([]string, 0, len(hashes)) - for _, h := range hashes { - if h != "" { - keys = append(keys, AuthTokenPrefix+h) - } - } - if len(keys) == 0 { - return nil - } - return c.Delete(ctx, keys...) -} - -// --- Per-account login-failure tracking (audit M5) --- - -const loginFailPrefix = "login_fail:" - -// RegisterLoginFailure records a failed login for an account from a given -// source IP, and returns the number of DISTINCT source IPs that have failed -// for this account within the window. Tracking distinct IPs as a set rather -// than a raw counter (audit MEDIUM-3) means one attacker, from one IP, cannot -// run the count up and lock a victim out by knowing only their email — a -// single IP is bounded by the per-IP edge/app rate limiters instead. A -// genuinely distributed credential-stuffing attack still trips the lockout. -func (c *CacheService) RegisterLoginFailure(ctx context.Context, identifier, ip string, window time.Duration) (int64, error) { - key := loginFailPrefix + identifier - member := ip - if member == "" { - member = "unknown" - } - if err := c.client.SAdd(ctx, key, member).Err(); err != nil { - return 0, err - } - // Refresh the TTL on each failure: an active attack keeps the window - // open, while a quiet account ages out `window` after its last failure. - _ = c.client.Expire(ctx, key, window).Err() - return c.client.SCard(ctx, key).Result() -} - -// LoginFailureIPCount returns how many distinct source IPs have failed to log -// in to this account within the window (audit MEDIUM-3). SCard on a missing -// key returns 0. -func (c *CacheService) LoginFailureIPCount(ctx context.Context, identifier string) (int64, error) { - return c.client.SCard(ctx, loginFailPrefix+identifier).Result() -} - -// ClearLoginFailures resets the failed-login IP set after a successful login. -func (c *CacheService) ClearLoginFailures(ctx context.Context, identifier string) error { - return c.client.Del(ctx, loginFailPrefix+identifier).Err() -} // Static data cache helpers const ( diff --git a/internal/services/google_auth.go b/internal/services/google_auth.go deleted file mode 100644 index cbd1ecc..0000000 --- a/internal/services/google_auth.go +++ /dev/null @@ -1,307 +0,0 @@ -package services - -import ( - "context" - "crypto/rsa" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "math/big" - "net/http" - "strings" - "time" - - "github.com/golang-jwt/jwt/v5" - - "github.com/treytartt/honeydue-api/internal/config" -) - -const ( - // googleKeysURL is Google's JWKS endpoint for ID-token signature verification. - googleKeysURL = "https://www.googleapis.com/oauth2/v3/certs" - googleKeysCacheTTL = 24 * time.Hour - googleKeysCacheKey = "google:public_keys" -) - -// googleIssuers is the set of valid `iss` claim values for a Google ID token. -var googleIssuers = map[string]bool{ - "accounts.google.com": true, - "https://accounts.google.com": true, -} - -var ( - ErrInvalidGoogleToken = errors.New("invalid Google ID token") - ErrGoogleTokenExpired = errors.New("Google ID token has expired") - ErrInvalidGoogleAudience = errors.New("invalid Google token audience") - ErrInvalidGoogleIssuer = errors.New("invalid Google token issuer") - ErrGoogleKeyNotFound = errors.New("Google public key not found") -) - -// GoogleJWKS represents Google's JSON Web Key Set. -type GoogleJWKS struct { - Keys []GoogleJWK `json:"keys"` -} - -// GoogleJWK represents a single JSON Web Key from Google. -type GoogleJWK struct { - Kty string `json:"kty"` // Key type (RSA) - Kid string `json:"kid"` // Key ID - Use string `json:"use"` // Key use (sig) - Alg string `json:"alg"` // Algorithm (RS256) - N string `json:"n"` // RSA modulus - E string `json:"e"` // RSA exponent -} - -// GoogleTokenClaims represents the claims in a Google ID token JWT. -type GoogleTokenClaims struct { - jwt.RegisteredClaims - Email string `json:"email,omitempty"` - EmailVerified bool `json:"email_verified,omitempty"` - Name string `json:"name,omitempty"` - GivenName string `json:"given_name,omitempty"` - FamilyName string `json:"family_name,omitempty"` - Picture string `json:"picture,omitempty"` - Azp string `json:"azp,omitempty"` // Authorized party -} - -// GoogleTokenInfo is the verified, caller-facing view of a Google ID token. -type GoogleTokenInfo struct { - Sub string // Unique Google user ID - Email string - EmailVerified string // "true" or "false" — string for caller compatibility - Name string - GivenName string - FamilyName string - Picture string - Aud string - Azp string - Iss string -} - -// IsEmailVerified returns whether the email is verified. -func (t *GoogleTokenInfo) IsEmailVerified() bool { - return t.EmailVerified == "true" -} - -// GoogleAuthService handles Google Sign In token verification. -type GoogleAuthService struct { - cache *CacheService - config *config.Config - client *http.Client -} - -// NewGoogleAuthService creates a new Google auth service. -func NewGoogleAuthService(cache *CacheService, cfg *config.Config) *GoogleAuthService { - return &GoogleAuthService{ - cache: cache, - config: cfg, - client: &http.Client{Timeout: 10 * time.Second}, - } -} - -// VerifyIDToken verifies a Google ID token locally (audit C2/C3): it checks -// the RS256 signature against Google's published JWKS and the iss, aud, and -// exp claims. It never sends the token to a third-party endpoint, so it no -// longer depends on the deprecated tokeninfo service and never leaks the -// token in a request URL. -func (s *GoogleAuthService) VerifyIDToken(ctx context.Context, idToken string) (*GoogleTokenInfo, error) { - // Parse the token header to get the key ID. - parts := strings.Split(idToken, ".") - if len(parts) != 3 { - return nil, ErrInvalidGoogleToken - } - - headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) - if err != nil { - return nil, fmt.Errorf("failed to decode token header: %w", err) - } - var header struct { - Kid string `json:"kid"` - Alg string `json:"alg"` - } - if err := json.Unmarshal(headerBytes, &header); err != nil { - return nil, fmt.Errorf("failed to parse token header: %w", err) - } - - publicKey, err := s.getPublicKey(ctx, header.Kid) - if err != nil { - return nil, err - } - - // Parse and verify the signature. jwt v5 validates exp/iat/nbf automatically. - token, err := jwt.ParseWithClaims(idToken, &GoogleTokenClaims{}, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) - } - return publicKey, nil - }) - if err != nil { - if errors.Is(err, jwt.ErrTokenExpired) { - return nil, ErrGoogleTokenExpired - } - return nil, fmt.Errorf("failed to parse token: %w", err) - } - - claims, ok := token.Claims.(*GoogleTokenClaims) - if !ok || !token.Valid { - return nil, ErrInvalidGoogleToken - } - - // Verify the issuer (audit C3). - if !googleIssuers[claims.Issuer] { - return nil, ErrInvalidGoogleIssuer - } - - // Verify the audience matches one of our configured client IDs. - if !s.verifyAudience(claims.Audience, claims.Azp) { - return nil, ErrInvalidGoogleAudience - } - - if claims.Subject == "" { - return nil, ErrInvalidGoogleToken - } - - emailVerified := "false" - if claims.EmailVerified { - emailVerified = "true" - } - aud := "" - if len(claims.Audience) > 0 { - aud = claims.Audience[0] - } - return &GoogleTokenInfo{ - Sub: claims.Subject, - Email: claims.Email, - EmailVerified: emailVerified, - Name: claims.Name, - GivenName: claims.GivenName, - FamilyName: claims.FamilyName, - Picture: claims.Picture, - Aud: aud, - Azp: claims.Azp, - Iss: claims.Issuer, - }, nil -} - -// verifyAudience checks the token audience against our configured client IDs. -// In production (non-debug) an empty client ID fails verification rather than -// silently bypassing the check. -func (s *GoogleAuthService) verifyAudience(audience jwt.ClaimStrings, azp string) bool { - clientID := s.config.GoogleAuth.ClientID - if clientID == "" { - // In debug mode only, skip audience verification for local development. - return s.config.Server.Debug - } - - candidates := []string{clientID} - if id := s.config.GoogleAuth.AndroidClientID; id != "" { - candidates = append(candidates, id) - } - if id := s.config.GoogleAuth.IOSClientID; id != "" { - candidates = append(candidates, id) - } - - for _, want := range candidates { - if azp == want { - return true - } - for _, aud := range audience { - if aud == want { - return true - } - } - } - return false -} - -// getPublicKey returns the RSA public key for the given key ID, using a -// Redis-cached copy of Google's JWKS and re-fetching once on a cache miss -// (Google rotates signing keys roughly daily). -func (s *GoogleAuthService) getPublicKey(ctx context.Context, kid string) (*rsa.PublicKey, error) { - keys, err := s.getCachedKeys(ctx) - if err != nil || keys == nil { - keys, err = s.fetchGooglePublicKeys(ctx) - if err != nil { - return nil, err - } - } - if pubKey, ok := keys[kid]; ok { - return pubKey, nil - } - - // Cache miss for this kid — keys may have rotated; fetch fresh. - keys, err = s.fetchGooglePublicKeys(ctx) - if err != nil { - return nil, err - } - if pubKey, ok := keys[kid]; ok { - return pubKey, nil - } - return nil, ErrGoogleKeyNotFound -} - -// getCachedKeys retrieves cached Google public keys from Redis. -func (s *GoogleAuthService) getCachedKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) { - if s.cache == nil { - return nil, nil - } - data, err := s.cache.GetString(ctx, googleKeysCacheKey) - if err != nil || data == "" { - return nil, nil - } - var jwks GoogleJWKS - if err := json.Unmarshal([]byte(data), &jwks); err != nil { - return nil, nil - } - return s.parseJWKS(&jwks), nil -} - -// fetchGooglePublicKeys fetches Google's JWKS and caches it. -func (s *GoogleAuthService) fetchGooglePublicKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, googleKeysURL, nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - resp, err := s.client.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to fetch Google keys: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("Google keys endpoint returned status %d", resp.StatusCode) - } - var jwks GoogleJWKS - if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil { - return nil, fmt.Errorf("failed to decode Google keys: %w", err) - } - if s.cache != nil { - keysJSON, _ := json.Marshal(jwks) - _ = s.cache.SetString(ctx, googleKeysCacheKey, string(keysJSON), googleKeysCacheTTL) - } - return s.parseJWKS(&jwks), nil -} - -// parseJWKS converts Google's JWKS into a map of RSA public keys by key ID. -func (s *GoogleAuthService) parseJWKS(jwks *GoogleJWKS) map[string]*rsa.PublicKey { - keys := make(map[string]*rsa.PublicKey) - for _, key := range jwks.Keys { - if key.Kty != "RSA" { - continue - } - nBytes, err := base64.RawURLEncoding.DecodeString(key.N) - if err != nil { - continue - } - eBytes, err := base64.RawURLEncoding.DecodeString(key.E) - if err != nil { - continue - } - e := 0 - for _, b := range eBytes { - e = e<<8 + int(b) - } - keys[key.Kid] = &rsa.PublicKey{N: new(big.Int).SetBytes(nBytes), E: e} - } - return keys -} diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go index 70f34ac..846a703 100644 --- a/internal/testutil/testutil.go +++ b/internal/testutil/testutil.go @@ -22,6 +22,14 @@ import ( "github.com/treytartt/honeydue-api/internal/validator" ) +// authUserKey and authTokenKey mirror middleware.AuthUserKey / middleware.AuthTokenKey. +// We duplicate the string constants here to avoid an import cycle +// (testutil <- middleware <- repositories <- testutil). +const ( + authUserKey = "auth_user" + authTokenKey = "auth_token" +) + var ( i18nOnce sync.Once testDBCounter uint64 @@ -52,9 +60,6 @@ func SetupTestDB(t *testing.T) *gorm.DB { err = db.AutoMigrate( &models.User{}, &models.UserProfile{}, - &models.AuthToken{}, - &models.ConfirmationCode{}, - &models.PasswordResetCode{}, &models.AdminUser{}, &models.Residence{}, &models.ResidenceType{}, @@ -73,8 +78,6 @@ func SetupTestDB(t *testing.T) *gorm.DB { &models.NotificationPreference{}, &models.APNSDevice{}, &models.GCMDevice{}, - &models.AppleSocialAuth{}, - &models.GoogleSocialAuth{}, &models.TaskReminderLog{}, &models.UserSubscription{}, &models.SubscriptionSettings{}, @@ -177,29 +180,24 @@ func ParseJSONArray(t *testing.T, body []byte) []map[string]interface{} { return result } -// CreateTestUser creates a test user in the database -func CreateTestUser(t *testing.T, db *gorm.DB, username, email, password string) *models.User { +// CreateTestUser creates a test user in the database. +// password is accepted for API compatibility but ignored — Kratos owns credentials. +// A synthetic KratosID is generated so the user satisfies the unique-index constraint. +func CreateTestUser(t *testing.T, db *gorm.DB, username, email, _ string) *models.User { + t.Helper() user := &models.User{ + KratosID: "test-kratos-" + username, Username: username, Email: email, IsActive: true, } - err := user.SetPassword(password) - require.NoError(t, err) - err = db.Create(user).Error + err := db.Create(user).Error require.NoError(t, err) return user } -// CreateTestToken creates an auth token for a user -func CreateTestToken(t *testing.T, db *gorm.DB, userID uint) *models.AuthToken { - token, err := models.GetOrCreateToken(db, userID) - require.NoError(t, err) - return token -} - // CreateTestResidenceType creates a test residence type func CreateTestResidenceType(t *testing.T, db *gorm.DB, name string) *models.ResidenceType { rt := &models.ResidenceType{Name: name} @@ -362,12 +360,15 @@ func AssertStatusCode(t *testing.T, w *httptest.ResponseRecorder, expected int) require.Equal(t, expected, w.Code, "unexpected status code: %s", w.Body.String()) } -// MockAuthMiddleware creates middleware that sets a test user in context +// MockAuthMiddleware creates middleware that sets a test user in context. +// Uses the same context keys as KratosAuth (authUserKey / authTokenKey) so +// handlers are unaware of the swap. The constants are duplicated here to +// avoid an import cycle (testutil <- middleware <- repositories <- testutil). func MockAuthMiddleware(user *models.User) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - c.Set("auth_user", user) - c.Set("auth_token", "test-token") + c.Set(authUserKey, user) + c.Set(authTokenKey, "test-token") return next(c) } } diff --git a/migrations/000007_kratos_identity.sql b/migrations/000007_kratos_identity.sql new file mode 100644 index 0000000..246c43c --- /dev/null +++ b/migrations/000007_kratos_identity.sql @@ -0,0 +1,37 @@ +-- +goose Up +-- Phase 2: hand-rolled auth replaced by Ory Kratos. Kratos owns identities, +-- credentials, sessions, email verification, recovery and social sign-in. +-- honeyDue keeps a slim auth_user row linked to the Kratos identity by +-- kratos_id; all domain tables keep their existing integer auth_user FKs. +-- +-- Pre-production: a clean slate is taken. auth_user is truncated (cascading +-- to all user-scoped domain data) so no auth_user row exists without a +-- Kratos identity behind it. There is no data migration. + +-- honeyDue's hand-rolled auth tables are no longer used — Kratos owns this. +DROP TABLE IF EXISTS user_authtoken; +DROP TABLE IF EXISTS user_confirmationcode; +DROP TABLE IF EXISTS user_passwordresetcode; +DROP TABLE IF EXISTS user_applesocialauth; +DROP TABLE IF EXISTS user_googlesocialauth; + +-- Link each auth_user row to its Kratos identity (UUID). +ALTER TABLE auth_user ADD COLUMN IF NOT EXISTS kratos_id uuid; +CREATE UNIQUE INDEX IF NOT EXISTS uq_auth_user_kratos_id + ON auth_user (kratos_id) WHERE kratos_id IS NOT NULL; + +-- password is NOT NULL in the Django-era schema but is no longer used — +-- Kratos holds credentials. Make it nullable so provisioning need not +-- invent a placeholder hash. +ALTER TABLE auth_user ALTER COLUMN password DROP NOT NULL; + +-- Clean slate (pre-production): drop every existing account and all +-- user-scoped domain data so nothing is left orphaned without a Kratos id. +TRUNCATE TABLE auth_user CASCADE; + +-- +goose Down +-- The dropped tables' data cannot be restored. Down only removes the +-- kratos_id column and restores the password NOT NULL constraint; reverting +-- to hand-rolled auth means reverting the Phase 2 application code. +DROP INDEX IF EXISTS uq_auth_user_kratos_id; +ALTER TABLE auth_user DROP COLUMN IF EXISTS kratos_id;