feat(auth): replace hand-rolled auth with Ory Kratos — phase 2 backend
Backend CI / Test (push) Has been cancelled
Backend CI / Contract Tests (push) Has been cancelled
Backend CI / Lint (push) Has been cancelled
Backend CI / Secret Scanning (push) Has been cancelled
Backend CI / Build (push) Has been cancelled

Delegates all credential management (login, register, password reset,
email verification, social sign-in) to Ory Kratos. The Go API now acts
as a resource server: the new KratosAuth middleware validates sessions
against the Kratos whoami endpoint, writes the local User mirror into
Echo context, and all existing domain handlers continue working
unchanged. Hand-rolled token auth, AuthToken model, apple_auth/
google_auth services, and the auth refresh flow are removed. Tests are
updated to use the fake-token middleware pattern so existing integration
assertions require no rewrite.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Trey t
2026-05-18 17:55:56 -05:00
parent b66151ddd9
commit 81578f6e27
36 changed files with 927 additions and 7002 deletions
@@ -1,215 +1,30 @@
// apple_social_auth_handler is a stub — the user_applesocialauth table was
// dropped in the Ory Kratos migration (phase 2). Social sign-in is now
// handled by Kratos.
package handlers package handlers
import ( import (
"net/http" "net/http"
"strconv"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/treytartt/honeydue-api/internal/admin/dto"
"github.com/treytartt/honeydue-api/internal/models"
) )
// AdminAppleSocialAuthHandler handles admin Apple social auth management endpoints // AdminAppleSocialAuthHandler is a no-op stub.
type AdminAppleSocialAuthHandler struct { type AdminAppleSocialAuthHandler struct {
db *gorm.DB db *gorm.DB
} }
// NewAdminAppleSocialAuthHandler creates a new admin Apple social auth handler
func NewAdminAppleSocialAuthHandler(db *gorm.DB) *AdminAppleSocialAuthHandler { func NewAdminAppleSocialAuthHandler(db *gorm.DB) *AdminAppleSocialAuthHandler {
return &AdminAppleSocialAuthHandler{db: db} return &AdminAppleSocialAuthHandler{db: db}
} }
// AppleSocialAuthResponse represents the response for an Apple social auth entry func (h *AdminAppleSocialAuthHandler) gone(c echo.Context) error {
type AppleSocialAuthResponse struct { return c.JSON(http.StatusGone, map[string]string{"message": "Apple social auth is managed by Ory Kratos"})
ID uint `json:"id"`
UserID uint `json:"user_id"`
Username string `json:"username"`
UserEmail string `json:"user_email"`
AppleID string `json:"apple_id"`
Email string `json:"email"`
IsPrivateEmail bool `json:"is_private_email"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// UpdateAppleSocialAuthRequest represents the request to update an Apple social auth entry
type UpdateAppleSocialAuthRequest struct {
Email *string `json:"email"`
IsPrivateEmail *bool `json:"is_private_email"`
}
// List handles GET /api/admin/apple-social-auth
func (h *AdminAppleSocialAuthHandler) List(c echo.Context) error {
var filters dto.PaginationParams
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var entries []models.AppleSocialAuth
var total int64
query := h.db.Model(&models.AppleSocialAuth{}).Preload("User")
// Apply search
if filters.Search != "" {
search := "%" + filters.Search + "%"
query = query.Joins("JOIN auth_user ON auth_user.id = user_applesocialauth.user_id").
Where("user_applesocialauth.apple_id ILIKE ? OR user_applesocialauth.email ILIKE ? OR auth_user.username ILIKE ? OR auth_user.email ILIKE ?",
search, search, search, search)
}
// Get total count
query.Count(&total)
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "user_id", "apple_id", "email", "is_private_email",
"created_at", "updated_at",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination
query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage())
if err := query.Find(&entries).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch Apple social auth entries"})
}
// Build response
responses := make([]AppleSocialAuthResponse, len(entries))
for i, entry := range entries {
responses[i] = h.toResponse(&entry)
}
return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage()))
}
// Get handles GET /api/admin/apple-social-auth/:id
func (h *AdminAppleSocialAuthHandler) Get(c echo.Context) error {
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"})
}
var entry models.AppleSocialAuth
if err := h.db.Preload("User").First(&entry, id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Apple social auth entry not found"})
}
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch Apple social auth entry"})
}
return c.JSON(http.StatusOK, h.toResponse(&entry))
}
// GetByUser handles GET /api/admin/apple-social-auth/user/:user_id
func (h *AdminAppleSocialAuthHandler) GetByUser(c echo.Context) error {
userID, err := strconv.ParseUint(c.Param("user_id"), 10, 32)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid user ID"})
}
var entry models.AppleSocialAuth
if err := h.db.Preload("User").Where("user_id = ?", userID).First(&entry).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Apple social auth entry not found for user"})
}
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch Apple social auth entry"})
}
return c.JSON(http.StatusOK, h.toResponse(&entry))
}
// Update handles PUT /api/admin/apple-social-auth/:id
func (h *AdminAppleSocialAuthHandler) Update(c echo.Context) error {
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"})
}
var entry models.AppleSocialAuth
if err := h.db.First(&entry, id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Apple social auth entry not found"})
}
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch Apple social auth entry"})
}
var req UpdateAppleSocialAuthRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
if req.Email != nil {
entry.Email = *req.Email
}
if req.IsPrivateEmail != nil {
entry.IsPrivateEmail = *req.IsPrivateEmail
}
if err := h.db.Save(&entry).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update Apple social auth entry"})
}
h.db.Preload("User").First(&entry, id)
return c.JSON(http.StatusOK, h.toResponse(&entry))
}
// Delete handles DELETE /api/admin/apple-social-auth/:id
func (h *AdminAppleSocialAuthHandler) Delete(c echo.Context) error {
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"})
}
var entry models.AppleSocialAuth
if err := h.db.First(&entry, id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Apple social auth entry not found"})
}
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch Apple social auth entry"})
}
if err := h.db.Delete(&entry).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth entry"})
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Apple social auth entry deleted successfully"})
}
// BulkDelete handles DELETE /api/admin/apple-social-auth/bulk
func (h *AdminAppleSocialAuthHandler) BulkDelete(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
result := h.db.Where("id IN ?", req.IDs).Delete(&models.AppleSocialAuth{})
if result.Error != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth entries"})
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Apple social auth entries deleted successfully", "count": result.RowsAffected})
}
// toResponse converts an AppleSocialAuth model to AppleSocialAuthResponse
func (h *AdminAppleSocialAuthHandler) toResponse(entry *models.AppleSocialAuth) AppleSocialAuthResponse {
response := AppleSocialAuthResponse{
ID: entry.ID,
UserID: entry.UserID,
AppleID: entry.AppleID,
Email: entry.Email,
IsPrivateEmail: entry.IsPrivateEmail,
CreatedAt: entry.CreatedAt.Format("2006-01-02T15:04:05Z"),
UpdatedAt: entry.UpdatedAt.Format("2006-01-02T15:04:05Z"),
}
if entry.User.ID != 0 {
response.Username = entry.User.Username
response.UserEmail = entry.User.Email
}
return response
} }
func (h *AdminAppleSocialAuthHandler) List(c echo.Context) error { return h.gone(c) }
func (h *AdminAppleSocialAuthHandler) Get(c echo.Context) error { return h.gone(c) }
func (h *AdminAppleSocialAuthHandler) Delete(c echo.Context) error { return h.gone(c) }
func (h *AdminAppleSocialAuthHandler) BulkDelete(c echo.Context) error { return h.gone(c) }
func (h *AdminAppleSocialAuthHandler) Update(c echo.Context) error { return h.gone(c) }
func (h *AdminAppleSocialAuthHandler) GetByUser(c echo.Context) error { return h.gone(c) }
+9 -126
View File
@@ -1,144 +1,27 @@
// auth_token_handler is a stub — the user_authtoken table was dropped in the
// Ory Kratos migration (phase 2). Auth tokens are now Kratos sessions.
package handlers package handlers
import ( import (
"net/http" "net/http"
"strconv"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/treytartt/honeydue-api/internal/admin/dto"
"github.com/treytartt/honeydue-api/internal/models"
) )
// AdminAuthTokenHandler handles admin auth token management endpoints // AdminAuthTokenHandler is a no-op stub.
type AdminAuthTokenHandler struct { type AdminAuthTokenHandler struct {
db *gorm.DB db *gorm.DB
} }
// NewAdminAuthTokenHandler creates a new admin auth token handler
func NewAdminAuthTokenHandler(db *gorm.DB) *AdminAuthTokenHandler { func NewAdminAuthTokenHandler(db *gorm.DB) *AdminAuthTokenHandler {
return &AdminAuthTokenHandler{db: db} return &AdminAuthTokenHandler{db: db}
} }
// AuthTokenResponse represents an auth token in API responses func (h *AdminAuthTokenHandler) gone(c echo.Context) error {
type AuthTokenResponse struct { return c.JSON(http.StatusGone, map[string]string{"message": "auth tokens are managed by Ory Kratos"})
Key string `json:"key"`
UserID uint `json:"user_id"`
Username string `json:"username"`
Email string `json:"email"`
Created string `json:"created"`
}
// List handles GET /api/admin/auth-tokens
func (h *AdminAuthTokenHandler) List(c echo.Context) error {
var filters dto.PaginationParams
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var tokens []models.AuthToken
var total int64
query := h.db.Model(&models.AuthToken{}).Preload("User")
// Apply search (search by user info)
if filters.Search != "" {
search := "%" + filters.Search + "%"
query = query.Joins("JOIN auth_user ON auth_user.id = user_authtoken.user_id").
Where(
"auth_user.username ILIKE ? OR auth_user.email ILIKE ? OR user_authtoken.key ILIKE ?",
search, search, search,
)
}
// Get total count
query.Count(&total)
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"created", "user_id",
}, "created")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination
query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage())
if err := query.Find(&tokens).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch auth tokens"})
}
// Build response
responses := make([]AuthTokenResponse, len(tokens))
for i, token := range tokens {
responses[i] = AuthTokenResponse{
Key: token.Key,
UserID: token.UserID,
Username: token.User.Username,
Email: token.User.Email,
Created: token.Created.Format("2006-01-02T15:04:05Z"),
}
}
return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage()))
}
// Get handles GET /api/admin/auth-tokens/:id (id is actually user_id)
func (h *AdminAuthTokenHandler) Get(c echo.Context) error {
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid user ID"})
}
var token models.AuthToken
if err := h.db.Preload("User").Where("user_id = ?", id).First(&token).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Auth token not found"})
}
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch auth token"})
}
response := AuthTokenResponse{
Key: token.Key,
UserID: token.UserID,
Username: token.User.Username,
Email: token.User.Email,
Created: token.Created.Format("2006-01-02T15:04:05Z"),
}
return c.JSON(http.StatusOK, response)
}
// Delete handles DELETE /api/admin/auth-tokens/:id (revoke token)
func (h *AdminAuthTokenHandler) Delete(c echo.Context) error {
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid user ID"})
}
result := h.db.Where("user_id = ?", id).Delete(&models.AuthToken{})
if result.Error != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to revoke token"})
}
if result.RowsAffected == 0 {
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Auth token not found"})
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Auth token revoked successfully"})
}
// BulkDelete handles DELETE /api/admin/auth-tokens/bulk
func (h *AdminAuthTokenHandler) BulkDelete(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
result := h.db.Where("user_id IN ?", req.IDs).Delete(&models.AuthToken{})
if result.Error != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to revoke tokens"})
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Auth tokens revoked successfully", "count": result.RowsAffected})
} }
func (h *AdminAuthTokenHandler) List(c echo.Context) error { return h.gone(c) }
func (h *AdminAuthTokenHandler) Get(c echo.Context) error { return h.gone(c) }
func (h *AdminAuthTokenHandler) Delete(c echo.Context) error { return h.gone(c) }
func (h *AdminAuthTokenHandler) BulkDelete(c echo.Context) error { return h.gone(c) }
@@ -1,162 +1,28 @@
// confirmation_code_handler is a stub — the user_confirmationcode table was
// dropped in the Ory Kratos migration (phase 2). Email verification is now
// handled by Kratos.
package handlers package handlers
import ( import (
"net/http" "net/http"
"strconv"
"strings"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/treytartt/honeydue-api/internal/admin/dto"
"github.com/treytartt/honeydue-api/internal/models"
) )
// maskCode masks a confirmation code, showing only the last 4 characters. // AdminConfirmationCodeHandler is a no-op stub.
func maskCode(code string) string {
if len(code) <= 4 {
return strings.Repeat("*", len(code))
}
return strings.Repeat("*", len(code)-4) + code[len(code)-4:]
}
// AdminConfirmationCodeHandler handles admin confirmation code management endpoints
type AdminConfirmationCodeHandler struct { type AdminConfirmationCodeHandler struct {
db *gorm.DB db *gorm.DB
} }
// NewAdminConfirmationCodeHandler creates a new admin confirmation code handler
func NewAdminConfirmationCodeHandler(db *gorm.DB) *AdminConfirmationCodeHandler { func NewAdminConfirmationCodeHandler(db *gorm.DB) *AdminConfirmationCodeHandler {
return &AdminConfirmationCodeHandler{db: db} return &AdminConfirmationCodeHandler{db: db}
} }
// ConfirmationCodeResponse represents a confirmation code in API responses func (h *AdminConfirmationCodeHandler) gone(c echo.Context) error {
type ConfirmationCodeResponse struct { return c.JSON(http.StatusGone, map[string]string{"message": "confirmation codes are managed by Ory Kratos"})
ID uint `json:"id"`
UserID uint `json:"user_id"`
Username string `json:"username"`
Email string `json:"email"`
Code string `json:"code"`
ExpiresAt string `json:"expires_at"`
IsUsed bool `json:"is_used"`
CreatedAt string `json:"created_at"`
}
// List handles GET /api/admin/confirmation-codes
func (h *AdminConfirmationCodeHandler) List(c echo.Context) error {
var filters dto.PaginationParams
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var codes []models.ConfirmationCode
var total int64
query := h.db.Model(&models.ConfirmationCode{}).Preload("User")
// Apply search (search by user info or code)
if filters.Search != "" {
search := "%" + filters.Search + "%"
query = query.Joins("JOIN auth_user ON auth_user.id = user_confirmationcode.user_id").
Where(
"auth_user.username ILIKE ? OR auth_user.email ILIKE ? OR user_confirmationcode.code ILIKE ?",
search, search, search,
)
}
// Get total count
query.Count(&total)
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "user_id", "created_at", "expires_at", "is_used",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination
query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage())
if err := query.Find(&codes).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch confirmation codes"})
}
// Build response
responses := make([]ConfirmationCodeResponse, len(codes))
for i, code := range codes {
responses[i] = ConfirmationCodeResponse{
ID: code.ID,
UserID: code.UserID,
Username: code.User.Username,
Email: code.User.Email,
Code: maskCode(code.Code),
ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"),
IsUsed: code.IsUsed,
CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"),
}
}
return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage()))
}
// Get handles GET /api/admin/confirmation-codes/:id
func (h *AdminConfirmationCodeHandler) Get(c echo.Context) error {
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"})
}
var code models.ConfirmationCode
if err := h.db.Preload("User").First(&code, id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Confirmation code not found"})
}
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch confirmation code"})
}
response := ConfirmationCodeResponse{
ID: code.ID,
UserID: code.UserID,
Username: code.User.Username,
Email: code.User.Email,
Code: maskCode(code.Code),
ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"),
IsUsed: code.IsUsed,
CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"),
}
return c.JSON(http.StatusOK, response)
}
// Delete handles DELETE /api/admin/confirmation-codes/:id
func (h *AdminConfirmationCodeHandler) Delete(c echo.Context) error {
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"})
}
result := h.db.Delete(&models.ConfirmationCode{}, id)
if result.Error != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete confirmation code"})
}
if result.RowsAffected == 0 {
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Confirmation code not found"})
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Confirmation code deleted successfully"})
}
// BulkDelete handles DELETE /api/admin/confirmation-codes/bulk
func (h *AdminConfirmationCodeHandler) BulkDelete(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
result := h.db.Where("id IN ?", req.IDs).Delete(&models.ConfirmationCode{})
if result.Error != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete confirmation codes"})
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Confirmation codes deleted successfully", "count": result.RowsAffected})
} }
func (h *AdminConfirmationCodeHandler) List(c echo.Context) error { return h.gone(c) }
func (h *AdminConfirmationCodeHandler) Get(c echo.Context) error { return h.gone(c) }
func (h *AdminConfirmationCodeHandler) Delete(c echo.Context) error { return h.gone(c) }
func (h *AdminConfirmationCodeHandler) BulkDelete(c echo.Context) error { return h.gone(c) }
@@ -1,159 +1,28 @@
// password_reset_code_handler is a stub — the user_passwordresetcode table
// was dropped in the Ory Kratos migration (phase 2). Password resets are now
// handled by Kratos.
package handlers package handlers
import ( import (
"net/http" "net/http"
"strconv"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/treytartt/honeydue-api/internal/admin/dto"
"github.com/treytartt/honeydue-api/internal/models"
) )
// AdminPasswordResetCodeHandler handles admin password reset code management endpoints // AdminPasswordResetCodeHandler is a no-op stub.
type AdminPasswordResetCodeHandler struct { type AdminPasswordResetCodeHandler struct {
db *gorm.DB db *gorm.DB
} }
// NewAdminPasswordResetCodeHandler creates a new admin password reset code handler
func NewAdminPasswordResetCodeHandler(db *gorm.DB) *AdminPasswordResetCodeHandler { func NewAdminPasswordResetCodeHandler(db *gorm.DB) *AdminPasswordResetCodeHandler {
return &AdminPasswordResetCodeHandler{db: db} return &AdminPasswordResetCodeHandler{db: db}
} }
// PasswordResetCodeResponse represents a password reset code in API responses func (h *AdminPasswordResetCodeHandler) gone(c echo.Context) error {
type PasswordResetCodeResponse struct { return c.JSON(http.StatusGone, map[string]string{"message": "password reset codes are managed by Ory Kratos"})
ID uint `json:"id"`
UserID uint `json:"user_id"`
Username string `json:"username"`
Email string `json:"email"`
ResetToken string `json:"reset_token"`
ExpiresAt string `json:"expires_at"`
Used bool `json:"used"`
Attempts int `json:"attempts"`
MaxAttempts int `json:"max_attempts"`
CreatedAt string `json:"created_at"`
}
// List handles GET /api/admin/password-reset-codes
func (h *AdminPasswordResetCodeHandler) List(c echo.Context) error {
var filters dto.PaginationParams
if err := c.Bind(&filters); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
var codes []models.PasswordResetCode
var total int64
query := h.db.Model(&models.PasswordResetCode{}).Preload("User")
// Apply search (search by user info or token)
if filters.Search != "" {
search := "%" + filters.Search + "%"
query = query.Joins("JOIN auth_user ON auth_user.id = user_passwordresetcode.user_id").
Where(
"auth_user.username ILIKE ? OR auth_user.email ILIKE ? OR user_passwordresetcode.reset_token ILIKE ?",
search, search, search,
)
}
// Get total count
query.Count(&total)
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
sortBy := filters.GetSafeSortBy([]string{
"id", "user_id", "created_at", "expires_at", "used",
}, "created_at")
query = query.Order(sortBy + " " + filters.GetSortDir())
// Apply pagination
query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage())
if err := query.Find(&codes).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch password reset codes"})
}
// Build response
responses := make([]PasswordResetCodeResponse, len(codes))
for i, code := range codes {
responses[i] = PasswordResetCodeResponse{
ID: code.ID,
UserID: code.UserID,
Username: code.User.Username,
Email: code.User.Email,
ResetToken: code.ResetToken[:8] + "..." + code.ResetToken[len(code.ResetToken)-4:], // Truncate for display
ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"),
Used: code.Used,
Attempts: code.Attempts,
MaxAttempts: code.MaxAttempts,
CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"),
}
}
return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage()))
}
// Get handles GET /api/admin/password-reset-codes/:id
func (h *AdminPasswordResetCodeHandler) Get(c echo.Context) error {
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"})
}
var code models.PasswordResetCode
if err := h.db.Preload("User").First(&code, id).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Password reset code not found"})
}
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch password reset code"})
}
response := PasswordResetCodeResponse{
ID: code.ID,
UserID: code.UserID,
Username: code.User.Username,
Email: code.User.Email,
ResetToken: code.ResetToken[:8] + "..." + code.ResetToken[len(code.ResetToken)-4:],
ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"),
Used: code.Used,
Attempts: code.Attempts,
MaxAttempts: code.MaxAttempts,
CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"),
}
return c.JSON(http.StatusOK, response)
}
// Delete handles DELETE /api/admin/password-reset-codes/:id
func (h *AdminPasswordResetCodeHandler) Delete(c echo.Context) error {
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"})
}
result := h.db.Delete(&models.PasswordResetCode{}, id)
if result.Error != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete password reset code"})
}
if result.RowsAffected == 0 {
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Password reset code not found"})
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Password reset code deleted successfully"})
}
// BulkDelete handles DELETE /api/admin/password-reset-codes/bulk
func (h *AdminPasswordResetCodeHandler) BulkDelete(c echo.Context) error {
var req dto.BulkDeleteRequest
if err := c.Bind(&req); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
}
result := h.db.Where("id IN ?", req.IDs).Delete(&models.PasswordResetCode{})
if result.Error != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete password reset codes"})
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Password reset codes deleted successfully", "count": result.RowsAffected})
} }
func (h *AdminPasswordResetCodeHandler) List(c echo.Context) error { return h.gone(c) }
func (h *AdminPasswordResetCodeHandler) Get(c echo.Context) error { return h.gone(c) }
func (h *AdminPasswordResetCodeHandler) Delete(c echo.Context) error { return h.gone(c) }
func (h *AdminPasswordResetCodeHandler) BulkDelete(c echo.Context) error { return h.gone(c) }
+3 -6
View File
@@ -207,9 +207,7 @@ func (h *AdminUserHandler) Create(c echo.Context) error {
user.IsSuperuser = *req.IsSuperuser user.IsSuperuser = *req.IsSuperuser
} }
if err := user.SetPassword(req.Password); err != nil { // Password management is handled by Ory Kratos; no local password hashing.
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to hash password"})
}
if err := h.db.Create(&user).Error; err != nil { if err := h.db.Create(&user).Error; err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create user"}) return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create user"})
@@ -284,10 +282,9 @@ func (h *AdminUserHandler) Update(c echo.Context) error {
if req.IsSuperuser != nil { if req.IsSuperuser != nil {
user.IsSuperuser = *req.IsSuperuser user.IsSuperuser = *req.IsSuperuser
} }
// Password management is handled by Ory Kratos; local password update ignored.
if req.Password != nil { if req.Password != nil {
if err := user.SetPassword(*req.Password); err != nil { _ = req.Password // Password changes must go through Kratos admin API
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to hash password"})
}
} }
if err := h.db.Save(&user).Error; err != nil { if err := h.db.Save(&user).Error; err != nil {
+5
View File
@@ -142,6 +142,9 @@ type SecurityConfig struct {
MaxPasswordResetRate int // per hour MaxPasswordResetRate int // per hour
TokenExpiryDays int // Number of days before auth tokens expire (default 90) TokenExpiryDays int // Number of days before auth tokens expire (default 90)
TokenRefreshDays int // Token must be at least this many days old before refresh (default 60) TokenRefreshDays int // Token must be at least this many days old before refresh (default 60)
// KratosPublicURL is the Ory Kratos public API base URL. The auth
// middleware validates sessions against {KratosPublicURL}/sessions/whoami.
KratosPublicURL string
} }
// StorageConfig holds file storage settings. // StorageConfig holds file storage settings.
@@ -304,6 +307,7 @@ func Load() (*Config, error) {
MaxPasswordResetRate: 3, MaxPasswordResetRate: 3,
TokenExpiryDays: viper.GetInt("TOKEN_EXPIRY_DAYS"), TokenExpiryDays: viper.GetInt("TOKEN_EXPIRY_DAYS"),
TokenRefreshDays: viper.GetInt("TOKEN_REFRESH_DAYS"), TokenRefreshDays: viper.GetInt("TOKEN_REFRESH_DAYS"),
KratosPublicURL: viper.GetString("KRATOS_PUBLIC_URL"),
}, },
Storage: StorageConfig{ Storage: StorageConfig{
UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"), UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"),
@@ -411,6 +415,7 @@ func setDefaults() {
// Token expiry defaults // Token expiry defaults
viper.SetDefault("TOKEN_EXPIRY_DAYS", 90) // Tokens expire after 90 days viper.SetDefault("TOKEN_EXPIRY_DAYS", 90) // Tokens expire after 90 days
viper.SetDefault("KRATOS_PUBLIC_URL", "http://kratos:4433") // Ory Kratos public API
viper.SetDefault("TOKEN_REFRESH_DAYS", 60) // Tokens can be refreshed after 60 days viper.SetDefault("TOKEN_REFRESH_DAYS", 60) // Tokens can be refreshed after 60 days
// Storage defaults // Storage defaults
-5
View File
@@ -244,12 +244,7 @@ func Migrate() error {
// User and auth tables // User and auth tables
&models.User{}, &models.User{},
&models.AuthToken{},
&models.UserProfile{}, &models.UserProfile{},
&models.ConfirmationCode{},
&models.PasswordResetCode{},
&models.AppleSocialAuth{},
&models.GoogleSocialAuth{},
// Admin users (separate from app users) // Admin users (separate from app users)
&models.AdminUser{}, &models.AdminUser{},
+12 -432
View File
@@ -1,8 +1,6 @@
package handlers package handlers
import ( import (
"context"
"errors"
"net/http" "net/http"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
@@ -16,18 +14,18 @@ import (
"github.com/treytartt/honeydue-api/internal/validator" "github.com/treytartt/honeydue-api/internal/validator"
) )
// AuthHandler handles authentication endpoints // AuthHandler handles user profile and account management endpoints.
// Session lifecycle (login, register, logout, password reset) is delegated
// to Ory Kratos; this handler only deals with the honeyDue user record.
type AuthHandler struct { type AuthHandler struct {
authService *services.AuthService authService *services.AuthService
emailService *services.EmailService emailService *services.EmailService
cache *services.CacheService cache *services.CacheService
appleAuthService *services.AppleAuthService storageService *services.StorageService
googleAuthService *services.GoogleAuthService auditService *services.AuditService
storageService *services.StorageService
auditService *services.AuditService
} }
// NewAuthHandler creates a new auth handler // NewAuthHandler creates a new auth handler.
func NewAuthHandler(authService *services.AuthService, emailService *services.EmailService, cache *services.CacheService) *AuthHandler { func NewAuthHandler(authService *services.AuthService, emailService *services.EmailService, cache *services.CacheService) *AuthHandler {
return &AuthHandler{ return &AuthHandler{
authService: authService, authService: authService,
@@ -36,136 +34,21 @@ func NewAuthHandler(authService *services.AuthService, emailService *services.Em
} }
} }
// SetAppleAuthService sets the Apple auth service (called after initialization) // SetStorageService sets the storage service for file deletion during account deletion.
func (h *AuthHandler) SetAppleAuthService(appleAuth *services.AppleAuthService) {
h.appleAuthService = appleAuth
}
// SetGoogleAuthService sets the Google auth service (called after initialization)
func (h *AuthHandler) SetGoogleAuthService(googleAuth *services.GoogleAuthService) {
h.googleAuthService = googleAuth
}
// SetStorageService sets the storage service for file deletion during account deletion
func (h *AuthHandler) SetStorageService(storageService *services.StorageService) { func (h *AuthHandler) SetStorageService(storageService *services.StorageService) {
h.storageService = storageService h.storageService = storageService
} }
// SetAuditService sets the audit service for logging security events // SetAuditService sets the audit service for logging security events.
func (h *AuthHandler) SetAuditService(auditService *services.AuditService) { func (h *AuthHandler) SetAuditService(auditService *services.AuditService) {
h.auditService = auditService h.auditService = auditService
} }
// noStore marks a response as non-cacheable (audit L2) — auth responses // noStore marks a response as non-cacheable.
// carry tokens and user data that must never sit in any cache.
func noStore(c echo.Context) { func noStore(c echo.Context) {
c.Response().Header().Set("Cache-Control", "no-store") c.Response().Header().Set("Cache-Control", "no-store")
} }
// Login handles POST /api/auth/login/
func (h *AuthHandler) Login(c echo.Context) error {
noStore(c)
var req requests.LoginRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
}
response, err := h.authService.Login(c.Request().Context(), &req, c.RealIP())
if err != nil {
log.Debug().Err(err).Str("identifier", req.Username).
Str("ip", c.RealIP()).Str("user_agent", c.Request().UserAgent()).
Msg("Login failed")
if h.auditService != nil {
h.auditService.LogEvent(c, nil, services.AuditEventLoginFailed, map[string]interface{}{
"identifier": req.Username,
})
}
return err
}
if h.auditService != nil {
userID := response.User.ID
h.auditService.LogEvent(c, &userID, services.AuditEventLogin, nil)
}
return c.JSON(http.StatusOK, response)
}
// Register handles POST /api/auth/register/
func (h *AuthHandler) Register(c echo.Context) error {
noStore(c)
var req requests.RegisterRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
}
response, confirmationCode, err := h.authService.Register(c.Request().Context(), &req)
if err != nil {
log.Debug().Err(err).Msg("Registration failed")
return err
}
if h.auditService != nil {
userID := response.User.ID
h.auditService.LogEvent(c, &userID, services.AuditEventRegister, map[string]interface{}{
"username": req.Username,
"email": req.Email,
})
}
// Send welcome email with confirmation code (async)
if h.emailService != nil && confirmationCode != "" {
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", req.Email).Msg("Panic in welcome email goroutine")
}
}()
if err := h.emailService.SendWelcomeEmail(req.Email, req.FirstName, confirmationCode); err != nil {
log.Error().Err(err).Str("email", req.Email).Msg("Failed to send welcome email")
}
}()
}
return c.JSON(http.StatusCreated, response)
}
// Logout handles POST /api/auth/logout/
func (h *AuthHandler) Logout(c echo.Context) error {
token := middleware.GetAuthToken(c)
if token == "" {
return apperrors.Unauthorized("error.not_authenticated")
}
// Log audit event before invalidating the token
if h.auditService != nil {
user := middleware.GetAuthUser(c)
if user != nil {
h.auditService.LogEvent(c, &user.ID, services.AuditEventLogout, nil)
}
}
// Invalidate token in database
if err := h.authService.Logout(c.Request().Context(), token); err != nil {
log.Warn().Err(err).Msg("Failed to delete token from database")
}
// Invalidate token in cache
if h.cache != nil {
if err := h.cache.InvalidateAuthToken(c.Request().Context(), token); err != nil {
log.Warn().Err(err).Msg("Failed to invalidate token in cache")
}
}
return c.JSON(http.StatusOK, responses.MessageResponse{Message: "Logged out successfully"})
}
// CurrentUser handles GET /api/auth/me/ // CurrentUser handles GET /api/auth/me/
func (h *AuthHandler) CurrentUser(c echo.Context) error { func (h *AuthHandler) CurrentUser(c echo.Context) error {
noStore(c) noStore(c)
@@ -207,301 +90,6 @@ func (h *AuthHandler) UpdateProfile(c echo.Context) error {
return c.JSON(http.StatusOK, response) return c.JSON(http.StatusOK, response)
} }
// VerifyEmail handles POST /api/auth/verify-email/
func (h *AuthHandler) VerifyEmail(c echo.Context) error {
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req requests.VerifyEmailRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
}
err = h.authService.VerifyEmail(c.Request().Context(), user.ID, req.Code)
if err != nil {
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Email verification failed")
return err
}
// Send post-verification welcome email with tips (async)
if h.emailService != nil {
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in post-verification email goroutine")
}
}()
if err := h.emailService.SendPostVerificationEmail(user.Email, user.FirstName); err != nil {
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send post-verification email")
}
}()
}
return c.JSON(http.StatusOK, responses.VerifyEmailResponse{
Message: "Email verified successfully",
Verified: true,
})
}
// ResendVerification handles POST /api/auth/resend-verification/
func (h *AuthHandler) ResendVerification(c echo.Context) error {
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
code, err := h.authService.ResendVerificationCode(c.Request().Context(), user.ID)
if err != nil {
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Failed to resend verification")
return err
}
// Send verification email (async)
if h.emailService != nil {
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in verification email goroutine")
}
}()
if err := h.emailService.SendVerificationEmail(user.Email, user.FirstName, code); err != nil {
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send verification email")
}
}()
}
return c.JSON(http.StatusOK, responses.MessageResponse{Message: "Verification email sent"})
}
// ForgotPassword handles POST /api/auth/forgot-password/
func (h *AuthHandler) ForgotPassword(c echo.Context) error {
var req requests.ForgotPasswordRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
}
noStore(c)
if h.auditService != nil {
h.auditService.LogEvent(c, nil, services.AuditEventPasswordReset, map[string]interface{}{
"email": req.Email,
})
}
// Audit LIVE-L13: run the user lookup, code generation, and email send
// entirely in the background, then return the generic response
// immediately. This makes the response time identical whether or not
// the email belongs to a real account, defeating timing-based user
// enumeration. context.Background() is used because the request context
// is cancelled the moment this handler returns. Per-account rate
// limiting still runs inside the service; the edge auth-rate-limit
// middleware covers per-IP abuse.
email := req.Email
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", email).Msg("Panic in forgot-password goroutine")
}
}()
code, user, err := h.authService.ForgotPassword(context.Background(), email)
if err != nil || code == "" || user == nil {
return
}
if h.emailService != nil {
if sendErr := h.emailService.SendPasswordResetEmail(user.Email, user.FirstName, code); sendErr != nil {
log.Error().Err(sendErr).Str("email", user.Email).Msg("Failed to send password reset email")
}
}
}()
// Always return success to prevent email enumeration.
return c.JSON(http.StatusOK, responses.ForgotPasswordResponse{
Message: "Password reset email sent",
})
}
// VerifyResetCode handles POST /api/auth/verify-reset-code/
func (h *AuthHandler) VerifyResetCode(c echo.Context) error {
var req requests.VerifyResetCodeRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
}
resetToken, err := h.authService.VerifyResetCode(c.Request().Context(), req.Email, req.Code)
if err != nil {
log.Debug().Err(err).Str("email", req.Email).Msg("Verify reset code failed")
return err
}
return c.JSON(http.StatusOK, responses.VerifyResetCodeResponse{
Message: "Reset code verified",
ResetToken: resetToken,
})
}
// ResetPassword handles POST /api/auth/reset-password/
func (h *AuthHandler) ResetPassword(c echo.Context) error {
var req requests.ResetPasswordRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
}
err := h.authService.ResetPassword(c.Request().Context(), req.ResetToken, req.NewPassword)
if err != nil {
log.Debug().Err(err).Msg("Password reset failed")
return err
}
if h.auditService != nil {
h.auditService.LogEvent(c, nil, services.AuditEventPasswordChanged, map[string]interface{}{
"method": "reset_token",
})
}
return c.JSON(http.StatusOK, responses.ResetPasswordResponse{
Message: "Password reset successful",
})
}
// AppleSignIn handles POST /api/auth/apple-sign-in/
func (h *AuthHandler) AppleSignIn(c echo.Context) error {
noStore(c)
var req requests.AppleSignInRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
}
if h.appleAuthService == nil {
log.Error().Msg("Apple auth service not configured")
return &apperrors.AppError{
Code: 500,
MessageKey: "error.apple_signin_not_configured",
}
}
response, err := h.authService.AppleSignIn(c.Request().Context(), h.appleAuthService, &req)
if err != nil {
// Check for legacy Apple Sign In error (not yet migrated)
if errors.Is(err, services.ErrAppleSignInFailed) {
log.Debug().Err(err).Msg("Apple Sign In failed (legacy error)")
return apperrors.Unauthorized("error.invalid_apple_token")
}
log.Debug().Err(err).Msg("Apple Sign In failed")
return err
}
// Send welcome email for new users (async)
if response.IsNewUser && h.emailService != nil && response.User.Email != "" {
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", response.User.Email).Msg("Panic in Apple welcome email goroutine")
}
}()
if err := h.emailService.SendAppleWelcomeEmail(response.User.Email, response.User.FirstName); err != nil {
log.Error().Err(err).Str("email", response.User.Email).Msg("Failed to send Apple welcome email")
}
}()
}
return c.JSON(http.StatusOK, response)
}
// GoogleSignIn handles POST /api/auth/google-sign-in/
func (h *AuthHandler) GoogleSignIn(c echo.Context) error {
noStore(c)
var req requests.GoogleSignInRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
}
if h.googleAuthService == nil {
log.Error().Msg("Google auth service not configured")
return &apperrors.AppError{
Code: 500,
MessageKey: "error.google_signin_not_configured",
}
}
response, err := h.authService.GoogleSignIn(c.Request().Context(), h.googleAuthService, &req)
if err != nil {
// Check for legacy Google Sign In error (not yet migrated)
if errors.Is(err, services.ErrGoogleSignInFailed) {
log.Debug().Err(err).Msg("Google Sign In failed (legacy error)")
return apperrors.Unauthorized("error.invalid_google_token")
}
log.Debug().Err(err).Msg("Google Sign In failed")
return err
}
// Send welcome email for new users (async)
if response.IsNewUser && h.emailService != nil && response.User.Email != "" {
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", response.User.Email).Msg("Panic in Google welcome email goroutine")
}
}()
if err := h.emailService.SendGoogleWelcomeEmail(response.User.Email, response.User.FirstName); err != nil {
log.Error().Err(err).Str("email", response.User.Email).Msg("Failed to send Google welcome email")
}
}()
}
return c.JSON(http.StatusOK, response)
}
// RefreshToken handles POST /api/auth/refresh/
func (h *AuthHandler) RefreshToken(c echo.Context) error {
noStore(c)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
token := middleware.GetAuthToken(c)
if token == "" {
return apperrors.Unauthorized("error.not_authenticated")
}
response, err := h.authService.RefreshToken(c.Request().Context(), token, user.ID)
if err != nil {
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Token refresh failed")
return err
}
// If the token was refreshed (new token), invalidate the old one from cache
if response.Token != token && h.cache != nil {
if cacheErr := h.cache.InvalidateAuthToken(c.Request().Context(), token); cacheErr != nil {
log.Warn().Err(cacheErr).Msg("Failed to invalidate old token from cache during refresh")
}
}
return c.JSON(http.StatusOK, response)
}
// DeleteAccount handles DELETE /api/auth/account/ // DeleteAccount handles DELETE /api/auth/account/
func (h *AuthHandler) DeleteAccount(c echo.Context) error { func (h *AuthHandler) DeleteAccount(c echo.Context) error {
user, err := middleware.MustGetAuthUser(c) user, err := middleware.MustGetAuthUser(c)
@@ -544,13 +132,5 @@ func (h *AuthHandler) DeleteAccount(c echo.Context) error {
}() }()
} }
// Invalidate auth token from cache
token := middleware.GetAuthToken(c)
if h.cache != nil && token != "" {
if err := h.cache.InvalidateAuthToken(c.Request().Context(), token); err != nil {
log.Warn().Err(err).Msg("Failed to invalidate token in cache after account deletion")
}
}
return c.JSON(http.StatusOK, responses.MessageResponse{Message: "Account deleted successfully"}) return c.JSON(http.StatusOK, responses.MessageResponse{Message: "Account deleted successfully"})
} }
+15 -106
View File
@@ -35,26 +35,25 @@ func setupDeleteAccountHandler(t *testing.T) (*AuthHandler, *echo.Echo, *gorm.DB
return handler, e, db return handler, e, db
} }
func TestAuthHandler_DeleteAccount_EmailUser(t *testing.T) { // TestAuthHandler_DeleteAccount_WithConfirmation verifies that DELETE /account/
// succeeds when the user sends confirmation: "DELETE".
// Post-Kratos: all users (regardless of provider) must confirm with "DELETE".
func TestAuthHandler_DeleteAccount_WithConfirmation(t *testing.T) {
handler, e, db := setupDeleteAccountHandler(t) handler, e, db := setupDeleteAccountHandler(t)
user := testutil.CreateTestUser(t, db, "deletetest", "delete@test.com", "Password123") user := testutil.CreateTestUser(t, db, "deletetest", "delete@test.com", "ignored")
// Create profile for the user // Create profile for the user
profile := &models.UserProfile{UserID: user.ID, Verified: true} profile := &models.UserProfile{UserID: user.ID, Verified: true}
require.NoError(t, db.Create(profile).Error) require.NoError(t, db.Create(profile).Error)
// Create auth token
testutil.CreateTestToken(t, db, user.ID)
authGroup := e.Group("/api/auth") authGroup := e.Group("/api/auth")
authGroup.Use(testutil.MockAuthMiddleware(user)) authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.DELETE("/account/", handler.DeleteAccount) authGroup.DELETE("/account/", handler.DeleteAccount)
t.Run("successful deletion with correct password", func(t *testing.T) { t.Run("successful deletion with DELETE confirmation", func(t *testing.T) {
password := "Password123"
req := map[string]interface{}{ req := map[string]interface{}{
"password": password, "confirmation": "DELETE",
} }
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token") w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token")
@@ -74,106 +73,15 @@ func TestAuthHandler_DeleteAccount_EmailUser(t *testing.T) {
// Verify profile is deleted // Verify profile is deleted
db.Model(&models.UserProfile{}).Where("user_id = ?", user.ID).Count(&count) db.Model(&models.UserProfile{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(0), count) assert.Equal(t, int64(0), count)
// Verify auth token is deleted
db.Model(&models.AuthToken{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(0), count)
}) })
} }
func TestAuthHandler_DeleteAccount_WrongPassword(t *testing.T) { // TestAuthHandler_DeleteAccount_MissingConfirmation verifies that a missing
// confirmation string is rejected with 400.
func TestAuthHandler_DeleteAccount_MissingConfirmation(t *testing.T) {
handler, e, db := setupDeleteAccountHandler(t) handler, e, db := setupDeleteAccountHandler(t)
user := testutil.CreateTestUser(t, db, "wrongpw", "wrongpw@test.com", "Password123") user := testutil.CreateTestUser(t, db, "nopw", "nopw@test.com", "ignored")
authGroup := e.Group("/api/auth")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.DELETE("/account/", handler.DeleteAccount)
t.Run("wrong password returns 401", func(t *testing.T) {
wrongPw := "wrongpassword"
req := map[string]interface{}{
"password": wrongPw,
}
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
})
}
func TestAuthHandler_DeleteAccount_MissingPassword(t *testing.T) {
handler, e, db := setupDeleteAccountHandler(t)
user := testutil.CreateTestUser(t, db, "nopw", "nopw@test.com", "Password123")
authGroup := e.Group("/api/auth")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.DELETE("/account/", handler.DeleteAccount)
t.Run("missing password returns 400", func(t *testing.T) {
req := map[string]interface{}{}
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
func TestAuthHandler_DeleteAccount_SocialAuthUser(t *testing.T) {
handler, e, db := setupDeleteAccountHandler(t)
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "randompassword")
// Create Apple social auth record
appleAuth := &models.AppleSocialAuth{
UserID: user.ID,
AppleID: "apple_sub_123",
Email: "apple@test.com",
}
require.NoError(t, db.Create(appleAuth).Error)
// Create profile
profile := &models.UserProfile{UserID: user.ID, Verified: true}
require.NoError(t, db.Create(profile).Error)
authGroup := e.Group("/api/auth")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.DELETE("/account/", handler.DeleteAccount)
t.Run("successful deletion with DELETE confirmation", func(t *testing.T) {
confirmation := "DELETE"
req := map[string]interface{}{
"confirmation": confirmation,
}
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
// Verify user is deleted
var count int64
db.Model(&models.User{}).Where("id = ?", user.ID).Count(&count)
assert.Equal(t, int64(0), count)
// Verify apple auth is deleted
db.Model(&models.AppleSocialAuth{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(0), count)
})
}
func TestAuthHandler_DeleteAccount_SocialAuthMissingConfirmation(t *testing.T) {
handler, e, db := setupDeleteAccountHandler(t)
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "randompassword")
// Create Google social auth record
googleAuth := &models.GoogleSocialAuth{
UserID: user.ID,
GoogleID: "google_sub_456",
Email: "google@test.com",
}
require.NoError(t, db.Create(googleAuth).Error)
authGroup := e.Group("/api/auth") authGroup := e.Group("/api/auth")
authGroup.Use(testutil.MockAuthMiddleware(user)) authGroup.Use(testutil.MockAuthMiddleware(user))
@@ -188,9 +96,8 @@ func TestAuthHandler_DeleteAccount_SocialAuthMissingConfirmation(t *testing.T) {
}) })
t.Run("wrong confirmation returns 400", func(t *testing.T) { t.Run("wrong confirmation returns 400", func(t *testing.T) {
wrongConfirmation := "delete"
req := map[string]interface{}{ req := map[string]interface{}{
"confirmation": wrongConfirmation, "confirmation": "delete", // lowercase — must be exact "DELETE"
} }
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token") w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token")
@@ -199,6 +106,8 @@ func TestAuthHandler_DeleteAccount_SocialAuthMissingConfirmation(t *testing.T) {
}) })
} }
// TestAuthHandler_DeleteAccount_Unauthenticated verifies that 401 is returned
// when no auth middleware is set.
func TestAuthHandler_DeleteAccount_Unauthenticated(t *testing.T) { func TestAuthHandler_DeleteAccount_Unauthenticated(t *testing.T) {
handler, e, _ := setupDeleteAccountHandler(t) handler, e, _ := setupDeleteAccountHandler(t)
@@ -207,7 +116,7 @@ func TestAuthHandler_DeleteAccount_Unauthenticated(t *testing.T) {
t.Run("unauthenticated request returns 401", func(t *testing.T) { t.Run("unauthenticated request returns 401", func(t *testing.T) {
req := map[string]interface{}{ req := map[string]interface{}{
"password": "Password123", "confirmation": "DELETE",
} }
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "") w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "")
+33 -325
View File
@@ -1,3 +1,7 @@
// auth_handler_test.go tests the auth handler endpoints that survived the
// Ory Kratos migration: GET /me/ and PUT/PATCH /profile/.
// Login, register, logout, forgot-password, and social sign-in are now
// handled by Kratos.
package handlers package handlers
import ( import (
@@ -34,204 +38,32 @@ func setupAuthHandler(t *testing.T) (*AuthHandler, *echo.Echo, *repositories.Use
return handler, e, userRepo return handler, e, userRepo
} }
func TestAuthHandler_Register(t *testing.T) {
handler, e, _ := setupAuthHandler(t)
e.POST("/api/auth/register/", handler.Register)
t.Run("successful registration", func(t *testing.T) {
req := requests.RegisterRequest{
Username: "newuser",
Email: "new@test.com",
Password: "Password123",
FirstName: "New",
LastName: "User",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
testutil.AssertStatusCode(t, w, http.StatusCreated)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
testutil.AssertJSONFieldExists(t, response, "token")
testutil.AssertJSONFieldExists(t, response, "user")
testutil.AssertJSONFieldExists(t, response, "message")
user := response["user"].(map[string]interface{})
assert.Equal(t, "newuser", user["username"])
assert.Equal(t, "new@test.com", user["email"])
assert.Equal(t, "New", user["first_name"])
assert.Equal(t, "User", user["last_name"])
})
t.Run("registration with missing fields", func(t *testing.T) {
req := map[string]string{
"username": "test",
// Missing email and password
}
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
response := testutil.ParseJSON(t, w.Body.Bytes())
testutil.AssertJSONFieldExists(t, response, "error")
})
t.Run("registration with short password", func(t *testing.T) {
req := requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "short", // Less than 8 chars
}
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("registration with duplicate username", func(t *testing.T) {
// First registration
req := requests.RegisterRequest{
Username: "duplicate",
Email: "unique1@test.com",
Password: "Password123",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
testutil.AssertStatusCode(t, w, http.StatusCreated)
// Try to register again with same username
req.Email = "unique2@test.com"
w = testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
testutil.AssertStatusCode(t, w, http.StatusConflict) // 409 for duplicate resource
response := testutil.ParseJSON(t, w.Body.Bytes())
assert.Contains(t, response["error"], "Username already taken")
})
t.Run("registration with duplicate email", func(t *testing.T) {
// First registration
req := requests.RegisterRequest{
Username: "user1",
Email: "duplicate@test.com",
Password: "Password123",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
testutil.AssertStatusCode(t, w, http.StatusCreated)
// Try to register again with same email
req.Username = "user2"
w = testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
testutil.AssertStatusCode(t, w, http.StatusConflict) // 409 for duplicate resource
response := testutil.ParseJSON(t, w.Body.Bytes())
assert.Contains(t, response["error"], "Email already registered")
})
}
func TestAuthHandler_Login(t *testing.T) {
handler, e, _ := setupAuthHandler(t)
e.POST("/api/auth/register/", handler.Register)
e.POST("/api/auth/login/", handler.Login)
// Create a test user
registerReq := requests.RegisterRequest{
Username: "logintest",
Email: "login@test.com",
Password: "Password123",
FirstName: "Test",
LastName: "User",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", registerReq, "")
testutil.AssertStatusCode(t, w, http.StatusCreated)
t.Run("successful login with username", func(t *testing.T) {
req := requests.LoginRequest{
Username: "logintest",
Password: "Password123",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
testutil.AssertJSONFieldExists(t, response, "token")
testutil.AssertJSONFieldExists(t, response, "user")
user := response["user"].(map[string]interface{})
assert.Equal(t, "logintest", user["username"])
assert.Equal(t, "login@test.com", user["email"])
})
t.Run("successful login with email", func(t *testing.T) {
req := requests.LoginRequest{
Username: "login@test.com", // Using email as username
Password: "Password123",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "")
testutil.AssertStatusCode(t, w, http.StatusOK)
})
t.Run("login with wrong password", func(t *testing.T) {
req := requests.LoginRequest{
Username: "logintest",
Password: "wrongpassword",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
response := testutil.ParseJSON(t, w.Body.Bytes())
assert.Contains(t, response["error"], "Invalid credentials")
})
t.Run("login with non-existent user", func(t *testing.T) {
req := requests.LoginRequest{
Username: "nonexistent",
Password: "Password123",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
})
t.Run("login with missing fields", func(t *testing.T) {
req := map[string]string{
"username": "logintest",
// Missing password
}
w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
func TestAuthHandler_CurrentUser(t *testing.T) { func TestAuthHandler_CurrentUser(t *testing.T) {
handler, e, userRepo := setupAuthHandler(t) handler, e, _ := setupAuthHandler(t)
db := testutil.SetupTestDB(t) db := testutil.SetupTestDB(t)
user := testutil.CreateTestUser(t, db, "metest", "me@test.com", "Password123") user := testutil.CreateTestUser(t, db, "metest", "me@test.com", "")
user.FirstName = "Test" user.FirstName = "Test"
user.LastName = "User" user.LastName = "User"
userRepo.Update(user) // Use the userRepo from setupAuthHandler's DB, but since we need the user
// in the same DB we re-create it there.
db2 := testutil.SetupTestDB(t)
user2 := testutil.CreateTestUser(t, db2, "metest2", "me2@test.com", "")
user2.FirstName = "Test"
user2.LastName = "User"
userRepo2 := repositories.NewUserRepository(db2)
require.NoError(t, userRepo2.Update(user2))
// Build handler against db2
cfg := &config.Config{}
authService2 := services.NewAuthService(userRepo2, cfg)
handler2 := NewAuthHandler(authService2, nil, nil)
// Set up route with mock auth middleware
authGroup := e.Group("/api/auth") authGroup := e.Group("/api/auth")
authGroup.Use(testutil.MockAuthMiddleware(user)) authGroup.Use(testutil.MockAuthMiddleware(user2))
authGroup.GET("/me/", handler.CurrentUser) authGroup.GET("/me/", handler2.CurrentUser)
_ = handler // avoid unused
t.Run("get current user", func(t *testing.T) { t.Run("get current user", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/auth/me/", nil, "test-token") w := testutil.MakeRequest(e, "GET", "/api/auth/me/", nil, "test-token")
@@ -242,23 +74,26 @@ func TestAuthHandler_CurrentUser(t *testing.T) {
err := json.Unmarshal(w.Body.Bytes(), &response) err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "metest", response["username"]) assert.Equal(t, "metest2", response["username"])
assert.Equal(t, "me@test.com", response["email"]) assert.Equal(t, "me2@test.com", response["email"])
}) })
} }
func TestAuthHandler_UpdateProfile(t *testing.T) { func TestAuthHandler_UpdateProfile(t *testing.T) {
handler, e, userRepo := setupAuthHandler(t)
db := testutil.SetupTestDB(t) db := testutil.SetupTestDB(t)
user := testutil.CreateTestUser(t, db, "updatetest", "update@test.com", "Password123") userRepo := repositories.NewUserRepository(db)
userRepo.Update(user) cfg := &config.Config{}
authService := services.NewAuthService(userRepo, cfg)
handler := NewAuthHandler(authService, nil, nil)
e := testutil.SetupTestRouter()
user := testutil.CreateTestUser(t, db, "updatetest", "update@test.com", "")
authGroup := e.Group("/api/auth") authGroup := e.Group("/api/auth")
authGroup.Use(testutil.MockAuthMiddleware(user)) authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.PUT("/profile/", handler.UpdateProfile) authGroup.PUT("/profile/", handler.UpdateProfile)
t.Run("update profile", func(t *testing.T) { t.Run("update first and last name", func(t *testing.T) {
firstName := "Updated" firstName := "Updated"
lastName := "Name" lastName := "Name"
req := requests.UpdateProfileRequest{ req := requests.UpdateProfileRequest{
@@ -278,130 +113,3 @@ func TestAuthHandler_UpdateProfile(t *testing.T) {
assert.Equal(t, "Name", response["last_name"]) assert.Equal(t, "Name", response["last_name"])
}) })
} }
func TestAuthHandler_ForgotPassword(t *testing.T) {
handler, e, _ := setupAuthHandler(t)
e.POST("/api/auth/register/", handler.Register)
e.POST("/api/auth/forgot-password/", handler.ForgotPassword)
// Create a test user
registerReq := requests.RegisterRequest{
Username: "forgottest",
Email: "forgot@test.com",
Password: "Password123",
}
testutil.MakeRequest(e, "POST", "/api/auth/register/", registerReq, "")
t.Run("forgot password with valid email", func(t *testing.T) {
req := requests.ForgotPasswordRequest{
Email: "forgot@test.com",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/forgot-password/", req, "")
// Always returns 200 to prevent email enumeration
testutil.AssertStatusCode(t, w, http.StatusOK)
response := testutil.ParseJSON(t, w.Body.Bytes())
testutil.AssertJSONFieldExists(t, response, "message")
})
t.Run("forgot password with invalid email", func(t *testing.T) {
req := requests.ForgotPasswordRequest{
Email: "nonexistent@test.com",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/forgot-password/", req, "")
// Still returns 200 to prevent email enumeration
testutil.AssertStatusCode(t, w, http.StatusOK)
})
}
func TestAuthHandler_Logout(t *testing.T) {
handler, e, userRepo := setupAuthHandler(t)
db := testutil.SetupTestDB(t)
user := testutil.CreateTestUser(t, db, "logouttest", "logout@test.com", "Password123")
userRepo.Update(user)
authGroup := e.Group("/api/auth")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/logout/", handler.Logout)
t.Run("successful logout", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", "/api/auth/logout/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
response := testutil.ParseJSON(t, w.Body.Bytes())
assert.Contains(t, response["message"], "Logged out successfully")
})
}
func TestAuthHandler_JSONResponses(t *testing.T) {
handler, e, _ := setupAuthHandler(t)
e.POST("/api/auth/register/", handler.Register)
e.POST("/api/auth/login/", handler.Login)
t.Run("register response has correct JSON structure", func(t *testing.T) {
req := requests.RegisterRequest{
Username: "jsontest",
Email: "json@test.com",
Password: "Password123",
FirstName: "JSON",
LastName: "Test",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
testutil.AssertStatusCode(t, w, http.StatusCreated)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
// Verify top-level structure
assert.Contains(t, response, "token")
assert.Contains(t, response, "user")
assert.Contains(t, response, "message")
// Verify token is not empty
assert.NotEmpty(t, response["token"])
// Verify user structure
user := response["user"].(map[string]interface{})
assert.Contains(t, user, "id")
assert.Contains(t, user, "username")
assert.Contains(t, user, "email")
assert.Contains(t, user, "first_name")
assert.Contains(t, user, "last_name")
assert.Contains(t, user, "is_active")
assert.Contains(t, user, "date_joined")
// Verify types
assert.IsType(t, float64(0), user["id"]) // JSON numbers are float64
assert.IsType(t, "", user["username"])
assert.IsType(t, "", user["email"])
assert.IsType(t, true, user["is_active"])
})
t.Run("error response has correct JSON structure", func(t *testing.T) {
req := map[string]string{
"username": "test",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response, "error")
assert.IsType(t, "", response["error"])
})
}
-226
View File
@@ -506,232 +506,6 @@ func TestTaskHandler_CreateCompletion_NoTaskID(t *testing.T) {
}) })
} }
// =============================================================================
// Auth Handler - Additional Coverage
// =============================================================================
func TestAuthHandler_AppleSignIn_NotConfigured(t *testing.T) {
handler, e, _ := setupAuthHandler(t)
e.POST("/api/auth/apple-sign-in/", handler.AppleSignIn)
t.Run("returns 500 when apple auth not configured", func(t *testing.T) {
req := map[string]interface{}{
"id_token": "fake-token",
"user_id": "fake-user-id",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/apple-sign-in/", req, "")
testutil.AssertStatusCode(t, w, http.StatusInternalServerError)
})
t.Run("missing identity_token returns 400", func(t *testing.T) {
req := map[string]interface{}{}
w := testutil.MakeRequest(e, "POST", "/api/auth/apple-sign-in/", req, "")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
func TestAuthHandler_GoogleSignIn_NotConfigured(t *testing.T) {
handler, e, _ := setupAuthHandler(t)
e.POST("/api/auth/google-sign-in/", handler.GoogleSignIn)
t.Run("returns 500 when google auth not configured", func(t *testing.T) {
req := map[string]interface{}{
"id_token": "fake-token",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/google-sign-in/", req, "")
testutil.AssertStatusCode(t, w, http.StatusInternalServerError)
})
t.Run("missing id_token returns 400", func(t *testing.T) {
req := map[string]interface{}{}
w := testutil.MakeRequest(e, "POST", "/api/auth/google-sign-in/", req, "")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
// setupAuthHandlerWithDB is like setupAuthHandler but also returns the underlying *gorm.DB
// for tests that need to create records like ConfirmationCode directly.
func setupAuthHandlerWithDB(t *testing.T) (*AuthHandler, *echo.Echo, *gorm.DB) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{
SecretKey: "test-secret-key",
PasswordResetExpiry: 15 * time.Minute,
ConfirmationExpiry: 24 * time.Hour,
MaxPasswordResetRate: 3,
},
}
authService := services.NewAuthService(userRepo, cfg)
handler := NewAuthHandler(authService, nil, nil)
e := testutil.SetupTestRouter()
return handler, e, db
}
func TestAuthHandler_VerifyEmail(t *testing.T) {
handler, e, db := setupAuthHandlerWithDB(t)
user := testutil.CreateTestUser(t, db, "verifytest", "verify@test.com", "Password123")
// Create confirmation code
confirmCode := &models.ConfirmationCode{
UserID: user.ID,
Code: "123456",
ExpiresAt: time.Now().Add(24 * time.Hour),
IsUsed: false,
}
require.NoError(t, db.Create(confirmCode).Error)
authGroup := e.Group("/api/auth")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/verify-email/", handler.VerifyEmail)
t.Run("successful verification", func(t *testing.T) {
req := requests.VerifyEmailRequest{
Code: "123456",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/verify-email/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, true, response["verified"])
})
t.Run("wrong code returns error", func(t *testing.T) {
req := requests.VerifyEmailRequest{
Code: "999999",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/verify-email/", req, "test-token")
// Code already used or wrong code
assert.True(t, w.Code == http.StatusBadRequest || w.Code == http.StatusNotFound,
"expected 400 or 404, got %d", w.Code)
})
t.Run("missing code returns 400", func(t *testing.T) {
req := map[string]interface{}{}
w := testutil.MakeRequest(e, "POST", "/api/auth/verify-email/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
func TestAuthHandler_ResendVerification(t *testing.T) {
handler, e, db := setupAuthHandlerWithDB(t)
user := testutil.CreateTestUser(t, db, "resendtest", "resend@test.com", "Password123")
authGroup := e.Group("/api/auth")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/resend-verification/", handler.ResendVerification)
t.Run("successful resend", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", "/api/auth/resend-verification/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response, "message")
})
}
func TestAuthHandler_RefreshToken(t *testing.T) {
handler, e, db := setupAuthHandlerWithDB(t)
user := testutil.CreateTestUser(t, db, "refreshtest", "refresh@test.com", "Password123")
// Create auth token and use its actual key in the middleware
authToken := testutil.CreateTestToken(t, db, user.ID)
authGroup := e.Group("/api/auth")
authGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Set("auth_user", user)
c.Set("auth_token", authToken.Plaintext) // raw token — repo hashes for lookup (audit C1)
return next(c)
}
})
authGroup.POST("/refresh/", handler.RefreshToken)
t.Run("successful refresh", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", "/api/auth/refresh/", nil, authToken.Plaintext)
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response, "token")
})
}
func TestAuthHandler_VerifyResetCode(t *testing.T) {
handler, e, _ := setupAuthHandler(t)
e.POST("/api/auth/register/", handler.Register)
e.POST("/api/auth/verify-reset-code/", handler.VerifyResetCode)
t.Run("invalid code returns error", func(t *testing.T) {
req := requests.VerifyResetCodeRequest{
Email: "nonexistent@test.com",
Code: "999999",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/verify-reset-code/", req, "")
// Should not be 200 since no valid code exists
assert.NotEqual(t, http.StatusOK, w.Code)
})
t.Run("missing fields returns 400", func(t *testing.T) {
req := map[string]interface{}{}
w := testutil.MakeRequest(e, "POST", "/api/auth/verify-reset-code/", req, "")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
func TestAuthHandler_ResetPassword(t *testing.T) {
handler, e, _ := setupAuthHandler(t)
e.POST("/api/auth/reset-password/", handler.ResetPassword)
t.Run("invalid reset token returns error", func(t *testing.T) {
req := requests.ResetPasswordRequest{
ResetToken: "invalid-token",
NewPassword: "NewPassword123",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/reset-password/", req, "")
assert.NotEqual(t, http.StatusOK, w.Code)
})
t.Run("missing fields returns 400", func(t *testing.T) {
req := map[string]interface{}{}
w := testutil.MakeRequest(e, "POST", "/api/auth/reset-password/", req, "")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("short password returns 400", func(t *testing.T) {
req := requests.ResetPasswordRequest{
ResetToken: "some-token",
NewPassword: "short",
}
w := testutil.MakeRequest(e, "POST", "/api/auth/reset-password/", req, "")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
func TestAuthHandler_ForgotPassword_MissingEmail(t *testing.T) {
handler, e, _ := setupAuthHandler(t)
e.POST("/api/auth/forgot-password/", handler.ForgotPassword)
t.Run("missing email returns 400", func(t *testing.T) {
req := map[string]interface{}{}
w := testutil.MakeRequest(e, "POST", "/api/auth/forgot-password/", req, "")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}
// ============================================================================= // =============================================================================
// Residence Handler - Additional Error Paths // Residence Handler - Additional Error Paths
// ============================================================================= // =============================================================================
+21
View File
@@ -190,6 +190,27 @@ func shouldSkipSpecRoute(path string) bool {
if strings.HasPrefix(path, "/uploads/") || strings.HasPrefix(path, "/media/") { if strings.HasPrefix(path, "/uploads/") || strings.HasPrefix(path, "/media/") {
return true return true
} }
// Auth routes delegated to Ory Kratos (phase 2 auth refactor).
// These endpoints are no longer served by the Go API; the spec is retained
// as documentation of the Kratos-facing contract.
kratosRoutes := map[string]bool{
"/auth/login/": true,
"/auth/register/": true,
"/auth/logout/": true,
"/auth/refresh/": true,
"/auth/forgot-password/": true,
"/auth/verify-reset-code/": true,
"/auth/reset-password/": true,
"/auth/verify-email/": true,
"/auth/resend-verification/": true,
"/auth/apple-sign-in/": true,
"/auth/google-sign-in/": true,
}
if kratosRoutes[path] {
return true
}
return false return false
} }
+98 -282
View File
@@ -6,9 +6,11 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync"
"testing" "testing"
"time" "time"
"github.com/google/uuid"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -17,6 +19,7 @@ import (
"github.com/treytartt/honeydue-api/internal/config" "github.com/treytartt/honeydue-api/internal/config"
"github.com/treytartt/honeydue-api/internal/handlers" "github.com/treytartt/honeydue-api/internal/handlers"
"github.com/treytartt/honeydue-api/internal/middleware" "github.com/treytartt/honeydue-api/internal/middleware"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/repositories" "github.com/treytartt/honeydue-api/internal/repositories"
"github.com/treytartt/honeydue-api/internal/services" "github.com/treytartt/honeydue-api/internal/services"
"github.com/treytartt/honeydue-api/internal/testutil" "github.com/treytartt/honeydue-api/internal/testutil"
@@ -105,11 +108,40 @@ type TestApp struct {
TaskRepo *repositories.TaskRepository TaskRepo *repositories.TaskRepository
ContractorRepo *repositories.ContractorRepository ContractorRepo *repositories.ContractorRepository
AuthService *services.AuthService AuthService *services.AuthService
// tokenStore maps fake token strings to users for the test-auth middleware.
tokenStore map[string]*models.User
tokenStoreMu sync.RWMutex
}
// fakeAuthMiddleware replaces the real Kratos middleware in integration tests.
// It looks up the "Authorization: Token <tok>" value in app.tokenStore.
func (app *TestApp) fakeAuthMiddleware() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
ah := c.Request().Header.Get("Authorization")
if ah == "" {
return apperrors.Unauthorized("error.not_authenticated")
}
tok := ah
if len(ah) > 6 && ah[:6] == "Token " {
tok = ah[6:]
} else if len(ah) > 7 && ah[:7] == "Bearer " {
tok = ah[7:]
}
app.tokenStoreMu.RLock()
user, ok := app.tokenStore[tok]
app.tokenStoreMu.RUnlock()
if !ok {
return apperrors.Unauthorized("error.not_authenticated")
}
c.Set("auth_user", user)
c.Set("auth_token", tok)
return next(c)
}
}
} }
func setupIntegrationTest(t *testing.T) *TestApp { func setupIntegrationTest(t *testing.T) *TestApp {
// Echo does not need test mode
db := testutil.SetupTestDB(t) db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db) testutil.SeedLookupData(t, db)
@@ -122,10 +154,7 @@ func setupIntegrationTest(t *testing.T) *TestApp {
// Create config // Create config
cfg := &config.Config{ cfg := &config.Config{
Security: config.SecurityConfig{ Security: config.SecurityConfig{
SecretKey: "test-secret-key-for-integration-tests", SecretKey: "test-secret-key-for-integration-tests",
PasswordResetExpiry: 15 * time.Minute,
ConfirmationExpiry: 24 * time.Hour,
MaxPasswordResetRate: 3,
}, },
} }
@@ -141,28 +170,33 @@ func setupIntegrationTest(t *testing.T) *TestApp {
taskHandler := handlers.NewTaskHandler(taskService, nil) taskHandler := handlers.NewTaskHandler(taskService, nil)
contractorHandler := handlers.NewContractorHandler(contractorService) contractorHandler := handlers.NewContractorHandler(contractorService)
// Create router with real middleware app := &TestApp{
e := echo.New() DB: db,
Router: echo.New(),
AuthHandler: authHandler,
ResidenceHandler: residenceHandler,
TaskHandler: taskHandler,
ContractorHandler: contractorHandler,
UserRepo: userRepo,
ResidenceRepo: residenceRepo,
TaskRepo: taskRepo,
ContractorRepo: contractorRepo,
AuthService: authService,
tokenStore: make(map[string]*models.User),
}
e := app.Router
e.Validator = validator.NewCustomValidator() e.Validator = validator.NewCustomValidator()
e.HTTPErrorHandler = apperrors.HTTPErrorHandler e.HTTPErrorHandler = apperrors.HTTPErrorHandler
// Add timezone middleware globally so X-Timezone header is processed // Timezone middleware processes X-Timezone header
e.Use(middleware.TimezoneMiddleware()) e.Use(middleware.TimezoneMiddleware())
// Public routes // Protected routes — guarded by the fake token middleware
auth := e.Group("/api/auth")
{
auth.POST("/register", authHandler.Register)
auth.POST("/login", authHandler.Login)
}
// Protected routes - use AuthMiddleware without Redis cache for testing
authMiddleware := middleware.NewAuthMiddleware(db, nil)
api := e.Group("/api") api := e.Group("/api")
api.Use(authMiddleware.TokenAuth()) api.Use(app.fakeAuthMiddleware())
{ {
api.GET("/auth/me", authHandler.CurrentUser) api.GET("/auth/me", authHandler.CurrentUser)
api.POST("/auth/logout", authHandler.Logout)
residences := api.Group("/residences") residences := api.Group("/residences")
{ {
@@ -216,19 +250,7 @@ func setupIntegrationTest(t *testing.T) *TestApp {
api.GET("/contractors/by-residence/:residence_id", contractorHandler.ListContractorsByResidence) api.GET("/contractors/by-residence/:residence_id", contractorHandler.ListContractorsByResidence)
} }
return &TestApp{ return app
DB: db,
Router: e,
AuthHandler: authHandler,
ResidenceHandler: residenceHandler,
TaskHandler: taskHandler,
ContractorHandler: contractorHandler,
UserRepo: userRepo,
ResidenceRepo: residenceRepo,
TaskRepo: taskRepo,
ContractorRepo: contractorRepo,
AuthService: authService,
}
} }
// Helper to make authenticated requests // Helper to make authenticated requests
@@ -251,156 +273,16 @@ func (app *TestApp) makeAuthenticatedRequest(t *testing.T, method, path string,
return w return w
} }
// Helper to register and login a user, returns token // registerAndLogin creates a user directly in the DB and returns a synthetic token
func (app *TestApp) registerAndLogin(t *testing.T, username, email, password string) string { // that the fake auth middleware will accept. No HTTP register/login endpoints are called.
// Register func (app *TestApp) registerAndLogin(t *testing.T, username, email, _ string) string {
registerBody := map[string]string{ t.Helper()
"username": username, user := testutil.CreateTestUser(t, app.DB, username, email, "")
"email": email, tok := uuid.NewString()
"password": password, app.tokenStoreMu.Lock()
} app.tokenStore[tok] = user
w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody, "") app.tokenStoreMu.Unlock()
require.Equal(t, http.StatusCreated, w.Code) return tok
// Login
loginBody := map[string]string{
"username": username,
"password": password,
}
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/login", loginBody, "")
require.Equal(t, http.StatusOK, w.Code)
var loginResp map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &loginResp)
require.NoError(t, err)
return loginResp["token"].(string)
}
// ============ Authentication Flow Tests ============
func TestIntegration_AuthenticationFlow(t *testing.T) {
app := setupIntegrationTest(t)
// 1. Register a new user
registerBody := map[string]string{
"username": "testuser",
"email": "test@example.com",
"password": "SecurePass123!",
}
w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody, "")
assert.Equal(t, http.StatusCreated, w.Code)
var registerResp map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &registerResp)
require.NoError(t, err)
assert.NotEmpty(t, registerResp["token"])
assert.NotNil(t, registerResp["user"])
// 2. Login with the same credentials
loginBody := map[string]string{
"username": "testuser",
"password": "SecurePass123!",
}
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/login", loginBody, "")
assert.Equal(t, http.StatusOK, w.Code)
var loginResp map[string]interface{}
err = json.Unmarshal(w.Body.Bytes(), &loginResp)
require.NoError(t, err)
token := loginResp["token"].(string)
assert.NotEmpty(t, token)
// 3. Get current user with token
w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, token)
assert.Equal(t, http.StatusOK, w.Code)
var meResp map[string]interface{}
err = json.Unmarshal(w.Body.Bytes(), &meResp)
require.NoError(t, err)
assert.Equal(t, "testuser", meResp["username"])
assert.Equal(t, "test@example.com", meResp["email"])
// 4. Access protected route without token should fail
w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, "")
assert.Equal(t, http.StatusUnauthorized, w.Code)
// 5. Access protected route with invalid token should fail
w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, "invalid-token")
assert.Equal(t, http.StatusUnauthorized, w.Code)
// 6. Logout
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/logout", nil, token)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestIntegration_RegistrationValidation(t *testing.T) {
app := setupIntegrationTest(t)
tests := []struct {
name string
body map[string]string
expectedStatus int
}{
{
name: "missing username",
body: map[string]string{"email": "test@example.com", "password": "pass123"},
expectedStatus: http.StatusBadRequest,
},
{
name: "missing email",
body: map[string]string{"username": "testuser", "password": "pass123"},
expectedStatus: http.StatusBadRequest,
},
{
name: "missing password",
body: map[string]string{"username": "testuser", "email": "test@example.com"},
expectedStatus: http.StatusBadRequest,
},
{
name: "invalid email",
body: map[string]string{"username": "testuser", "email": "invalid", "password": "pass123"},
expectedStatus: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", tt.body, "")
assert.Equal(t, tt.expectedStatus, w.Code)
})
}
}
func TestIntegration_DuplicateRegistration(t *testing.T) {
app := setupIntegrationTest(t)
// Register first user (password must be >= 8 chars)
registerBody := map[string]string{
"username": "testuser",
"email": "test@example.com",
"password": "Password123",
}
w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody, "")
assert.Equal(t, http.StatusCreated, w.Code)
// Try to register with same username - returns 409 (Conflict)
registerBody2 := map[string]string{
"username": "testuser",
"email": "different@example.com",
"password": "Password123",
}
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody2, "")
assert.Equal(t, http.StatusConflict, w.Code)
// Try to register with same email - returns 409 (Conflict)
registerBody3 := map[string]string{
"username": "differentuser",
"email": "test@example.com",
"password": "Password123",
}
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody3, "")
assert.Equal(t, http.StatusConflict, w.Code)
} }
// ============ Residence Flow Tests ============ // ============ Residence Flow Tests ============
@@ -827,48 +709,16 @@ func TestIntegration_ResponseStructure(t *testing.T) {
func TestIntegration_ComprehensiveE2E(t *testing.T) { func TestIntegration_ComprehensiveE2E(t *testing.T) {
app := setupIntegrationTest(t) app := setupIntegrationTest(t)
// ============ Phase 1: Authentication ============ // ============ Phase 1: User Setup ============
t.Log("Phase 1: Testing authentication flow") t.Log("Phase 1: Setting up test user")
// Register new user token := app.registerAndLogin(t, "e2e_testuser", "e2e@example.com", "")
registerBody := map[string]string{
"username": "e2e_testuser",
"email": "e2e@example.com",
"password": "SecurePass123!",
}
w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody, "")
require.Equal(t, http.StatusCreated, w.Code, "Registration should succeed")
var registerResp map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &registerResp)
require.NoError(t, err)
assert.NotEmpty(t, registerResp["token"], "Registration should return token")
assert.NotNil(t, registerResp["user"], "Registration should return user")
// Verify login with same credentials
loginBody := map[string]string{
"username": "e2e_testuser",
"password": "SecurePass123!",
}
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/login", loginBody, "")
require.Equal(t, http.StatusOK, w.Code, "Login should succeed")
var loginResp map[string]interface{}
err = json.Unmarshal(w.Body.Bytes(), &loginResp)
require.NoError(t, err)
token := loginResp["token"].(string)
assert.NotEmpty(t, token, "Login should return token")
// Verify authenticated access // Verify authenticated access
w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, token) w := app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, token)
require.Equal(t, http.StatusOK, w.Code, "Should access protected route with valid token") require.Equal(t, http.StatusOK, w.Code, "Should access protected route with valid token")
var meResp map[string]interface{} t.Log("✓ User setup verified")
json.Unmarshal(w.Body.Bytes(), &meResp)
assert.Equal(t, "e2e_testuser", meResp["username"])
assert.Equal(t, "e2e@example.com", meResp["email"])
t.Log("✓ Authentication flow verified")
// ============ Phase 2: Create 5 Residences ============ // ============ Phase 2: Create 5 Residences ============
t.Log("Phase 2: Creating 5 residences") t.Log("Phase 2: Creating 5 residences")
@@ -1244,29 +1094,9 @@ func TestIntegration_ComprehensiveE2E(t *testing.T) {
t.Logf("✓ All %d visible tasks verified in correct columns by ID", expectedVisibleTasks) t.Logf("✓ All %d visible tasks verified in correct columns by ID", expectedVisibleTasks)
// ============ Phase 9: Create User B ============ // ============ Phase 9: Create User B ============
t.Log("Phase 9: Creating User B and verifying login") t.Log("Phase 9: Creating User B")
// Register User B tokenB := app.registerAndLogin(t, "e2e_userb", "e2e_userb@example.com", "")
registerBodyB := map[string]string{
"username": "e2e_userb",
"email": "e2e_userb@example.com",
"password": "SecurePass456!",
}
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBodyB, "")
require.Equal(t, http.StatusCreated, w.Code, "User B registration should succeed")
// Login as User B
loginBodyB := map[string]string{
"username": "e2e_userb",
"password": "SecurePass456!",
}
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/login", loginBodyB, "")
require.Equal(t, http.StatusOK, w.Code, "User B login should succeed")
var loginRespB map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &loginRespB)
tokenB := loginRespB["token"].(string)
assert.NotEmpty(t, tokenB, "User B should have a token")
// Verify User B can access their own profile // Verify User B can access their own profile
w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, tokenB) w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, tokenB)
@@ -1592,8 +1422,6 @@ func formatID(id float64) string {
// setupContractorTest sets up a test environment including contractor routes // setupContractorTest sets up a test environment including contractor routes
func setupContractorTest(t *testing.T) *TestApp { func setupContractorTest(t *testing.T) *TestApp {
// Echo does not need test mode
db := testutil.SetupTestDB(t) db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db) testutil.SeedLookupData(t, db)
@@ -1606,10 +1434,7 @@ func setupContractorTest(t *testing.T) *TestApp {
// Create config // Create config
cfg := &config.Config{ cfg := &config.Config{
Security: config.SecurityConfig{ Security: config.SecurityConfig{
SecretKey: "test-secret-key-for-integration-tests", SecretKey: "test-secret-key-for-integration-tests",
PasswordResetExpiry: 15 * time.Minute,
ConfirmationExpiry: 24 * time.Hour,
MaxPasswordResetRate: 3,
}, },
} }
@@ -1625,29 +1450,32 @@ func setupContractorTest(t *testing.T) *TestApp {
taskHandler := handlers.NewTaskHandler(taskService, nil) taskHandler := handlers.NewTaskHandler(taskService, nil)
contractorHandler := handlers.NewContractorHandler(contractorService) contractorHandler := handlers.NewContractorHandler(contractorService)
// Create router with real middleware app := &TestApp{
e := echo.New() DB: db,
Router: echo.New(),
AuthHandler: authHandler,
ResidenceHandler: residenceHandler,
TaskHandler: taskHandler,
ContractorHandler: contractorHandler,
UserRepo: userRepo,
ResidenceRepo: residenceRepo,
TaskRepo: taskRepo,
ContractorRepo: contractorRepo,
AuthService: authService,
tokenStore: make(map[string]*models.User),
}
e := app.Router
e.Validator = validator.NewCustomValidator() e.Validator = validator.NewCustomValidator()
e.HTTPErrorHandler = apperrors.HTTPErrorHandler e.HTTPErrorHandler = apperrors.HTTPErrorHandler
// Add timezone middleware globally so X-Timezone header is processed // Timezone middleware
e.Use(middleware.TimezoneMiddleware()) e.Use(middleware.TimezoneMiddleware())
// Public routes
auth := e.Group("/api/auth")
{
auth.POST("/register", authHandler.Register)
auth.POST("/login", authHandler.Login)
}
// Protected routes // Protected routes
authMiddleware := middleware.NewAuthMiddleware(db, nil)
api := e.Group("/api") api := e.Group("/api")
api.Use(authMiddleware.TokenAuth()) api.Use(app.fakeAuthMiddleware())
{ {
api.GET("/auth/me", authHandler.CurrentUser)
api.POST("/auth/logout", authHandler.Logout)
residences := api.Group("/residences") residences := api.Group("/residences")
{ {
residences.GET("", residenceHandler.ListResidences) residences.GET("", residenceHandler.ListResidences)
@@ -1680,19 +1508,7 @@ func setupContractorTest(t *testing.T) *TestApp {
} }
} }
return &TestApp{ return app
DB: db,
Router: e,
AuthHandler: authHandler,
ResidenceHandler: residenceHandler,
TaskHandler: taskHandler,
ContractorHandler: contractorHandler,
UserRepo: userRepo,
ResidenceRepo: residenceRepo,
TaskRepo: taskRepo,
ContractorRepo: contractorRepo,
AuthService: authService,
}
} }
// ============ Test 1: Recurring Task Lifecycle ============ // ============ Test 1: Recurring Task Lifecycle ============
@@ -2045,12 +1861,12 @@ func TestIntegration_MultiUserSharing(t *testing.T) {
// Phase 9: Remove User B from residence 3 // Phase 9: Remove User B from residence 3
t.Log("Phase 9: Remove User B from residence 3") t.Log("Phase 9: Remove User B from residence 3")
// Get User B's ID // Get User B's ID from the token store
w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, tokenB) app.tokenStoreMu.RLock()
require.Equal(t, http.StatusOK, w.Code) userBModel := app.tokenStore[tokenB]
var userBInfo map[string]interface{} app.tokenStoreMu.RUnlock()
json.Unmarshal(w.Body.Bytes(), &userBInfo) require.NotNil(t, userBModel, "User B should be in token store")
userBID := uint(userBInfo["id"].(float64)) userBID := userBModel.ID
// Remove User B from residence 3 // Remove User B from residence 3
w = app.makeAuthenticatedRequest(t, "DELETE", fmt.Sprintf("/api/residences/%d/users/%d", residenceIDs[2], userBID), nil, tokenA) w = app.makeAuthenticatedRequest(t, "DELETE", fmt.Sprintf("/api/residences/%d/users/%d", residenceIDs[2], userBID), nil, tokenA)
@@ -6,9 +6,11 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync"
"testing" "testing"
"time" "time"
"github.com/google/uuid"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -35,6 +37,48 @@ type SecurityTestApp struct {
Router *echo.Echo Router *echo.Echo
SubscriptionService *services.SubscriptionService SubscriptionService *services.SubscriptionService
SubscriptionRepo *repositories.SubscriptionRepository SubscriptionRepo *repositories.SubscriptionRepository
tokenStore map[string]*models.User
tokenStoreMu sync.RWMutex
}
// fakeAuthMiddleware returns an Echo middleware that authenticates requests using
// the in-process tokenStore instead of calling the real Kratos session endpoint.
func (app *SecurityTestApp) fakeAuthMiddleware() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
ah := c.Request().Header.Get("Authorization")
if ah == "" {
return apperrors.Unauthorized("error.not_authenticated")
}
tok := ah
if len(ah) > 6 && ah[:6] == "Token " {
tok = ah[6:]
} else if len(ah) > 7 && ah[:7] == "Bearer " {
tok = ah[7:]
}
app.tokenStoreMu.RLock()
user, ok := app.tokenStore[tok]
app.tokenStoreMu.RUnlock()
if !ok {
return apperrors.Unauthorized("error.not_authenticated")
}
c.Set("auth_user", user)
c.Set("auth_token", tok)
return next(c)
}
}
}
// registerAndLoginSec creates a user directly in the DB and returns a fake token
// that the fakeAuthMiddleware will accept. No HTTP register/login calls are made.
func (app *SecurityTestApp) registerAndLoginSec(t *testing.T, username, email, _ string) (string, uint) {
t.Helper()
user := testutil.CreateTestUser(t, app.DB, username, email, "")
tok := uuid.NewString()
app.tokenStoreMu.Lock()
app.tokenStore[tok] = user
app.tokenStoreMu.Unlock()
return tok, user.ID
} }
func setupSecurityTest(t *testing.T) *SecurityTestApp { func setupSecurityTest(t *testing.T) *SecurityTestApp {
@@ -78,27 +122,25 @@ func setupSecurityTest(t *testing.T) *SecurityTestApp {
notificationHandler := handlers.NewNotificationHandler(notificationService) notificationHandler := handlers.NewNotificationHandler(notificationService)
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, nil) subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, nil)
// Create router with real middleware app := &SecurityTestApp{
DB: db,
SubscriptionService: subscriptionService,
SubscriptionRepo: subscriptionRepo,
tokenStore: make(map[string]*models.User),
}
// Create router with fake auth middleware
e := echo.New() e := echo.New()
e.Validator = validator.NewCustomValidator() e.Validator = validator.NewCustomValidator()
e.HTTPErrorHandler = apperrors.HTTPErrorHandler e.HTTPErrorHandler = apperrors.HTTPErrorHandler
e.Use(middleware.TimezoneMiddleware()) e.Use(middleware.TimezoneMiddleware())
// Public routes
auth := e.Group("/api/auth")
{
auth.POST("/register", authHandler.Register)
auth.POST("/login", authHandler.Login)
}
// Protected routes // Protected routes
authMiddleware := middleware.NewAuthMiddleware(db, nil)
api := e.Group("/api") api := e.Group("/api")
api.Use(authMiddleware.TokenAuth()) api.Use(app.fakeAuthMiddleware())
{ {
api.GET("/auth/me", authHandler.CurrentUser) api.GET("/auth/me", authHandler.CurrentUser)
api.POST("/auth/logout", authHandler.Logout)
residences := api.Group("/residences") residences := api.Group("/residences")
{ {
@@ -146,42 +188,8 @@ func setupSecurityTest(t *testing.T) *SecurityTestApp {
} }
} }
return &SecurityTestApp{ app.Router = e
DB: db, return app
Router: e,
SubscriptionService: subscriptionService,
SubscriptionRepo: subscriptionRepo,
}
}
// registerAndLoginSec registers and logs in a user, returns token and user ID.
func (app *SecurityTestApp) registerAndLoginSec(t *testing.T, username, email, password string) (string, uint) {
// Register
registerBody := map[string]string{
"username": username,
"email": email,
"password": password,
}
w := app.makeAuthReq(t, "POST", "/api/auth/register", registerBody, "")
require.Equal(t, http.StatusCreated, w.Code, "Registration should succeed for %s", username)
// Login
loginBody := map[string]string{
"username": username,
"password": password,
}
w = app.makeAuthReq(t, "POST", "/api/auth/login", loginBody, "")
require.Equal(t, http.StatusOK, w.Code, "Login should succeed for %s", username)
var loginResp map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &loginResp)
require.NoError(t, err)
token := loginResp["token"].(string)
userMap := loginResp["user"].(map[string]interface{})
userID := uint(userMap["id"].(float64))
return token, userID
} }
// makeAuthReq creates and sends an HTTP request through the router. // makeAuthReq creates and sends an HTTP request through the router.
@@ -6,12 +6,15 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync"
"testing" "testing"
"time" "time"
"github.com/google/uuid"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gorm.io/gorm"
"github.com/treytartt/honeydue-api/internal/apperrors" "github.com/treytartt/honeydue-api/internal/apperrors"
"github.com/treytartt/honeydue-api/internal/config" "github.com/treytartt/honeydue-api/internal/config"
@@ -22,7 +25,6 @@ import (
"github.com/treytartt/honeydue-api/internal/services" "github.com/treytartt/honeydue-api/internal/services"
"github.com/treytartt/honeydue-api/internal/testutil" "github.com/treytartt/honeydue-api/internal/testutil"
"github.com/treytartt/honeydue-api/internal/validator" "github.com/treytartt/honeydue-api/internal/validator"
"gorm.io/gorm"
) )
// SubscriptionTestApp holds components for subscription integration testing // SubscriptionTestApp holds components for subscription integration testing
@@ -31,11 +33,51 @@ type SubscriptionTestApp struct {
Router *echo.Echo Router *echo.Echo
SubscriptionService *services.SubscriptionService SubscriptionService *services.SubscriptionService
SubscriptionRepo *repositories.SubscriptionRepository SubscriptionRepo *repositories.SubscriptionRepository
tokenStore map[string]*models.User
tokenStoreMu sync.RWMutex
}
// fakeAuthMiddleware returns an Echo middleware that authenticates requests using
// the in-process tokenStore instead of calling the real Kratos session endpoint.
func (app *SubscriptionTestApp) fakeAuthMiddleware() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
ah := c.Request().Header.Get("Authorization")
if ah == "" {
return apperrors.Unauthorized("error.not_authenticated")
}
tok := ah
if len(ah) > 6 && ah[:6] == "Token " {
tok = ah[6:]
} else if len(ah) > 7 && ah[:7] == "Bearer " {
tok = ah[7:]
}
app.tokenStoreMu.RLock()
user, ok := app.tokenStore[tok]
app.tokenStoreMu.RUnlock()
if !ok {
return apperrors.Unauthorized("error.not_authenticated")
}
c.Set("auth_user", user)
c.Set("auth_token", tok)
return next(c)
}
}
}
// registerAndLogin creates a user directly in the DB and returns a fake token
// and user ID. No HTTP register/login calls are made.
func (app *SubscriptionTestApp) registerAndLogin(t *testing.T, username, email, _ string) (string, uint) {
t.Helper()
user := testutil.CreateTestUser(t, app.DB, username, email, "")
tok := uuid.NewString()
app.tokenStoreMu.Lock()
app.tokenStore[tok] = user
app.tokenStoreMu.Unlock()
return tok, user.ID
} }
func setupSubscriptionTest(t *testing.T) *SubscriptionTestApp { func setupSubscriptionTest(t *testing.T) *SubscriptionTestApp {
// Echo does not need test mode
db := testutil.SetupTestDB(t) db := testutil.SetupTestDB(t)
testutil.SeedLookupData(t, db) testutil.SeedLookupData(t, db)
@@ -67,22 +109,23 @@ func setupSubscriptionTest(t *testing.T) *SubscriptionTestApp {
residenceHandler := handlers.NewResidenceHandler(residenceService, nil, nil, true) residenceHandler := handlers.NewResidenceHandler(residenceService, nil, nil, true)
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, nil) subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, nil)
// Create router app := &SubscriptionTestApp{
DB: db,
SubscriptionService: subscriptionService,
SubscriptionRepo: subscriptionRepo,
tokenStore: make(map[string]*models.User),
}
// Create router with fake auth middleware
e := echo.New() e := echo.New()
e.Validator = validator.NewCustomValidator() e.Validator = validator.NewCustomValidator()
e.HTTPErrorHandler = apperrors.HTTPErrorHandler e.HTTPErrorHandler = apperrors.HTTPErrorHandler
// Public routes e.Use(middleware.TimezoneMiddleware())
auth := e.Group("/api/auth")
{
auth.POST("/register", authHandler.Register)
auth.POST("/login", authHandler.Login)
}
// Protected routes // Protected routes
authMiddleware := middleware.NewAuthMiddleware(db, nil)
api := e.Group("/api") api := e.Group("/api")
api.Use(authMiddleware.TokenAuth()) api.Use(app.fakeAuthMiddleware())
{ {
api.GET("/auth/me", authHandler.CurrentUser) api.GET("/auth/me", authHandler.CurrentUser)
@@ -98,12 +141,8 @@ func setupSubscriptionTest(t *testing.T) *SubscriptionTestApp {
} }
} }
return &SubscriptionTestApp{ app.Router = e
DB: db, return app
Router: e,
SubscriptionService: subscriptionService,
SubscriptionRepo: subscriptionRepo,
}
} }
// Helper to make authenticated requests // Helper to make authenticated requests
@@ -129,36 +168,6 @@ func (app *SubscriptionTestApp) makeAuthenticatedRequest(t *testing.T, method, p
return w return w
} }
// Helper to register and login a user, returns token and user ID
func (app *SubscriptionTestApp) registerAndLogin(t *testing.T, username, email, password string) (string, uint) {
// Register
registerBody := map[string]string{
"username": username,
"email": email,
"password": password,
}
w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody, "")
require.Equal(t, http.StatusCreated, w.Code)
// Login
loginBody := map[string]string{
"username": username,
"password": password,
}
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/login", loginBody, "")
require.Equal(t, http.StatusOK, w.Code)
var loginResp map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &loginResp)
require.NoError(t, err)
token := loginResp["token"].(string)
userMap := loginResp["user"].(map[string]interface{})
userID := uint(userMap["id"].(float64))
return token, userID
}
// TestIntegration_IsFreeBypassesLimitations tests that users with IsFree=true // TestIntegration_IsFreeBypassesLimitations tests that users with IsFree=true
// see limitations_enabled=false regardless of global settings // see limitations_enabled=false regardless of global settings
func TestIntegration_IsFreeBypassesLimitations(t *testing.T) { func TestIntegration_IsFreeBypassesLimitations(t *testing.T) {
+107
View File
@@ -0,0 +1,107 @@
// Package kratos is a thin client for the Ory Kratos public API. honeyDue
// delegates all identity concerns (credentials, sessions, verification,
// recovery, social sign-in) to Kratos; this client only validates sessions.
package kratos
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
)
// ErrUnauthorized is returned by Whoami when the session is missing, invalid,
// or inactive — the caller should respond 401.
var ErrUnauthorized = errors.New("kratos: session invalid or inactive")
// Client talks to the Ory Kratos public API.
type Client struct {
publicURL string
http *http.Client
}
// NewClient builds a Kratos client for the given public-API base URL
// (e.g. http://kratos:4433 in-cluster).
func NewClient(publicURL string) *Client {
return &Client{
publicURL: strings.TrimRight(publicURL, "/"),
http: &http.Client{Timeout: 5 * time.Second},
}
}
// Identity is the subset of a Kratos identity honeyDue consumes. It mirrors
// the identity schema in deploy-k3s/manifests/kratos/configmap.yaml.
type Identity struct {
ID string `json:"id"` // UUID — the stable identity identifier
Traits struct {
Email string `json:"email"`
Name struct {
First string `json:"first"`
Last string `json:"last"`
} `json:"name"`
} `json:"traits"`
VerifiableAddresses []struct {
Value string `json:"value"`
Verified bool `json:"verified"`
} `json:"verifiable_addresses"`
}
// Session is a Kratos session as returned by GET /sessions/whoami.
type Session struct {
ID string `json:"id"`
Active bool `json:"active"`
Identity Identity `json:"identity"`
}
// EmailVerified reports whether any of the identity's email addresses is
// verified — the source of truth for honeyDue's RequireVerified gate.
func (s *Session) EmailVerified() bool {
for _, a := range s.Identity.VerifiableAddresses {
if a.Verified {
return true
}
}
return false
}
// Whoami validates a session against Kratos. Supply the mobile session token
// (sessionToken) OR the browser cookie header (cookie) — whichever is
// non-empty is forwarded to Kratos. Returns ErrUnauthorized for an invalid or
// inactive session.
func (c *Client) Whoami(ctx context.Context, sessionToken, cookie string) (*Session, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.publicURL+"/sessions/whoami", nil)
if err != nil {
return nil, err
}
if sessionToken != "" {
req.Header.Set("X-Session-Token", sessionToken)
}
if cookie != "" {
req.Header.Set("Cookie", cookie)
}
resp, err := c.http.Do(req)
if err != nil {
return nil, fmt.Errorf("kratos whoami: %w", err)
}
defer resp.Body.Close()
switch {
case resp.StatusCode == http.StatusUnauthorized, resp.StatusCode == http.StatusForbidden:
return nil, ErrUnauthorized
case resp.StatusCode != http.StatusOK:
return nil, fmt.Errorf("kratos whoami: unexpected status %d", resp.StatusCode)
}
var s Session
if err := json.NewDecoder(resp.Body).Decode(&s); err != nil {
return nil, fmt.Errorf("kratos whoami: decode: %w", err)
}
if !s.Active || s.Identity.ID == "" {
return nil, ErrUnauthorized
}
return &s, nil
}
-438
View File
@@ -1,438 +0,0 @@
package middleware
import (
"context"
"fmt"
"strings"
"time"
"github.com/labstack/echo/v4"
"github.com/redis/go-redis/v9"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"github.com/treytartt/honeydue-api/internal/apperrors"
"github.com/treytartt/honeydue-api/internal/config"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/services"
)
const (
// AuthUserKey is the key used to store the authenticated user in the context
AuthUserKey = "auth_user"
// AuthTokenKey is the key used to store the token in the context
AuthTokenKey = "auth_token"
// TokenCacheTTL is the duration to cache tokens in Redis. Tokens are
// valid for DefaultTokenExpiryDays (90), and explicit logout invalidates
// the cache, so a long TTL here just means most authed requests skip the
// auth-token SQL query entirely.
TokenCacheTTL = 1 * time.Hour
// TokenCachePrefix is the prefix for token cache keys
TokenCachePrefix = "auth_token_"
// UserCacheTTL is how long full user records are cached in memory to
// avoid hitting the database on every authenticated request. Bumped from
// 30s — at 30s the trace showed a SELECT auth_user query on most warm
// requests because users aren't in cache long enough to hit twice.
UserCacheTTL = 5 * time.Minute
// UserCacheMaxSize bounds the per-pod in-memory user cache. With ~1KB
// per User struct, 5000 entries = ~5MB per pod. Older entries are
// evicted LRU before the limit is exceeded.
UserCacheMaxSize = 5000
// DefaultTokenExpiryDays is the default number of days before a token expires.
DefaultTokenExpiryDays = 90
)
// AuthMiddleware provides token authentication middleware
type AuthMiddleware struct {
db *gorm.DB
cache *services.CacheService
userCache *UserCache
tokenExpiryDays int
}
// NewAuthMiddleware creates a new auth middleware instance
func NewAuthMiddleware(db *gorm.DB, cache *services.CacheService) *AuthMiddleware {
return &AuthMiddleware{
db: db,
cache: cache,
userCache: NewUserCache(UserCacheTTL, UserCacheMaxSize),
tokenExpiryDays: DefaultTokenExpiryDays,
}
}
// NewAuthMiddlewareWithConfig creates a new auth middleware instance with configuration
func NewAuthMiddlewareWithConfig(db *gorm.DB, cache *services.CacheService, cfg *config.Config) *AuthMiddleware {
expiryDays := DefaultTokenExpiryDays
if cfg != nil && cfg.Security.TokenExpiryDays > 0 {
expiryDays = cfg.Security.TokenExpiryDays
}
return &AuthMiddleware{
db: db,
cache: cache,
userCache: NewUserCache(UserCacheTTL, UserCacheMaxSize),
tokenExpiryDays: expiryDays,
}
}
// TokenExpiryDuration returns the token expiry duration.
func (m *AuthMiddleware) TokenExpiryDuration() time.Duration {
return time.Duration(m.tokenExpiryDays) * 24 * time.Hour
}
// isTokenExpired checks if a token's created timestamp indicates expiry.
func (m *AuthMiddleware) isTokenExpired(created time.Time) bool {
if created.IsZero() {
return false // Legacy tokens without created time are not expired
}
return time.Since(created) > m.TokenExpiryDuration()
}
// TokenAuth returns an Echo middleware that validates token authentication
func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// Extract token from Authorization header
token, err := extractToken(c)
if err != nil {
return apperrors.Unauthorized("error.not_authenticated")
}
// Try to get user from cache first (includes expiry check)
user, err := m.getUserFromCache(c.Request().Context(), token)
if err == nil && user != nil {
// Cache hit - set user in context and continue
c.Set(AuthUserKey, user)
c.Set(AuthTokenKey, token)
return next(c)
}
// Check if the cache indicated token expiry
if err != nil && err.Error() == "token expired" {
return apperrors.Unauthorized("error.token_expired")
}
// Cache miss - look up token in database
user, authToken, err := m.getUserFromDatabaseWithToken(c.Request().Context(), token)
if err != nil {
log.Debug().Err(err).Str("token", truncateToken(token)).Msg("Token authentication failed")
return apperrors.Unauthorized("error.invalid_token")
}
// Check token expiry
if m.isTokenExpired(authToken.Created) {
log.Debug().Str("token", truncateToken(token)).Time("created", authToken.Created).Msg("Token expired")
return apperrors.Unauthorized("error.token_expired")
}
// Cache the user ID and token creation time for future requests
if cacheErr := m.cacheTokenInfo(c.Request().Context(), token, user.ID, authToken.Created); cacheErr != nil {
log.Warn().Err(cacheErr).Msg("Failed to cache token info")
}
// Set user in context
c.Set(AuthUserKey, user)
c.Set(AuthTokenKey, token)
return next(c)
}
}
}
// OptionalTokenAuth returns middleware that authenticates if token is present but doesn't require it
func (m *AuthMiddleware) OptionalTokenAuth() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
token, err := extractToken(c)
if err != nil {
// No token or invalid format - continue without user
return next(c)
}
// Try cache first
user, err := m.getUserFromCache(c.Request().Context(), token)
if err == nil && user != nil {
c.Set(AuthUserKey, user)
c.Set(AuthTokenKey, token)
return next(c)
}
// Try database
user, authToken, err := m.getUserFromDatabaseWithToken(c.Request().Context(), token)
if err == nil && !m.isTokenExpired(authToken.Created) {
m.cacheTokenInfo(c.Request().Context(), token, user.ID, authToken.Created)
c.Set(AuthUserKey, user)
c.Set(AuthTokenKey, token)
}
return next(c)
}
}
}
// RequireVerified returns middleware that rejects users whose email is not
// verified (audit LIVE-L19). Apply it after TokenAuth to gate sensitive
// actions — e.g. generating residence share codes — behind proof that the
// account actually controls its email address.
func (m *AuthMiddleware) RequireVerified() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
user := GetAuthUser(c)
if user == nil {
return apperrors.Unauthorized("error.not_authenticated")
}
var verified bool
err := m.db.WithContext(c.Request().Context()).
Model(&models.UserProfile{}).
Where("user_id = ?", user.ID).
Select("verified").
Scan(&verified).Error
if err != nil {
return apperrors.Internal(err)
}
if !verified {
return apperrors.Forbidden("error.email_not_verified")
}
return next(c)
}
}
}
// extractToken extracts the token from the Authorization header
func extractToken(c echo.Context) (string, error) {
authHeader := c.Request().Header.Get("Authorization")
if authHeader == "" {
return "", fmt.Errorf("authorization header required")
}
// Support both "Token xxx" (Django style) and "Bearer xxx" formats
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 {
return "", fmt.Errorf("invalid authorization header format")
}
scheme := parts[0]
token := parts[1]
if scheme != "Token" && scheme != "Bearer" {
return "", fmt.Errorf("invalid authorization scheme: %s", scheme)
}
if token == "" {
return "", fmt.Errorf("token is empty")
}
return token, nil
}
// getUserFromCache tries to get user from Redis cache, then from the
// in-memory user cache, before falling back to the database.
// Returns a "token expired" error if the cached creation time indicates expiry.
func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*models.User, error) {
if m.cache == nil {
return nil, fmt.Errorf("cache not available")
}
userID, createdUnix, err := m.cache.GetCachedAuthTokenWithCreated(ctx, token)
if err != nil {
if err == redis.Nil {
return nil, fmt.Errorf("token not in cache")
}
return nil, err
}
// Check token expiry from cached creation time
if createdUnix > 0 {
created := time.Unix(createdUnix, 0)
if m.isTokenExpired(created) {
m.cache.InvalidateAuthToken(ctx, token)
return nil, fmt.Errorf("token expired")
}
}
// Try in-memory user cache first to avoid a DB round-trip
if cached := m.userCache.Get(userID); cached != nil {
if !cached.IsActive {
m.cache.InvalidateAuthToken(ctx, token)
m.userCache.Invalidate(userID)
return nil, fmt.Errorf("user is inactive")
}
return cached, nil
}
// In-memory cache miss — fetch from database
var user models.User
if err := m.db.WithContext(ctx).First(&user, userID).Error; err != nil {
// User was deleted - invalidate caches
m.cache.InvalidateAuthToken(ctx, token)
return nil, err
}
// Check if user is active
if !user.IsActive {
m.cache.InvalidateAuthToken(ctx, token)
return nil, fmt.Errorf("user is inactive")
}
// Store in in-memory cache for subsequent requests
m.userCache.Set(&user)
return &user, nil
}
// getUserFromDatabaseWithToken looks up the token in the database and returns
// both the user and the auth token record (for expiry checking). The ctx is
// threaded into the GORM session so the SQL span attaches to the request trace.
//
// Uses a single JOIN query instead of GORM's Preload (which issues 2 SELECTs).
// Over a transatlantic link this saves ~110ms RTT per cache miss.
func (m *AuthMiddleware) getUserFromDatabaseWithToken(ctx context.Context, token string) (*models.User, *models.AuthToken, error) {
// Flat result row: every column from auth_user prefixed `u_`, every
// column from user_authtoken left in its native shape. Mapping to two
// structs is mechanical so we don't need a struct tag soup.
type joinedRow struct {
// AuthToken columns
Key string `gorm:"column:key"`
Created time.Time `gorm:"column:created"`
UserID uint `gorm:"column:user_id"`
// User columns (prefixed to avoid collision with UserID)
UID uint `gorm:"column:u_id"`
UUsername string `gorm:"column:u_username"`
UEmail string `gorm:"column:u_email"`
UFirstName string `gorm:"column:u_first_name"`
ULastName string `gorm:"column:u_last_name"`
UPassword string `gorm:"column:u_password"`
UIsActive bool `gorm:"column:u_is_active"`
UIsStaff bool `gorm:"column:u_is_staff"`
UIsSuper bool `gorm:"column:u_is_superuser"`
UDateJoined time.Time `gorm:"column:u_date_joined"`
ULastLogin *time.Time `gorm:"column:u_last_login"`
}
var row joinedRow
err := m.db.WithContext(ctx).
Table("user_authtoken AS t").
Select(`
t.key, t.created, t.user_id,
u.id AS u_id,
u.username AS u_username,
u.email AS u_email,
u.first_name AS u_first_name,
u.last_name AS u_last_name,
u.password AS u_password,
u.is_active AS u_is_active,
u.is_staff AS u_is_staff,
u.is_superuser AS u_is_superuser,
u.date_joined AS u_date_joined,
u.last_login AS u_last_login
`).
Joins("INNER JOIN auth_user u ON u.id = t.user_id").
Where("t.key = ?", models.HashToken(token)). // audit C1: only the hash is stored
Limit(1).
Scan(&row).Error
if err != nil || row.Key == "" {
return nil, nil, fmt.Errorf("token not found")
}
user := models.User{
ID: row.UID,
Username: row.UUsername,
Email: row.UEmail,
FirstName: row.UFirstName,
LastName: row.ULastName,
Password: row.UPassword,
IsActive: row.UIsActive,
IsStaff: row.UIsStaff,
IsSuperuser: row.UIsSuper,
DateJoined: row.UDateJoined,
LastLogin: row.ULastLogin,
}
authToken := models.AuthToken{
Key: row.Key,
Created: row.Created,
UserID: row.UserID,
User: user,
}
if !user.IsActive {
return nil, nil, fmt.Errorf("user is inactive")
}
m.userCache.Set(&user)
return &user, &authToken, nil
}
// getUserFromDatabase looks up the token in the database and caches the
// resulting user record in memory.
// Deprecated: Use getUserFromDatabaseWithToken for new code paths that need expiry checking.
func (m *AuthMiddleware) getUserFromDatabase(ctx context.Context, token string) (*models.User, error) {
user, _, err := m.getUserFromDatabaseWithToken(ctx, token)
return user, err
}
// cacheTokenInfo caches the user ID and token creation time for a token
func (m *AuthMiddleware) cacheTokenInfo(ctx context.Context, token string, userID uint, created time.Time) error {
if m.cache == nil {
return nil
}
return m.cache.CacheAuthTokenWithCreated(ctx, token, userID, created.Unix())
}
// cacheUserID caches the user ID for a token
func (m *AuthMiddleware) cacheUserID(ctx context.Context, token string, userID uint) error {
if m.cache == nil {
return nil
}
return m.cache.CacheAuthToken(ctx, token, userID)
}
// InvalidateToken removes a token from the cache
func (m *AuthMiddleware) InvalidateToken(ctx context.Context, token string) error {
if m.cache == nil {
return nil
}
return m.cache.InvalidateAuthToken(ctx, token)
}
// GetAuthUser retrieves the authenticated user from the Echo context.
// Returns nil if the context value is missing or not of the expected type.
func GetAuthUser(c echo.Context) *models.User {
val := c.Get(AuthUserKey)
if val == nil {
return nil
}
user, ok := val.(*models.User)
if !ok {
return nil
}
return user
}
// GetAuthToken retrieves the auth token from the Echo context
func GetAuthToken(c echo.Context) string {
token := c.Get(AuthTokenKey)
if token == nil {
return ""
}
tokenStr, ok := token.(string)
if !ok {
return ""
}
return tokenStr
}
// MustGetAuthUser retrieves the authenticated user or returns error with 401
func MustGetAuthUser(c echo.Context) (*models.User, error) {
user := GetAuthUser(c)
if user == nil {
return nil, apperrors.Unauthorized("error.not_authenticated")
}
return user, nil
}
// truncateToken safely truncates a token string for logging.
// Returns at most the first 8 characters followed by "...".
func truncateToken(token string) string {
if len(token) > 8 {
return token[:8] + "..."
}
return token + "..."
}
-165
View File
@@ -1,165 +0,0 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/treytartt/honeydue-api/internal/models"
)
// setupTestDB creates a temporary in-memory SQLite database with the required
// tables for auth middleware tests.
func setupTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
require.NoError(t, err)
err = db.AutoMigrate(&models.User{}, &models.AuthToken{})
require.NoError(t, err)
return db
}
// createTestUserAndToken creates a user and an auth token, then backdates the
// token's Created timestamp by the specified number of days.
func createTestUserAndToken(t *testing.T, db *gorm.DB, username string, ageDays int) (*models.User, *models.AuthToken) {
t.Helper()
user := &models.User{
Username: username,
Email: username + "@test.com",
IsActive: true,
}
require.NoError(t, user.SetPassword("Password123"))
require.NoError(t, db.Create(user).Error)
token := &models.AuthToken{
UserID: user.ID,
}
require.NoError(t, db.Create(token).Error)
// Backdate the token's Created timestamp after creation to bypass autoCreateTime
backdated := time.Now().UTC().AddDate(0, 0, -ageDays)
require.NoError(t, db.Model(token).Update("created", backdated).Error)
token.Created = backdated
return user, token
}
func TestTokenAuth_RejectsExpiredToken(t *testing.T) {
db := setupTestDB(t)
_, token := createTestUserAndToken(t, db, "expired_user", 91) // 91 days old > 90 day expiry
m := NewAuthMiddleware(db, nil) // No Redis cache for these tests
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.Error(t, err)
assert.Contains(t, err.Error(), "error.token_expired")
}
func TestTokenAuth_AcceptsValidToken(t *testing.T) {
db := setupTestDB(t)
_, token := createTestUserAndToken(t, db, "valid_user", 30) // 30 days old < 90 day expiry
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
// Verify user was set in context
user := GetAuthUser(c)
require.NotNil(t, user)
assert.Equal(t, "valid_user", user.Username)
}
func TestTokenAuth_AcceptsTokenAtBoundary(t *testing.T) {
db := setupTestDB(t)
_, token := createTestUserAndToken(t, db, "boundary_user", 89) // 89 days old, just under 90 day expiry
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
}
func TestTokenAuth_RejectsInvalidToken(t *testing.T) {
db := setupTestDB(t)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token nonexistent-token")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.Error(t, err)
assert.Contains(t, err.Error(), "error.invalid_token")
}
func TestTokenAuth_RejectsNoAuthHeader(t *testing.T) {
db := setupTestDB(t)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.Error(t, err)
assert.Contains(t, err.Error(), "error.not_authenticated")
}
-337
View File
@@ -1,337 +0,0 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/honeydue-api/internal/config"
"github.com/treytartt/honeydue-api/internal/models"
)
func TestTokenAuth_BearerScheme_Accepted(t *testing.T) {
db := setupTestDB(t)
_, token := createTestUserAndToken(t, db, "bearer_user", 10)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Bearer "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
user := GetAuthUser(c)
require.NotNil(t, user)
assert.Equal(t, "bearer_user", user.Username)
}
func TestTokenAuth_InvalidScheme_Rejected(t *testing.T) {
db := setupTestDB(t)
_, token := createTestUserAndToken(t, db, "scheme_user", 10)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Basic "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.Error(t, err)
assert.Contains(t, err.Error(), "error.not_authenticated")
}
func TestTokenAuth_MalformedHeader_Rejected(t *testing.T) {
db := setupTestDB(t)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "JustATokenWithNoScheme")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.Error(t, err)
assert.Contains(t, err.Error(), "error.not_authenticated")
}
func TestTokenAuth_EmptyToken_Rejected(t *testing.T) {
db := setupTestDB(t)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token ")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.Error(t, err)
assert.Contains(t, err.Error(), "error.not_authenticated")
}
func TestTokenAuth_InactiveUser_Rejected(t *testing.T) {
db := setupTestDB(t)
user, token := createTestUserAndToken(t, db, "inactive_user", 10)
// Deactivate the user
require.NoError(t, db.Model(user).Update("is_active", false).Error)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.Error(t, err)
assert.Contains(t, err.Error(), "error.invalid_token")
}
func TestOptionalTokenAuth_NoToken_PassesThrough(t *testing.T) {
db := setupTestDB(t)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
// No Authorization header
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
user := GetAuthUser(c)
if user == nil {
return c.String(http.StatusOK, "no-user")
}
return c.String(http.StatusOK, user.Username)
})
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "no-user", rec.Body.String())
}
func TestOptionalTokenAuth_ValidToken_SetsUser(t *testing.T) {
db := setupTestDB(t)
_, token := createTestUserAndToken(t, db, "opt_user", 10)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
user := GetAuthUser(c)
if user == nil {
return c.String(http.StatusOK, "no-user")
}
return c.String(http.StatusOK, user.Username)
})
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "opt_user", rec.Body.String())
}
func TestOptionalTokenAuth_ExpiredToken_IgnoresUser(t *testing.T) {
db := setupTestDB(t)
_, token := createTestUserAndToken(t, db, "expired_opt_user", 91)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
user := GetAuthUser(c)
if user == nil {
return c.String(http.StatusOK, "no-user")
}
return c.String(http.StatusOK, user.Username)
})
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "no-user", rec.Body.String())
}
func TestOptionalTokenAuth_InvalidToken_IgnoresUser(t *testing.T) {
db := setupTestDB(t)
m := NewAuthMiddleware(db, nil)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token nonexistent-token")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
user := GetAuthUser(c)
if user == nil {
return c.String(http.StatusOK, "no-user")
}
return c.String(http.StatusOK, user.Username)
})
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "no-user", rec.Body.String())
}
func TestNewAuthMiddlewareWithConfig_CustomExpiryDays(t *testing.T) {
db := setupTestDB(t)
cfg := &config.Config{
Security: config.SecurityConfig{
TokenExpiryDays: 30,
},
}
m := NewAuthMiddlewareWithConfig(db, nil, cfg)
assert.NotNil(t, m)
assert.Equal(t, 30, m.tokenExpiryDays)
// Token at 29 days should be valid
_, token := createTestUserAndToken(t, db, "short_expiry_user", 29)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
}
func TestNewAuthMiddlewareWithConfig_ExpiredWithCustomExpiry(t *testing.T) {
db := setupTestDB(t)
cfg := &config.Config{
Security: config.SecurityConfig{
TokenExpiryDays: 30,
},
}
m := NewAuthMiddlewareWithConfig(db, nil, cfg)
// Token at 31 days should be expired with 30-day config
_, token := createTestUserAndToken(t, db, "custom_expired_user", 31)
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
handler := m.TokenAuth()(func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
err := handler(c)
require.Error(t, err)
assert.Contains(t, err.Error(), "error.token_expired")
}
func TestNewAuthMiddlewareWithConfig_NilConfig_UsesDefault(t *testing.T) {
db := setupTestDB(t)
m := NewAuthMiddlewareWithConfig(db, nil, nil)
assert.Equal(t, DefaultTokenExpiryDays, m.tokenExpiryDays)
}
func TestGetAuthToken_ReturnsToken(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.Set(AuthTokenKey, "test-token-value")
assert.Equal(t, "test-token-value", GetAuthToken(c))
}
func TestGetAuthToken_NilContext_ReturnsEmpty(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
// No token set
assert.Equal(t, "", GetAuthToken(c))
}
func TestGetAuthToken_WrongType_ReturnsEmpty(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
c.Set(AuthTokenKey, 12345) // Wrong type
assert.Equal(t, "", GetAuthToken(c))
}
func TestIsTokenExpired_ZeroTime_NotExpired(t *testing.T) {
db := setupTestDB(t)
m := NewAuthMiddleware(db, nil)
// Legacy tokens without created time should not be expired
assert.False(t, m.isTokenExpired(models.AuthToken{}.Created))
}
func TestInvalidateToken_NilCache_NoError(t *testing.T) {
db := setupTestDB(t)
m := NewAuthMiddleware(db, nil) // nil cache
err := m.InvalidateToken(nil, "some-token")
assert.NoError(t, err)
}
+271
View File
@@ -0,0 +1,271 @@
package middleware
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"strings"
"time"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"github.com/treytartt/honeydue-api/internal/apperrors"
"github.com/treytartt/honeydue-api/internal/kratos"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/services"
)
const (
// AuthUserKey stores the authenticated *models.User in the echo context.
AuthUserKey = "auth_user"
// AuthTokenKey stores the raw session credential in the echo context.
AuthTokenKey = "auth_token"
// authVerifiedKey stores the Kratos email-verified flag in the context.
authVerifiedKey = "auth_email_verified"
// UserCacheTTL / UserCacheMaxSize bound the in-memory local-user cache.
UserCacheTTL = 5 * time.Minute
UserCacheMaxSize = 5000
// kratosSessionCacheTTL is how long a validated session is cached in
// Redis, so most authed requests skip the Kratos /whoami round trip.
kratosSessionCacheTTL = 5 * time.Minute
kratosSessionPrefix = "kratos_sess:"
)
// KratosAuth authenticates requests against an Ory Kratos session. It
// replaces the hand-rolled token auth: the session is validated via Kratos
// /sessions/whoami (Redis-cached), and the matching local auth_user row is
// lazily provisioned on first sight of a Kratos identity.
type KratosAuth struct {
kratos *kratos.Client
cache *services.CacheService
db *gorm.DB
userCache *UserCache
}
// NewKratosAuth builds the Kratos auth middleware.
func NewKratosAuth(k *kratos.Client, cache *services.CacheService, db *gorm.DB) *KratosAuth {
return &KratosAuth{
kratos: k,
cache: cache,
db: db,
userCache: NewUserCache(UserCacheTTL, UserCacheMaxSize),
}
}
// Authenticate validates the Kratos session and requires it.
func (m *KratosAuth) Authenticate() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
user, verified, cred, err := m.resolve(c)
if err != nil {
log.Debug().Err(err).Msg("Kratos authentication failed")
return apperrors.Unauthorized("error.not_authenticated")
}
c.Set(AuthUserKey, user)
c.Set(AuthTokenKey, cred)
c.Set(authVerifiedKey, verified)
return next(c)
}
}
}
// OptionalAuthenticate authenticates if a session is present, else continues
// unauthenticated.
func (m *KratosAuth) OptionalAuthenticate() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if user, verified, cred, err := m.resolve(c); err == nil {
c.Set(AuthUserKey, user)
c.Set(AuthTokenKey, cred)
c.Set(authVerifiedKey, verified)
}
return next(c)
}
}
}
// RequireVerified rejects users whose Kratos email address is not verified.
// Apply after Authenticate.
func (m *KratosAuth) RequireVerified() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if GetAuthUser(c) == nil {
return apperrors.Unauthorized("error.not_authenticated")
}
if verified, _ := c.Get(authVerifiedKey).(bool); !verified {
return apperrors.Forbidden("error.email_not_verified")
}
return next(c)
}
}
}
// resolve validates the request's session and returns the local user.
func (m *KratosAuth) resolve(c echo.Context) (*models.User, bool, string, error) {
token, cookie := extractSession(c)
if token == "" && cookie == "" {
return nil, false, "", errors.New("no session credential")
}
cred := token
if cred == "" {
cred = cookie
}
ctx := c.Request().Context()
// Redis cache: kratos_sess:<hash(cred)> -> "<userID>|<0|1>"
cacheKey := kratosSessionPrefix + hashCredential(cred)
if m.cache != nil {
if v, err := m.cache.GetString(ctx, cacheKey); err == nil && v != "" {
if user, verified, ok := m.userFromCacheValue(ctx, v); ok {
return user, verified, cred, nil
}
}
}
sess, err := m.kratos.Whoami(ctx, token, cookie)
if err != nil {
return nil, false, "", err
}
user, err := m.provision(ctx, sess)
if err != nil {
return nil, false, "", err
}
if m.cache != nil {
_ = m.cache.SetString(ctx, cacheKey,
fmt.Sprintf("%d|%s", user.ID, boolDigit(sess.EmailVerified())), kratosSessionCacheTTL)
}
return user, sess.EmailVerified(), cred, nil
}
// provision finds the local auth_user row for a Kratos identity, creating it
// (and a UserProfile) on first sight. Concurrent first requests are handled
// by re-reading after a unique-constraint conflict.
func (m *KratosAuth) provision(ctx context.Context, sess *kratos.Session) (*models.User, error) {
var user models.User
err := m.db.WithContext(ctx).Where("kratos_id = ?", sess.Identity.ID).First(&user).Error
if err == nil {
return &user, nil
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
user = models.User{
KratosID: sess.Identity.ID,
Email: sess.Identity.Traits.Email,
Username: sess.Identity.Traits.Email,
FirstName: sess.Identity.Traits.Name.First,
LastName: sess.Identity.Traits.Name.Last,
IsActive: true,
DateJoined: time.Now().UTC(),
}
txErr := m.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Create(&user).Error; err != nil {
return err
}
return tx.Create(&models.UserProfile{
UserID: user.ID,
Verified: sess.EmailVerified(),
}).Error
})
if txErr != nil {
// Likely a concurrent provision of the same identity — re-read.
if e := m.db.WithContext(ctx).Where("kratos_id = ?", sess.Identity.ID).First(&user).Error; e == nil {
return &user, nil
}
return nil, txErr
}
log.Info().Str("kratos_id", sess.Identity.ID).Uint("user_id", user.ID).
Msg("provisioned local user from Kratos identity")
return &user, nil
}
// userFromCacheValue resolves a cached "<userID>|<0|1>" value to a user.
func (m *KratosAuth) userFromCacheValue(ctx context.Context, v string) (*models.User, bool, bool) {
parts := strings.SplitN(v, "|", 2)
if len(parts) != 2 {
return nil, false, false
}
var id uint
if _, err := fmt.Sscanf(parts[0], "%d", &id); err != nil || id == 0 {
return nil, false, false
}
verified := parts[1] == "1"
if cached := m.userCache.Get(id); cached != nil {
return cached, verified, true
}
var user models.User
if err := m.db.WithContext(ctx).First(&user, id).Error; err != nil {
return nil, false, false
}
m.userCache.Set(&user)
return &user, verified, true
}
// extractSession pulls the session credential from the request: the
// X-Session-Token header or Authorization bearer (mobile clients), or the
// ory_kratos_session cookie (web).
func extractSession(c echo.Context) (token, cookie string) {
if t := c.Request().Header.Get("X-Session-Token"); t != "" {
token = t
} else if ah := c.Request().Header.Get("Authorization"); ah != "" {
parts := strings.SplitN(ah, " ", 2)
if len(parts) == 2 && (parts[0] == "Bearer" || parts[0] == "Token") {
token = parts[1]
}
}
if token == "" {
if ck := c.Request().Header.Get("Cookie"); strings.Contains(ck, "ory_kratos_session") {
cookie = ck
}
}
return token, cookie
}
func hashCredential(cred string) string {
sum := sha256.Sum256([]byte(cred))
return hex.EncodeToString(sum[:])
}
func boolDigit(b bool) string {
if b {
return "1"
}
return "0"
}
// truncateToken returns the first 8 characters of a credential followed by
// "..." for safe inclusion in log lines.
func truncateToken(tok string) string {
if len(tok) <= 8 {
return tok + "..."
}
return tok[:8] + "..."
}
// GetAuthUser retrieves the authenticated user from the echo context.
func GetAuthUser(c echo.Context) *models.User {
user, _ := c.Get(AuthUserKey).(*models.User)
return user
}
// GetAuthToken retrieves the session credential from the echo context.
func GetAuthToken(c echo.Context) string {
tok, _ := c.Get(AuthTokenKey).(string)
return tok
}
// MustGetAuthUser retrieves the authenticated user or returns a 401 error.
func MustGetAuthUser(c echo.Context) (*models.User, error) {
user := GetAuthUser(c)
if user == nil {
return nil, apperrors.Unauthorized("error.not_authenticated")
}
return user, nil
}
+1 -125
View File
@@ -19,7 +19,7 @@ func setupModelsTestDB(t *testing.T) *gorm.DB {
Logger: logger.Default.LogMode(logger.Silent), Logger: logger.Default.LogMode(logger.Silent),
}) })
require.NoError(t, err) require.NoError(t, err)
err = db.AutoMigrate(&User{}, &AuthToken{}, &UserProfile{}) err = db.AutoMigrate(&User{}, &UserProfile{})
require.NoError(t, err) require.NoError(t, err)
return db return db
} }
@@ -233,105 +233,6 @@ func TestNotificationType_Constants(t *testing.T) {
assert.Equal(t, NotificationType("warranty_expiring"), NotificationWarrantyExpiring) assert.Equal(t, NotificationType("warranty_expiring"), NotificationWarrantyExpiring)
} }
// === AuthToken model tests ===
func TestAuthToken_BeforeCreate_GeneratesKey(t *testing.T) {
db := setupModelsTestDB(t)
user := &User{
Username: "tokenuser",
Email: "token@test.com",
Password: "dummy",
IsActive: true,
}
err := db.Create(user).Error
require.NoError(t, err)
token := &AuthToken{UserID: user.ID}
err = db.Create(token).Error
require.NoError(t, err)
assert.NotEmpty(t, token.Key)
assert.Len(t, token.Key, 64) // SHA-256 hex hash (audit C1)
assert.Len(t, token.Plaintext, 40) // raw 20-byte token, returned to the client
assert.False(t, token.Created.IsZero())
}
func TestAuthToken_BeforeCreate_PreservesExistingKey(t *testing.T) {
db := setupModelsTestDB(t)
user := &User{
Username: "tokenuser",
Email: "token@test.com",
Password: "dummy",
IsActive: true,
}
err := db.Create(user).Error
require.NoError(t, err)
existingKey := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
token := &AuthToken{
Key: existingKey,
UserID: user.ID,
}
err = db.Create(token).Error
require.NoError(t, err)
assert.Equal(t, existingKey, token.Key)
}
func TestGetOrCreateToken_CreatesNew(t *testing.T) {
db := setupModelsTestDB(t)
user := &User{
Username: "newtoken",
Email: "newtoken@test.com",
Password: "dummy",
IsActive: true,
}
err := db.Create(user).Error
require.NoError(t, err)
token, err := GetOrCreateToken(db, user.ID)
require.NoError(t, err)
assert.NotEmpty(t, token.Key)
assert.Equal(t, user.ID, token.UserID)
}
func TestGetOrCreateToken_ReturnsExisting(t *testing.T) {
db := setupModelsTestDB(t)
user := &User{
Username: "existingtoken",
Email: "existingtoken@test.com",
Password: "dummy",
IsActive: true,
}
err := db.Create(user).Error
require.NoError(t, err)
token1, err := GetOrCreateToken(db, user.ID)
require.NoError(t, err)
token2, err := GetOrCreateToken(db, user.ID)
require.NoError(t, err)
assert.Equal(t, token1.Key, token2.Key)
}
// === User model additional tests ===
func TestUser_SetPassword_And_CheckPassword_Integration(t *testing.T) {
user := &User{}
err := user.SetPassword("Password123")
require.NoError(t, err)
assert.True(t, user.CheckPassword("Password123"))
assert.False(t, user.CheckPassword("WrongPassword"))
assert.False(t, user.CheckPassword(""))
assert.False(t, user.CheckPassword("password123")) // case sensitive
}
// === Task model additional tests === // === Task model additional tests ===
func TestTask_IsOverdue_CancelledNotOverdue(t *testing.T) { func TestTask_IsOverdue_CancelledNotOverdue(t *testing.T) {
@@ -565,31 +466,6 @@ func TestGetDefaultProLimits(t *testing.T) {
assert.Nil(t, limits.DocumentsLimit) assert.Nil(t, limits.DocumentsLimit)
} }
// === ConfirmationCode additional tests ===
func TestConfirmationCode_TableName(t *testing.T) {
cc := ConfirmationCode{}
assert.Equal(t, "user_confirmationcode", cc.TableName())
}
// === PasswordResetCode additional tests ===
func TestPasswordResetCode_TableName(t *testing.T) {
prc := PasswordResetCode{}
assert.Equal(t, "user_passwordresetcode", prc.TableName())
}
// === Social Auth TableName tests ===
func TestAppleSocialAuth_TableName(t *testing.T) {
a := AppleSocialAuth{}
assert.Equal(t, "user_applesocialauth", a.TableName())
}
func TestGoogleSocialAuth_TableName(t *testing.T) {
g := GoogleSocialAuth{}
assert.Equal(t, "user_googlesocialauth", g.TableName())
}
// === BaseModel tests === // === BaseModel tests ===
+26 -263
View File
@@ -1,69 +1,38 @@
package models package models
import ( import "time"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"time"
"golang.org/x/crypto/bcrypt" // User represents the auth_user table. Identity — credentials, email
"gorm.io/gorm" // verification, sessions, social sign-in — is owned by Ory Kratos (phase 2).
) // This row is honeyDue's local mirror of a Kratos identity, linked by
// KratosID; every domain table keeps its existing integer FK to auth_user.id.
// User represents the auth_user table (Django's default User model)
type User struct { type User struct {
ID uint `gorm:"primaryKey" json:"id"` ID uint `gorm:"primaryKey" json:"id"`
Password string `gorm:"column:password;size:128;not null" json:"-"` KratosID string `gorm:"column:kratos_id;uniqueIndex;size:36" json:"-"` // Kratos identity UUID
Username string `gorm:"column:username;uniqueIndex;size:150" json:"username"`
FirstName string `gorm:"column:first_name;size:150" json:"first_name"`
LastName string `gorm:"column:last_name;size:150" json:"last_name"`
Email string `gorm:"column:email;uniqueIndex;size:254" json:"email"`
IsStaff bool `gorm:"column:is_staff;default:false" json:"is_staff"`
IsActive bool `gorm:"column:is_active;default:true" json:"is_active"`
IsSuperuser bool `gorm:"column:is_superuser;default:false" json:"is_superuser"`
DateJoined time.Time `gorm:"column:date_joined;autoCreateTime" json:"date_joined"`
LastLogin *time.Time `gorm:"column:last_login" json:"last_login,omitempty"` LastLogin *time.Time `gorm:"column:last_login" json:"last_login,omitempty"`
IsSuperuser bool `gorm:"column:is_superuser;default:false" json:"is_superuser"`
Username string `gorm:"column:username;uniqueIndex;size:150;not null" json:"username"`
FirstName string `gorm:"column:first_name;size:150" json:"first_name"`
LastName string `gorm:"column:last_name;size:150" json:"last_name"`
Email string `gorm:"column:email;uniqueIndex;size:254" json:"email"`
IsStaff bool `gorm:"column:is_staff;default:false" json:"is_staff"`
IsActive bool `gorm:"column:is_active;default:true" json:"is_active"`
DateJoined time.Time `gorm:"column:date_joined;autoCreateTime" json:"date_joined"`
// Relations (not stored in auth_user table) // Relations not columns on auth_user.
Profile *UserProfile `gorm:"foreignKey:UserID" json:"profile,omitempty"` Profile *UserProfile `gorm:"foreignKey:UserID" json:"profile,omitempty"`
AuthToken *AuthToken `gorm:"foreignKey:UserID" json:"-"` OwnedResidences []Residence `gorm:"foreignKey:OwnerID" json:"-"`
OwnedResidences []Residence `gorm:"foreignKey:OwnerID" json:"-"` SharedResidences []Residence `gorm:"many2many:residence_residence_users;" json:"-"`
SharedResidences []Residence `gorm:"many2many:residence_residence_users;" json:"-"`
NotificationPref *NotificationPreference `gorm:"foreignKey:UserID" json:"-"` NotificationPref *NotificationPreference `gorm:"foreignKey:UserID" json:"-"`
Subscription *UserSubscription `gorm:"foreignKey:UserID" json:"-"` Subscription *UserSubscription `gorm:"foreignKey:UserID" json:"-"`
} }
// TableName returns the table name for GORM // TableName returns the table name for GORM.
func (User) TableName() string { func (User) TableName() string {
return "auth_user" return "auth_user"
} }
// BcryptCost is the bcrypt work factor for password and code hashing. // GetFullName returns the user's display name.
// 12 (audit M2) is stronger than bcrypt.DefaultCost (10).
const BcryptCost = 12
// SetPassword hashes and sets the password
func (u *User) SetPassword(password string) error {
// Django uses PBKDF2_SHA256 by default, but we use bcrypt for Go.
// Passwords set by Django won't verify with Go's bcrypt check — those
// users must reset their password after migration.
hash, err := bcrypt.GenerateFromPassword([]byte(password), BcryptCost)
if err != nil {
return err
}
u.Password = string(hash)
return nil
}
// CheckPassword verifies a password against the stored hash
func (u *User) CheckPassword(password string) bool {
err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
return err == nil
}
// GetFullName returns the user's full name
func (u *User) GetFullName() string { func (u *User) GetFullName() string {
if u.FirstName != "" && u.LastName != "" { if u.FirstName != "" && u.LastName != "" {
return u.FirstName + " " + u.LastName return u.FirstName + " " + u.LastName
@@ -74,80 +43,9 @@ func (u *User) GetFullName() string {
return u.Username return u.Username
} }
// AuthToken represents the user_authtoken table. // UserProfile represents the user_userprofile table — honeyDue-specific
// // profile data, keyed to a local user. Email-verification state is owned by
// Audit C1: the Key column stores the SHA-256 hash of the token, never the // Kratos; the Verified column is a convenience mirror set at provision time.
// token itself. The raw token is handed to the client exactly once, at
// creation, via the non-persisted Plaintext field — it is never stored or
// logged. A database compromise therefore yields no usable session tokens.
type AuthToken struct {
Key string `gorm:"column:key;primaryKey;size:64" json:"-"`
UserID uint `gorm:"column:user_id;uniqueIndex;not null" json:"user_id"`
Created time.Time `gorm:"column:created;autoCreateTime" json:"created"`
// Plaintext is the raw token value. It is never persisted (gorm:"-")
// and is only populated on a freshly-created token so the caller can
// return it to the client. On a token loaded from the DB it is "".
Plaintext string `gorm:"-" json:"-"`
// Relations
User User `gorm:"foreignKey:UserID" json:"-"`
}
// TableName returns the table name for GORM
func (AuthToken) TableName() string {
return "user_authtoken"
}
// BeforeCreate generates a token if one is not already set, storing only
// its hash in Key and the raw value in the non-persisted Plaintext field.
func (t *AuthToken) BeforeCreate(tx *gorm.DB) error {
if t.Key == "" {
raw := generateToken()
t.Plaintext = raw
t.Key = HashToken(raw)
}
if t.Created.IsZero() {
t.Created = time.Now().UTC()
}
return nil
}
// generateToken creates a random 40-character hex token (the raw value).
func generateToken() string {
b := make([]byte, 20)
rand.Read(b)
return hex.EncodeToString(b)
}
// HashToken returns the at-rest representation of an auth token: the
// hex-encoded SHA-256 hash. Auth tokens are 160-bit random values, so a
// fast deterministic hash is appropriate — there is nothing to brute-force,
// and determinism preserves the single indexed-lookup query in the auth
// middleware. The raw token is never stored.
func HashToken(raw string) string {
sum := sha256.Sum256([]byte(raw))
return hex.EncodeToString(sum[:])
}
// GetOrCreate gets an existing token or creates a new one for the user
func GetOrCreateToken(tx *gorm.DB, userID uint) (*AuthToken, error) {
var token AuthToken
result := tx.Where("user_id = ?", userID).First(&token)
if result.Error == gorm.ErrRecordNotFound {
token = AuthToken{UserID: userID}
if err := tx.Create(&token).Error; err != nil {
return nil, err
}
} else if result.Error != nil {
return nil, result.Error
}
return &token, nil
}
// UserProfile represents the user_userprofile table
type UserProfile struct { type UserProfile struct {
BaseModel BaseModel
UserID uint `gorm:"column:user_id;uniqueIndex;not null" json:"user_id"` UserID uint `gorm:"column:user_id;uniqueIndex;not null" json:"user_id"`
@@ -161,142 +59,7 @@ type UserProfile struct {
User User `gorm:"foreignKey:UserID" json:"-"` User User `gorm:"foreignKey:UserID" json:"-"`
} }
// TableName returns the table name for GORM // TableName returns the table name for GORM.
func (UserProfile) TableName() string { func (UserProfile) TableName() string {
return "user_userprofile" return "user_userprofile"
} }
// ConfirmationCode represents the user_confirmationcode table
type ConfirmationCode struct {
BaseModel
UserID uint `gorm:"column:user_id;index;not null" json:"user_id"`
Code string `gorm:"column:code;size:6;not null" json:"-"`
ExpiresAt time.Time `gorm:"column:expires_at;not null" json:"expires_at"`
IsUsed bool `gorm:"column:is_used;default:false" json:"is_used"`
// Relations
User User `gorm:"foreignKey:UserID" json:"-"`
}
// TableName returns the table name for GORM
func (ConfirmationCode) TableName() string {
return "user_confirmationcode"
}
// IsValid checks if the confirmation code is still valid
func (c *ConfirmationCode) IsValid() bool {
return !c.IsUsed && time.Now().UTC().Before(c.ExpiresAt)
}
// GenerateConfirmationCode creates a uniformly-random 6-digit code using
// rejection sampling on crypto/rand (audit H4 — removes the modulo bias of
// the previous implementation).
func GenerateConfirmationCode() string {
for {
var b [4]byte
if _, err := rand.Read(b[:]); err != nil {
continue
}
// 4294000000 is the largest multiple of 1e6 <= MaxUint32; rejecting
// the tail above it makes n % 1000000 perfectly uniform.
n := binary.BigEndian.Uint32(b[:])
if n < 4294000000 {
return fmt.Sprintf("%06d", n%1000000)
}
}
}
// PasswordResetCode represents the user_passwordresetcode table
type PasswordResetCode struct {
BaseModel
UserID uint `gorm:"column:user_id;index;not null" json:"user_id"`
CodeHash string `gorm:"column:code_hash;size:128;not null" json:"-"`
ResetToken string `gorm:"column:reset_token;uniqueIndex;size:64;not null" json:"reset_token"`
ExpiresAt time.Time `gorm:"column:expires_at;not null" json:"expires_at"`
Used bool `gorm:"column:used;default:false" json:"used"`
Attempts int `gorm:"column:attempts;default:0" json:"attempts"`
MaxAttempts int `gorm:"column:max_attempts;default:5" json:"max_attempts"`
// Relations
User User `gorm:"foreignKey:UserID" json:"-"`
}
// TableName returns the table name for GORM
func (PasswordResetCode) TableName() string {
return "user_passwordresetcode"
}
// SetCode hashes and stores the reset code
func (p *PasswordResetCode) SetCode(code string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(code), BcryptCost)
if err != nil {
return err
}
p.CodeHash = string(hash)
return nil
}
// CheckCode verifies a code against the stored hash
func (p *PasswordResetCode) CheckCode(code string) bool {
err := bcrypt.CompareHashAndPassword([]byte(p.CodeHash), []byte(code))
return err == nil
}
// IsValid checks if the reset code is still valid
func (p *PasswordResetCode) IsValid() bool {
return !p.Used && time.Now().UTC().Before(p.ExpiresAt) && p.Attempts < p.MaxAttempts
}
// IncrementAttempts increments the attempt counter
func (p *PasswordResetCode) IncrementAttempts(tx *gorm.DB) error {
p.Attempts++
return tx.Model(p).Update("attempts", p.Attempts).Error
}
// MarkAsUsed marks the code as used
func (p *PasswordResetCode) MarkAsUsed(tx *gorm.DB) error {
p.Used = true
return tx.Model(p).Update("used", true).Error
}
// GenerateResetToken creates a URL-safe token
func GenerateResetToken() string {
b := make([]byte, 32)
rand.Read(b)
return hex.EncodeToString(b)
}
// AppleSocialAuth represents a user's linked Apple ID for Sign in with Apple
type AppleSocialAuth struct {
ID uint `gorm:"primaryKey" json:"id"`
UserID uint `gorm:"uniqueIndex;not null" json:"user_id"`
User User `gorm:"foreignKey:UserID" json:"-"`
AppleID string `gorm:"column:apple_id;size:255;uniqueIndex;not null" json:"apple_id"` // Apple's unique subject ID
Email string `gorm:"column:email;size:254" json:"email"` // May be private relay
IsPrivateEmail bool `gorm:"column:is_private_email;default:false" json:"is_private_email"`
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"`
}
// TableName returns the table name for GORM
func (AppleSocialAuth) TableName() string {
return "user_applesocialauth"
}
// GoogleSocialAuth represents a user's linked Google account for Sign in with Google
type GoogleSocialAuth struct {
ID uint `gorm:"primaryKey" json:"id"`
UserID uint `gorm:"uniqueIndex;not null" json:"user_id"`
User User `gorm:"foreignKey:UserID" json:"-"`
GoogleID string `gorm:"column:google_id;size:255;uniqueIndex;not null" json:"google_id"` // Google's unique subject ID
Email string `gorm:"column:email;size:254" json:"email"`
Name string `gorm:"column:name;size:255" json:"name"`
Picture string `gorm:"column:picture;size:512" json:"picture"` // Profile picture URL
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"`
}
// TableName returns the table name for GORM
func (GoogleSocialAuth) TableName() string {
return "user_googlesocialauth"
}
+3 -167
View File
@@ -2,50 +2,15 @@ package models
import ( import (
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestUser_SetPassword(t *testing.T) {
user := &User{}
err := user.SetPassword("testPassword123")
require.NoError(t, err)
assert.NotEmpty(t, user.Password)
assert.NotEqual(t, "testPassword123", user.Password) // Should be hashed
}
func TestUser_CheckPassword(t *testing.T) {
user := &User{}
err := user.SetPassword("correctpassword")
require.NoError(t, err)
tests := []struct {
name string
password string
expected bool
}{
{"correct password", "correctpassword", true},
{"wrong password", "wrongpassword", false},
{"empty password", "", false},
{"similar password", "correctpassword1", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := user.CheckPassword(tt.password)
assert.Equal(t, tt.expected, result)
})
}
}
func TestUser_GetFullName(t *testing.T) { func TestUser_GetFullName(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
user User user User
expected string expected string
}{ }{
{ {
name: "first and last name", name: "first and last name",
@@ -82,136 +47,7 @@ func TestUser_TableName(t *testing.T) {
assert.Equal(t, "auth_user", user.TableName()) assert.Equal(t, "auth_user", user.TableName())
} }
func TestAuthToken_TableName(t *testing.T) {
token := AuthToken{}
assert.Equal(t, "user_authtoken", token.TableName())
}
func TestUserProfile_TableName(t *testing.T) { func TestUserProfile_TableName(t *testing.T) {
profile := UserProfile{} profile := UserProfile{}
assert.Equal(t, "user_userprofile", profile.TableName()) assert.Equal(t, "user_userprofile", profile.TableName())
} }
func TestConfirmationCode_IsValid(t *testing.T) {
now := time.Now().UTC()
future := now.Add(1 * time.Hour)
past := now.Add(-1 * time.Hour)
tests := []struct {
name string
code ConfirmationCode
expected bool
}{
{
name: "valid code",
code: ConfirmationCode{IsUsed: false, ExpiresAt: future},
expected: true,
},
{
name: "used code",
code: ConfirmationCode{IsUsed: true, ExpiresAt: future},
expected: false,
},
{
name: "expired code",
code: ConfirmationCode{IsUsed: false, ExpiresAt: past},
expected: false,
},
{
name: "used and expired",
code: ConfirmationCode{IsUsed: true, ExpiresAt: past},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.code.IsValid()
assert.Equal(t, tt.expected, result)
})
}
}
func TestPasswordResetCode_IsValid(t *testing.T) {
now := time.Now().UTC()
future := now.Add(1 * time.Hour)
past := now.Add(-1 * time.Hour)
tests := []struct {
name string
code PasswordResetCode
expected bool
}{
{
name: "valid code",
code: PasswordResetCode{Used: false, ExpiresAt: future, Attempts: 0, MaxAttempts: 5},
expected: true,
},
{
name: "used code",
code: PasswordResetCode{Used: true, ExpiresAt: future, Attempts: 0, MaxAttempts: 5},
expected: false,
},
{
name: "expired code",
code: PasswordResetCode{Used: false, ExpiresAt: past, Attempts: 0, MaxAttempts: 5},
expected: false,
},
{
name: "max attempts reached",
code: PasswordResetCode{Used: false, ExpiresAt: future, Attempts: 5, MaxAttempts: 5},
expected: false,
},
{
name: "attempts under max",
code: PasswordResetCode{Used: false, ExpiresAt: future, Attempts: 4, MaxAttempts: 5},
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.code.IsValid()
assert.Equal(t, tt.expected, result)
})
}
}
func TestPasswordResetCode_SetAndCheckCode(t *testing.T) {
code := &PasswordResetCode{}
err := code.SetCode("123456")
require.NoError(t, err)
assert.NotEmpty(t, code.CodeHash)
// Check correct code
assert.True(t, code.CheckCode("123456"))
// Check wrong code
assert.False(t, code.CheckCode("654321"))
assert.False(t, code.CheckCode(""))
}
func TestGenerateConfirmationCode(t *testing.T) {
code := GenerateConfirmationCode()
assert.Len(t, code, 6)
// Generate multiple codes and ensure they're different
codes := make(map[string]bool)
for i := 0; i < 10; i++ {
c := GenerateConfirmationCode()
assert.Len(t, c, 6)
codes[c] = true
}
// Most codes should be unique (very unlikely to have collisions)
assert.Greater(t, len(codes), 5)
}
func TestGenerateResetToken(t *testing.T) {
token := GenerateResetToken()
assert.Len(t, token, 64) // 32 bytes = 64 hex chars
// Ensure uniqueness
token2 := GenerateResetToken()
assert.NotEqual(t, token, token2)
}
+21 -349
View File
@@ -11,18 +11,21 @@ import (
"github.com/treytartt/honeydue-api/internal/models" "github.com/treytartt/honeydue-api/internal/models"
) )
// FindByKratosID finds a user by Kratos identity UUID.
func (r *UserRepository) FindByKratosID(kratosID string) (*models.User, error) {
var user models.User
if err := r.db.Where("kratos_id = ?", kratosID).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, err
}
return &user, nil
}
var ( var (
ErrUserNotFound = errors.New("user not found") ErrUserNotFound = errors.New("user not found")
ErrUserExists = errors.New("user already exists") ErrUserExists = errors.New("user already exists")
ErrInvalidToken = errors.New("invalid token")
ErrTokenNotFound = errors.New("token not found")
ErrCodeNotFound = errors.New("code not found")
ErrCodeExpired = errors.New("code expired")
ErrCodeUsed = errors.New("code already used")
ErrTooManyAttempts = errors.New("too many attempts")
ErrRateLimitExceeded = errors.New("rate limit exceeded")
ErrAppleAuthNotFound = errors.New("apple social auth not found")
ErrGoogleAuthNotFound = errors.New("google social auth not found")
) )
// UserRepository handles user-related database operations // UserRepository handles user-related database operations
@@ -145,111 +148,6 @@ func (r *UserRepository) ExistsByEmail(email string) (bool, error) {
return count > 0, nil return count > 0, nil
} }
// --- Auth Token Methods ---
// GetOrCreateToken gets or creates an auth token for a user.
// Wrapped in a transaction to prevent race conditions where two
// concurrent requests could create duplicate tokens for the same user.
func (r *UserRepository) GetOrCreateToken(userID uint) (*models.AuthToken, error) {
var token models.AuthToken
err := r.db.Transaction(func(tx *gorm.DB) error {
result := tx.Where("user_id = ?", userID).First(&token)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
token = models.AuthToken{UserID: userID}
if err := tx.Create(&token).Error; err != nil {
return err
}
} else if result.Error != nil {
return result.Error
}
return nil
})
if err != nil {
return nil, err
}
return &token, nil
}
// FindTokenByKey looks up an auth token by its raw key value. The raw token
// is hashed (audit C1) before the indexed lookup, since only the hash is
// stored.
func (r *UserRepository) FindTokenByKey(rawKey string) (*models.AuthToken, error) {
var token models.AuthToken
if err := r.db.Where("key = ?", models.HashToken(rawKey)).First(&token).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTokenNotFound
}
return nil, err
}
return &token, nil
}
// CreateToken creates a new auth token for a user.
func (r *UserRepository) CreateToken(userID uint) (*models.AuthToken, error) {
token := models.AuthToken{UserID: userID}
if err := r.db.Create(&token).Error; err != nil {
return nil, err
}
return &token, nil
}
// CreateFreshToken issues a new auth token for the user, replacing any
// existing one. Because tokens are stored hashed (audit C1) the server
// cannot re-issue a previously-minted token's plaintext, so every login
// mints a fresh token. The returned token's Plaintext field carries the
// raw value to hand to the client; it is never persisted.
//
// It also returns the stored hashes of the token rows it deleted, so the
// caller can evict those entries from the Redis token cache (audit MEDIUM-1).
// Without that, a prior (e.g. stolen) token keeps authenticating via a cache
// hit for up to the cache TTL even though its DB row is gone.
func (r *UserRepository) CreateFreshToken(userID uint) (*models.AuthToken, []string, error) {
var token models.AuthToken
var oldHashes []string
err := r.db.Transaction(func(tx *gorm.DB) error {
var old []models.AuthToken
if err := tx.Where("user_id = ?", userID).Find(&old).Error; err != nil {
return err
}
oldHashes = make([]string, 0, len(old))
for i := range old {
if old[i].Key != "" {
oldHashes = append(oldHashes, old[i].Key)
}
}
if err := tx.Where("user_id = ?", userID).Delete(&models.AuthToken{}).Error; err != nil {
return err
}
token = models.AuthToken{UserID: userID}
return tx.Create(&token).Error
})
if err != nil {
return nil, nil, err
}
return &token, oldHashes, nil
}
// DeleteToken deletes an auth token by its raw key value. The raw token is
// hashed (audit C1) before the lookup, since only the hash is stored.
func (r *UserRepository) DeleteToken(token string) error {
result := r.db.Where("key = ?", models.HashToken(token)).Delete(&models.AuthToken{})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return ErrTokenNotFound
}
return nil
}
// DeleteTokenByUserID deletes an auth token by user ID
func (r *UserRepository) DeleteTokenByUserID(userID uint) error {
return r.db.Where("user_id = ?", userID).Delete(&models.AuthToken{}).Error
}
// --- User Profile Methods --- // --- User Profile Methods ---
@@ -280,146 +178,6 @@ func (r *UserRepository) SetProfileVerified(userID uint, verified bool) error {
return r.db.Model(&models.UserProfile{}).Where("user_id = ?", userID).Update("verified", verified).Error return r.db.Model(&models.UserProfile{}).Where("user_id = ?", userID).Update("verified", verified).Error
} }
// --- Confirmation Code Methods ---
// CreateConfirmationCode creates a new confirmation code
func (r *UserRepository) CreateConfirmationCode(userID uint, code string, expiresAt time.Time) (*models.ConfirmationCode, error) {
// Invalidate any existing unused codes for this user
r.db.Model(&models.ConfirmationCode{}).
Where("user_id = ? AND is_used = ?", userID, false).
Update("is_used", true)
confirmCode := &models.ConfirmationCode{
UserID: userID,
Code: code,
ExpiresAt: expiresAt,
IsUsed: false,
}
if err := r.db.Create(confirmCode).Error; err != nil {
return nil, err
}
return confirmCode, nil
}
// FindConfirmationCode finds a valid confirmation code for a user
func (r *UserRepository) FindConfirmationCode(userID uint, code string) (*models.ConfirmationCode, error) {
var confirmCode models.ConfirmationCode
if err := r.db.Where("user_id = ? AND code = ? AND is_used = ?", userID, code, false).
First(&confirmCode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrCodeNotFound
}
return nil, err
}
if !confirmCode.IsValid() {
if confirmCode.IsUsed {
return nil, ErrCodeUsed
}
return nil, ErrCodeExpired
}
return &confirmCode, nil
}
// MarkConfirmationCodeUsed marks a confirmation code as used
func (r *UserRepository) MarkConfirmationCodeUsed(codeID uint) error {
return r.db.Model(&models.ConfirmationCode{}).Where("id = ?", codeID).Update("is_used", true).Error
}
// --- Password Reset Code Methods ---
// CreatePasswordResetCode creates a new password reset code
func (r *UserRepository) CreatePasswordResetCode(userID uint, codeHash string, resetToken string, expiresAt time.Time) (*models.PasswordResetCode, error) {
// Invalidate any existing unused codes for this user
r.db.Model(&models.PasswordResetCode{}).
Where("user_id = ? AND used = ?", userID, false).
Update("used", true)
resetCode := &models.PasswordResetCode{
UserID: userID,
CodeHash: codeHash,
ResetToken: resetToken,
ExpiresAt: expiresAt,
Used: false,
Attempts: 0,
MaxAttempts: 5,
}
if err := r.db.Create(resetCode).Error; err != nil {
return nil, err
}
return resetCode, nil
}
// FindPasswordResetCode finds a password reset code by email and checks validity
func (r *UserRepository) FindPasswordResetCodeByEmail(email string) (*models.PasswordResetCode, *models.User, error) {
user, err := r.FindByEmail(email)
if err != nil {
return nil, nil, err
}
var resetCode models.PasswordResetCode
if err := r.db.Where("user_id = ? AND used = ?", user.ID, false).
Order("created_at DESC").
First(&resetCode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, ErrCodeNotFound
}
return nil, nil, err
}
return &resetCode, user, nil
}
// FindPasswordResetCodeByToken finds a password reset code by reset token
func (r *UserRepository) FindPasswordResetCodeByToken(resetToken string) (*models.PasswordResetCode, error) {
var resetCode models.PasswordResetCode
if err := r.db.Where("reset_token = ?", resetToken).First(&resetCode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrCodeNotFound
}
return nil, err
}
if !resetCode.IsValid() {
if resetCode.Used {
return nil, ErrCodeUsed
}
if resetCode.Attempts >= resetCode.MaxAttempts {
return nil, ErrTooManyAttempts
}
return nil, ErrCodeExpired
}
return &resetCode, nil
}
// IncrementResetCodeAttempts increments the attempt counter
func (r *UserRepository) IncrementResetCodeAttempts(codeID uint) error {
return r.db.Model(&models.PasswordResetCode{}).Where("id = ?", codeID).
Update("attempts", gorm.Expr("attempts + 1")).Error
}
// MarkPasswordResetCodeUsed marks a password reset code as used
func (r *UserRepository) MarkPasswordResetCodeUsed(codeID uint) error {
return r.db.Model(&models.PasswordResetCode{}).Where("id = ?", codeID).Update("used", true).Error
}
// CountRecentPasswordResetRequests counts reset requests in the last hour
func (r *UserRepository) CountRecentPasswordResetRequests(userID uint) (int64, error) {
var count int64
oneHourAgo := time.Now().UTC().Add(-1 * time.Hour)
if err := r.db.Model(&models.PasswordResetCode{}).
Where("user_id = ? AND created_at > ?", userID, oneHourAgo).
Count(&count).Error; err != nil {
return 0, err
}
return count, nil
}
// --- Search Methods --- // --- Search Methods ---
@@ -576,27 +334,11 @@ func (r *UserRepository) FindProfilesInSharedResidences(userID uint) ([]models.U
return profiles, err return profiles, err
} }
// --- Auth Provider Detection --- // FindAuthProvider returns "kratos" for all Kratos-managed users (the sole
// provider after the Ory Kratos migration). Kept for compatibility with
// FindAuthProvider determines the auth provider for a user. // callers that still check the provider string.
// Returns "apple", "google", or "email". func (r *UserRepository) FindAuthProvider(_ uint) (string, error) {
func (r *UserRepository) FindAuthProvider(userID uint) (string, error) { return "kratos", nil
var count int64
if err := r.db.Model(&models.AppleSocialAuth{}).Where("user_id = ?", userID).Count(&count).Error; err != nil {
return "", err
}
if count > 0 {
return "apple", nil
}
if err := r.db.Model(&models.GoogleSocialAuth{}).Where("user_id = ?", userID).Count(&count).Error; err != nil {
return "", err
}
if count > 0 {
return "google", nil
}
return "email", nil
} }
// --- Account Deletion --- // --- Account Deletion ---
@@ -721,35 +463,12 @@ func (r *UserRepository) DeleteUserCascade(userID uint) ([]string, error) {
return nil, err return nil, err
} }
// 8. Social auth records // 8. User profile
if err := db.Where("user_id = ?", userID).Delete(&models.AppleSocialAuth{}).Error; err != nil {
return nil, err
}
if err := db.Where("user_id = ?", userID).Delete(&models.GoogleSocialAuth{}).Error; err != nil {
return nil, err
}
// 9. Confirmation codes
if err := db.Where("user_id = ?", userID).Delete(&models.ConfirmationCode{}).Error; err != nil {
return nil, err
}
// 10. Password reset codes
if err := db.Where("user_id = ?", userID).Delete(&models.PasswordResetCode{}).Error; err != nil {
return nil, err
}
// 11. Auth tokens
if err := db.Where("user_id = ?", userID).Delete(&models.AuthToken{}).Error; err != nil {
return nil, err
}
// 12. User profile
if err := db.Where("user_id = ?", userID).Delete(&models.UserProfile{}).Error; err != nil { if err := db.Where("user_id = ?", userID).Delete(&models.UserProfile{}).Error; err != nil {
return nil, err return nil, err
} }
// 13. User // 9. User
if err := db.Where("id = ?", userID).Delete(&models.User{}).Error; err != nil { if err := db.Where("id = ?", userID).Delete(&models.User{}).Error; err != nil {
return nil, err return nil, err
} }
@@ -765,53 +484,6 @@ func (r *UserRepository) DeleteUserCascade(userID uint) ([]string, error) {
return cleanURLs, nil return cleanURLs, nil
} }
// --- Apple Social Auth Methods ---
// FindByAppleID finds an Apple social auth by Apple ID
func (r *UserRepository) FindByAppleID(appleID string) (*models.AppleSocialAuth, error) {
var auth models.AppleSocialAuth
if err := r.db.Where("apple_id = ?", appleID).First(&auth).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAppleAuthNotFound
}
return nil, err
}
return &auth, nil
}
// CreateAppleSocialAuth creates a new Apple social auth record
func (r *UserRepository) CreateAppleSocialAuth(auth *models.AppleSocialAuth) error {
return r.db.Create(auth).Error
}
// UpdateAppleSocialAuth updates an Apple social auth record
func (r *UserRepository) UpdateAppleSocialAuth(auth *models.AppleSocialAuth) error {
return r.db.Save(auth).Error
}
// --- Google Social Auth Methods ---
// FindByGoogleID finds a Google social auth by Google ID
func (r *UserRepository) FindByGoogleID(googleID string) (*models.GoogleSocialAuth, error) {
var auth models.GoogleSocialAuth
if err := r.db.Where("google_id = ?", googleID).First(&auth).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGoogleAuthNotFound
}
return nil, err
}
return &auth, nil
}
// CreateGoogleSocialAuth creates a new Google social auth record
func (r *UserRepository) CreateGoogleSocialAuth(auth *models.GoogleSocialAuth) error {
return r.db.Create(auth).Error
}
// UpdateGoogleSocialAuth updates a Google social auth record
func (r *UserRepository) UpdateGoogleSocialAuth(auth *models.GoogleSocialAuth) error {
return r.db.Save(auth).Error
}
// WithContext returns a copy of the repository whose underlying *gorm.DB carries // WithContext returns a copy of the repository whose underlying *gorm.DB carries
// the supplied context. SQL emitted via this copy gets attached to ctx's trace span // the supplied context. SQL emitted via this copy gets attached to ctx's trace span
@@ -2,7 +2,6 @@ package repositories
import ( import (
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -78,99 +77,25 @@ func TestUserRepository_ExistsByEmail_CaseInsensitive(t *testing.T) {
assert.True(t, exists) assert.True(t, exists)
} }
func TestUserRepository_GetOrCreateToken(t *testing.T) { func TestUserRepository_FindByKratosID(t *testing.T) {
db := testutil.SetupTestDB(t) db := testutil.SetupTestDB(t)
repo := NewUserRepository(db) repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") user := testutil.CreateTestUser(t, db, "kratosuser", "kratos@example.com", "")
// Create token found, err := repo.FindByKratosID(user.KratosID)
token1, err := repo.GetOrCreateToken(user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, token1.Key) assert.Equal(t, user.ID, found.ID)
assert.Equal(t, user.KratosID, found.KratosID)
// Should return same token
token2, err := repo.GetOrCreateToken(user.ID)
require.NoError(t, err)
assert.Equal(t, token1.Key, token2.Key)
} }
func TestUserRepository_FindTokenByKey(t *testing.T) { func TestUserRepository_FindByKratosID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t) db := testutil.SetupTestDB(t)
repo := NewUserRepository(db) repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123") _, err := repo.FindByKratosID("nonexistent-kratos-id")
token, err := repo.GetOrCreateToken(user.ID)
require.NoError(t, err)
found, err := repo.FindTokenByKey(token.Plaintext)
require.NoError(t, err)
assert.Equal(t, token.Key, found.Key)
assert.Equal(t, user.ID, found.UserID)
}
func TestUserRepository_FindTokenByKey_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
_, err := repo.FindTokenByKey("nonexistent-token-key")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrTokenNotFound) assert.ErrorIs(t, err, ErrUserNotFound)
}
func TestUserRepository_DeleteToken(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
token, err := repo.GetOrCreateToken(user.ID)
require.NoError(t, err)
err = repo.DeleteToken(token.Plaintext)
require.NoError(t, err)
_, err = repo.FindTokenByKey(token.Plaintext)
assert.ErrorIs(t, err, ErrTokenNotFound)
}
func TestUserRepository_DeleteToken_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
err := repo.DeleteToken("nonexistent-key")
assert.ErrorIs(t, err, ErrTokenNotFound)
}
func TestUserRepository_DeleteTokenByUserID(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
_, err := repo.GetOrCreateToken(user.ID)
require.NoError(t, err)
err = repo.DeleteTokenByUserID(user.ID)
require.NoError(t, err)
// Token should be gone
var count int64
db.Model(&models.AuthToken{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(0), count)
}
func TestUserRepository_CreateToken(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
token, err := repo.CreateToken(user.ID)
require.NoError(t, err)
assert.NotEmpty(t, token.Key)
assert.Equal(t, user.ID, token.UserID)
} }
func TestUserRepository_UpdateLastLogin(t *testing.T) { func TestUserRepository_UpdateLastLogin(t *testing.T) {
@@ -255,54 +180,6 @@ func TestUserRepository_FindByIDWithProfile_NotFound(t *testing.T) {
assert.ErrorIs(t, err, ErrUserNotFound) assert.ErrorIs(t, err, ErrUserNotFound)
} }
func TestUserRepository_ConfirmationCode_Lifecycle(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
// Create confirmation code
expiresAt := time.Now().UTC().Add(1 * time.Hour)
code, err := repo.CreateConfirmationCode(user.ID, "123456", expiresAt)
require.NoError(t, err)
assert.NotZero(t, code.ID)
// Find it
found, err := repo.FindConfirmationCode(user.ID, "123456")
require.NoError(t, err)
assert.Equal(t, code.ID, found.ID)
// Mark as used
err = repo.MarkConfirmationCodeUsed(code.ID)
require.NoError(t, err)
// Should not find used code
_, err = repo.FindConfirmationCode(user.ID, "123456")
assert.Error(t, err)
}
func TestUserRepository_ConfirmationCode_InvalidatesExisting(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
// Create first code
code1, err := repo.CreateConfirmationCode(user.ID, "111111", expiresAt)
require.NoError(t, err)
// Create second code (should invalidate first)
_, err = repo.CreateConfirmationCode(user.ID, "222222", expiresAt)
require.NoError(t, err)
// First code should be used/invalidated
var c models.ConfirmationCode
db.First(&c, code1.ID)
assert.True(t, c.IsUsed)
}
func TestUserRepository_Transaction(t *testing.T) { func TestUserRepository_Transaction(t *testing.T) {
db := testutil.SetupTestDB(t) db := testutil.SetupTestDB(t)
repo := NewUserRepository(db) repo := NewUserRepository(db)
@@ -331,105 +208,6 @@ func TestUserRepository_DB(t *testing.T) {
assert.NotNil(t, repo.DB()) assert.NotNil(t, repo.DB())
} }
func TestUserRepository_FindByAppleID(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "Password123")
appleAuth := &models.AppleSocialAuth{
UserID: user.ID,
AppleID: "apple_sub_123",
Email: "apple@test.com",
}
require.NoError(t, db.Create(appleAuth).Error)
found, err := repo.FindByAppleID("apple_sub_123")
require.NoError(t, err)
assert.Equal(t, user.ID, found.UserID)
}
func TestUserRepository_FindByAppleID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
_, err := repo.FindByAppleID("nonexistent_apple_id")
assert.ErrorIs(t, err, ErrAppleAuthNotFound)
}
func TestUserRepository_FindByGoogleID(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "Password123")
googleAuth := &models.GoogleSocialAuth{
UserID: user.ID,
GoogleID: "google_sub_123",
Email: "google@test.com",
}
require.NoError(t, db.Create(googleAuth).Error)
found, err := repo.FindByGoogleID("google_sub_123")
require.NoError(t, err)
assert.Equal(t, user.ID, found.UserID)
}
func TestUserRepository_FindByGoogleID_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
_, err := repo.FindByGoogleID("nonexistent_google_id")
assert.ErrorIs(t, err, ErrGoogleAuthNotFound)
}
func TestUserRepository_CreateAndUpdateAppleSocialAuth(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "Password123")
auth := &models.AppleSocialAuth{
UserID: user.ID,
AppleID: "apple_sub_456",
Email: "apple@test.com",
}
err := repo.CreateAppleSocialAuth(auth)
require.NoError(t, err)
assert.NotZero(t, auth.ID)
auth.Email = "updated@test.com"
err = repo.UpdateAppleSocialAuth(auth)
require.NoError(t, err)
found, err := repo.FindByAppleID("apple_sub_456")
require.NoError(t, err)
assert.Equal(t, "updated@test.com", found.Email)
}
func TestUserRepository_CreateAndUpdateGoogleSocialAuth(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "Password123")
auth := &models.GoogleSocialAuth{
UserID: user.ID,
GoogleID: "google_sub_456",
Email: "google@test.com",
Name: "Test User",
}
err := repo.CreateGoogleSocialAuth(auth)
require.NoError(t, err)
assert.NotZero(t, auth.ID)
auth.Name = "Updated Name"
err = repo.UpdateGoogleSocialAuth(auth)
require.NoError(t, err)
found, err := repo.FindByGoogleID("google_sub_456")
require.NoError(t, err)
assert.Equal(t, "Updated Name", found.Name)
}
func TestUserRepository_SearchUsers(t *testing.T) { func TestUserRepository_SearchUsers(t *testing.T) {
db := testutil.SetupTestDB(t) db := testutil.SetupTestDB(t)
repo := NewUserRepository(db) repo := NewUserRepository(db)
@@ -2,7 +2,6 @@ package repositories
import ( import (
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -11,207 +10,6 @@ import (
"github.com/treytartt/honeydue-api/internal/testutil" "github.com/treytartt/honeydue-api/internal/testutil"
) )
// === Password Reset Code Lifecycle ===
func TestUserRepository_PasswordResetCode_Lifecycle(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
code, err := repo.CreatePasswordResetCode(user.ID, "hash_abc123", "reset_token_xyz", expiresAt)
require.NoError(t, err)
assert.NotZero(t, code.ID)
assert.Equal(t, "hash_abc123", code.CodeHash)
assert.Equal(t, "reset_token_xyz", code.ResetToken)
assert.False(t, code.Used)
assert.Equal(t, 0, code.Attempts)
}
func TestUserRepository_CreatePasswordResetCode_InvalidatesExisting(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
code1, err := repo.CreatePasswordResetCode(user.ID, "hash1", "token1", expiresAt)
require.NoError(t, err)
_, err = repo.CreatePasswordResetCode(user.ID, "hash2", "token2", expiresAt)
require.NoError(t, err)
// First code should be marked as used
var c models.PasswordResetCode
db.First(&c, code1.ID)
assert.True(t, c.Used)
}
func TestUserRepository_FindPasswordResetCodeByEmail(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
_, err := repo.CreatePasswordResetCode(user.ID, "hash_abc", "token_abc", expiresAt)
require.NoError(t, err)
found, foundUser, err := repo.FindPasswordResetCodeByEmail("test@example.com")
require.NoError(t, err)
assert.Equal(t, user.ID, foundUser.ID)
assert.Equal(t, "hash_abc", found.CodeHash)
}
func TestUserRepository_FindPasswordResetCodeByEmail_UserNotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
_, _, err := repo.FindPasswordResetCodeByEmail("nonexistent@example.com")
assert.Error(t, err)
}
func TestUserRepository_FindPasswordResetCodeByEmail_NoCode(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
_, _, err := repo.FindPasswordResetCodeByEmail("test@example.com")
assert.ErrorIs(t, err, ErrCodeNotFound)
}
func TestUserRepository_FindPasswordResetCodeByToken(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
_, err := repo.CreatePasswordResetCode(user.ID, "hash_xyz", "token_xyz", expiresAt)
require.NoError(t, err)
found, err := repo.FindPasswordResetCodeByToken("token_xyz")
require.NoError(t, err)
assert.Equal(t, "hash_xyz", found.CodeHash)
}
func TestUserRepository_FindPasswordResetCodeByToken_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
_, err := repo.FindPasswordResetCodeByToken("nonexistent_token")
assert.ErrorIs(t, err, ErrCodeNotFound)
}
func TestUserRepository_FindPasswordResetCodeByToken_Expired(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
// Already expired
expiresAt := time.Now().UTC().Add(-1 * time.Hour)
_, err := repo.CreatePasswordResetCode(user.ID, "hash_exp", "token_exp", expiresAt)
require.NoError(t, err)
_, err = repo.FindPasswordResetCodeByToken("token_exp")
assert.ErrorIs(t, err, ErrCodeExpired)
}
func TestUserRepository_FindPasswordResetCodeByToken_Used(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
code, err := repo.CreatePasswordResetCode(user.ID, "hash_used", "token_used", expiresAt)
require.NoError(t, err)
// Mark as used
err = repo.MarkPasswordResetCodeUsed(code.ID)
require.NoError(t, err)
_, err = repo.FindPasswordResetCodeByToken("token_used")
assert.ErrorIs(t, err, ErrCodeUsed)
}
func TestUserRepository_FindPasswordResetCodeByToken_TooManyAttempts(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
code, err := repo.CreatePasswordResetCode(user.ID, "hash_attempts", "token_attempts", expiresAt)
require.NoError(t, err)
// Max out attempts
for i := 0; i < 5; i++ {
err = repo.IncrementResetCodeAttempts(code.ID)
require.NoError(t, err)
}
_, err = repo.FindPasswordResetCodeByToken("token_attempts")
assert.ErrorIs(t, err, ErrTooManyAttempts)
}
func TestUserRepository_IncrementResetCodeAttempts(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
code, err := repo.CreatePasswordResetCode(user.ID, "hash_inc", "token_inc", expiresAt)
require.NoError(t, err)
err = repo.IncrementResetCodeAttempts(code.ID)
require.NoError(t, err)
var updated models.PasswordResetCode
db.First(&updated, code.ID)
assert.Equal(t, 1, updated.Attempts)
}
func TestUserRepository_MarkPasswordResetCodeUsed(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
code, err := repo.CreatePasswordResetCode(user.ID, "hash_mark", "token_mark", expiresAt)
require.NoError(t, err)
err = repo.MarkPasswordResetCodeUsed(code.ID)
require.NoError(t, err)
var updated models.PasswordResetCode
db.First(&updated, code.ID)
assert.True(t, updated.Used)
}
func TestUserRepository_CountRecentPasswordResetRequests(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
expiresAt := time.Now().UTC().Add(1 * time.Hour)
_, err := repo.CreatePasswordResetCode(user.ID, "hash1", "token1", expiresAt)
require.NoError(t, err)
_, err = repo.CreatePasswordResetCode(user.ID, "hash2", "token2", expiresAt)
require.NoError(t, err)
count, err := repo.CountRecentPasswordResetRequests(user.ID)
require.NoError(t, err)
assert.Equal(t, int64(2), count)
}
// === FindUsersInSharedResidences === // === FindUsersInSharedResidences ===
func TestUserRepository_FindUsersInSharedResidences(t *testing.T) { func TestUserRepository_FindUsersInSharedResidences(t *testing.T) {
@@ -301,33 +99,6 @@ func TestUserRepository_FindProfilesInSharedResidences(t *testing.T) {
assert.Len(t, profiles, 2) assert.Len(t, profiles, 2)
} }
// === ConfirmationCode Expired ===
func TestUserRepository_FindConfirmationCode_Expired(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
// Create already-expired code
expiresAt := time.Now().UTC().Add(-1 * time.Hour)
_, err := repo.CreateConfirmationCode(user.ID, "999999", expiresAt)
require.NoError(t, err)
_, err = repo.FindConfirmationCode(user.ID, "999999")
assert.ErrorIs(t, err, ErrCodeExpired)
}
func TestUserRepository_FindConfirmationCode_NotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
repo := NewUserRepository(db)
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
_, err := repo.FindConfirmationCode(user.ID, "000000")
assert.ErrorIs(t, err, ErrCodeNotFound)
}
// === Transaction Rollback === // === Transaction Rollback ===
func TestUserRepository_Transaction_Rollback(t *testing.T) { func TestUserRepository_Transaction_Rollback(t *testing.T) {
+3 -38
View File
@@ -19,7 +19,6 @@ func TestUserRepository_Create(t *testing.T) {
Email: "test@example.com", Email: "test@example.com",
IsActive: true, IsActive: true,
} }
user.SetPassword("Password123")
err := repo.Create(user) err := repo.Create(user)
require.NoError(t, err) require.NoError(t, err)
@@ -192,39 +191,11 @@ func TestUserRepository_FindAuthProvider(t *testing.T) {
db := testutil.SetupTestDB(t) db := testutil.SetupTestDB(t)
repo := NewUserRepository(db) repo := NewUserRepository(db)
t.Run("email user", func(t *testing.T) { t.Run("kratos user", func(t *testing.T) {
user := testutil.CreateTestUser(t, db, "emailuser", "email@test.com", "Password123") user := testutil.CreateTestUser(t, db, "emailuser", "email@test.com", "Password123")
provider, err := repo.FindAuthProvider(user.ID) provider, err := repo.FindAuthProvider(user.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "email", provider) assert.Equal(t, "kratos", provider) // All users are Kratos-managed
})
t.Run("apple user", func(t *testing.T) {
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "Password123")
appleAuth := &models.AppleSocialAuth{
UserID: user.ID,
AppleID: "apple_sub_test",
Email: "apple@test.com",
}
require.NoError(t, db.Create(appleAuth).Error)
provider, err := repo.FindAuthProvider(user.ID)
require.NoError(t, err)
assert.Equal(t, "apple", provider)
})
t.Run("google user", func(t *testing.T) {
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "Password123")
googleAuth := &models.GoogleSocialAuth{
UserID: user.ID,
GoogleID: "google_sub_test",
Email: "google@test.com",
}
require.NoError(t, db.Create(googleAuth).Error)
provider, err := repo.FindAuthProvider(user.ID)
require.NoError(t, err)
assert.Equal(t, "google", provider)
}) })
} }
@@ -235,11 +206,9 @@ func TestUserRepository_DeleteUserCascade(t *testing.T) {
user := testutil.CreateTestUser(t, db, "deletebare", "deletebare@test.com", "Password123") user := testutil.CreateTestUser(t, db, "deletebare", "deletebare@test.com", "Password123")
// Create profile and token // Create profile
profile := &models.UserProfile{UserID: user.ID, Verified: true} profile := &models.UserProfile{UserID: user.ID, Verified: true}
require.NoError(t, db.Create(profile).Error) require.NoError(t, db.Create(profile).Error)
_, err := models.GetOrCreateToken(db, user.ID)
require.NoError(t, err)
var fileURLs []string var fileURLs []string
txErr := repo.Transaction(func(txRepo *UserRepository) error { txErr := repo.Transaction(func(txRepo *UserRepository) error {
@@ -261,10 +230,6 @@ func TestUserRepository_DeleteUserCascade(t *testing.T) {
// Verify profile is gone // Verify profile is gone
db.Model(&models.UserProfile{}).Where("user_id = ?", user.ID).Count(&count) db.Model(&models.UserProfile{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(0), count) assert.Equal(t, int64(0), count)
// Verify token is gone
db.Model(&models.AuthToken{}).Where("user_id = ?", user.ID).Count(&count)
assert.Equal(t, int64(0), count)
}) })
t.Run("returns file URLs for cleanup", func(t *testing.T) { t.Run("returns file URLs for cleanup", func(t *testing.T) {
+14 -50
View File
@@ -22,6 +22,7 @@ import (
"github.com/treytartt/honeydue-api/internal/dto/responses" "github.com/treytartt/honeydue-api/internal/dto/responses"
"github.com/treytartt/honeydue-api/internal/handlers" "github.com/treytartt/honeydue-api/internal/handlers"
"github.com/treytartt/honeydue-api/internal/i18n" "github.com/treytartt/honeydue-api/internal/i18n"
"github.com/treytartt/honeydue-api/internal/kratos"
custommiddleware "github.com/treytartt/honeydue-api/internal/middleware" custommiddleware "github.com/treytartt/honeydue-api/internal/middleware"
"github.com/treytartt/honeydue-api/internal/monitoring" "github.com/treytartt/honeydue-api/internal/monitoring"
"github.com/treytartt/honeydue-api/internal/prom" "github.com/treytartt/honeydue-api/internal/prom"
@@ -200,7 +201,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
// Initialize services // Initialize services
authService := services.NewAuthService(userRepo, cfg) authService := services.NewAuthService(userRepo, cfg)
authService.SetNotificationRepository(notificationRepo) // For creating notification preferences on registration authService.SetNotificationRepository(notificationRepo)
userService := services.NewUserService(userRepo) userService := services.NewUserService(userRepo)
residenceService := services.NewResidenceService(residenceRepo, userRepo, cfg) residenceService := services.NewResidenceService(residenceRepo, userRepo, cfg)
residenceService.SetTaskRepository(taskRepo) // Wire up task repo for statistics residenceService.SetTaskRepository(taskRepo) // Wire up task repo for statistics
@@ -220,7 +221,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
// Wire Redis cache for residence-ID lookups across the four services that // Wire Redis cache for residence-ID lookups across the four services that
// read it on the request hot path. Cache is best-effort; nil cache is OK. // read it on the request hot path. Cache is best-effort; nil cache is OK.
if deps.Cache != nil { if deps.Cache != nil {
authService.SetCacheService(deps.Cache) // per-account login lockout (audit M5) authService.SetCacheService(deps.Cache)
residenceService.SetCacheService(deps.Cache) residenceService.SetCacheService(deps.Cache)
taskService.SetCacheService(deps.Cache) taskService.SetCacheService(deps.Cache)
contractorService.SetCacheService(deps.Cache) contractorService.SetCacheService(deps.Cache)
@@ -244,20 +245,15 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
subscriptionWebhookHandler.SetStripeService(stripeService) subscriptionWebhookHandler.SetStripeService(stripeService)
subscriptionWebhookHandler.SetCacheService(deps.Cache) subscriptionWebhookHandler.SetCacheService(deps.Cache)
// Initialize middleware // Initialize Kratos auth middleware (replaces hand-rolled token auth).
authMiddleware := custommiddleware.NewAuthMiddlewareWithConfig(deps.DB, deps.Cache, cfg) kratosClient := kratos.NewClient(cfg.Security.KratosPublicURL)
authMiddleware := custommiddleware.NewKratosAuth(kratosClient, deps.Cache, deps.DB)
// Initialize Apple auth service
appleAuthService := services.NewAppleAuthService(deps.Cache, cfg)
googleAuthService := services.NewGoogleAuthService(deps.Cache, cfg)
// Initialize audit service for security event logging // Initialize audit service for security event logging
auditService := services.NewAuditService(deps.DB) auditService := services.NewAuditService(deps.DB)
// Initialize handlers // Initialize handlers
authHandler := handlers.NewAuthHandler(authService, deps.EmailService, deps.Cache) authHandler := handlers.NewAuthHandler(authService, deps.EmailService, deps.Cache)
authHandler.SetAppleAuthService(appleAuthService)
authHandler.SetGoogleAuthService(googleAuthService)
authHandler.SetStorageService(deps.StorageService) authHandler.SetStorageService(deps.StorageService)
authHandler.SetAuditService(auditService) authHandler.SetAuditService(auditService)
userHandler := handlers.NewUserHandler(userService) userHandler := handlers.NewUserHandler(userService)
@@ -318,8 +314,8 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
// API group // API group
api := e.Group("/api") api := e.Group("/api")
{ {
// Public auth routes (no auth required) // Session lifecycle (login, register, logout, password reset) is
setupPublicAuthRoutes(api, authHandler, cfg.Server.Debug) // handled by Ory Kratos — no public auth routes in this service.
// Public data routes (no auth required) // Public data routes (no auth required)
setupPublicDataRoutes(api, residenceHandler, taskHandler, contractorHandler, staticDataHandler, subscriptionHandler, taskTemplateHandler) setupPublicDataRoutes(api, residenceHandler, taskHandler, contractorHandler, staticDataHandler, subscriptionHandler, taskTemplateHandler)
@@ -329,7 +325,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
// Protected routes (auth required) // Protected routes (auth required)
protected := api.Group("") protected := api.Group("")
protected.Use(authMiddleware.TokenAuth()) protected.Use(authMiddleware.Authenticate())
protected.Use(custommiddleware.TimezoneMiddleware()) protected.Use(custommiddleware.TimezoneMiddleware())
{ {
setupProtectedAuthRoutes(protected, authHandler) setupProtectedAuthRoutes(protected, authHandler)
@@ -516,50 +512,18 @@ func prometheusMetrics(monSvc *monitoring.Service) echo.HandlerFunc {
} }
} }
// setupPublicAuthRoutes configures public authentication routes with // setupPublicAuthRoutes was removed — session lifecycle (login, register,
// per-endpoint rate limiters to mitigate brute-force and credential-stuffing. // logout, password reset, Apple/Google sign-in) is delegated to Ory Kratos.
// Rate limiters are disabled in debug mode to allow UI test suites to run
// without hitting 429 errors.
func setupPublicAuthRoutes(api *echo.Group, authHandler *handlers.AuthHandler, debug bool) {
auth := api.Group("/auth")
if debug { // setupProtectedAuthRoutes configures protected auth routes.
// No rate limiters in debug/local mode // Session lifecycle (login, logout, password reset, email verification) is
auth.POST("/login/", authHandler.Login) // delegated to Ory Kratos — only profile and account-deletion routes remain.
auth.POST("/register/", authHandler.Register)
auth.POST("/forgot-password/", authHandler.ForgotPassword)
auth.POST("/verify-reset-code/", authHandler.VerifyResetCode)
auth.POST("/reset-password/", authHandler.ResetPassword)
auth.POST("/apple-sign-in/", authHandler.AppleSignIn)
auth.POST("/google-sign-in/", authHandler.GoogleSignIn)
} else {
// Rate limiters — created once, shared across requests.
loginRL := custommiddleware.LoginRateLimiter() // 10 req/min
registerRL := custommiddleware.RegistrationRateLimiter() // 5 req/min
passwordRL := custommiddleware.PasswordResetRateLimiter() // 3 req/min
auth.POST("/login/", authHandler.Login, loginRL)
auth.POST("/register/", authHandler.Register, registerRL)
auth.POST("/forgot-password/", authHandler.ForgotPassword, passwordRL)
auth.POST("/verify-reset-code/", authHandler.VerifyResetCode, passwordRL)
auth.POST("/reset-password/", authHandler.ResetPassword, passwordRL)
auth.POST("/apple-sign-in/", authHandler.AppleSignIn, loginRL)
auth.POST("/google-sign-in/", authHandler.GoogleSignIn, loginRL)
}
}
// setupProtectedAuthRoutes configures protected authentication routes
func setupProtectedAuthRoutes(api *echo.Group, authHandler *handlers.AuthHandler) { func setupProtectedAuthRoutes(api *echo.Group, authHandler *handlers.AuthHandler) {
auth := api.Group("/auth") auth := api.Group("/auth")
{ {
auth.POST("/logout/", authHandler.Logout)
auth.POST("/refresh/", authHandler.RefreshToken)
auth.GET("/me/", authHandler.CurrentUser) auth.GET("/me/", authHandler.CurrentUser)
auth.PUT("/profile/", authHandler.UpdateProfile) auth.PUT("/profile/", authHandler.UpdateProfile)
auth.PATCH("/profile/", authHandler.UpdateProfile) auth.PATCH("/profile/", authHandler.UpdateProfile)
auth.POST("/verify/", authHandler.VerifyEmail) // Alias for mobile app compatibility
auth.POST("/verify-email/", authHandler.VerifyEmail) // Original route
auth.POST("/resend-verification/", authHandler.ResendVerification)
auth.DELETE("/account/", authHandler.DeleteAccount) auth.DELETE("/account/", authHandler.DeleteAccount)
} }
} }
-301
View File
@@ -1,301 +0,0 @@
package services
import (
"context"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/treytartt/honeydue-api/internal/config"
)
const (
appleKeysURL = "https://appleid.apple.com/auth/keys"
appleIssuer = "https://appleid.apple.com"
appleKeysCacheTTL = 24 * time.Hour
appleKeysCacheKey = "apple:public_keys"
)
var (
ErrInvalidAppleToken = errors.New("invalid Apple identity token")
ErrAppleTokenExpired = errors.New("Apple identity token has expired")
ErrInvalidAppleAudience = errors.New("invalid Apple token audience")
ErrInvalidAppleIssuer = errors.New("invalid Apple token issuer")
ErrAppleKeyNotFound = errors.New("Apple public key not found")
)
// AppleJWKS represents Apple's JSON Web Key Set
type AppleJWKS struct {
Keys []AppleJWK `json:"keys"`
}
// AppleJWK represents a single JSON Web Key from Apple
type AppleJWK struct {
Kty string `json:"kty"` // Key type (RSA)
Kid string `json:"kid"` // Key ID
Use string `json:"use"` // Key use (sig)
Alg string `json:"alg"` // Algorithm (RS256)
N string `json:"n"` // RSA modulus
E string `json:"e"` // RSA exponent
}
// AppleTokenClaims represents the claims in an Apple identity token
type AppleTokenClaims struct {
jwt.RegisteredClaims
Email string `json:"email,omitempty"`
EmailVerified any `json:"email_verified,omitempty"` // Can be bool or string
IsPrivateEmail any `json:"is_private_email,omitempty"` // Can be bool or string
AuthTime int64 `json:"auth_time,omitempty"`
}
// IsEmailVerified returns whether the email is verified (handles both bool and string types)
func (c *AppleTokenClaims) IsEmailVerified() bool {
switch v := c.EmailVerified.(type) {
case bool:
return v
case string:
return v == "true"
default:
return false
}
}
// IsPrivateRelayEmail returns whether the email is a private relay email
func (c *AppleTokenClaims) IsPrivateRelayEmail() bool {
switch v := c.IsPrivateEmail.(type) {
case bool:
return v
case string:
return v == "true"
default:
return false
}
}
// AppleAuthService handles Apple Sign In token verification
type AppleAuthService struct {
cache *CacheService
config *config.Config
client *http.Client
}
// NewAppleAuthService creates a new Apple auth service
func NewAppleAuthService(cache *CacheService, cfg *config.Config) *AppleAuthService {
return &AppleAuthService{
cache: cache,
config: cfg,
client: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// VerifyIdentityToken verifies an Apple identity token and returns the claims
func (s *AppleAuthService) VerifyIdentityToken(ctx context.Context, idToken string) (*AppleTokenClaims, error) {
// Parse the token header to get the key ID
parts := strings.Split(idToken, ".")
if len(parts) != 3 {
return nil, ErrInvalidAppleToken
}
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return nil, fmt.Errorf("failed to decode token header: %w", err)
}
var header struct {
Kid string `json:"kid"`
Alg string `json:"alg"`
}
if err := json.Unmarshal(headerBytes, &header); err != nil {
return nil, fmt.Errorf("failed to parse token header: %w", err)
}
// Get the public key for this key ID
publicKey, err := s.getPublicKey(ctx, header.Kid)
if err != nil {
return nil, err
}
// Parse and verify the token
token, err := jwt.ParseWithClaims(idToken, &AppleTokenClaims{}, func(token *jwt.Token) (interface{}, error) {
// Verify the signing method
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return publicKey, nil
})
if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
return nil, ErrAppleTokenExpired
}
return nil, fmt.Errorf("failed to parse token: %w", err)
}
claims, ok := token.Claims.(*AppleTokenClaims)
if !ok || !token.Valid {
return nil, ErrInvalidAppleToken
}
// Verify the issuer
if claims.Issuer != appleIssuer {
return nil, ErrInvalidAppleIssuer
}
// Verify the audience (should be our bundle ID)
if !s.verifyAudience(claims.Audience) {
return nil, ErrInvalidAppleAudience
}
return claims, nil
}
// verifyAudience checks if the token audience matches our client ID.
// In production (non-debug), an empty clientID causes verification to fail
// rather than silently bypassing the check.
func (s *AppleAuthService) verifyAudience(audience jwt.ClaimStrings) bool {
clientID := s.config.AppleAuth.ClientID
if clientID == "" {
if s.config.Server.Debug {
// In debug mode only, skip audience verification for local development
return true
}
// In production, missing client ID means we cannot verify the audience
return false
}
for _, aud := range audience {
if aud == clientID {
return true
}
}
return false
}
// getPublicKey retrieves the public key for the given key ID
func (s *AppleAuthService) getPublicKey(ctx context.Context, kid string) (*rsa.PublicKey, error) {
// Try to get from cache first
keys, err := s.getCachedKeys(ctx)
if err != nil || keys == nil {
// Fetch fresh keys
keys, err = s.fetchApplePublicKeys(ctx)
if err != nil {
return nil, err
}
}
// Find the key with the matching ID
for keyID, pubKey := range keys {
if keyID == kid {
return pubKey, nil
}
}
// Key not found in cache, try fetching fresh keys
keys, err = s.fetchApplePublicKeys(ctx)
if err != nil {
return nil, err
}
if pubKey, ok := keys[kid]; ok {
return pubKey, nil
}
return nil, ErrAppleKeyNotFound
}
// getCachedKeys retrieves cached Apple public keys from Redis
func (s *AppleAuthService) getCachedKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
if s.cache == nil {
return nil, nil
}
data, err := s.cache.GetString(ctx, appleKeysCacheKey)
if err != nil || data == "" {
return nil, nil
}
var jwks AppleJWKS
if err := json.Unmarshal([]byte(data), &jwks); err != nil {
return nil, nil
}
return s.parseJWKS(&jwks)
}
// fetchApplePublicKeys fetches Apple's public keys and caches them
func (s *AppleAuthService) fetchApplePublicKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, appleKeysURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := s.client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to fetch Apple keys: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Apple keys endpoint returned status %d", resp.StatusCode)
}
var jwks AppleJWKS
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
return nil, fmt.Errorf("failed to decode Apple keys: %w", err)
}
// Cache the keys
if s.cache != nil {
keysJSON, _ := json.Marshal(jwks)
_ = s.cache.SetString(ctx, appleKeysCacheKey, string(keysJSON), appleKeysCacheTTL)
}
return s.parseJWKS(&jwks)
}
// parseJWKS converts Apple's JWKS to RSA public keys
func (s *AppleAuthService) parseJWKS(jwks *AppleJWKS) (map[string]*rsa.PublicKey, error) {
keys := make(map[string]*rsa.PublicKey)
for _, key := range jwks.Keys {
if key.Kty != "RSA" {
continue
}
// Decode the modulus (N)
nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
if err != nil {
continue
}
n := new(big.Int).SetBytes(nBytes)
// Decode the exponent (E)
eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
if err != nil {
continue
}
e := 0
for _, b := range eBytes {
e = e<<8 + int(b)
}
pubKey := &rsa.PublicKey{
N: n,
E: e,
}
keys[key.Kid] = pubKey
}
return keys, nil
}
-176
View File
@@ -1,176 +0,0 @@
package services
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/treytartt/honeydue-api/internal/config"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/repositories"
)
func setupRefreshTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
require.NoError(t, err)
err = db.AutoMigrate(&models.User{}, &models.UserProfile{}, &models.AuthToken{})
require.NoError(t, err)
return db
}
func createRefreshTestUser(t *testing.T, db *gorm.DB) *models.User {
t.Helper()
user := &models.User{
Username: "refreshtest",
Email: "refresh@test.com",
IsActive: true,
}
require.NoError(t, user.SetPassword("Password123"))
require.NoError(t, db.Create(user).Error)
return user
}
func createTokenWithAge(t *testing.T, db *gorm.DB, userID uint, ageDays int) *models.AuthToken {
t.Helper()
token := &models.AuthToken{
UserID: userID,
}
require.NoError(t, db.Create(token).Error)
// Backdate the token's Created timestamp after creation to bypass autoCreateTime
backdated := time.Now().UTC().AddDate(0, 0, -ageDays)
require.NoError(t, db.Model(token).Update("created", backdated).Error)
token.Created = backdated
return token
}
func newTestAuthService(db *gorm.DB) *AuthService {
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{
SecretKey: "test-secret",
TokenExpiryDays: 90,
TokenRefreshDays: 60,
},
}
return NewAuthService(userRepo, cfg)
}
func TestRefreshToken_FreshToken_ReturnsExisting(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
token := createTokenWithAge(t, db, user.ID, 30) // 30 days old, well within fresh window
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
require.NoError(t, err)
assert.Equal(t, token.Plaintext, resp.Token, "fresh token should return the same token")
assert.Contains(t, resp.Message, "still valid")
}
func TestRefreshToken_InRenewalWindow_ReturnsNewToken(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
token := createTokenWithAge(t, db, user.ID, 75) // 75 days old, in renewal window (60-90)
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
require.NoError(t, err)
assert.NotEqual(t, token.Plaintext, resp.Token, "should return a new token")
assert.Contains(t, resp.Message, "refreshed")
// Verify old token was deleted
var count int64
// The DB stores the SHA-256 hash, so query by token.Key (the hash).
db.Model(&models.AuthToken{}).Where("key = ?", token.Key).Count(&count)
assert.Equal(t, int64(0), count, "old token should be deleted")
// Verify new token exists in DB
// resp.Token is the raw token; the DB stores its hash.
db.Model(&models.AuthToken{}).Where("key = ?", models.HashToken(resp.Token)).Count(&count)
assert.Equal(t, int64(1), count, "new token should exist in DB")
// Verify new token belongs to the same user
var newToken models.AuthToken
require.NoError(t, db.Where("key = ?", models.HashToken(resp.Token)).First(&newToken).Error)
assert.Equal(t, user.ID, newToken.UserID)
}
func TestRefreshToken_ExpiredToken_Returns401(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
token := createTokenWithAge(t, db, user.ID, 91) // 91 days old, past 90-day expiry
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
require.Error(t, err)
assert.Nil(t, resp)
assert.Contains(t, err.Error(), "error.token_expired")
}
func TestRefreshToken_AtExactBoundary60Days(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
// Exactly 60 days: token age == refreshDays, so tokenAge < refreshDuration is false,
// meaning it enters the renewal window
token := createTokenWithAge(t, db, user.ID, 61)
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
require.NoError(t, err)
assert.NotEqual(t, token.Plaintext, resp.Token, "token at 61 days should be refreshed")
}
func TestRefreshToken_InvalidToken_Returns401(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(context.Background(), "nonexistent-token-key", user.ID)
require.Error(t, err)
assert.Nil(t, resp)
assert.Contains(t, err.Error(), "error.invalid_token")
}
func TestRefreshToken_WrongUser_Returns401(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
token := createTokenWithAge(t, db, user.ID, 75)
svc := newTestAuthService(db)
// Try to refresh with a different user ID
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID+999)
require.Error(t, err)
assert.Nil(t, resp)
assert.Contains(t, err.Error(), "error.invalid_token")
}
func TestRefreshToken_FreshTokenAt59Days_ReturnsExisting(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
token := createTokenWithAge(t, db, user.ID, 59) // 59 days, just under the 60-day threshold
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
require.NoError(t, err)
assert.Equal(t, token.Plaintext, resp.Token, "token at 59 days should NOT be refreshed")
}
File diff suppressed because it is too large Load Diff
+43 -666
View File
@@ -4,7 +4,6 @@ import (
"context" "context"
"net/http" "net/http"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -19,195 +18,18 @@ func setupAuthService(t *testing.T) (*AuthService, *repositories.UserRepository)
db := testutil.SetupTestDB(t) db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db) userRepo := repositories.NewUserRepository(db)
notifRepo := repositories.NewNotificationRepository(db) notifRepo := repositories.NewNotificationRepository(db)
cfg := &config.Config{ cfg := &config.Config{}
Server: config.ServerConfig{
DebugFixedCodes: true,
},
Security: config.SecurityConfig{
SecretKey: "test-secret",
ConfirmationExpiry: 24 * time.Hour,
PasswordResetExpiry: 15 * time.Minute,
MaxPasswordResetRate: 3,
TokenExpiryDays: 90,
TokenRefreshDays: 60,
},
}
service := NewAuthService(userRepo, cfg) service := NewAuthService(userRepo, cfg)
service.SetNotificationRepository(notifRepo) service.SetNotificationRepository(notifRepo)
return service, userRepo return service, userRepo
} }
// === Login ===
func TestAuthService_Login(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
req := &requests.LoginRequest{
Username: "testuser",
Password: "Password123",
}
resp, err := service.Login(context.Background(), req, "")
require.NoError(t, err)
assert.NotEmpty(t, resp.Token)
assert.Equal(t, "testuser", resp.User.Username)
}
func TestAuthService_Login_ByEmail(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
req := &requests.LoginRequest{
Email: "test@test.com",
Password: "Password123",
}
resp, err := service.Login(context.Background(), req, "")
require.NoError(t, err)
assert.NotEmpty(t, resp.Token)
}
func TestAuthService_Login_InvalidCredentials(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
req := &requests.LoginRequest{
Username: "testuser",
Password: "WrongPassword1",
}
_, err := service.Login(context.Background(), req, "")
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
func TestAuthService_Login_UserNotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
req := &requests.LoginRequest{
Username: "nonexistent",
Password: "Password123",
}
_, err := service.Login(context.Background(), req, "")
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
func TestAuthService_Login_InactiveUser(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
user := testutil.CreateTestUser(t, db, "inactive", "inactive@test.com", "Password123")
// Deactivate
user.IsActive = false
db.Save(user)
req := &requests.LoginRequest{
Username: "inactive",
Password: "Password123",
}
_, err := service.Login(context.Background(), req, "")
// Audit L1: inactive accounts return the same generic error as bad
// credentials so login does not disclose which accounts exist.
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
// === Register ===
func TestAuthService_Register(t *testing.T) {
service, _ := setupAuthService(t)
req := &requests.RegisterRequest{
Username: "newuser",
Email: "new@test.com",
Password: "Password123",
}
resp, code, err := service.Register(context.Background(), req)
require.NoError(t, err)
assert.NotEmpty(t, resp.Token)
assert.Equal(t, "newuser", resp.User.Username)
assert.Equal(t, "123456", code) // DebugFixedCodes=true
}
func TestAuthService_Register_DuplicateUsername(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Server: config.ServerConfig{DebugFixedCodes: true},
Security: config.SecurityConfig{SecretKey: "test", ConfirmationExpiry: 24 * time.Hour},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "taken", "taken@test.com", "Password123")
req := &requests.RegisterRequest{
Username: "taken",
Email: "different@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), req)
testutil.AssertAppError(t, err, http.StatusConflict, "error.username_taken")
}
func TestAuthService_Register_DuplicateEmail(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Server: config.ServerConfig{DebugFixedCodes: true},
Security: config.SecurityConfig{SecretKey: "test", ConfirmationExpiry: 24 * time.Hour},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "existing", "taken@test.com", "Password123")
req := &requests.RegisterRequest{
Username: "newuser",
Email: "taken@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), req)
testutil.AssertAppError(t, err, http.StatusConflict, "error.email_taken")
}
// === GetCurrentUser === // === GetCurrentUser ===
func TestAuthService_GetCurrentUser(t *testing.T) { func TestAuthService_GetCurrentUser(t *testing.T) {
db := testutil.SetupTestDB(t) db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db) userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{ cfg := &config.Config{}
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg) service := NewAuthService(userRepo, cfg)
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123") user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
@@ -218,7 +40,7 @@ func TestAuthService_GetCurrentUser(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "testuser", resp.Username) assert.Equal(t, "testuser", resp.Username)
assert.Equal(t, "test@test.com", resp.Email) assert.Equal(t, "test@test.com", resp.Email)
assert.Equal(t, "email", resp.AuthProvider) // Default for no social auth assert.Equal(t, "kratos", resp.AuthProvider) // All users are Kratos-managed
} }
// === UpdateProfile === // === UpdateProfile ===
@@ -226,9 +48,7 @@ func TestAuthService_GetCurrentUser(t *testing.T) {
func TestAuthService_UpdateProfile(t *testing.T) { func TestAuthService_UpdateProfile(t *testing.T) {
db := testutil.SetupTestDB(t) db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db) userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{ cfg := &config.Config{}
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg) service := NewAuthService(userRepo, cfg)
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123") user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
@@ -250,9 +70,7 @@ func TestAuthService_UpdateProfile(t *testing.T) {
func TestAuthService_UpdateProfile_DuplicateEmail(t *testing.T) { func TestAuthService_UpdateProfile_DuplicateEmail(t *testing.T) {
db := testutil.SetupTestDB(t) db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db) userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{ cfg := &config.Config{}
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg) service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "user1", "user1@test.com", "Password123") testutil.CreateTestUser(t, db, "user1", "user1@test.com", "Password123")
@@ -271,9 +89,7 @@ func TestAuthService_UpdateProfile_DuplicateEmail(t *testing.T) {
func TestAuthService_UpdateProfile_SameEmail(t *testing.T) { func TestAuthService_UpdateProfile_SameEmail(t *testing.T) {
db := testutil.SetupTestDB(t) db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db) userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{ cfg := &config.Config{}
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg) service := NewAuthService(userRepo, cfg)
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123") user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
@@ -290,443 +106,10 @@ func TestAuthService_UpdateProfile_SameEmail(t *testing.T) {
assert.Equal(t, "test@test.com", resp.Email) assert.Equal(t, "test@test.com", resp.Email)
} }
// === VerifyEmail ===
func TestAuthService_VerifyEmail(t *testing.T) {
service, _ := setupAuthService(t)
// Register a user (creates confirmation code)
req := &requests.RegisterRequest{
Username: "newuser",
Email: "new@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), req)
require.NoError(t, err)
// Get the user ID
user, err := service.userRepo.FindByEmail("new@test.com")
require.NoError(t, err)
// Verify with the debug code
err = service.VerifyEmail(context.Background(), user.ID, "123456")
require.NoError(t, err)
// Verify again — should get already verified error
err = service.VerifyEmail(context.Background(), user.ID, "123456")
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
}
func TestAuthService_VerifyEmail_InvalidCode(t *testing.T) {
service, _ := setupAuthService(t)
// Register
req := &requests.RegisterRequest{
Username: "newuser",
Email: "new@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), req)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("new@test.com")
require.NoError(t, err)
// Wrong code — with DebugFixedCodes enabled, "123456" bypasses normal lookup,
// but a wrong code should use the normal path
err = service.VerifyEmail(context.Background(), user.ID, "000000")
assert.Error(t, err)
}
// === ResendVerificationCode ===
func TestAuthService_ResendVerificationCode(t *testing.T) {
service, _ := setupAuthService(t)
// Register
req := &requests.RegisterRequest{
Username: "newuser",
Email: "new@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), req)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("new@test.com")
require.NoError(t, err)
code, err := service.ResendVerificationCode(context.Background(), user.ID)
require.NoError(t, err)
assert.Equal(t, "123456", code) // DebugFixedCodes
}
func TestAuthService_ResendVerificationCode_AlreadyVerified(t *testing.T) {
service, _ := setupAuthService(t)
// Register and verify
req := &requests.RegisterRequest{
Username: "newuser",
Email: "new@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), req)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("new@test.com")
require.NoError(t, err)
err = service.VerifyEmail(context.Background(), user.ID, "123456")
require.NoError(t, err)
_, err = service.ResendVerificationCode(context.Background(), user.ID)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
}
// === ForgotPassword ===
func TestAuthService_ForgotPassword(t *testing.T) {
service, _ := setupAuthService(t)
// Register a user first
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
code, user, err := service.ForgotPassword(context.Background(), "test@test.com")
require.NoError(t, err)
assert.Equal(t, "123456", code) // DebugFixedCodes
assert.NotNil(t, user)
assert.Equal(t, "test@test.com", user.Email)
}
func TestAuthService_ForgotPassword_NonexistentEmail(t *testing.T) {
service, _ := setupAuthService(t)
// Should not reveal that email doesn't exist
code, user, err := service.ForgotPassword(context.Background(), "nonexistent@test.com")
require.NoError(t, err)
assert.Empty(t, code)
assert.Nil(t, user)
}
// === ResetPassword ===
func TestAuthService_ResetPassword(t *testing.T) {
service, _ := setupAuthService(t)
// Register
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
// Forgot password
_, _, err = service.ForgotPassword(context.Background(), "test@test.com")
require.NoError(t, err)
// Verify reset code to get the token
resetToken, err := service.VerifyResetCode(context.Background(), "test@test.com", "123456")
require.NoError(t, err)
assert.NotEmpty(t, resetToken)
// Reset password
err = service.ResetPassword(context.Background(), resetToken, "NewPassword123")
require.NoError(t, err)
// Login with new password
loginReq := &requests.LoginRequest{
Username: "testuser",
Password: "NewPassword123",
}
loginResp, err := service.Login(context.Background(), loginReq, "")
require.NoError(t, err)
assert.NotEmpty(t, loginResp.Token)
}
func TestAuthService_ResetPassword_InvalidToken(t *testing.T) {
service, _ := setupAuthService(t)
err := service.ResetPassword(context.Background(), "invalid-token", "NewPassword123")
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_reset_token")
}
// === Logout ===
func TestAuthService_Logout(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
// Login first
loginReq := &requests.LoginRequest{
Username: "testuser",
Password: "Password123",
}
loginResp, err := service.Login(context.Background(), loginReq, "")
require.NoError(t, err)
// Logout
err = service.Logout(context.Background(), loginResp.Token)
require.NoError(t, err)
// Token should be deleted — refreshing should fail
_, err = service.RefreshToken(context.Background(), loginResp.Token, user.ID)
assert.Error(t, err)
}
// === DeleteAccount ===
func TestAuthService_DeleteAccount_EmailAuth(t *testing.T) {
service, _ := setupAuthService(t)
// Register
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err)
password := "Password123"
_, err = service.DeleteAccount(context.Background(), user.ID, &password, nil)
require.NoError(t, err)
}
func TestAuthService_DeleteAccount_WrongPassword(t *testing.T) {
service, _ := setupAuthService(t)
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err)
wrongPassword := "WrongPassword1"
_, err = service.DeleteAccount(context.Background(), user.ID, &wrongPassword, nil)
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
func TestAuthService_DeleteAccount_NoPassword(t *testing.T) {
service, _ := setupAuthService(t)
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err)
_, err = service.DeleteAccount(context.Background(), user.ID, nil, nil)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
}
func TestAuthService_DeleteAccount_UserNotFound(t *testing.T) {
service, _ := setupAuthService(t)
password := "Password123"
_, err := service.DeleteAccount(context.Background(), 99999, &password, nil)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.user_not_found")
}
// === Helper functions ===
func TestGenerateSixDigitCode(t *testing.T) {
code := generateSixDigitCode()
assert.Len(t, code, 6)
// Should be numeric
for _, c := range code {
assert.True(t, c >= '0' && c <= '9', "code should contain only digits")
}
}
func TestGenerateResetToken(t *testing.T) {
token := generateResetToken()
assert.NotEmpty(t, token)
assert.Len(t, token, 64) // 32 bytes = 64 hex chars
}
func TestGetStringOrEmpty(t *testing.T) {
s := "hello"
assert.Equal(t, "hello", getStringOrEmpty(&s))
assert.Equal(t, "", getStringOrEmpty(nil))
}
func TestIsPrivateRelayEmail(t *testing.T) {
assert.True(t, isPrivateRelayEmail("abc@privaterelay.appleid.com"))
assert.True(t, isPrivateRelayEmail("ABC@PRIVATERELAY.APPLEID.COM"))
assert.False(t, isPrivateRelayEmail("user@gmail.com"))
}
func TestGetEmailFromRequest(t *testing.T) {
email := "req@test.com"
assert.Equal(t, "req@test.com", getEmailFromRequest(&email, "claims@test.com"))
assert.Equal(t, "claims@test.com", getEmailFromRequest(nil, "claims@test.com"))
empty := ""
assert.Equal(t, "claims@test.com", getEmailFromRequest(&empty, "claims@test.com"))
}
// === getEmailOrDefault ===
func TestGetEmailOrDefault(t *testing.T) {
// Non-empty email returns itself
assert.Equal(t, "user@test.com", getEmailOrDefault("user@test.com"))
// Empty email returns a generated placeholder
result := getEmailOrDefault("")
assert.Contains(t, result, "@privaterelay.appleid.com")
assert.Contains(t, result, "apple_")
}
// === generateUniqueUsername ===
func TestGenerateUniqueUsername(t *testing.T) {
// Normal email generates username from email prefix
username := generateUniqueUsername("john@test.com", nil)
assert.Contains(t, username, "john_")
// Private relay email falls back to first name
firstName := "Jane"
username = generateUniqueUsername("abc@privaterelay.appleid.com", &firstName)
assert.Contains(t, username, "jane_")
// Private relay email and no first name — fallback
username = generateUniqueUsername("abc@privaterelay.appleid.com", nil)
assert.Contains(t, username, "user_")
// Empty email with first name
firstName2 := "Bob"
username = generateUniqueUsername("", &firstName2)
assert.Contains(t, username, "bob_")
// Empty email and no first name
username = generateUniqueUsername("", nil)
assert.Contains(t, username, "user_")
}
// === generateGoogleUsername ===
func TestGenerateGoogleUsername(t *testing.T) {
// Normal email
username := generateGoogleUsername("john@gmail.com", "John")
assert.Contains(t, username, "john_")
// Empty email falls back to first name
username = generateGoogleUsername("", "Alice")
assert.Contains(t, username, "alice_")
// Empty email and empty first name — fallback
username = generateGoogleUsername("", "")
assert.Contains(t, username, "google_")
}
// === Login with empty password ===
func TestAuthService_Login_EmptyPassword(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
req := &requests.LoginRequest{
Username: "testuser",
Password: "",
}
_, err := service.Login(context.Background(), req, "")
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
// === ForgotPassword rate limiting ===
func TestAuthService_ForgotPassword_RateLimit(t *testing.T) {
service, _ := setupAuthService(t)
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
// Make max allowed reset requests (3 based on setup)
for i := 0; i < 3; i++ {
_, _, err := service.ForgotPassword(context.Background(), "test@test.com")
require.NoError(t, err)
}
// The 4th should be rate limited
_, _, err = service.ForgotPassword(context.Background(), "test@test.com")
assert.Error(t, err)
}
// === VerifyResetCode with wrong code ===
func TestAuthService_VerifyResetCode_WrongCode(t *testing.T) {
service, _ := setupAuthService(t)
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
_, _, err = service.ForgotPassword(context.Background(), "test@test.com")
require.NoError(t, err)
// Wrong code but with debug mode, "123456" works, "000000" should fail
_, err = service.VerifyResetCode(context.Background(), "test@test.com", "000000")
assert.Error(t, err)
}
// === VerifyResetCode with nonexistent email ===
func TestAuthService_VerifyResetCode_NonexistentEmail(t *testing.T) {
service, _ := setupAuthService(t)
_, err := service.VerifyResetCode(context.Background(), "nonexistent@test.com", "123456")
assert.Error(t, err)
}
// === UpdateProfile — change email to new email ===
func TestAuthService_UpdateProfile_ChangeEmail(t *testing.T) { func TestAuthService_UpdateProfile_ChangeEmail(t *testing.T) {
db := testutil.SetupTestDB(t) db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db) userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{ cfg := &config.Config{}
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg) service := NewAuthService(userRepo, cfg)
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123") user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
@@ -742,25 +125,44 @@ func TestAuthService_UpdateProfile_ChangeEmail(t *testing.T) {
assert.Equal(t, "newemail@test.com", resp.Email) assert.Equal(t, "newemail@test.com", resp.Email)
} }
// === DeleteAccount — empty password string === // === DeleteAccount ===
func TestAuthService_DeleteAccount_EmptyPassword(t *testing.T) { func TestAuthService_DeleteAccount_WithConfirmation(t *testing.T) {
service, userRepo := setupAuthService(t)
user := testutil.CreateTestUser(t, (*userRepo).DB(), "testuser", "test@test.com", "")
_ = user
confirmation := "DELETE"
_, err := service.DeleteAccount(context.Background(), user.ID, nil, &confirmation)
require.NoError(t, err)
}
func TestAuthService_DeleteAccount_WrongConfirmation(t *testing.T) {
service, userRepo := setupAuthService(t)
user := testutil.CreateTestUser(t, (*userRepo).DB(), "testuser", "test@test.com", "")
wrongConf := "delete"
_, err := service.DeleteAccount(context.Background(), user.ID, nil, &wrongConf)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.confirmation_required")
}
func TestAuthService_DeleteAccount_NoConfirmation(t *testing.T) {
service, userRepo := setupAuthService(t)
user := testutil.CreateTestUser(t, (*userRepo).DB(), "testuser", "test@test.com", "")
_, err := service.DeleteAccount(context.Background(), user.ID, nil, nil)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.confirmation_required")
}
func TestAuthService_DeleteAccount_UserNotFound(t *testing.T) {
service, _ := setupAuthService(t) service, _ := setupAuthService(t)
registerReq := &requests.RegisterRequest{ confirmation := "DELETE"
Username: "testuser", _, err := service.DeleteAccount(context.Background(), 99999, nil, &confirmation)
Email: "test@test.com", testutil.AssertAppError(t, err, http.StatusNotFound, "error.user_not_found")
Password: "Password123",
}
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err)
emptyPw := ""
_, err = service.DeleteAccount(context.Background(), user.ID, &emptyPw, nil)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
} }
// === SetNotificationRepository === // === SetNotificationRepository ===
@@ -769,35 +171,10 @@ func TestAuthService_SetNotificationRepository(t *testing.T) {
db := testutil.SetupTestDB(t) db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db) userRepo := repositories.NewUserRepository(db)
notifRepo := repositories.NewNotificationRepository(db) notifRepo := repositories.NewNotificationRepository(db)
cfg := &config.Config{ cfg := &config.Config{}
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg) service := NewAuthService(userRepo, cfg)
assert.Nil(t, service.notificationRepo) assert.Nil(t, service.notificationRepo)
service.SetNotificationRepository(notifRepo) service.SetNotificationRepository(notifRepo)
assert.NotNil(t, service.notificationRepo) assert.NotNil(t, service.notificationRepo)
} }
// === Register creates profile and notification preferences ===
func TestAuthService_Register_CreatesProfile(t *testing.T) {
service, userRepo := setupAuthService(t)
req := &requests.RegisterRequest{
Username: "profileuser",
Email: "profile@test.com",
Password: "Password123",
FirstName: "John",
LastName: "Doe",
}
resp, _, err := service.Register(context.Background(), req)
require.NoError(t, err)
assert.Equal(t, "profileuser", resp.User.Username)
// Profile should exist
profile, err := userRepo.GetOrCreateProfile(resp.User.ID)
require.NoError(t, err)
assert.NotNil(t, profile)
}
-111
View File
@@ -12,7 +12,6 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/treytartt/honeydue-api/internal/config" "github.com/treytartt/honeydue-api/internal/config"
"github.com/treytartt/honeydue-api/internal/models"
) )
// CacheService provides Redis caching functionality // CacheService provides Redis caching functionality
@@ -134,116 +133,6 @@ func (c *CacheService) Close() error {
return nil return nil
} }
// Auth token cache helpers
const (
AuthTokenPrefix = "auth_token_"
TokenCacheTTL = 5 * time.Minute
)
// authTokenCacheKey returns the Redis key for an auth token. The raw token
// is hashed (audit C1) so the plaintext token never appears in a Redis key.
func authTokenCacheKey(token string) string {
return AuthTokenPrefix + models.HashToken(token)
}
// CacheAuthToken caches a user ID for a token
func (c *CacheService) CacheAuthToken(ctx context.Context, token string, userID uint) error {
return c.SetString(ctx, authTokenCacheKey(token), fmt.Sprintf("%d", userID), TokenCacheTTL)
}
// CacheAuthTokenWithCreated caches a user ID and token creation time for a token
func (c *CacheService) CacheAuthTokenWithCreated(ctx context.Context, token string, userID uint, createdUnix int64) error {
return c.SetString(ctx, authTokenCacheKey(token), fmt.Sprintf("%d|%d", userID, createdUnix), TokenCacheTTL)
}
// GetCachedAuthToken gets a cached user ID for a token
func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (uint, error) {
val, err := c.GetString(ctx, authTokenCacheKey(token))
if err != nil {
return 0, err
}
var userID uint
_, err = fmt.Sscanf(val, "%d", &userID)
return userID, err
}
// GetCachedAuthTokenWithCreated gets a cached user ID and token creation time.
// Returns userID, createdUnix, error. createdUnix is 0 if not stored (legacy format).
func (c *CacheService) GetCachedAuthTokenWithCreated(ctx context.Context, token string) (uint, int64, error) {
val, err := c.GetString(ctx, authTokenCacheKey(token))
if err != nil {
return 0, 0, err
}
var userID uint
var createdUnix int64
n, _ := fmt.Sscanf(val, "%d|%d", &userID, &createdUnix)
if n < 1 {
return 0, 0, fmt.Errorf("invalid cached token format")
}
return userID, createdUnix, nil
}
// InvalidateAuthToken removes a cached token
func (c *CacheService) InvalidateAuthToken(ctx context.Context, token string) error {
return c.Delete(ctx, authTokenCacheKey(token))
}
// InvalidateAuthTokenHashes removes cached entries for already-hashed token
// keys. Unlike InvalidateAuthToken (which hashes a plaintext), this takes the
// stored hash directly — used to evict a user's prior token on re-login
// (audit MEDIUM-1), where the server no longer has the plaintext.
func (c *CacheService) InvalidateAuthTokenHashes(ctx context.Context, hashes ...string) error {
keys := make([]string, 0, len(hashes))
for _, h := range hashes {
if h != "" {
keys = append(keys, AuthTokenPrefix+h)
}
}
if len(keys) == 0 {
return nil
}
return c.Delete(ctx, keys...)
}
// --- Per-account login-failure tracking (audit M5) ---
const loginFailPrefix = "login_fail:"
// RegisterLoginFailure records a failed login for an account from a given
// source IP, and returns the number of DISTINCT source IPs that have failed
// for this account within the window. Tracking distinct IPs as a set rather
// than a raw counter (audit MEDIUM-3) means one attacker, from one IP, cannot
// run the count up and lock a victim out by knowing only their email — a
// single IP is bounded by the per-IP edge/app rate limiters instead. A
// genuinely distributed credential-stuffing attack still trips the lockout.
func (c *CacheService) RegisterLoginFailure(ctx context.Context, identifier, ip string, window time.Duration) (int64, error) {
key := loginFailPrefix + identifier
member := ip
if member == "" {
member = "unknown"
}
if err := c.client.SAdd(ctx, key, member).Err(); err != nil {
return 0, err
}
// Refresh the TTL on each failure: an active attack keeps the window
// open, while a quiet account ages out `window` after its last failure.
_ = c.client.Expire(ctx, key, window).Err()
return c.client.SCard(ctx, key).Result()
}
// LoginFailureIPCount returns how many distinct source IPs have failed to log
// in to this account within the window (audit MEDIUM-3). SCard on a missing
// key returns 0.
func (c *CacheService) LoginFailureIPCount(ctx context.Context, identifier string) (int64, error) {
return c.client.SCard(ctx, loginFailPrefix+identifier).Result()
}
// ClearLoginFailures resets the failed-login IP set after a successful login.
func (c *CacheService) ClearLoginFailures(ctx context.Context, identifier string) error {
return c.client.Del(ctx, loginFailPrefix+identifier).Err()
}
// Static data cache helpers // Static data cache helpers
const ( const (
-307
View File
@@ -1,307 +0,0 @@
package services
import (
"context"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/treytartt/honeydue-api/internal/config"
)
const (
// googleKeysURL is Google's JWKS endpoint for ID-token signature verification.
googleKeysURL = "https://www.googleapis.com/oauth2/v3/certs"
googleKeysCacheTTL = 24 * time.Hour
googleKeysCacheKey = "google:public_keys"
)
// googleIssuers is the set of valid `iss` claim values for a Google ID token.
var googleIssuers = map[string]bool{
"accounts.google.com": true,
"https://accounts.google.com": true,
}
var (
ErrInvalidGoogleToken = errors.New("invalid Google ID token")
ErrGoogleTokenExpired = errors.New("Google ID token has expired")
ErrInvalidGoogleAudience = errors.New("invalid Google token audience")
ErrInvalidGoogleIssuer = errors.New("invalid Google token issuer")
ErrGoogleKeyNotFound = errors.New("Google public key not found")
)
// GoogleJWKS represents Google's JSON Web Key Set.
type GoogleJWKS struct {
Keys []GoogleJWK `json:"keys"`
}
// GoogleJWK represents a single JSON Web Key from Google.
type GoogleJWK struct {
Kty string `json:"kty"` // Key type (RSA)
Kid string `json:"kid"` // Key ID
Use string `json:"use"` // Key use (sig)
Alg string `json:"alg"` // Algorithm (RS256)
N string `json:"n"` // RSA modulus
E string `json:"e"` // RSA exponent
}
// GoogleTokenClaims represents the claims in a Google ID token JWT.
type GoogleTokenClaims struct {
jwt.RegisteredClaims
Email string `json:"email,omitempty"`
EmailVerified bool `json:"email_verified,omitempty"`
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
Picture string `json:"picture,omitempty"`
Azp string `json:"azp,omitempty"` // Authorized party
}
// GoogleTokenInfo is the verified, caller-facing view of a Google ID token.
type GoogleTokenInfo struct {
Sub string // Unique Google user ID
Email string
EmailVerified string // "true" or "false" — string for caller compatibility
Name string
GivenName string
FamilyName string
Picture string
Aud string
Azp string
Iss string
}
// IsEmailVerified returns whether the email is verified.
func (t *GoogleTokenInfo) IsEmailVerified() bool {
return t.EmailVerified == "true"
}
// GoogleAuthService handles Google Sign In token verification.
type GoogleAuthService struct {
cache *CacheService
config *config.Config
client *http.Client
}
// NewGoogleAuthService creates a new Google auth service.
func NewGoogleAuthService(cache *CacheService, cfg *config.Config) *GoogleAuthService {
return &GoogleAuthService{
cache: cache,
config: cfg,
client: &http.Client{Timeout: 10 * time.Second},
}
}
// VerifyIDToken verifies a Google ID token locally (audit C2/C3): it checks
// the RS256 signature against Google's published JWKS and the iss, aud, and
// exp claims. It never sends the token to a third-party endpoint, so it no
// longer depends on the deprecated tokeninfo service and never leaks the
// token in a request URL.
func (s *GoogleAuthService) VerifyIDToken(ctx context.Context, idToken string) (*GoogleTokenInfo, error) {
// Parse the token header to get the key ID.
parts := strings.Split(idToken, ".")
if len(parts) != 3 {
return nil, ErrInvalidGoogleToken
}
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return nil, fmt.Errorf("failed to decode token header: %w", err)
}
var header struct {
Kid string `json:"kid"`
Alg string `json:"alg"`
}
if err := json.Unmarshal(headerBytes, &header); err != nil {
return nil, fmt.Errorf("failed to parse token header: %w", err)
}
publicKey, err := s.getPublicKey(ctx, header.Kid)
if err != nil {
return nil, err
}
// Parse and verify the signature. jwt v5 validates exp/iat/nbf automatically.
token, err := jwt.ParseWithClaims(idToken, &GoogleTokenClaims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return publicKey, nil
})
if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
return nil, ErrGoogleTokenExpired
}
return nil, fmt.Errorf("failed to parse token: %w", err)
}
claims, ok := token.Claims.(*GoogleTokenClaims)
if !ok || !token.Valid {
return nil, ErrInvalidGoogleToken
}
// Verify the issuer (audit C3).
if !googleIssuers[claims.Issuer] {
return nil, ErrInvalidGoogleIssuer
}
// Verify the audience matches one of our configured client IDs.
if !s.verifyAudience(claims.Audience, claims.Azp) {
return nil, ErrInvalidGoogleAudience
}
if claims.Subject == "" {
return nil, ErrInvalidGoogleToken
}
emailVerified := "false"
if claims.EmailVerified {
emailVerified = "true"
}
aud := ""
if len(claims.Audience) > 0 {
aud = claims.Audience[0]
}
return &GoogleTokenInfo{
Sub: claims.Subject,
Email: claims.Email,
EmailVerified: emailVerified,
Name: claims.Name,
GivenName: claims.GivenName,
FamilyName: claims.FamilyName,
Picture: claims.Picture,
Aud: aud,
Azp: claims.Azp,
Iss: claims.Issuer,
}, nil
}
// verifyAudience checks the token audience against our configured client IDs.
// In production (non-debug) an empty client ID fails verification rather than
// silently bypassing the check.
func (s *GoogleAuthService) verifyAudience(audience jwt.ClaimStrings, azp string) bool {
clientID := s.config.GoogleAuth.ClientID
if clientID == "" {
// In debug mode only, skip audience verification for local development.
return s.config.Server.Debug
}
candidates := []string{clientID}
if id := s.config.GoogleAuth.AndroidClientID; id != "" {
candidates = append(candidates, id)
}
if id := s.config.GoogleAuth.IOSClientID; id != "" {
candidates = append(candidates, id)
}
for _, want := range candidates {
if azp == want {
return true
}
for _, aud := range audience {
if aud == want {
return true
}
}
}
return false
}
// getPublicKey returns the RSA public key for the given key ID, using a
// Redis-cached copy of Google's JWKS and re-fetching once on a cache miss
// (Google rotates signing keys roughly daily).
func (s *GoogleAuthService) getPublicKey(ctx context.Context, kid string) (*rsa.PublicKey, error) {
keys, err := s.getCachedKeys(ctx)
if err != nil || keys == nil {
keys, err = s.fetchGooglePublicKeys(ctx)
if err != nil {
return nil, err
}
}
if pubKey, ok := keys[kid]; ok {
return pubKey, nil
}
// Cache miss for this kid — keys may have rotated; fetch fresh.
keys, err = s.fetchGooglePublicKeys(ctx)
if err != nil {
return nil, err
}
if pubKey, ok := keys[kid]; ok {
return pubKey, nil
}
return nil, ErrGoogleKeyNotFound
}
// getCachedKeys retrieves cached Google public keys from Redis.
func (s *GoogleAuthService) getCachedKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
if s.cache == nil {
return nil, nil
}
data, err := s.cache.GetString(ctx, googleKeysCacheKey)
if err != nil || data == "" {
return nil, nil
}
var jwks GoogleJWKS
if err := json.Unmarshal([]byte(data), &jwks); err != nil {
return nil, nil
}
return s.parseJWKS(&jwks), nil
}
// fetchGooglePublicKeys fetches Google's JWKS and caches it.
func (s *GoogleAuthService) fetchGooglePublicKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, googleKeysURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := s.client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to fetch Google keys: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Google keys endpoint returned status %d", resp.StatusCode)
}
var jwks GoogleJWKS
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
return nil, fmt.Errorf("failed to decode Google keys: %w", err)
}
if s.cache != nil {
keysJSON, _ := json.Marshal(jwks)
_ = s.cache.SetString(ctx, googleKeysCacheKey, string(keysJSON), googleKeysCacheTTL)
}
return s.parseJWKS(&jwks), nil
}
// parseJWKS converts Google's JWKS into a map of RSA public keys by key ID.
func (s *GoogleAuthService) parseJWKS(jwks *GoogleJWKS) map[string]*rsa.PublicKey {
keys := make(map[string]*rsa.PublicKey)
for _, key := range jwks.Keys {
if key.Kty != "RSA" {
continue
}
nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
if err != nil {
continue
}
eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
if err != nil {
continue
}
e := 0
for _, b := range eBytes {
e = e<<8 + int(b)
}
keys[key.Kid] = &rsa.PublicKey{N: new(big.Int).SetBytes(nBytes), E: e}
}
return keys
}
+21 -20
View File
@@ -22,6 +22,14 @@ import (
"github.com/treytartt/honeydue-api/internal/validator" "github.com/treytartt/honeydue-api/internal/validator"
) )
// authUserKey and authTokenKey mirror middleware.AuthUserKey / middleware.AuthTokenKey.
// We duplicate the string constants here to avoid an import cycle
// (testutil <- middleware <- repositories <- testutil).
const (
authUserKey = "auth_user"
authTokenKey = "auth_token"
)
var ( var (
i18nOnce sync.Once i18nOnce sync.Once
testDBCounter uint64 testDBCounter uint64
@@ -52,9 +60,6 @@ func SetupTestDB(t *testing.T) *gorm.DB {
err = db.AutoMigrate( err = db.AutoMigrate(
&models.User{}, &models.User{},
&models.UserProfile{}, &models.UserProfile{},
&models.AuthToken{},
&models.ConfirmationCode{},
&models.PasswordResetCode{},
&models.AdminUser{}, &models.AdminUser{},
&models.Residence{}, &models.Residence{},
&models.ResidenceType{}, &models.ResidenceType{},
@@ -73,8 +78,6 @@ func SetupTestDB(t *testing.T) *gorm.DB {
&models.NotificationPreference{}, &models.NotificationPreference{},
&models.APNSDevice{}, &models.APNSDevice{},
&models.GCMDevice{}, &models.GCMDevice{},
&models.AppleSocialAuth{},
&models.GoogleSocialAuth{},
&models.TaskReminderLog{}, &models.TaskReminderLog{},
&models.UserSubscription{}, &models.UserSubscription{},
&models.SubscriptionSettings{}, &models.SubscriptionSettings{},
@@ -177,29 +180,24 @@ func ParseJSONArray(t *testing.T, body []byte) []map[string]interface{} {
return result return result
} }
// CreateTestUser creates a test user in the database // CreateTestUser creates a test user in the database.
func CreateTestUser(t *testing.T, db *gorm.DB, username, email, password string) *models.User { // password is accepted for API compatibility but ignored — Kratos owns credentials.
// A synthetic KratosID is generated so the user satisfies the unique-index constraint.
func CreateTestUser(t *testing.T, db *gorm.DB, username, email, _ string) *models.User {
t.Helper()
user := &models.User{ user := &models.User{
KratosID: "test-kratos-" + username,
Username: username, Username: username,
Email: email, Email: email,
IsActive: true, IsActive: true,
} }
err := user.SetPassword(password)
require.NoError(t, err)
err = db.Create(user).Error err := db.Create(user).Error
require.NoError(t, err) require.NoError(t, err)
return user return user
} }
// CreateTestToken creates an auth token for a user
func CreateTestToken(t *testing.T, db *gorm.DB, userID uint) *models.AuthToken {
token, err := models.GetOrCreateToken(db, userID)
require.NoError(t, err)
return token
}
// CreateTestResidenceType creates a test residence type // CreateTestResidenceType creates a test residence type
func CreateTestResidenceType(t *testing.T, db *gorm.DB, name string) *models.ResidenceType { func CreateTestResidenceType(t *testing.T, db *gorm.DB, name string) *models.ResidenceType {
rt := &models.ResidenceType{Name: name} rt := &models.ResidenceType{Name: name}
@@ -362,12 +360,15 @@ func AssertStatusCode(t *testing.T, w *httptest.ResponseRecorder, expected int)
require.Equal(t, expected, w.Code, "unexpected status code: %s", w.Body.String()) require.Equal(t, expected, w.Code, "unexpected status code: %s", w.Body.String())
} }
// MockAuthMiddleware creates middleware that sets a test user in context // MockAuthMiddleware creates middleware that sets a test user in context.
// Uses the same context keys as KratosAuth (authUserKey / authTokenKey) so
// handlers are unaware of the swap. The constants are duplicated here to
// avoid an import cycle (testutil <- middleware <- repositories <- testutil).
func MockAuthMiddleware(user *models.User) echo.MiddlewareFunc { func MockAuthMiddleware(user *models.User) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
c.Set("auth_user", user) c.Set(authUserKey, user)
c.Set("auth_token", "test-token") c.Set(authTokenKey, "test-token")
return next(c) return next(c)
} }
} }
+37
View File
@@ -0,0 +1,37 @@
-- +goose Up
-- Phase 2: hand-rolled auth replaced by Ory Kratos. Kratos owns identities,
-- credentials, sessions, email verification, recovery and social sign-in.
-- honeyDue keeps a slim auth_user row linked to the Kratos identity by
-- kratos_id; all domain tables keep their existing integer auth_user FKs.
--
-- Pre-production: a clean slate is taken. auth_user is truncated (cascading
-- to all user-scoped domain data) so no auth_user row exists without a
-- Kratos identity behind it. There is no data migration.
-- honeyDue's hand-rolled auth tables are no longer used — Kratos owns this.
DROP TABLE IF EXISTS user_authtoken;
DROP TABLE IF EXISTS user_confirmationcode;
DROP TABLE IF EXISTS user_passwordresetcode;
DROP TABLE IF EXISTS user_applesocialauth;
DROP TABLE IF EXISTS user_googlesocialauth;
-- Link each auth_user row to its Kratos identity (UUID).
ALTER TABLE auth_user ADD COLUMN IF NOT EXISTS kratos_id uuid;
CREATE UNIQUE INDEX IF NOT EXISTS uq_auth_user_kratos_id
ON auth_user (kratos_id) WHERE kratos_id IS NOT NULL;
-- password is NOT NULL in the Django-era schema but is no longer used —
-- Kratos holds credentials. Make it nullable so provisioning need not
-- invent a placeholder hash.
ALTER TABLE auth_user ALTER COLUMN password DROP NOT NULL;
-- Clean slate (pre-production): drop every existing account and all
-- user-scoped domain data so nothing is left orphaned without a Kratos id.
TRUNCATE TABLE auth_user CASCADE;
-- +goose Down
-- The dropped tables' data cannot be restored. Down only removes the
-- kratos_id column and restores the password NOT NULL constraint; reverting
-- to hand-rolled auth means reverting the Phase 2 application code.
DROP INDEX IF EXISTS uq_auth_user_kratos_id;
ALTER TABLE auth_user DROP COLUMN IF EXISTS kratos_id;