feat(auth): replace hand-rolled auth with Ory Kratos — phase 2 backend
Delegates all credential management (login, register, password reset, email verification, social sign-in) to Ory Kratos. The Go API now acts as a resource server: the new KratosAuth middleware validates sessions against the Kratos whoami endpoint, writes the local User mirror into Echo context, and all existing domain handlers continue working unchanged. Hand-rolled token auth, AuthToken model, apple_auth/ google_auth services, and the auth refresh flow are removed. Tests are updated to use the fake-token middleware pattern so existing integration assertions require no rewrite. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,215 +1,30 @@
|
||||
// apple_social_auth_handler is a stub — the user_applesocialauth table was
|
||||
// dropped in the Ory Kratos migration (phase 2). Social sign-in is now
|
||||
// handled by Kratos.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/admin/dto"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
// AdminAppleSocialAuthHandler handles admin Apple social auth management endpoints
|
||||
// AdminAppleSocialAuthHandler is a no-op stub.
|
||||
type AdminAppleSocialAuthHandler struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewAdminAppleSocialAuthHandler creates a new admin Apple social auth handler
|
||||
func NewAdminAppleSocialAuthHandler(db *gorm.DB) *AdminAppleSocialAuthHandler {
|
||||
return &AdminAppleSocialAuthHandler{db: db}
|
||||
}
|
||||
|
||||
// AppleSocialAuthResponse represents the response for an Apple social auth entry
|
||||
type AppleSocialAuthResponse struct {
|
||||
ID uint `json:"id"`
|
||||
UserID uint `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
UserEmail string `json:"user_email"`
|
||||
AppleID string `json:"apple_id"`
|
||||
Email string `json:"email"`
|
||||
IsPrivateEmail bool `json:"is_private_email"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// UpdateAppleSocialAuthRequest represents the request to update an Apple social auth entry
|
||||
type UpdateAppleSocialAuthRequest struct {
|
||||
Email *string `json:"email"`
|
||||
IsPrivateEmail *bool `json:"is_private_email"`
|
||||
}
|
||||
|
||||
// List handles GET /api/admin/apple-social-auth
|
||||
func (h *AdminAppleSocialAuthHandler) List(c echo.Context) error {
|
||||
var filters dto.PaginationParams
|
||||
if err := c.Bind(&filters); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
var entries []models.AppleSocialAuth
|
||||
var total int64
|
||||
|
||||
query := h.db.Model(&models.AppleSocialAuth{}).Preload("User")
|
||||
|
||||
// Apply search
|
||||
if filters.Search != "" {
|
||||
search := "%" + filters.Search + "%"
|
||||
query = query.Joins("JOIN auth_user ON auth_user.id = user_applesocialauth.user_id").
|
||||
Where("user_applesocialauth.apple_id ILIKE ? OR user_applesocialauth.email ILIKE ? OR auth_user.username ILIKE ? OR auth_user.email ILIKE ?",
|
||||
search, search, search, search)
|
||||
}
|
||||
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "user_id", "apple_id", "email", "is_private_email",
|
||||
"created_at", "updated_at",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage())
|
||||
|
||||
if err := query.Find(&entries).Error; err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch Apple social auth entries"})
|
||||
}
|
||||
|
||||
// Build response
|
||||
responses := make([]AppleSocialAuthResponse, len(entries))
|
||||
for i, entry := range entries {
|
||||
responses[i] = h.toResponse(&entry)
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage()))
|
||||
}
|
||||
|
||||
// Get handles GET /api/admin/apple-social-auth/:id
|
||||
func (h *AdminAppleSocialAuthHandler) Get(c echo.Context) error {
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"})
|
||||
}
|
||||
|
||||
var entry models.AppleSocialAuth
|
||||
if err := h.db.Preload("User").First(&entry, id).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Apple social auth entry not found"})
|
||||
}
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch Apple social auth entry"})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, h.toResponse(&entry))
|
||||
}
|
||||
|
||||
// GetByUser handles GET /api/admin/apple-social-auth/user/:user_id
|
||||
func (h *AdminAppleSocialAuthHandler) GetByUser(c echo.Context) error {
|
||||
userID, err := strconv.ParseUint(c.Param("user_id"), 10, 32)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid user ID"})
|
||||
}
|
||||
|
||||
var entry models.AppleSocialAuth
|
||||
if err := h.db.Preload("User").Where("user_id = ?", userID).First(&entry).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Apple social auth entry not found for user"})
|
||||
}
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch Apple social auth entry"})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, h.toResponse(&entry))
|
||||
}
|
||||
|
||||
// Update handles PUT /api/admin/apple-social-auth/:id
|
||||
func (h *AdminAppleSocialAuthHandler) Update(c echo.Context) error {
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"})
|
||||
}
|
||||
|
||||
var entry models.AppleSocialAuth
|
||||
if err := h.db.First(&entry, id).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Apple social auth entry not found"})
|
||||
}
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch Apple social auth entry"})
|
||||
}
|
||||
|
||||
var req UpdateAppleSocialAuthRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
if req.Email != nil {
|
||||
entry.Email = *req.Email
|
||||
}
|
||||
if req.IsPrivateEmail != nil {
|
||||
entry.IsPrivateEmail = *req.IsPrivateEmail
|
||||
}
|
||||
|
||||
if err := h.db.Save(&entry).Error; err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to update Apple social auth entry"})
|
||||
}
|
||||
|
||||
h.db.Preload("User").First(&entry, id)
|
||||
return c.JSON(http.StatusOK, h.toResponse(&entry))
|
||||
}
|
||||
|
||||
// Delete handles DELETE /api/admin/apple-social-auth/:id
|
||||
func (h *AdminAppleSocialAuthHandler) Delete(c echo.Context) error {
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"})
|
||||
}
|
||||
|
||||
var entry models.AppleSocialAuth
|
||||
if err := h.db.First(&entry, id).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Apple social auth entry not found"})
|
||||
}
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch Apple social auth entry"})
|
||||
}
|
||||
|
||||
if err := h.db.Delete(&entry).Error; err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth entry"})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Apple social auth entry deleted successfully"})
|
||||
}
|
||||
|
||||
// BulkDelete handles DELETE /api/admin/apple-social-auth/bulk
|
||||
func (h *AdminAppleSocialAuthHandler) BulkDelete(c echo.Context) error {
|
||||
var req dto.BulkDeleteRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
result := h.db.Where("id IN ?", req.IDs).Delete(&models.AppleSocialAuth{})
|
||||
if result.Error != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete Apple social auth entries"})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Apple social auth entries deleted successfully", "count": result.RowsAffected})
|
||||
}
|
||||
|
||||
// toResponse converts an AppleSocialAuth model to AppleSocialAuthResponse
|
||||
func (h *AdminAppleSocialAuthHandler) toResponse(entry *models.AppleSocialAuth) AppleSocialAuthResponse {
|
||||
response := AppleSocialAuthResponse{
|
||||
ID: entry.ID,
|
||||
UserID: entry.UserID,
|
||||
AppleID: entry.AppleID,
|
||||
Email: entry.Email,
|
||||
IsPrivateEmail: entry.IsPrivateEmail,
|
||||
CreatedAt: entry.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
UpdatedAt: entry.UpdatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
}
|
||||
|
||||
if entry.User.ID != 0 {
|
||||
response.Username = entry.User.Username
|
||||
response.UserEmail = entry.User.Email
|
||||
}
|
||||
|
||||
return response
|
||||
func (h *AdminAppleSocialAuthHandler) gone(c echo.Context) error {
|
||||
return c.JSON(http.StatusGone, map[string]string{"message": "Apple social auth is managed by Ory Kratos"})
|
||||
}
|
||||
func (h *AdminAppleSocialAuthHandler) List(c echo.Context) error { return h.gone(c) }
|
||||
func (h *AdminAppleSocialAuthHandler) Get(c echo.Context) error { return h.gone(c) }
|
||||
func (h *AdminAppleSocialAuthHandler) Delete(c echo.Context) error { return h.gone(c) }
|
||||
func (h *AdminAppleSocialAuthHandler) BulkDelete(c echo.Context) error { return h.gone(c) }
|
||||
func (h *AdminAppleSocialAuthHandler) Update(c echo.Context) error { return h.gone(c) }
|
||||
func (h *AdminAppleSocialAuthHandler) GetByUser(c echo.Context) error { return h.gone(c) }
|
||||
|
||||
@@ -1,144 +1,27 @@
|
||||
// auth_token_handler is a stub — the user_authtoken table was dropped in the
|
||||
// Ory Kratos migration (phase 2). Auth tokens are now Kratos sessions.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/admin/dto"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
// AdminAuthTokenHandler handles admin auth token management endpoints
|
||||
// AdminAuthTokenHandler is a no-op stub.
|
||||
type AdminAuthTokenHandler struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewAdminAuthTokenHandler creates a new admin auth token handler
|
||||
func NewAdminAuthTokenHandler(db *gorm.DB) *AdminAuthTokenHandler {
|
||||
return &AdminAuthTokenHandler{db: db}
|
||||
}
|
||||
|
||||
// AuthTokenResponse represents an auth token in API responses
|
||||
type AuthTokenResponse struct {
|
||||
Key string `json:"key"`
|
||||
UserID uint `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Created string `json:"created"`
|
||||
}
|
||||
|
||||
// List handles GET /api/admin/auth-tokens
|
||||
func (h *AdminAuthTokenHandler) List(c echo.Context) error {
|
||||
var filters dto.PaginationParams
|
||||
if err := c.Bind(&filters); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
var tokens []models.AuthToken
|
||||
var total int64
|
||||
|
||||
query := h.db.Model(&models.AuthToken{}).Preload("User")
|
||||
|
||||
// Apply search (search by user info)
|
||||
if filters.Search != "" {
|
||||
search := "%" + filters.Search + "%"
|
||||
query = query.Joins("JOIN auth_user ON auth_user.id = user_authtoken.user_id").
|
||||
Where(
|
||||
"auth_user.username ILIKE ? OR auth_user.email ILIKE ? OR user_authtoken.key ILIKE ?",
|
||||
search, search, search,
|
||||
)
|
||||
}
|
||||
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"created", "user_id",
|
||||
}, "created")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage())
|
||||
|
||||
if err := query.Find(&tokens).Error; err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch auth tokens"})
|
||||
}
|
||||
|
||||
// Build response
|
||||
responses := make([]AuthTokenResponse, len(tokens))
|
||||
for i, token := range tokens {
|
||||
responses[i] = AuthTokenResponse{
|
||||
Key: token.Key,
|
||||
UserID: token.UserID,
|
||||
Username: token.User.Username,
|
||||
Email: token.User.Email,
|
||||
Created: token.Created.Format("2006-01-02T15:04:05Z"),
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage()))
|
||||
}
|
||||
|
||||
// Get handles GET /api/admin/auth-tokens/:id (id is actually user_id)
|
||||
func (h *AdminAuthTokenHandler) Get(c echo.Context) error {
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid user ID"})
|
||||
}
|
||||
|
||||
var token models.AuthToken
|
||||
if err := h.db.Preload("User").Where("user_id = ?", id).First(&token).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Auth token not found"})
|
||||
}
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch auth token"})
|
||||
}
|
||||
|
||||
response := AuthTokenResponse{
|
||||
Key: token.Key,
|
||||
UserID: token.UserID,
|
||||
Username: token.User.Username,
|
||||
Email: token.User.Email,
|
||||
Created: token.Created.Format("2006-01-02T15:04:05Z"),
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// Delete handles DELETE /api/admin/auth-tokens/:id (revoke token)
|
||||
func (h *AdminAuthTokenHandler) Delete(c echo.Context) error {
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid user ID"})
|
||||
}
|
||||
|
||||
result := h.db.Where("user_id = ?", id).Delete(&models.AuthToken{})
|
||||
if result.Error != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to revoke token"})
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Auth token not found"})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Auth token revoked successfully"})
|
||||
}
|
||||
|
||||
// BulkDelete handles DELETE /api/admin/auth-tokens/bulk
|
||||
func (h *AdminAuthTokenHandler) BulkDelete(c echo.Context) error {
|
||||
var req dto.BulkDeleteRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
result := h.db.Where("user_id IN ?", req.IDs).Delete(&models.AuthToken{})
|
||||
if result.Error != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to revoke tokens"})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Auth tokens revoked successfully", "count": result.RowsAffected})
|
||||
func (h *AdminAuthTokenHandler) gone(c echo.Context) error {
|
||||
return c.JSON(http.StatusGone, map[string]string{"message": "auth tokens are managed by Ory Kratos"})
|
||||
}
|
||||
func (h *AdminAuthTokenHandler) List(c echo.Context) error { return h.gone(c) }
|
||||
func (h *AdminAuthTokenHandler) Get(c echo.Context) error { return h.gone(c) }
|
||||
func (h *AdminAuthTokenHandler) Delete(c echo.Context) error { return h.gone(c) }
|
||||
func (h *AdminAuthTokenHandler) BulkDelete(c echo.Context) error { return h.gone(c) }
|
||||
|
||||
@@ -1,162 +1,28 @@
|
||||
// confirmation_code_handler is a stub — the user_confirmationcode table was
|
||||
// dropped in the Ory Kratos migration (phase 2). Email verification is now
|
||||
// handled by Kratos.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/admin/dto"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
// maskCode masks a confirmation code, showing only the last 4 characters.
|
||||
func maskCode(code string) string {
|
||||
if len(code) <= 4 {
|
||||
return strings.Repeat("*", len(code))
|
||||
}
|
||||
return strings.Repeat("*", len(code)-4) + code[len(code)-4:]
|
||||
}
|
||||
|
||||
// AdminConfirmationCodeHandler handles admin confirmation code management endpoints
|
||||
// AdminConfirmationCodeHandler is a no-op stub.
|
||||
type AdminConfirmationCodeHandler struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewAdminConfirmationCodeHandler creates a new admin confirmation code handler
|
||||
func NewAdminConfirmationCodeHandler(db *gorm.DB) *AdminConfirmationCodeHandler {
|
||||
return &AdminConfirmationCodeHandler{db: db}
|
||||
}
|
||||
|
||||
// ConfirmationCodeResponse represents a confirmation code in API responses
|
||||
type ConfirmationCodeResponse struct {
|
||||
ID uint `json:"id"`
|
||||
UserID uint `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Code string `json:"code"`
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
IsUsed bool `json:"is_used"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// List handles GET /api/admin/confirmation-codes
|
||||
func (h *AdminConfirmationCodeHandler) List(c echo.Context) error {
|
||||
var filters dto.PaginationParams
|
||||
if err := c.Bind(&filters); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
var codes []models.ConfirmationCode
|
||||
var total int64
|
||||
|
||||
query := h.db.Model(&models.ConfirmationCode{}).Preload("User")
|
||||
|
||||
// Apply search (search by user info or code)
|
||||
if filters.Search != "" {
|
||||
search := "%" + filters.Search + "%"
|
||||
query = query.Joins("JOIN auth_user ON auth_user.id = user_confirmationcode.user_id").
|
||||
Where(
|
||||
"auth_user.username ILIKE ? OR auth_user.email ILIKE ? OR user_confirmationcode.code ILIKE ?",
|
||||
search, search, search,
|
||||
)
|
||||
}
|
||||
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "user_id", "created_at", "expires_at", "is_used",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage())
|
||||
|
||||
if err := query.Find(&codes).Error; err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch confirmation codes"})
|
||||
}
|
||||
|
||||
// Build response
|
||||
responses := make([]ConfirmationCodeResponse, len(codes))
|
||||
for i, code := range codes {
|
||||
responses[i] = ConfirmationCodeResponse{
|
||||
ID: code.ID,
|
||||
UserID: code.UserID,
|
||||
Username: code.User.Username,
|
||||
Email: code.User.Email,
|
||||
Code: maskCode(code.Code),
|
||||
ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"),
|
||||
IsUsed: code.IsUsed,
|
||||
CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage()))
|
||||
}
|
||||
|
||||
// Get handles GET /api/admin/confirmation-codes/:id
|
||||
func (h *AdminConfirmationCodeHandler) Get(c echo.Context) error {
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"})
|
||||
}
|
||||
|
||||
var code models.ConfirmationCode
|
||||
if err := h.db.Preload("User").First(&code, id).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Confirmation code not found"})
|
||||
}
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch confirmation code"})
|
||||
}
|
||||
|
||||
response := ConfirmationCodeResponse{
|
||||
ID: code.ID,
|
||||
UserID: code.UserID,
|
||||
Username: code.User.Username,
|
||||
Email: code.User.Email,
|
||||
Code: maskCode(code.Code),
|
||||
ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"),
|
||||
IsUsed: code.IsUsed,
|
||||
CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// Delete handles DELETE /api/admin/confirmation-codes/:id
|
||||
func (h *AdminConfirmationCodeHandler) Delete(c echo.Context) error {
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"})
|
||||
}
|
||||
|
||||
result := h.db.Delete(&models.ConfirmationCode{}, id)
|
||||
if result.Error != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete confirmation code"})
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Confirmation code not found"})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Confirmation code deleted successfully"})
|
||||
}
|
||||
|
||||
// BulkDelete handles DELETE /api/admin/confirmation-codes/bulk
|
||||
func (h *AdminConfirmationCodeHandler) BulkDelete(c echo.Context) error {
|
||||
var req dto.BulkDeleteRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
result := h.db.Where("id IN ?", req.IDs).Delete(&models.ConfirmationCode{})
|
||||
if result.Error != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete confirmation codes"})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Confirmation codes deleted successfully", "count": result.RowsAffected})
|
||||
func (h *AdminConfirmationCodeHandler) gone(c echo.Context) error {
|
||||
return c.JSON(http.StatusGone, map[string]string{"message": "confirmation codes are managed by Ory Kratos"})
|
||||
}
|
||||
func (h *AdminConfirmationCodeHandler) List(c echo.Context) error { return h.gone(c) }
|
||||
func (h *AdminConfirmationCodeHandler) Get(c echo.Context) error { return h.gone(c) }
|
||||
func (h *AdminConfirmationCodeHandler) Delete(c echo.Context) error { return h.gone(c) }
|
||||
func (h *AdminConfirmationCodeHandler) BulkDelete(c echo.Context) error { return h.gone(c) }
|
||||
|
||||
@@ -1,159 +1,28 @@
|
||||
// password_reset_code_handler is a stub — the user_passwordresetcode table
|
||||
// was dropped in the Ory Kratos migration (phase 2). Password resets are now
|
||||
// handled by Kratos.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/admin/dto"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
// AdminPasswordResetCodeHandler handles admin password reset code management endpoints
|
||||
// AdminPasswordResetCodeHandler is a no-op stub.
|
||||
type AdminPasswordResetCodeHandler struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewAdminPasswordResetCodeHandler creates a new admin password reset code handler
|
||||
func NewAdminPasswordResetCodeHandler(db *gorm.DB) *AdminPasswordResetCodeHandler {
|
||||
return &AdminPasswordResetCodeHandler{db: db}
|
||||
}
|
||||
|
||||
// PasswordResetCodeResponse represents a password reset code in API responses
|
||||
type PasswordResetCodeResponse struct {
|
||||
ID uint `json:"id"`
|
||||
UserID uint `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
ResetToken string `json:"reset_token"`
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
Used bool `json:"used"`
|
||||
Attempts int `json:"attempts"`
|
||||
MaxAttempts int `json:"max_attempts"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
// List handles GET /api/admin/password-reset-codes
|
||||
func (h *AdminPasswordResetCodeHandler) List(c echo.Context) error {
|
||||
var filters dto.PaginationParams
|
||||
if err := c.Bind(&filters); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
var codes []models.PasswordResetCode
|
||||
var total int64
|
||||
|
||||
query := h.db.Model(&models.PasswordResetCode{}).Preload("User")
|
||||
|
||||
// Apply search (search by user info or token)
|
||||
if filters.Search != "" {
|
||||
search := "%" + filters.Search + "%"
|
||||
query = query.Joins("JOIN auth_user ON auth_user.id = user_passwordresetcode.user_id").
|
||||
Where(
|
||||
"auth_user.username ILIKE ? OR auth_user.email ILIKE ? OR user_passwordresetcode.reset_token ILIKE ?",
|
||||
search, search, search,
|
||||
)
|
||||
}
|
||||
|
||||
// Get total count
|
||||
query.Count(&total)
|
||||
|
||||
// Apply sorting (allowlist prevents SQL injection via sort_by parameter)
|
||||
sortBy := filters.GetSafeSortBy([]string{
|
||||
"id", "user_id", "created_at", "expires_at", "used",
|
||||
}, "created_at")
|
||||
query = query.Order(sortBy + " " + filters.GetSortDir())
|
||||
|
||||
// Apply pagination
|
||||
query = query.Offset(filters.GetOffset()).Limit(filters.GetPerPage())
|
||||
|
||||
if err := query.Find(&codes).Error; err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch password reset codes"})
|
||||
}
|
||||
|
||||
// Build response
|
||||
responses := make([]PasswordResetCodeResponse, len(codes))
|
||||
for i, code := range codes {
|
||||
responses[i] = PasswordResetCodeResponse{
|
||||
ID: code.ID,
|
||||
UserID: code.UserID,
|
||||
Username: code.User.Username,
|
||||
Email: code.User.Email,
|
||||
ResetToken: code.ResetToken[:8] + "..." + code.ResetToken[len(code.ResetToken)-4:], // Truncate for display
|
||||
ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"),
|
||||
Used: code.Used,
|
||||
Attempts: code.Attempts,
|
||||
MaxAttempts: code.MaxAttempts,
|
||||
CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, dto.NewPaginatedResponse(responses, total, filters.GetPage(), filters.GetPerPage()))
|
||||
}
|
||||
|
||||
// Get handles GET /api/admin/password-reset-codes/:id
|
||||
func (h *AdminPasswordResetCodeHandler) Get(c echo.Context) error {
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"})
|
||||
}
|
||||
|
||||
var code models.PasswordResetCode
|
||||
if err := h.db.Preload("User").First(&code, id).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Password reset code not found"})
|
||||
}
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to fetch password reset code"})
|
||||
}
|
||||
|
||||
response := PasswordResetCodeResponse{
|
||||
ID: code.ID,
|
||||
UserID: code.UserID,
|
||||
Username: code.User.Username,
|
||||
Email: code.User.Email,
|
||||
ResetToken: code.ResetToken[:8] + "..." + code.ResetToken[len(code.ResetToken)-4:],
|
||||
ExpiresAt: code.ExpiresAt.Format("2006-01-02T15:04:05Z"),
|
||||
Used: code.Used,
|
||||
Attempts: code.Attempts,
|
||||
MaxAttempts: code.MaxAttempts,
|
||||
CreatedAt: code.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// Delete handles DELETE /api/admin/password-reset-codes/:id
|
||||
func (h *AdminPasswordResetCodeHandler) Delete(c echo.Context) error {
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid ID"})
|
||||
}
|
||||
|
||||
result := h.db.Delete(&models.PasswordResetCode{}, id)
|
||||
if result.Error != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete password reset code"})
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return c.JSON(http.StatusNotFound, map[string]interface{}{"error": "Password reset code not found"})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Password reset code deleted successfully"})
|
||||
}
|
||||
|
||||
// BulkDelete handles DELETE /api/admin/password-reset-codes/bulk
|
||||
func (h *AdminPasswordResetCodeHandler) BulkDelete(c echo.Context) error {
|
||||
var req dto.BulkDeleteRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid request body"})
|
||||
}
|
||||
|
||||
result := h.db.Where("id IN ?", req.IDs).Delete(&models.PasswordResetCode{})
|
||||
if result.Error != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to delete password reset codes"})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Password reset codes deleted successfully", "count": result.RowsAffected})
|
||||
func (h *AdminPasswordResetCodeHandler) gone(c echo.Context) error {
|
||||
return c.JSON(http.StatusGone, map[string]string{"message": "password reset codes are managed by Ory Kratos"})
|
||||
}
|
||||
func (h *AdminPasswordResetCodeHandler) List(c echo.Context) error { return h.gone(c) }
|
||||
func (h *AdminPasswordResetCodeHandler) Get(c echo.Context) error { return h.gone(c) }
|
||||
func (h *AdminPasswordResetCodeHandler) Delete(c echo.Context) error { return h.gone(c) }
|
||||
func (h *AdminPasswordResetCodeHandler) BulkDelete(c echo.Context) error { return h.gone(c) }
|
||||
|
||||
@@ -207,9 +207,7 @@ func (h *AdminUserHandler) Create(c echo.Context) error {
|
||||
user.IsSuperuser = *req.IsSuperuser
|
||||
}
|
||||
|
||||
if err := user.SetPassword(req.Password); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to hash password"})
|
||||
}
|
||||
// Password management is handled by Ory Kratos; no local password hashing.
|
||||
|
||||
if err := h.db.Create(&user).Error; err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to create user"})
|
||||
@@ -284,10 +282,9 @@ func (h *AdminUserHandler) Update(c echo.Context) error {
|
||||
if req.IsSuperuser != nil {
|
||||
user.IsSuperuser = *req.IsSuperuser
|
||||
}
|
||||
// Password management is handled by Ory Kratos; local password update ignored.
|
||||
if req.Password != nil {
|
||||
if err := user.SetPassword(*req.Password); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Failed to hash password"})
|
||||
}
|
||||
_ = req.Password // Password changes must go through Kratos admin API
|
||||
}
|
||||
|
||||
if err := h.db.Save(&user).Error; err != nil {
|
||||
|
||||
@@ -142,6 +142,9 @@ type SecurityConfig struct {
|
||||
MaxPasswordResetRate int // per hour
|
||||
TokenExpiryDays int // Number of days before auth tokens expire (default 90)
|
||||
TokenRefreshDays int // Token must be at least this many days old before refresh (default 60)
|
||||
// KratosPublicURL is the Ory Kratos public API base URL. The auth
|
||||
// middleware validates sessions against {KratosPublicURL}/sessions/whoami.
|
||||
KratosPublicURL string
|
||||
}
|
||||
|
||||
// StorageConfig holds file storage settings.
|
||||
@@ -304,6 +307,7 @@ func Load() (*Config, error) {
|
||||
MaxPasswordResetRate: 3,
|
||||
TokenExpiryDays: viper.GetInt("TOKEN_EXPIRY_DAYS"),
|
||||
TokenRefreshDays: viper.GetInt("TOKEN_REFRESH_DAYS"),
|
||||
KratosPublicURL: viper.GetString("KRATOS_PUBLIC_URL"),
|
||||
},
|
||||
Storage: StorageConfig{
|
||||
UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"),
|
||||
@@ -411,6 +415,7 @@ func setDefaults() {
|
||||
|
||||
// Token expiry defaults
|
||||
viper.SetDefault("TOKEN_EXPIRY_DAYS", 90) // Tokens expire after 90 days
|
||||
viper.SetDefault("KRATOS_PUBLIC_URL", "http://kratos:4433") // Ory Kratos public API
|
||||
viper.SetDefault("TOKEN_REFRESH_DAYS", 60) // Tokens can be refreshed after 60 days
|
||||
|
||||
// Storage defaults
|
||||
|
||||
@@ -244,12 +244,7 @@ func Migrate() error {
|
||||
|
||||
// User and auth tables
|
||||
&models.User{},
|
||||
&models.AuthToken{},
|
||||
&models.UserProfile{},
|
||||
&models.ConfirmationCode{},
|
||||
&models.PasswordResetCode{},
|
||||
&models.AppleSocialAuth{},
|
||||
&models.GoogleSocialAuth{},
|
||||
|
||||
// Admin users (separate from app users)
|
||||
&models.AdminUser{},
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
@@ -16,18 +14,18 @@ import (
|
||||
"github.com/treytartt/honeydue-api/internal/validator"
|
||||
)
|
||||
|
||||
// AuthHandler handles authentication endpoints
|
||||
// AuthHandler handles user profile and account management endpoints.
|
||||
// Session lifecycle (login, register, logout, password reset) is delegated
|
||||
// to Ory Kratos; this handler only deals with the honeyDue user record.
|
||||
type AuthHandler struct {
|
||||
authService *services.AuthService
|
||||
emailService *services.EmailService
|
||||
cache *services.CacheService
|
||||
appleAuthService *services.AppleAuthService
|
||||
googleAuthService *services.GoogleAuthService
|
||||
storageService *services.StorageService
|
||||
auditService *services.AuditService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new auth handler
|
||||
// NewAuthHandler creates a new auth handler.
|
||||
func NewAuthHandler(authService *services.AuthService, emailService *services.EmailService, cache *services.CacheService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
authService: authService,
|
||||
@@ -36,136 +34,21 @@ func NewAuthHandler(authService *services.AuthService, emailService *services.Em
|
||||
}
|
||||
}
|
||||
|
||||
// SetAppleAuthService sets the Apple auth service (called after initialization)
|
||||
func (h *AuthHandler) SetAppleAuthService(appleAuth *services.AppleAuthService) {
|
||||
h.appleAuthService = appleAuth
|
||||
}
|
||||
|
||||
// SetGoogleAuthService sets the Google auth service (called after initialization)
|
||||
func (h *AuthHandler) SetGoogleAuthService(googleAuth *services.GoogleAuthService) {
|
||||
h.googleAuthService = googleAuth
|
||||
}
|
||||
|
||||
// SetStorageService sets the storage service for file deletion during account deletion
|
||||
// SetStorageService sets the storage service for file deletion during account deletion.
|
||||
func (h *AuthHandler) SetStorageService(storageService *services.StorageService) {
|
||||
h.storageService = storageService
|
||||
}
|
||||
|
||||
// SetAuditService sets the audit service for logging security events
|
||||
// SetAuditService sets the audit service for logging security events.
|
||||
func (h *AuthHandler) SetAuditService(auditService *services.AuditService) {
|
||||
h.auditService = auditService
|
||||
}
|
||||
|
||||
// noStore marks a response as non-cacheable (audit L2) — auth responses
|
||||
// carry tokens and user data that must never sit in any cache.
|
||||
// noStore marks a response as non-cacheable.
|
||||
func noStore(c echo.Context) {
|
||||
c.Response().Header().Set("Cache-Control", "no-store")
|
||||
}
|
||||
|
||||
// Login handles POST /api/auth/login/
|
||||
func (h *AuthHandler) Login(c echo.Context) error {
|
||||
noStore(c)
|
||||
var req requests.LoginRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
||||
}
|
||||
|
||||
response, err := h.authService.Login(c.Request().Context(), &req, c.RealIP())
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Str("identifier", req.Username).
|
||||
Str("ip", c.RealIP()).Str("user_agent", c.Request().UserAgent()).
|
||||
Msg("Login failed")
|
||||
if h.auditService != nil {
|
||||
h.auditService.LogEvent(c, nil, services.AuditEventLoginFailed, map[string]interface{}{
|
||||
"identifier": req.Username,
|
||||
})
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if h.auditService != nil {
|
||||
userID := response.User.ID
|
||||
h.auditService.LogEvent(c, &userID, services.AuditEventLogin, nil)
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// Register handles POST /api/auth/register/
|
||||
func (h *AuthHandler) Register(c echo.Context) error {
|
||||
noStore(c)
|
||||
var req requests.RegisterRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
||||
}
|
||||
|
||||
response, confirmationCode, err := h.authService.Register(c.Request().Context(), &req)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Msg("Registration failed")
|
||||
return err
|
||||
}
|
||||
|
||||
if h.auditService != nil {
|
||||
userID := response.User.ID
|
||||
h.auditService.LogEvent(c, &userID, services.AuditEventRegister, map[string]interface{}{
|
||||
"username": req.Username,
|
||||
"email": req.Email,
|
||||
})
|
||||
}
|
||||
|
||||
// Send welcome email with confirmation code (async)
|
||||
if h.emailService != nil && confirmationCode != "" {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", req.Email).Msg("Panic in welcome email goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.emailService.SendWelcomeEmail(req.Email, req.FirstName, confirmationCode); err != nil {
|
||||
log.Error().Err(err).Str("email", req.Email).Msg("Failed to send welcome email")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusCreated, response)
|
||||
}
|
||||
|
||||
// Logout handles POST /api/auth/logout/
|
||||
func (h *AuthHandler) Logout(c echo.Context) error {
|
||||
token := middleware.GetAuthToken(c)
|
||||
if token == "" {
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
|
||||
// Log audit event before invalidating the token
|
||||
if h.auditService != nil {
|
||||
user := middleware.GetAuthUser(c)
|
||||
if user != nil {
|
||||
h.auditService.LogEvent(c, &user.ID, services.AuditEventLogout, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// Invalidate token in database
|
||||
if err := h.authService.Logout(c.Request().Context(), token); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to delete token from database")
|
||||
}
|
||||
|
||||
// Invalidate token in cache
|
||||
if h.cache != nil {
|
||||
if err := h.cache.InvalidateAuthToken(c.Request().Context(), token); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to invalidate token in cache")
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, responses.MessageResponse{Message: "Logged out successfully"})
|
||||
}
|
||||
|
||||
// CurrentUser handles GET /api/auth/me/
|
||||
func (h *AuthHandler) CurrentUser(c echo.Context) error {
|
||||
noStore(c)
|
||||
@@ -207,301 +90,6 @@ func (h *AuthHandler) UpdateProfile(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// VerifyEmail handles POST /api/auth/verify-email/
|
||||
func (h *AuthHandler) VerifyEmail(c echo.Context) error {
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req requests.VerifyEmailRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
||||
}
|
||||
|
||||
err = h.authService.VerifyEmail(c.Request().Context(), user.ID, req.Code)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Email verification failed")
|
||||
return err
|
||||
}
|
||||
|
||||
// Send post-verification welcome email with tips (async)
|
||||
if h.emailService != nil {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in post-verification email goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.emailService.SendPostVerificationEmail(user.Email, user.FirstName); err != nil {
|
||||
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send post-verification email")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, responses.VerifyEmailResponse{
|
||||
Message: "Email verified successfully",
|
||||
Verified: true,
|
||||
})
|
||||
}
|
||||
|
||||
// ResendVerification handles POST /api/auth/resend-verification/
|
||||
func (h *AuthHandler) ResendVerification(c echo.Context) error {
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
code, err := h.authService.ResendVerificationCode(c.Request().Context(), user.ID)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Failed to resend verification")
|
||||
return err
|
||||
}
|
||||
|
||||
// Send verification email (async)
|
||||
if h.emailService != nil {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in verification email goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.emailService.SendVerificationEmail(user.Email, user.FirstName, code); err != nil {
|
||||
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send verification email")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, responses.MessageResponse{Message: "Verification email sent"})
|
||||
}
|
||||
|
||||
// ForgotPassword handles POST /api/auth/forgot-password/
|
||||
func (h *AuthHandler) ForgotPassword(c echo.Context) error {
|
||||
var req requests.ForgotPasswordRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
||||
}
|
||||
|
||||
noStore(c)
|
||||
|
||||
if h.auditService != nil {
|
||||
h.auditService.LogEvent(c, nil, services.AuditEventPasswordReset, map[string]interface{}{
|
||||
"email": req.Email,
|
||||
})
|
||||
}
|
||||
|
||||
// Audit LIVE-L13: run the user lookup, code generation, and email send
|
||||
// entirely in the background, then return the generic response
|
||||
// immediately. This makes the response time identical whether or not
|
||||
// the email belongs to a real account, defeating timing-based user
|
||||
// enumeration. context.Background() is used because the request context
|
||||
// is cancelled the moment this handler returns. Per-account rate
|
||||
// limiting still runs inside the service; the edge auth-rate-limit
|
||||
// middleware covers per-IP abuse.
|
||||
email := req.Email
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", email).Msg("Panic in forgot-password goroutine")
|
||||
}
|
||||
}()
|
||||
code, user, err := h.authService.ForgotPassword(context.Background(), email)
|
||||
if err != nil || code == "" || user == nil {
|
||||
return
|
||||
}
|
||||
if h.emailService != nil {
|
||||
if sendErr := h.emailService.SendPasswordResetEmail(user.Email, user.FirstName, code); sendErr != nil {
|
||||
log.Error().Err(sendErr).Str("email", user.Email).Msg("Failed to send password reset email")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Always return success to prevent email enumeration.
|
||||
return c.JSON(http.StatusOK, responses.ForgotPasswordResponse{
|
||||
Message: "Password reset email sent",
|
||||
})
|
||||
}
|
||||
|
||||
// VerifyResetCode handles POST /api/auth/verify-reset-code/
|
||||
func (h *AuthHandler) VerifyResetCode(c echo.Context) error {
|
||||
var req requests.VerifyResetCodeRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
||||
}
|
||||
|
||||
resetToken, err := h.authService.VerifyResetCode(c.Request().Context(), req.Email, req.Code)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Str("email", req.Email).Msg("Verify reset code failed")
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, responses.VerifyResetCodeResponse{
|
||||
Message: "Reset code verified",
|
||||
ResetToken: resetToken,
|
||||
})
|
||||
}
|
||||
|
||||
// ResetPassword handles POST /api/auth/reset-password/
|
||||
func (h *AuthHandler) ResetPassword(c echo.Context) error {
|
||||
var req requests.ResetPasswordRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
||||
}
|
||||
|
||||
err := h.authService.ResetPassword(c.Request().Context(), req.ResetToken, req.NewPassword)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Msg("Password reset failed")
|
||||
return err
|
||||
}
|
||||
|
||||
if h.auditService != nil {
|
||||
h.auditService.LogEvent(c, nil, services.AuditEventPasswordChanged, map[string]interface{}{
|
||||
"method": "reset_token",
|
||||
})
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, responses.ResetPasswordResponse{
|
||||
Message: "Password reset successful",
|
||||
})
|
||||
}
|
||||
|
||||
// AppleSignIn handles POST /api/auth/apple-sign-in/
|
||||
func (h *AuthHandler) AppleSignIn(c echo.Context) error {
|
||||
noStore(c)
|
||||
var req requests.AppleSignInRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
||||
}
|
||||
|
||||
if h.appleAuthService == nil {
|
||||
log.Error().Msg("Apple auth service not configured")
|
||||
return &apperrors.AppError{
|
||||
Code: 500,
|
||||
MessageKey: "error.apple_signin_not_configured",
|
||||
}
|
||||
}
|
||||
|
||||
response, err := h.authService.AppleSignIn(c.Request().Context(), h.appleAuthService, &req)
|
||||
if err != nil {
|
||||
// Check for legacy Apple Sign In error (not yet migrated)
|
||||
if errors.Is(err, services.ErrAppleSignInFailed) {
|
||||
log.Debug().Err(err).Msg("Apple Sign In failed (legacy error)")
|
||||
return apperrors.Unauthorized("error.invalid_apple_token")
|
||||
}
|
||||
|
||||
log.Debug().Err(err).Msg("Apple Sign In failed")
|
||||
return err
|
||||
}
|
||||
|
||||
// Send welcome email for new users (async)
|
||||
if response.IsNewUser && h.emailService != nil && response.User.Email != "" {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", response.User.Email).Msg("Panic in Apple welcome email goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.emailService.SendAppleWelcomeEmail(response.User.Email, response.User.FirstName); err != nil {
|
||||
log.Error().Err(err).Str("email", response.User.Email).Msg("Failed to send Apple welcome email")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GoogleSignIn handles POST /api/auth/google-sign-in/
|
||||
func (h *AuthHandler) GoogleSignIn(c echo.Context) error {
|
||||
noStore(c)
|
||||
var req requests.GoogleSignInRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
||||
}
|
||||
|
||||
if h.googleAuthService == nil {
|
||||
log.Error().Msg("Google auth service not configured")
|
||||
return &apperrors.AppError{
|
||||
Code: 500,
|
||||
MessageKey: "error.google_signin_not_configured",
|
||||
}
|
||||
}
|
||||
|
||||
response, err := h.authService.GoogleSignIn(c.Request().Context(), h.googleAuthService, &req)
|
||||
if err != nil {
|
||||
// Check for legacy Google Sign In error (not yet migrated)
|
||||
if errors.Is(err, services.ErrGoogleSignInFailed) {
|
||||
log.Debug().Err(err).Msg("Google Sign In failed (legacy error)")
|
||||
return apperrors.Unauthorized("error.invalid_google_token")
|
||||
}
|
||||
|
||||
log.Debug().Err(err).Msg("Google Sign In failed")
|
||||
return err
|
||||
}
|
||||
|
||||
// Send welcome email for new users (async)
|
||||
if response.IsNewUser && h.emailService != nil && response.User.Email != "" {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", response.User.Email).Msg("Panic in Google welcome email goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.emailService.SendGoogleWelcomeEmail(response.User.Email, response.User.FirstName); err != nil {
|
||||
log.Error().Err(err).Str("email", response.User.Email).Msg("Failed to send Google welcome email")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// RefreshToken handles POST /api/auth/refresh/
|
||||
func (h *AuthHandler) RefreshToken(c echo.Context) error {
|
||||
noStore(c)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
token := middleware.GetAuthToken(c)
|
||||
if token == "" {
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
|
||||
response, err := h.authService.RefreshToken(c.Request().Context(), token, user.ID)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Uint("user_id", user.ID).Msg("Token refresh failed")
|
||||
return err
|
||||
}
|
||||
|
||||
// If the token was refreshed (new token), invalidate the old one from cache
|
||||
if response.Token != token && h.cache != nil {
|
||||
if cacheErr := h.cache.InvalidateAuthToken(c.Request().Context(), token); cacheErr != nil {
|
||||
log.Warn().Err(cacheErr).Msg("Failed to invalidate old token from cache during refresh")
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// DeleteAccount handles DELETE /api/auth/account/
|
||||
func (h *AuthHandler) DeleteAccount(c echo.Context) error {
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
@@ -544,13 +132,5 @@ func (h *AuthHandler) DeleteAccount(c echo.Context) error {
|
||||
}()
|
||||
}
|
||||
|
||||
// Invalidate auth token from cache
|
||||
token := middleware.GetAuthToken(c)
|
||||
if h.cache != nil && token != "" {
|
||||
if err := h.cache.InvalidateAuthToken(c.Request().Context(), token); err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to invalidate token in cache after account deletion")
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(http.StatusOK, responses.MessageResponse{Message: "Account deleted successfully"})
|
||||
}
|
||||
|
||||
@@ -35,26 +35,25 @@ func setupDeleteAccountHandler(t *testing.T) (*AuthHandler, *echo.Echo, *gorm.DB
|
||||
return handler, e, db
|
||||
}
|
||||
|
||||
func TestAuthHandler_DeleteAccount_EmailUser(t *testing.T) {
|
||||
// TestAuthHandler_DeleteAccount_WithConfirmation verifies that DELETE /account/
|
||||
// succeeds when the user sends confirmation: "DELETE".
|
||||
// Post-Kratos: all users (regardless of provider) must confirm with "DELETE".
|
||||
func TestAuthHandler_DeleteAccount_WithConfirmation(t *testing.T) {
|
||||
handler, e, db := setupDeleteAccountHandler(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "deletetest", "delete@test.com", "Password123")
|
||||
user := testutil.CreateTestUser(t, db, "deletetest", "delete@test.com", "ignored")
|
||||
|
||||
// Create profile for the user
|
||||
profile := &models.UserProfile{UserID: user.ID, Verified: true}
|
||||
require.NoError(t, db.Create(profile).Error)
|
||||
|
||||
// Create auth token
|
||||
testutil.CreateTestToken(t, db, user.ID)
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.DELETE("/account/", handler.DeleteAccount)
|
||||
|
||||
t.Run("successful deletion with correct password", func(t *testing.T) {
|
||||
password := "Password123"
|
||||
t.Run("successful deletion with DELETE confirmation", func(t *testing.T) {
|
||||
req := map[string]interface{}{
|
||||
"password": password,
|
||||
"confirmation": "DELETE",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token")
|
||||
@@ -74,106 +73,15 @@ func TestAuthHandler_DeleteAccount_EmailUser(t *testing.T) {
|
||||
// Verify profile is deleted
|
||||
db.Model(&models.UserProfile{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
|
||||
// Verify auth token is deleted
|
||||
db.Model(&models.AuthToken{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_DeleteAccount_WrongPassword(t *testing.T) {
|
||||
// TestAuthHandler_DeleteAccount_MissingConfirmation verifies that a missing
|
||||
// confirmation string is rejected with 400.
|
||||
func TestAuthHandler_DeleteAccount_MissingConfirmation(t *testing.T) {
|
||||
handler, e, db := setupDeleteAccountHandler(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "wrongpw", "wrongpw@test.com", "Password123")
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.DELETE("/account/", handler.DeleteAccount)
|
||||
|
||||
t.Run("wrong password returns 401", func(t *testing.T) {
|
||||
wrongPw := "wrongpassword"
|
||||
req := map[string]interface{}{
|
||||
"password": wrongPw,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_DeleteAccount_MissingPassword(t *testing.T) {
|
||||
handler, e, db := setupDeleteAccountHandler(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "nopw", "nopw@test.com", "Password123")
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.DELETE("/account/", handler.DeleteAccount)
|
||||
|
||||
t.Run("missing password returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{}
|
||||
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_DeleteAccount_SocialAuthUser(t *testing.T) {
|
||||
handler, e, db := setupDeleteAccountHandler(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "randompassword")
|
||||
|
||||
// Create Apple social auth record
|
||||
appleAuth := &models.AppleSocialAuth{
|
||||
UserID: user.ID,
|
||||
AppleID: "apple_sub_123",
|
||||
Email: "apple@test.com",
|
||||
}
|
||||
require.NoError(t, db.Create(appleAuth).Error)
|
||||
|
||||
// Create profile
|
||||
profile := &models.UserProfile{UserID: user.ID, Verified: true}
|
||||
require.NoError(t, db.Create(profile).Error)
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.DELETE("/account/", handler.DeleteAccount)
|
||||
|
||||
t.Run("successful deletion with DELETE confirmation", func(t *testing.T) {
|
||||
confirmation := "DELETE"
|
||||
req := map[string]interface{}{
|
||||
"confirmation": confirmation,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
// Verify user is deleted
|
||||
var count int64
|
||||
db.Model(&models.User{}).Where("id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
|
||||
// Verify apple auth is deleted
|
||||
db.Model(&models.AppleSocialAuth{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_DeleteAccount_SocialAuthMissingConfirmation(t *testing.T) {
|
||||
handler, e, db := setupDeleteAccountHandler(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "randompassword")
|
||||
|
||||
// Create Google social auth record
|
||||
googleAuth := &models.GoogleSocialAuth{
|
||||
UserID: user.ID,
|
||||
GoogleID: "google_sub_456",
|
||||
Email: "google@test.com",
|
||||
}
|
||||
require.NoError(t, db.Create(googleAuth).Error)
|
||||
user := testutil.CreateTestUser(t, db, "nopw", "nopw@test.com", "ignored")
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
@@ -188,9 +96,8 @@ func TestAuthHandler_DeleteAccount_SocialAuthMissingConfirmation(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("wrong confirmation returns 400", func(t *testing.T) {
|
||||
wrongConfirmation := "delete"
|
||||
req := map[string]interface{}{
|
||||
"confirmation": wrongConfirmation,
|
||||
"confirmation": "delete", // lowercase — must be exact "DELETE"
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "test-token")
|
||||
@@ -199,6 +106,8 @@ func TestAuthHandler_DeleteAccount_SocialAuthMissingConfirmation(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestAuthHandler_DeleteAccount_Unauthenticated verifies that 401 is returned
|
||||
// when no auth middleware is set.
|
||||
func TestAuthHandler_DeleteAccount_Unauthenticated(t *testing.T) {
|
||||
handler, e, _ := setupDeleteAccountHandler(t)
|
||||
|
||||
@@ -207,7 +116,7 @@ func TestAuthHandler_DeleteAccount_Unauthenticated(t *testing.T) {
|
||||
|
||||
t.Run("unauthenticated request returns 401", func(t *testing.T) {
|
||||
req := map[string]interface{}{
|
||||
"password": "Password123",
|
||||
"confirmation": "DELETE",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "")
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
// auth_handler_test.go tests the auth handler endpoints that survived the
|
||||
// Ory Kratos migration: GET /me/ and PUT/PATCH /profile/.
|
||||
// Login, register, logout, forgot-password, and social sign-in are now
|
||||
// handled by Kratos.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
@@ -34,204 +38,32 @@ func setupAuthHandler(t *testing.T) (*AuthHandler, *echo.Echo, *repositories.Use
|
||||
return handler, e, userRepo
|
||||
}
|
||||
|
||||
func TestAuthHandler_Register(t *testing.T) {
|
||||
handler, e, _ := setupAuthHandler(t)
|
||||
|
||||
e.POST("/api/auth/register/", handler.Register)
|
||||
|
||||
t.Run("successful registration", func(t *testing.T) {
|
||||
req := requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "new@test.com",
|
||||
Password: "Password123",
|
||||
FirstName: "New",
|
||||
LastName: "User",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusCreated)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.AssertJSONFieldExists(t, response, "token")
|
||||
testutil.AssertJSONFieldExists(t, response, "user")
|
||||
testutil.AssertJSONFieldExists(t, response, "message")
|
||||
|
||||
user := response["user"].(map[string]interface{})
|
||||
assert.Equal(t, "newuser", user["username"])
|
||||
assert.Equal(t, "new@test.com", user["email"])
|
||||
assert.Equal(t, "New", user["first_name"])
|
||||
assert.Equal(t, "User", user["last_name"])
|
||||
})
|
||||
|
||||
t.Run("registration with missing fields", func(t *testing.T) {
|
||||
req := map[string]string{
|
||||
"username": "test",
|
||||
// Missing email and password
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
|
||||
response := testutil.ParseJSON(t, w.Body.Bytes())
|
||||
testutil.AssertJSONFieldExists(t, response, "error")
|
||||
})
|
||||
|
||||
t.Run("registration with short password", func(t *testing.T) {
|
||||
req := requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "short", // Less than 8 chars
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("registration with duplicate username", func(t *testing.T) {
|
||||
// First registration
|
||||
req := requests.RegisterRequest{
|
||||
Username: "duplicate",
|
||||
Email: "unique1@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusCreated)
|
||||
|
||||
// Try to register again with same username
|
||||
req.Email = "unique2@test.com"
|
||||
w = testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusConflict) // 409 for duplicate resource
|
||||
|
||||
response := testutil.ParseJSON(t, w.Body.Bytes())
|
||||
assert.Contains(t, response["error"], "Username already taken")
|
||||
})
|
||||
|
||||
t.Run("registration with duplicate email", func(t *testing.T) {
|
||||
// First registration
|
||||
req := requests.RegisterRequest{
|
||||
Username: "user1",
|
||||
Email: "duplicate@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusCreated)
|
||||
|
||||
// Try to register again with same email
|
||||
req.Username = "user2"
|
||||
w = testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusConflict) // 409 for duplicate resource
|
||||
|
||||
response := testutil.ParseJSON(t, w.Body.Bytes())
|
||||
assert.Contains(t, response["error"], "Email already registered")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_Login(t *testing.T) {
|
||||
handler, e, _ := setupAuthHandler(t)
|
||||
|
||||
e.POST("/api/auth/register/", handler.Register)
|
||||
e.POST("/api/auth/login/", handler.Login)
|
||||
|
||||
// Create a test user
|
||||
registerReq := requests.RegisterRequest{
|
||||
Username: "logintest",
|
||||
Email: "login@test.com",
|
||||
Password: "Password123",
|
||||
FirstName: "Test",
|
||||
LastName: "User",
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", registerReq, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusCreated)
|
||||
|
||||
t.Run("successful login with username", func(t *testing.T) {
|
||||
req := requests.LoginRequest{
|
||||
Username: "logintest",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.AssertJSONFieldExists(t, response, "token")
|
||||
testutil.AssertJSONFieldExists(t, response, "user")
|
||||
|
||||
user := response["user"].(map[string]interface{})
|
||||
assert.Equal(t, "logintest", user["username"])
|
||||
assert.Equal(t, "login@test.com", user["email"])
|
||||
})
|
||||
|
||||
t.Run("successful login with email", func(t *testing.T) {
|
||||
req := requests.LoginRequest{
|
||||
Username: "login@test.com", // Using email as username
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("login with wrong password", func(t *testing.T) {
|
||||
req := requests.LoginRequest{
|
||||
Username: "logintest",
|
||||
Password: "wrongpassword",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
|
||||
response := testutil.ParseJSON(t, w.Body.Bytes())
|
||||
assert.Contains(t, response["error"], "Invalid credentials")
|
||||
})
|
||||
|
||||
t.Run("login with non-existent user", func(t *testing.T) {
|
||||
req := requests.LoginRequest{
|
||||
Username: "nonexistent",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
|
||||
t.Run("login with missing fields", func(t *testing.T) {
|
||||
req := map[string]string{
|
||||
"username": "logintest",
|
||||
// Missing password
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_CurrentUser(t *testing.T) {
|
||||
handler, e, userRepo := setupAuthHandler(t)
|
||||
handler, e, _ := setupAuthHandler(t)
|
||||
|
||||
db := testutil.SetupTestDB(t)
|
||||
user := testutil.CreateTestUser(t, db, "metest", "me@test.com", "Password123")
|
||||
user := testutil.CreateTestUser(t, db, "metest", "me@test.com", "")
|
||||
user.FirstName = "Test"
|
||||
user.LastName = "User"
|
||||
userRepo.Update(user)
|
||||
// Use the userRepo from setupAuthHandler's DB, but since we need the user
|
||||
// in the same DB we re-create it there.
|
||||
db2 := testutil.SetupTestDB(t)
|
||||
user2 := testutil.CreateTestUser(t, db2, "metest2", "me2@test.com", "")
|
||||
user2.FirstName = "Test"
|
||||
user2.LastName = "User"
|
||||
userRepo2 := repositories.NewUserRepository(db2)
|
||||
require.NoError(t, userRepo2.Update(user2))
|
||||
|
||||
// Build handler against db2
|
||||
cfg := &config.Config{}
|
||||
authService2 := services.NewAuthService(userRepo2, cfg)
|
||||
handler2 := NewAuthHandler(authService2, nil, nil)
|
||||
|
||||
// Set up route with mock auth middleware
|
||||
authGroup := e.Group("/api/auth")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/me/", handler.CurrentUser)
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user2))
|
||||
authGroup.GET("/me/", handler2.CurrentUser)
|
||||
|
||||
_ = handler // avoid unused
|
||||
|
||||
t.Run("get current user", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/auth/me/", nil, "test-token")
|
||||
@@ -242,23 +74,26 @@ func TestAuthHandler_CurrentUser(t *testing.T) {
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "metest", response["username"])
|
||||
assert.Equal(t, "me@test.com", response["email"])
|
||||
assert.Equal(t, "metest2", response["username"])
|
||||
assert.Equal(t, "me2@test.com", response["email"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_UpdateProfile(t *testing.T) {
|
||||
handler, e, userRepo := setupAuthHandler(t)
|
||||
|
||||
db := testutil.SetupTestDB(t)
|
||||
user := testutil.CreateTestUser(t, db, "updatetest", "update@test.com", "Password123")
|
||||
userRepo.Update(user)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
authService := services.NewAuthService(userRepo, cfg)
|
||||
handler := NewAuthHandler(authService, nil, nil)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "updatetest", "update@test.com", "")
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.PUT("/profile/", handler.UpdateProfile)
|
||||
|
||||
t.Run("update profile", func(t *testing.T) {
|
||||
t.Run("update first and last name", func(t *testing.T) {
|
||||
firstName := "Updated"
|
||||
lastName := "Name"
|
||||
req := requests.UpdateProfileRequest{
|
||||
@@ -278,130 +113,3 @@ func TestAuthHandler_UpdateProfile(t *testing.T) {
|
||||
assert.Equal(t, "Name", response["last_name"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_ForgotPassword(t *testing.T) {
|
||||
handler, e, _ := setupAuthHandler(t)
|
||||
|
||||
e.POST("/api/auth/register/", handler.Register)
|
||||
e.POST("/api/auth/forgot-password/", handler.ForgotPassword)
|
||||
|
||||
// Create a test user
|
||||
registerReq := requests.RegisterRequest{
|
||||
Username: "forgottest",
|
||||
Email: "forgot@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
testutil.MakeRequest(e, "POST", "/api/auth/register/", registerReq, "")
|
||||
|
||||
t.Run("forgot password with valid email", func(t *testing.T) {
|
||||
req := requests.ForgotPasswordRequest{
|
||||
Email: "forgot@test.com",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/forgot-password/", req, "")
|
||||
|
||||
// Always returns 200 to prevent email enumeration
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
response := testutil.ParseJSON(t, w.Body.Bytes())
|
||||
testutil.AssertJSONFieldExists(t, response, "message")
|
||||
})
|
||||
|
||||
t.Run("forgot password with invalid email", func(t *testing.T) {
|
||||
req := requests.ForgotPasswordRequest{
|
||||
Email: "nonexistent@test.com",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/forgot-password/", req, "")
|
||||
|
||||
// Still returns 200 to prevent email enumeration
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_Logout(t *testing.T) {
|
||||
handler, e, userRepo := setupAuthHandler(t)
|
||||
|
||||
db := testutil.SetupTestDB(t)
|
||||
user := testutil.CreateTestUser(t, db, "logouttest", "logout@test.com", "Password123")
|
||||
userRepo.Update(user)
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/logout/", handler.Logout)
|
||||
|
||||
t.Run("successful logout", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/logout/", nil, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
response := testutil.ParseJSON(t, w.Body.Bytes())
|
||||
assert.Contains(t, response["message"], "Logged out successfully")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_JSONResponses(t *testing.T) {
|
||||
handler, e, _ := setupAuthHandler(t)
|
||||
|
||||
e.POST("/api/auth/register/", handler.Register)
|
||||
e.POST("/api/auth/login/", handler.Login)
|
||||
|
||||
t.Run("register response has correct JSON structure", func(t *testing.T) {
|
||||
req := requests.RegisterRequest{
|
||||
Username: "jsontest",
|
||||
Email: "json@test.com",
|
||||
Password: "Password123",
|
||||
FirstName: "JSON",
|
||||
LastName: "Test",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusCreated)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify top-level structure
|
||||
assert.Contains(t, response, "token")
|
||||
assert.Contains(t, response, "user")
|
||||
assert.Contains(t, response, "message")
|
||||
|
||||
// Verify token is not empty
|
||||
assert.NotEmpty(t, response["token"])
|
||||
|
||||
// Verify user structure
|
||||
user := response["user"].(map[string]interface{})
|
||||
assert.Contains(t, user, "id")
|
||||
assert.Contains(t, user, "username")
|
||||
assert.Contains(t, user, "email")
|
||||
assert.Contains(t, user, "first_name")
|
||||
assert.Contains(t, user, "last_name")
|
||||
assert.Contains(t, user, "is_active")
|
||||
assert.Contains(t, user, "date_joined")
|
||||
|
||||
// Verify types
|
||||
assert.IsType(t, float64(0), user["id"]) // JSON numbers are float64
|
||||
assert.IsType(t, "", user["username"])
|
||||
assert.IsType(t, "", user["email"])
|
||||
assert.IsType(t, true, user["is_active"])
|
||||
})
|
||||
|
||||
t.Run("error response has correct JSON structure", func(t *testing.T) {
|
||||
req := map[string]string{
|
||||
"username": "test",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, response, "error")
|
||||
assert.IsType(t, "", response["error"])
|
||||
})
|
||||
}
|
||||
|
||||
@@ -506,232 +506,6 @@ func TestTaskHandler_CreateCompletion_NoTaskID(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Auth Handler - Additional Coverage
|
||||
// =============================================================================
|
||||
|
||||
func TestAuthHandler_AppleSignIn_NotConfigured(t *testing.T) {
|
||||
handler, e, _ := setupAuthHandler(t)
|
||||
|
||||
e.POST("/api/auth/apple-sign-in/", handler.AppleSignIn)
|
||||
|
||||
t.Run("returns 500 when apple auth not configured", func(t *testing.T) {
|
||||
req := map[string]interface{}{
|
||||
"id_token": "fake-token",
|
||||
"user_id": "fake-user-id",
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/apple-sign-in/", req, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusInternalServerError)
|
||||
})
|
||||
|
||||
t.Run("missing identity_token returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/apple-sign-in/", req, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_GoogleSignIn_NotConfigured(t *testing.T) {
|
||||
handler, e, _ := setupAuthHandler(t)
|
||||
|
||||
e.POST("/api/auth/google-sign-in/", handler.GoogleSignIn)
|
||||
|
||||
t.Run("returns 500 when google auth not configured", func(t *testing.T) {
|
||||
req := map[string]interface{}{
|
||||
"id_token": "fake-token",
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/google-sign-in/", req, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusInternalServerError)
|
||||
})
|
||||
|
||||
t.Run("missing id_token returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/google-sign-in/", req, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
// setupAuthHandlerWithDB is like setupAuthHandler but also returns the underlying *gorm.DB
|
||||
// for tests that need to create records like ConfirmationCode directly.
|
||||
func setupAuthHandlerWithDB(t *testing.T) (*AuthHandler, *echo.Echo, *gorm.DB) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
SecretKey: "test-secret-key",
|
||||
PasswordResetExpiry: 15 * time.Minute,
|
||||
ConfirmationExpiry: 24 * time.Hour,
|
||||
MaxPasswordResetRate: 3,
|
||||
},
|
||||
}
|
||||
authService := services.NewAuthService(userRepo, cfg)
|
||||
handler := NewAuthHandler(authService, nil, nil)
|
||||
e := testutil.SetupTestRouter()
|
||||
return handler, e, db
|
||||
}
|
||||
|
||||
func TestAuthHandler_VerifyEmail(t *testing.T) {
|
||||
handler, e, db := setupAuthHandlerWithDB(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "verifytest", "verify@test.com", "Password123")
|
||||
|
||||
// Create confirmation code
|
||||
confirmCode := &models.ConfirmationCode{
|
||||
UserID: user.ID,
|
||||
Code: "123456",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
IsUsed: false,
|
||||
}
|
||||
require.NoError(t, db.Create(confirmCode).Error)
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/verify-email/", handler.VerifyEmail)
|
||||
|
||||
t.Run("successful verification", func(t *testing.T) {
|
||||
req := requests.VerifyEmailRequest{
|
||||
Code: "123456",
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/verify-email/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, true, response["verified"])
|
||||
})
|
||||
|
||||
t.Run("wrong code returns error", func(t *testing.T) {
|
||||
req := requests.VerifyEmailRequest{
|
||||
Code: "999999",
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/verify-email/", req, "test-token")
|
||||
// Code already used or wrong code
|
||||
assert.True(t, w.Code == http.StatusBadRequest || w.Code == http.StatusNotFound,
|
||||
"expected 400 or 404, got %d", w.Code)
|
||||
})
|
||||
|
||||
t.Run("missing code returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/verify-email/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_ResendVerification(t *testing.T) {
|
||||
handler, e, db := setupAuthHandlerWithDB(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "resendtest", "resend@test.com", "Password123")
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/resend-verification/", handler.ResendVerification)
|
||||
|
||||
t.Run("successful resend", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/resend-verification/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response, "message")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_RefreshToken(t *testing.T) {
|
||||
handler, e, db := setupAuthHandlerWithDB(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "refreshtest", "refresh@test.com", "Password123")
|
||||
|
||||
// Create auth token and use its actual key in the middleware
|
||||
authToken := testutil.CreateTestToken(t, db, user.ID)
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
authGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
c.Set("auth_user", user)
|
||||
c.Set("auth_token", authToken.Plaintext) // raw token — repo hashes for lookup (audit C1)
|
||||
return next(c)
|
||||
}
|
||||
})
|
||||
authGroup.POST("/refresh/", handler.RefreshToken)
|
||||
|
||||
t.Run("successful refresh", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/refresh/", nil, authToken.Plaintext)
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response, "token")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_VerifyResetCode(t *testing.T) {
|
||||
handler, e, _ := setupAuthHandler(t)
|
||||
|
||||
e.POST("/api/auth/register/", handler.Register)
|
||||
e.POST("/api/auth/verify-reset-code/", handler.VerifyResetCode)
|
||||
|
||||
t.Run("invalid code returns error", func(t *testing.T) {
|
||||
req := requests.VerifyResetCodeRequest{
|
||||
Email: "nonexistent@test.com",
|
||||
Code: "999999",
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/verify-reset-code/", req, "")
|
||||
// Should not be 200 since no valid code exists
|
||||
assert.NotEqual(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("missing fields returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/verify-reset-code/", req, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_ResetPassword(t *testing.T) {
|
||||
handler, e, _ := setupAuthHandler(t)
|
||||
|
||||
e.POST("/api/auth/reset-password/", handler.ResetPassword)
|
||||
|
||||
t.Run("invalid reset token returns error", func(t *testing.T) {
|
||||
req := requests.ResetPasswordRequest{
|
||||
ResetToken: "invalid-token",
|
||||
NewPassword: "NewPassword123",
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/reset-password/", req, "")
|
||||
assert.NotEqual(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("missing fields returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/reset-password/", req, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("short password returns 400", func(t *testing.T) {
|
||||
req := requests.ResetPasswordRequest{
|
||||
ResetToken: "some-token",
|
||||
NewPassword: "short",
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/reset-password/", req, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_ForgotPassword_MissingEmail(t *testing.T) {
|
||||
handler, e, _ := setupAuthHandler(t)
|
||||
|
||||
e.POST("/api/auth/forgot-password/", handler.ForgotPassword)
|
||||
|
||||
t.Run("missing email returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/forgot-password/", req, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Residence Handler - Additional Error Paths
|
||||
// =============================================================================
|
||||
|
||||
@@ -190,6 +190,27 @@ func shouldSkipSpecRoute(path string) bool {
|
||||
if strings.HasPrefix(path, "/uploads/") || strings.HasPrefix(path, "/media/") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Auth routes delegated to Ory Kratos (phase 2 auth refactor).
|
||||
// These endpoints are no longer served by the Go API; the spec is retained
|
||||
// as documentation of the Kratos-facing contract.
|
||||
kratosRoutes := map[string]bool{
|
||||
"/auth/login/": true,
|
||||
"/auth/register/": true,
|
||||
"/auth/logout/": true,
|
||||
"/auth/refresh/": true,
|
||||
"/auth/forgot-password/": true,
|
||||
"/auth/verify-reset-code/": true,
|
||||
"/auth/reset-password/": true,
|
||||
"/auth/verify-email/": true,
|
||||
"/auth/resend-verification/": true,
|
||||
"/auth/apple-sign-in/": true,
|
||||
"/auth/google-sign-in/": true,
|
||||
}
|
||||
if kratosRoutes[path] {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -6,9 +6,11 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -17,6 +19,7 @@ import (
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/handlers"
|
||||
"github.com/treytartt/honeydue-api/internal/middleware"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/repositories"
|
||||
"github.com/treytartt/honeydue-api/internal/services"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
@@ -105,11 +108,40 @@ type TestApp struct {
|
||||
TaskRepo *repositories.TaskRepository
|
||||
ContractorRepo *repositories.ContractorRepository
|
||||
AuthService *services.AuthService
|
||||
// tokenStore maps fake token strings to users for the test-auth middleware.
|
||||
tokenStore map[string]*models.User
|
||||
tokenStoreMu sync.RWMutex
|
||||
}
|
||||
|
||||
// fakeAuthMiddleware replaces the real Kratos middleware in integration tests.
|
||||
// It looks up the "Authorization: Token <tok>" value in app.tokenStore.
|
||||
func (app *TestApp) fakeAuthMiddleware() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ah := c.Request().Header.Get("Authorization")
|
||||
if ah == "" {
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
tok := ah
|
||||
if len(ah) > 6 && ah[:6] == "Token " {
|
||||
tok = ah[6:]
|
||||
} else if len(ah) > 7 && ah[:7] == "Bearer " {
|
||||
tok = ah[7:]
|
||||
}
|
||||
app.tokenStoreMu.RLock()
|
||||
user, ok := app.tokenStore[tok]
|
||||
app.tokenStoreMu.RUnlock()
|
||||
if !ok {
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
c.Set("auth_user", user)
|
||||
c.Set("auth_token", tok)
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setupIntegrationTest(t *testing.T) *TestApp {
|
||||
// Echo does not need test mode
|
||||
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
|
||||
@@ -123,9 +155,6 @@ func setupIntegrationTest(t *testing.T) *TestApp {
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
SecretKey: "test-secret-key-for-integration-tests",
|
||||
PasswordResetExpiry: 15 * time.Minute,
|
||||
ConfirmationExpiry: 24 * time.Hour,
|
||||
MaxPasswordResetRate: 3,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -141,28 +170,33 @@ func setupIntegrationTest(t *testing.T) *TestApp {
|
||||
taskHandler := handlers.NewTaskHandler(taskService, nil)
|
||||
contractorHandler := handlers.NewContractorHandler(contractorService)
|
||||
|
||||
// Create router with real middleware
|
||||
e := echo.New()
|
||||
app := &TestApp{
|
||||
DB: db,
|
||||
Router: echo.New(),
|
||||
AuthHandler: authHandler,
|
||||
ResidenceHandler: residenceHandler,
|
||||
TaskHandler: taskHandler,
|
||||
ContractorHandler: contractorHandler,
|
||||
UserRepo: userRepo,
|
||||
ResidenceRepo: residenceRepo,
|
||||
TaskRepo: taskRepo,
|
||||
ContractorRepo: contractorRepo,
|
||||
AuthService: authService,
|
||||
tokenStore: make(map[string]*models.User),
|
||||
}
|
||||
|
||||
e := app.Router
|
||||
e.Validator = validator.NewCustomValidator()
|
||||
e.HTTPErrorHandler = apperrors.HTTPErrorHandler
|
||||
|
||||
// Add timezone middleware globally so X-Timezone header is processed
|
||||
// Timezone middleware processes X-Timezone header
|
||||
e.Use(middleware.TimezoneMiddleware())
|
||||
|
||||
// Public routes
|
||||
auth := e.Group("/api/auth")
|
||||
{
|
||||
auth.POST("/register", authHandler.Register)
|
||||
auth.POST("/login", authHandler.Login)
|
||||
}
|
||||
|
||||
// Protected routes - use AuthMiddleware without Redis cache for testing
|
||||
authMiddleware := middleware.NewAuthMiddleware(db, nil)
|
||||
// Protected routes — guarded by the fake token middleware
|
||||
api := e.Group("/api")
|
||||
api.Use(authMiddleware.TokenAuth())
|
||||
api.Use(app.fakeAuthMiddleware())
|
||||
{
|
||||
api.GET("/auth/me", authHandler.CurrentUser)
|
||||
api.POST("/auth/logout", authHandler.Logout)
|
||||
|
||||
residences := api.Group("/residences")
|
||||
{
|
||||
@@ -216,19 +250,7 @@ func setupIntegrationTest(t *testing.T) *TestApp {
|
||||
api.GET("/contractors/by-residence/:residence_id", contractorHandler.ListContractorsByResidence)
|
||||
}
|
||||
|
||||
return &TestApp{
|
||||
DB: db,
|
||||
Router: e,
|
||||
AuthHandler: authHandler,
|
||||
ResidenceHandler: residenceHandler,
|
||||
TaskHandler: taskHandler,
|
||||
ContractorHandler: contractorHandler,
|
||||
UserRepo: userRepo,
|
||||
ResidenceRepo: residenceRepo,
|
||||
TaskRepo: taskRepo,
|
||||
ContractorRepo: contractorRepo,
|
||||
AuthService: authService,
|
||||
}
|
||||
return app
|
||||
}
|
||||
|
||||
// Helper to make authenticated requests
|
||||
@@ -251,156 +273,16 @@ func (app *TestApp) makeAuthenticatedRequest(t *testing.T, method, path string,
|
||||
return w
|
||||
}
|
||||
|
||||
// Helper to register and login a user, returns token
|
||||
func (app *TestApp) registerAndLogin(t *testing.T, username, email, password string) string {
|
||||
// Register
|
||||
registerBody := map[string]string{
|
||||
"username": username,
|
||||
"email": email,
|
||||
"password": password,
|
||||
}
|
||||
w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody, "")
|
||||
require.Equal(t, http.StatusCreated, w.Code)
|
||||
|
||||
// Login
|
||||
loginBody := map[string]string{
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/login", loginBody, "")
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var loginResp map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &loginResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
return loginResp["token"].(string)
|
||||
}
|
||||
|
||||
// ============ Authentication Flow Tests ============
|
||||
|
||||
func TestIntegration_AuthenticationFlow(t *testing.T) {
|
||||
app := setupIntegrationTest(t)
|
||||
|
||||
// 1. Register a new user
|
||||
registerBody := map[string]string{
|
||||
"username": "testuser",
|
||||
"email": "test@example.com",
|
||||
"password": "SecurePass123!",
|
||||
}
|
||||
w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody, "")
|
||||
assert.Equal(t, http.StatusCreated, w.Code)
|
||||
|
||||
var registerResp map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), ®isterResp)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, registerResp["token"])
|
||||
assert.NotNil(t, registerResp["user"])
|
||||
|
||||
// 2. Login with the same credentials
|
||||
loginBody := map[string]string{
|
||||
"username": "testuser",
|
||||
"password": "SecurePass123!",
|
||||
}
|
||||
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/login", loginBody, "")
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var loginResp map[string]interface{}
|
||||
err = json.Unmarshal(w.Body.Bytes(), &loginResp)
|
||||
require.NoError(t, err)
|
||||
token := loginResp["token"].(string)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// 3. Get current user with token
|
||||
w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, token)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var meResp map[string]interface{}
|
||||
err = json.Unmarshal(w.Body.Bytes(), &meResp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "testuser", meResp["username"])
|
||||
assert.Equal(t, "test@example.com", meResp["email"])
|
||||
|
||||
// 4. Access protected route without token should fail
|
||||
w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, "")
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
|
||||
// 5. Access protected route with invalid token should fail
|
||||
w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, "invalid-token")
|
||||
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
|
||||
// 6. Logout
|
||||
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/logout", nil, token)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestIntegration_RegistrationValidation(t *testing.T) {
|
||||
app := setupIntegrationTest(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body map[string]string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "missing username",
|
||||
body: map[string]string{"email": "test@example.com", "password": "pass123"},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "missing email",
|
||||
body: map[string]string{"username": "testuser", "password": "pass123"},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "missing password",
|
||||
body: map[string]string{"username": "testuser", "email": "test@example.com"},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invalid email",
|
||||
body: map[string]string{"username": "testuser", "email": "invalid", "password": "pass123"},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", tt.body, "")
|
||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_DuplicateRegistration(t *testing.T) {
|
||||
app := setupIntegrationTest(t)
|
||||
|
||||
// Register first user (password must be >= 8 chars)
|
||||
registerBody := map[string]string{
|
||||
"username": "testuser",
|
||||
"email": "test@example.com",
|
||||
"password": "Password123",
|
||||
}
|
||||
w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody, "")
|
||||
assert.Equal(t, http.StatusCreated, w.Code)
|
||||
|
||||
// Try to register with same username - returns 409 (Conflict)
|
||||
registerBody2 := map[string]string{
|
||||
"username": "testuser",
|
||||
"email": "different@example.com",
|
||||
"password": "Password123",
|
||||
}
|
||||
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody2, "")
|
||||
assert.Equal(t, http.StatusConflict, w.Code)
|
||||
|
||||
// Try to register with same email - returns 409 (Conflict)
|
||||
registerBody3 := map[string]string{
|
||||
"username": "differentuser",
|
||||
"email": "test@example.com",
|
||||
"password": "Password123",
|
||||
}
|
||||
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody3, "")
|
||||
assert.Equal(t, http.StatusConflict, w.Code)
|
||||
// registerAndLogin creates a user directly in the DB and returns a synthetic token
|
||||
// that the fake auth middleware will accept. No HTTP register/login endpoints are called.
|
||||
func (app *TestApp) registerAndLogin(t *testing.T, username, email, _ string) string {
|
||||
t.Helper()
|
||||
user := testutil.CreateTestUser(t, app.DB, username, email, "")
|
||||
tok := uuid.NewString()
|
||||
app.tokenStoreMu.Lock()
|
||||
app.tokenStore[tok] = user
|
||||
app.tokenStoreMu.Unlock()
|
||||
return tok
|
||||
}
|
||||
|
||||
// ============ Residence Flow Tests ============
|
||||
@@ -827,48 +709,16 @@ func TestIntegration_ResponseStructure(t *testing.T) {
|
||||
func TestIntegration_ComprehensiveE2E(t *testing.T) {
|
||||
app := setupIntegrationTest(t)
|
||||
|
||||
// ============ Phase 1: Authentication ============
|
||||
t.Log("Phase 1: Testing authentication flow")
|
||||
// ============ Phase 1: User Setup ============
|
||||
t.Log("Phase 1: Setting up test user")
|
||||
|
||||
// Register new user
|
||||
registerBody := map[string]string{
|
||||
"username": "e2e_testuser",
|
||||
"email": "e2e@example.com",
|
||||
"password": "SecurePass123!",
|
||||
}
|
||||
w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody, "")
|
||||
require.Equal(t, http.StatusCreated, w.Code, "Registration should succeed")
|
||||
|
||||
var registerResp map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), ®isterResp)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, registerResp["token"], "Registration should return token")
|
||||
assert.NotNil(t, registerResp["user"], "Registration should return user")
|
||||
|
||||
// Verify login with same credentials
|
||||
loginBody := map[string]string{
|
||||
"username": "e2e_testuser",
|
||||
"password": "SecurePass123!",
|
||||
}
|
||||
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/login", loginBody, "")
|
||||
require.Equal(t, http.StatusOK, w.Code, "Login should succeed")
|
||||
|
||||
var loginResp map[string]interface{}
|
||||
err = json.Unmarshal(w.Body.Bytes(), &loginResp)
|
||||
require.NoError(t, err)
|
||||
token := loginResp["token"].(string)
|
||||
assert.NotEmpty(t, token, "Login should return token")
|
||||
token := app.registerAndLogin(t, "e2e_testuser", "e2e@example.com", "")
|
||||
|
||||
// Verify authenticated access
|
||||
w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, token)
|
||||
w := app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, token)
|
||||
require.Equal(t, http.StatusOK, w.Code, "Should access protected route with valid token")
|
||||
|
||||
var meResp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &meResp)
|
||||
assert.Equal(t, "e2e_testuser", meResp["username"])
|
||||
assert.Equal(t, "e2e@example.com", meResp["email"])
|
||||
|
||||
t.Log("✓ Authentication flow verified")
|
||||
t.Log("✓ User setup verified")
|
||||
|
||||
// ============ Phase 2: Create 5 Residences ============
|
||||
t.Log("Phase 2: Creating 5 residences")
|
||||
@@ -1244,29 +1094,9 @@ func TestIntegration_ComprehensiveE2E(t *testing.T) {
|
||||
t.Logf("✓ All %d visible tasks verified in correct columns by ID", expectedVisibleTasks)
|
||||
|
||||
// ============ Phase 9: Create User B ============
|
||||
t.Log("Phase 9: Creating User B and verifying login")
|
||||
t.Log("Phase 9: Creating User B")
|
||||
|
||||
// Register User B
|
||||
registerBodyB := map[string]string{
|
||||
"username": "e2e_userb",
|
||||
"email": "e2e_userb@example.com",
|
||||
"password": "SecurePass456!",
|
||||
}
|
||||
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBodyB, "")
|
||||
require.Equal(t, http.StatusCreated, w.Code, "User B registration should succeed")
|
||||
|
||||
// Login as User B
|
||||
loginBodyB := map[string]string{
|
||||
"username": "e2e_userb",
|
||||
"password": "SecurePass456!",
|
||||
}
|
||||
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/login", loginBodyB, "")
|
||||
require.Equal(t, http.StatusOK, w.Code, "User B login should succeed")
|
||||
|
||||
var loginRespB map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &loginRespB)
|
||||
tokenB := loginRespB["token"].(string)
|
||||
assert.NotEmpty(t, tokenB, "User B should have a token")
|
||||
tokenB := app.registerAndLogin(t, "e2e_userb", "e2e_userb@example.com", "")
|
||||
|
||||
// Verify User B can access their own profile
|
||||
w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, tokenB)
|
||||
@@ -1592,8 +1422,6 @@ func formatID(id float64) string {
|
||||
|
||||
// setupContractorTest sets up a test environment including contractor routes
|
||||
func setupContractorTest(t *testing.T) *TestApp {
|
||||
// Echo does not need test mode
|
||||
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
|
||||
@@ -1607,9 +1435,6 @@ func setupContractorTest(t *testing.T) *TestApp {
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
SecretKey: "test-secret-key-for-integration-tests",
|
||||
PasswordResetExpiry: 15 * time.Minute,
|
||||
ConfirmationExpiry: 24 * time.Hour,
|
||||
MaxPasswordResetRate: 3,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1625,29 +1450,32 @@ func setupContractorTest(t *testing.T) *TestApp {
|
||||
taskHandler := handlers.NewTaskHandler(taskService, nil)
|
||||
contractorHandler := handlers.NewContractorHandler(contractorService)
|
||||
|
||||
// Create router with real middleware
|
||||
e := echo.New()
|
||||
app := &TestApp{
|
||||
DB: db,
|
||||
Router: echo.New(),
|
||||
AuthHandler: authHandler,
|
||||
ResidenceHandler: residenceHandler,
|
||||
TaskHandler: taskHandler,
|
||||
ContractorHandler: contractorHandler,
|
||||
UserRepo: userRepo,
|
||||
ResidenceRepo: residenceRepo,
|
||||
TaskRepo: taskRepo,
|
||||
ContractorRepo: contractorRepo,
|
||||
AuthService: authService,
|
||||
tokenStore: make(map[string]*models.User),
|
||||
}
|
||||
|
||||
e := app.Router
|
||||
e.Validator = validator.NewCustomValidator()
|
||||
e.HTTPErrorHandler = apperrors.HTTPErrorHandler
|
||||
|
||||
// Add timezone middleware globally so X-Timezone header is processed
|
||||
// Timezone middleware
|
||||
e.Use(middleware.TimezoneMiddleware())
|
||||
|
||||
// Public routes
|
||||
auth := e.Group("/api/auth")
|
||||
{
|
||||
auth.POST("/register", authHandler.Register)
|
||||
auth.POST("/login", authHandler.Login)
|
||||
}
|
||||
|
||||
// Protected routes
|
||||
authMiddleware := middleware.NewAuthMiddleware(db, nil)
|
||||
api := e.Group("/api")
|
||||
api.Use(authMiddleware.TokenAuth())
|
||||
api.Use(app.fakeAuthMiddleware())
|
||||
{
|
||||
api.GET("/auth/me", authHandler.CurrentUser)
|
||||
api.POST("/auth/logout", authHandler.Logout)
|
||||
|
||||
residences := api.Group("/residences")
|
||||
{
|
||||
residences.GET("", residenceHandler.ListResidences)
|
||||
@@ -1680,19 +1508,7 @@ func setupContractorTest(t *testing.T) *TestApp {
|
||||
}
|
||||
}
|
||||
|
||||
return &TestApp{
|
||||
DB: db,
|
||||
Router: e,
|
||||
AuthHandler: authHandler,
|
||||
ResidenceHandler: residenceHandler,
|
||||
TaskHandler: taskHandler,
|
||||
ContractorHandler: contractorHandler,
|
||||
UserRepo: userRepo,
|
||||
ResidenceRepo: residenceRepo,
|
||||
TaskRepo: taskRepo,
|
||||
ContractorRepo: contractorRepo,
|
||||
AuthService: authService,
|
||||
}
|
||||
return app
|
||||
}
|
||||
|
||||
// ============ Test 1: Recurring Task Lifecycle ============
|
||||
@@ -2045,12 +1861,12 @@ func TestIntegration_MultiUserSharing(t *testing.T) {
|
||||
// Phase 9: Remove User B from residence 3
|
||||
t.Log("Phase 9: Remove User B from residence 3")
|
||||
|
||||
// Get User B's ID
|
||||
w = app.makeAuthenticatedRequest(t, "GET", "/api/auth/me", nil, tokenB)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
var userBInfo map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &userBInfo)
|
||||
userBID := uint(userBInfo["id"].(float64))
|
||||
// Get User B's ID from the token store
|
||||
app.tokenStoreMu.RLock()
|
||||
userBModel := app.tokenStore[tokenB]
|
||||
app.tokenStoreMu.RUnlock()
|
||||
require.NotNil(t, userBModel, "User B should be in token store")
|
||||
userBID := userBModel.ID
|
||||
|
||||
// Remove User B from residence 3
|
||||
w = app.makeAuthenticatedRequest(t, "DELETE", fmt.Sprintf("/api/residences/%d/users/%d", residenceIDs[2], userBID), nil, tokenA)
|
||||
|
||||
@@ -6,9 +6,11 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -35,6 +37,48 @@ type SecurityTestApp struct {
|
||||
Router *echo.Echo
|
||||
SubscriptionService *services.SubscriptionService
|
||||
SubscriptionRepo *repositories.SubscriptionRepository
|
||||
tokenStore map[string]*models.User
|
||||
tokenStoreMu sync.RWMutex
|
||||
}
|
||||
|
||||
// fakeAuthMiddleware returns an Echo middleware that authenticates requests using
|
||||
// the in-process tokenStore instead of calling the real Kratos session endpoint.
|
||||
func (app *SecurityTestApp) fakeAuthMiddleware() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ah := c.Request().Header.Get("Authorization")
|
||||
if ah == "" {
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
tok := ah
|
||||
if len(ah) > 6 && ah[:6] == "Token " {
|
||||
tok = ah[6:]
|
||||
} else if len(ah) > 7 && ah[:7] == "Bearer " {
|
||||
tok = ah[7:]
|
||||
}
|
||||
app.tokenStoreMu.RLock()
|
||||
user, ok := app.tokenStore[tok]
|
||||
app.tokenStoreMu.RUnlock()
|
||||
if !ok {
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
c.Set("auth_user", user)
|
||||
c.Set("auth_token", tok)
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// registerAndLoginSec creates a user directly in the DB and returns a fake token
|
||||
// that the fakeAuthMiddleware will accept. No HTTP register/login calls are made.
|
||||
func (app *SecurityTestApp) registerAndLoginSec(t *testing.T, username, email, _ string) (string, uint) {
|
||||
t.Helper()
|
||||
user := testutil.CreateTestUser(t, app.DB, username, email, "")
|
||||
tok := uuid.NewString()
|
||||
app.tokenStoreMu.Lock()
|
||||
app.tokenStore[tok] = user
|
||||
app.tokenStoreMu.Unlock()
|
||||
return tok, user.ID
|
||||
}
|
||||
|
||||
func setupSecurityTest(t *testing.T) *SecurityTestApp {
|
||||
@@ -78,27 +122,25 @@ func setupSecurityTest(t *testing.T) *SecurityTestApp {
|
||||
notificationHandler := handlers.NewNotificationHandler(notificationService)
|
||||
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, nil)
|
||||
|
||||
// Create router with real middleware
|
||||
app := &SecurityTestApp{
|
||||
DB: db,
|
||||
SubscriptionService: subscriptionService,
|
||||
SubscriptionRepo: subscriptionRepo,
|
||||
tokenStore: make(map[string]*models.User),
|
||||
}
|
||||
|
||||
// Create router with fake auth middleware
|
||||
e := echo.New()
|
||||
e.Validator = validator.NewCustomValidator()
|
||||
e.HTTPErrorHandler = apperrors.HTTPErrorHandler
|
||||
|
||||
e.Use(middleware.TimezoneMiddleware())
|
||||
|
||||
// Public routes
|
||||
auth := e.Group("/api/auth")
|
||||
{
|
||||
auth.POST("/register", authHandler.Register)
|
||||
auth.POST("/login", authHandler.Login)
|
||||
}
|
||||
|
||||
// Protected routes
|
||||
authMiddleware := middleware.NewAuthMiddleware(db, nil)
|
||||
api := e.Group("/api")
|
||||
api.Use(authMiddleware.TokenAuth())
|
||||
api.Use(app.fakeAuthMiddleware())
|
||||
{
|
||||
api.GET("/auth/me", authHandler.CurrentUser)
|
||||
api.POST("/auth/logout", authHandler.Logout)
|
||||
|
||||
residences := api.Group("/residences")
|
||||
{
|
||||
@@ -146,42 +188,8 @@ func setupSecurityTest(t *testing.T) *SecurityTestApp {
|
||||
}
|
||||
}
|
||||
|
||||
return &SecurityTestApp{
|
||||
DB: db,
|
||||
Router: e,
|
||||
SubscriptionService: subscriptionService,
|
||||
SubscriptionRepo: subscriptionRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// registerAndLoginSec registers and logs in a user, returns token and user ID.
|
||||
func (app *SecurityTestApp) registerAndLoginSec(t *testing.T, username, email, password string) (string, uint) {
|
||||
// Register
|
||||
registerBody := map[string]string{
|
||||
"username": username,
|
||||
"email": email,
|
||||
"password": password,
|
||||
}
|
||||
w := app.makeAuthReq(t, "POST", "/api/auth/register", registerBody, "")
|
||||
require.Equal(t, http.StatusCreated, w.Code, "Registration should succeed for %s", username)
|
||||
|
||||
// Login
|
||||
loginBody := map[string]string{
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
w = app.makeAuthReq(t, "POST", "/api/auth/login", loginBody, "")
|
||||
require.Equal(t, http.StatusOK, w.Code, "Login should succeed for %s", username)
|
||||
|
||||
var loginResp map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &loginResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
token := loginResp["token"].(string)
|
||||
userMap := loginResp["user"].(map[string]interface{})
|
||||
userID := uint(userMap["id"].(float64))
|
||||
|
||||
return token, userID
|
||||
app.Router = e
|
||||
return app
|
||||
}
|
||||
|
||||
// makeAuthReq creates and sends an HTTP request through the router.
|
||||
|
||||
@@ -6,12 +6,15 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/apperrors"
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
@@ -22,7 +25,6 @@ import (
|
||||
"github.com/treytartt/honeydue-api/internal/services"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
"github.com/treytartt/honeydue-api/internal/validator"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SubscriptionTestApp holds components for subscription integration testing
|
||||
@@ -31,11 +33,51 @@ type SubscriptionTestApp struct {
|
||||
Router *echo.Echo
|
||||
SubscriptionService *services.SubscriptionService
|
||||
SubscriptionRepo *repositories.SubscriptionRepository
|
||||
tokenStore map[string]*models.User
|
||||
tokenStoreMu sync.RWMutex
|
||||
}
|
||||
|
||||
// fakeAuthMiddleware returns an Echo middleware that authenticates requests using
|
||||
// the in-process tokenStore instead of calling the real Kratos session endpoint.
|
||||
func (app *SubscriptionTestApp) fakeAuthMiddleware() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
ah := c.Request().Header.Get("Authorization")
|
||||
if ah == "" {
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
tok := ah
|
||||
if len(ah) > 6 && ah[:6] == "Token " {
|
||||
tok = ah[6:]
|
||||
} else if len(ah) > 7 && ah[:7] == "Bearer " {
|
||||
tok = ah[7:]
|
||||
}
|
||||
app.tokenStoreMu.RLock()
|
||||
user, ok := app.tokenStore[tok]
|
||||
app.tokenStoreMu.RUnlock()
|
||||
if !ok {
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
c.Set("auth_user", user)
|
||||
c.Set("auth_token", tok)
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// registerAndLogin creates a user directly in the DB and returns a fake token
|
||||
// and user ID. No HTTP register/login calls are made.
|
||||
func (app *SubscriptionTestApp) registerAndLogin(t *testing.T, username, email, _ string) (string, uint) {
|
||||
t.Helper()
|
||||
user := testutil.CreateTestUser(t, app.DB, username, email, "")
|
||||
tok := uuid.NewString()
|
||||
app.tokenStoreMu.Lock()
|
||||
app.tokenStore[tok] = user
|
||||
app.tokenStoreMu.Unlock()
|
||||
return tok, user.ID
|
||||
}
|
||||
|
||||
func setupSubscriptionTest(t *testing.T) *SubscriptionTestApp {
|
||||
// Echo does not need test mode
|
||||
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
|
||||
@@ -67,22 +109,23 @@ func setupSubscriptionTest(t *testing.T) *SubscriptionTestApp {
|
||||
residenceHandler := handlers.NewResidenceHandler(residenceService, nil, nil, true)
|
||||
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService, nil)
|
||||
|
||||
// Create router
|
||||
app := &SubscriptionTestApp{
|
||||
DB: db,
|
||||
SubscriptionService: subscriptionService,
|
||||
SubscriptionRepo: subscriptionRepo,
|
||||
tokenStore: make(map[string]*models.User),
|
||||
}
|
||||
|
||||
// Create router with fake auth middleware
|
||||
e := echo.New()
|
||||
e.Validator = validator.NewCustomValidator()
|
||||
e.HTTPErrorHandler = apperrors.HTTPErrorHandler
|
||||
|
||||
// Public routes
|
||||
auth := e.Group("/api/auth")
|
||||
{
|
||||
auth.POST("/register", authHandler.Register)
|
||||
auth.POST("/login", authHandler.Login)
|
||||
}
|
||||
e.Use(middleware.TimezoneMiddleware())
|
||||
|
||||
// Protected routes
|
||||
authMiddleware := middleware.NewAuthMiddleware(db, nil)
|
||||
api := e.Group("/api")
|
||||
api.Use(authMiddleware.TokenAuth())
|
||||
api.Use(app.fakeAuthMiddleware())
|
||||
{
|
||||
api.GET("/auth/me", authHandler.CurrentUser)
|
||||
|
||||
@@ -98,12 +141,8 @@ func setupSubscriptionTest(t *testing.T) *SubscriptionTestApp {
|
||||
}
|
||||
}
|
||||
|
||||
return &SubscriptionTestApp{
|
||||
DB: db,
|
||||
Router: e,
|
||||
SubscriptionService: subscriptionService,
|
||||
SubscriptionRepo: subscriptionRepo,
|
||||
}
|
||||
app.Router = e
|
||||
return app
|
||||
}
|
||||
|
||||
// Helper to make authenticated requests
|
||||
@@ -129,36 +168,6 @@ func (app *SubscriptionTestApp) makeAuthenticatedRequest(t *testing.T, method, p
|
||||
return w
|
||||
}
|
||||
|
||||
// Helper to register and login a user, returns token and user ID
|
||||
func (app *SubscriptionTestApp) registerAndLogin(t *testing.T, username, email, password string) (string, uint) {
|
||||
// Register
|
||||
registerBody := map[string]string{
|
||||
"username": username,
|
||||
"email": email,
|
||||
"password": password,
|
||||
}
|
||||
w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody, "")
|
||||
require.Equal(t, http.StatusCreated, w.Code)
|
||||
|
||||
// Login
|
||||
loginBody := map[string]string{
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/login", loginBody, "")
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var loginResp map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &loginResp)
|
||||
require.NoError(t, err)
|
||||
|
||||
token := loginResp["token"].(string)
|
||||
userMap := loginResp["user"].(map[string]interface{})
|
||||
userID := uint(userMap["id"].(float64))
|
||||
|
||||
return token, userID
|
||||
}
|
||||
|
||||
// TestIntegration_IsFreeBypassesLimitations tests that users with IsFree=true
|
||||
// see limitations_enabled=false regardless of global settings
|
||||
func TestIntegration_IsFreeBypassesLimitations(t *testing.T) {
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
// Package kratos is a thin client for the Ory Kratos public API. honeyDue
|
||||
// delegates all identity concerns (credentials, sessions, verification,
|
||||
// recovery, social sign-in) to Kratos; this client only validates sessions.
|
||||
package kratos
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrUnauthorized is returned by Whoami when the session is missing, invalid,
|
||||
// or inactive — the caller should respond 401.
|
||||
var ErrUnauthorized = errors.New("kratos: session invalid or inactive")
|
||||
|
||||
// Client talks to the Ory Kratos public API.
|
||||
type Client struct {
|
||||
publicURL string
|
||||
http *http.Client
|
||||
}
|
||||
|
||||
// NewClient builds a Kratos client for the given public-API base URL
|
||||
// (e.g. http://kratos:4433 in-cluster).
|
||||
func NewClient(publicURL string) *Client {
|
||||
return &Client{
|
||||
publicURL: strings.TrimRight(publicURL, "/"),
|
||||
http: &http.Client{Timeout: 5 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// Identity is the subset of a Kratos identity honeyDue consumes. It mirrors
|
||||
// the identity schema in deploy-k3s/manifests/kratos/configmap.yaml.
|
||||
type Identity struct {
|
||||
ID string `json:"id"` // UUID — the stable identity identifier
|
||||
Traits struct {
|
||||
Email string `json:"email"`
|
||||
Name struct {
|
||||
First string `json:"first"`
|
||||
Last string `json:"last"`
|
||||
} `json:"name"`
|
||||
} `json:"traits"`
|
||||
VerifiableAddresses []struct {
|
||||
Value string `json:"value"`
|
||||
Verified bool `json:"verified"`
|
||||
} `json:"verifiable_addresses"`
|
||||
}
|
||||
|
||||
// Session is a Kratos session as returned by GET /sessions/whoami.
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
Active bool `json:"active"`
|
||||
Identity Identity `json:"identity"`
|
||||
}
|
||||
|
||||
// EmailVerified reports whether any of the identity's email addresses is
|
||||
// verified — the source of truth for honeyDue's RequireVerified gate.
|
||||
func (s *Session) EmailVerified() bool {
|
||||
for _, a := range s.Identity.VerifiableAddresses {
|
||||
if a.Verified {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Whoami validates a session against Kratos. Supply the mobile session token
|
||||
// (sessionToken) OR the browser cookie header (cookie) — whichever is
|
||||
// non-empty is forwarded to Kratos. Returns ErrUnauthorized for an invalid or
|
||||
// inactive session.
|
||||
func (c *Client) Whoami(ctx context.Context, sessionToken, cookie string) (*Session, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.publicURL+"/sessions/whoami", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sessionToken != "" {
|
||||
req.Header.Set("X-Session-Token", sessionToken)
|
||||
}
|
||||
if cookie != "" {
|
||||
req.Header.Set("Cookie", cookie)
|
||||
}
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kratos whoami: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusUnauthorized, resp.StatusCode == http.StatusForbidden:
|
||||
return nil, ErrUnauthorized
|
||||
case resp.StatusCode != http.StatusOK:
|
||||
return nil, fmt.Errorf("kratos whoami: unexpected status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var s Session
|
||||
if err := json.NewDecoder(resp.Body).Decode(&s); err != nil {
|
||||
return nil, fmt.Errorf("kratos whoami: decode: %w", err)
|
||||
}
|
||||
if !s.Active || s.Identity.ID == "" {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
@@ -1,438 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/apperrors"
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/services"
|
||||
)
|
||||
|
||||
const (
|
||||
// AuthUserKey is the key used to store the authenticated user in the context
|
||||
AuthUserKey = "auth_user"
|
||||
// AuthTokenKey is the key used to store the token in the context
|
||||
AuthTokenKey = "auth_token"
|
||||
// TokenCacheTTL is the duration to cache tokens in Redis. Tokens are
|
||||
// valid for DefaultTokenExpiryDays (90), and explicit logout invalidates
|
||||
// the cache, so a long TTL here just means most authed requests skip the
|
||||
// auth-token SQL query entirely.
|
||||
TokenCacheTTL = 1 * time.Hour
|
||||
// TokenCachePrefix is the prefix for token cache keys
|
||||
TokenCachePrefix = "auth_token_"
|
||||
// UserCacheTTL is how long full user records are cached in memory to
|
||||
// avoid hitting the database on every authenticated request. Bumped from
|
||||
// 30s — at 30s the trace showed a SELECT auth_user query on most warm
|
||||
// requests because users aren't in cache long enough to hit twice.
|
||||
UserCacheTTL = 5 * time.Minute
|
||||
// UserCacheMaxSize bounds the per-pod in-memory user cache. With ~1KB
|
||||
// per User struct, 5000 entries = ~5MB per pod. Older entries are
|
||||
// evicted LRU before the limit is exceeded.
|
||||
UserCacheMaxSize = 5000
|
||||
|
||||
// DefaultTokenExpiryDays is the default number of days before a token expires.
|
||||
DefaultTokenExpiryDays = 90
|
||||
)
|
||||
|
||||
// AuthMiddleware provides token authentication middleware
|
||||
type AuthMiddleware struct {
|
||||
db *gorm.DB
|
||||
cache *services.CacheService
|
||||
userCache *UserCache
|
||||
tokenExpiryDays int
|
||||
}
|
||||
|
||||
// NewAuthMiddleware creates a new auth middleware instance
|
||||
func NewAuthMiddleware(db *gorm.DB, cache *services.CacheService) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
db: db,
|
||||
cache: cache,
|
||||
userCache: NewUserCache(UserCacheTTL, UserCacheMaxSize),
|
||||
tokenExpiryDays: DefaultTokenExpiryDays,
|
||||
}
|
||||
}
|
||||
|
||||
// NewAuthMiddlewareWithConfig creates a new auth middleware instance with configuration
|
||||
func NewAuthMiddlewareWithConfig(db *gorm.DB, cache *services.CacheService, cfg *config.Config) *AuthMiddleware {
|
||||
expiryDays := DefaultTokenExpiryDays
|
||||
if cfg != nil && cfg.Security.TokenExpiryDays > 0 {
|
||||
expiryDays = cfg.Security.TokenExpiryDays
|
||||
}
|
||||
return &AuthMiddleware{
|
||||
db: db,
|
||||
cache: cache,
|
||||
userCache: NewUserCache(UserCacheTTL, UserCacheMaxSize),
|
||||
tokenExpiryDays: expiryDays,
|
||||
}
|
||||
}
|
||||
|
||||
// TokenExpiryDuration returns the token expiry duration.
|
||||
func (m *AuthMiddleware) TokenExpiryDuration() time.Duration {
|
||||
return time.Duration(m.tokenExpiryDays) * 24 * time.Hour
|
||||
}
|
||||
|
||||
// isTokenExpired checks if a token's created timestamp indicates expiry.
|
||||
func (m *AuthMiddleware) isTokenExpired(created time.Time) bool {
|
||||
if created.IsZero() {
|
||||
return false // Legacy tokens without created time are not expired
|
||||
}
|
||||
return time.Since(created) > m.TokenExpiryDuration()
|
||||
}
|
||||
|
||||
// TokenAuth returns an Echo middleware that validates token authentication
|
||||
func (m *AuthMiddleware) TokenAuth() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// Extract token from Authorization header
|
||||
token, err := extractToken(c)
|
||||
if err != nil {
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
|
||||
// Try to get user from cache first (includes expiry check)
|
||||
user, err := m.getUserFromCache(c.Request().Context(), token)
|
||||
if err == nil && user != nil {
|
||||
// Cache hit - set user in context and continue
|
||||
c.Set(AuthUserKey, user)
|
||||
c.Set(AuthTokenKey, token)
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Check if the cache indicated token expiry
|
||||
if err != nil && err.Error() == "token expired" {
|
||||
return apperrors.Unauthorized("error.token_expired")
|
||||
}
|
||||
|
||||
// Cache miss - look up token in database
|
||||
user, authToken, err := m.getUserFromDatabaseWithToken(c.Request().Context(), token)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Str("token", truncateToken(token)).Msg("Token authentication failed")
|
||||
return apperrors.Unauthorized("error.invalid_token")
|
||||
}
|
||||
|
||||
// Check token expiry
|
||||
if m.isTokenExpired(authToken.Created) {
|
||||
log.Debug().Str("token", truncateToken(token)).Time("created", authToken.Created).Msg("Token expired")
|
||||
return apperrors.Unauthorized("error.token_expired")
|
||||
}
|
||||
|
||||
// Cache the user ID and token creation time for future requests
|
||||
if cacheErr := m.cacheTokenInfo(c.Request().Context(), token, user.ID, authToken.Created); cacheErr != nil {
|
||||
log.Warn().Err(cacheErr).Msg("Failed to cache token info")
|
||||
}
|
||||
|
||||
// Set user in context
|
||||
c.Set(AuthUserKey, user)
|
||||
c.Set(AuthTokenKey, token)
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OptionalTokenAuth returns middleware that authenticates if token is present but doesn't require it
|
||||
func (m *AuthMiddleware) OptionalTokenAuth() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
token, err := extractToken(c)
|
||||
if err != nil {
|
||||
// No token or invalid format - continue without user
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Try cache first
|
||||
user, err := m.getUserFromCache(c.Request().Context(), token)
|
||||
if err == nil && user != nil {
|
||||
c.Set(AuthUserKey, user)
|
||||
c.Set(AuthTokenKey, token)
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Try database
|
||||
user, authToken, err := m.getUserFromDatabaseWithToken(c.Request().Context(), token)
|
||||
if err == nil && !m.isTokenExpired(authToken.Created) {
|
||||
m.cacheTokenInfo(c.Request().Context(), token, user.ID, authToken.Created)
|
||||
c.Set(AuthUserKey, user)
|
||||
c.Set(AuthTokenKey, token)
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RequireVerified returns middleware that rejects users whose email is not
|
||||
// verified (audit LIVE-L19). Apply it after TokenAuth to gate sensitive
|
||||
// actions — e.g. generating residence share codes — behind proof that the
|
||||
// account actually controls its email address.
|
||||
func (m *AuthMiddleware) RequireVerified() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
var verified bool
|
||||
err := m.db.WithContext(c.Request().Context()).
|
||||
Model(&models.UserProfile{}).
|
||||
Where("user_id = ?", user.ID).
|
||||
Select("verified").
|
||||
Scan(&verified).Error
|
||||
if err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
if !verified {
|
||||
return apperrors.Forbidden("error.email_not_verified")
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractToken extracts the token from the Authorization header
|
||||
func extractToken(c echo.Context) (string, error) {
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return "", fmt.Errorf("authorization header required")
|
||||
}
|
||||
|
||||
// Support both "Token xxx" (Django style) and "Bearer xxx" formats
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 {
|
||||
return "", fmt.Errorf("invalid authorization header format")
|
||||
}
|
||||
|
||||
scheme := parts[0]
|
||||
token := parts[1]
|
||||
|
||||
if scheme != "Token" && scheme != "Bearer" {
|
||||
return "", fmt.Errorf("invalid authorization scheme: %s", scheme)
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
return "", fmt.Errorf("token is empty")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// getUserFromCache tries to get user from Redis cache, then from the
|
||||
// in-memory user cache, before falling back to the database.
|
||||
// Returns a "token expired" error if the cached creation time indicates expiry.
|
||||
func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*models.User, error) {
|
||||
if m.cache == nil {
|
||||
return nil, fmt.Errorf("cache not available")
|
||||
}
|
||||
|
||||
userID, createdUnix, err := m.cache.GetCachedAuthTokenWithCreated(ctx, token)
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return nil, fmt.Errorf("token not in cache")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check token expiry from cached creation time
|
||||
if createdUnix > 0 {
|
||||
created := time.Unix(createdUnix, 0)
|
||||
if m.isTokenExpired(created) {
|
||||
m.cache.InvalidateAuthToken(ctx, token)
|
||||
return nil, fmt.Errorf("token expired")
|
||||
}
|
||||
}
|
||||
|
||||
// Try in-memory user cache first to avoid a DB round-trip
|
||||
if cached := m.userCache.Get(userID); cached != nil {
|
||||
if !cached.IsActive {
|
||||
m.cache.InvalidateAuthToken(ctx, token)
|
||||
m.userCache.Invalidate(userID)
|
||||
return nil, fmt.Errorf("user is inactive")
|
||||
}
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
// In-memory cache miss — fetch from database
|
||||
var user models.User
|
||||
if err := m.db.WithContext(ctx).First(&user, userID).Error; err != nil {
|
||||
// User was deleted - invalidate caches
|
||||
m.cache.InvalidateAuthToken(ctx, token)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if user is active
|
||||
if !user.IsActive {
|
||||
m.cache.InvalidateAuthToken(ctx, token)
|
||||
return nil, fmt.Errorf("user is inactive")
|
||||
}
|
||||
|
||||
// Store in in-memory cache for subsequent requests
|
||||
m.userCache.Set(&user)
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// getUserFromDatabaseWithToken looks up the token in the database and returns
|
||||
// both the user and the auth token record (for expiry checking). The ctx is
|
||||
// threaded into the GORM session so the SQL span attaches to the request trace.
|
||||
//
|
||||
// Uses a single JOIN query instead of GORM's Preload (which issues 2 SELECTs).
|
||||
// Over a transatlantic link this saves ~110ms RTT per cache miss.
|
||||
func (m *AuthMiddleware) getUserFromDatabaseWithToken(ctx context.Context, token string) (*models.User, *models.AuthToken, error) {
|
||||
// Flat result row: every column from auth_user prefixed `u_`, every
|
||||
// column from user_authtoken left in its native shape. Mapping to two
|
||||
// structs is mechanical so we don't need a struct tag soup.
|
||||
type joinedRow struct {
|
||||
// AuthToken columns
|
||||
Key string `gorm:"column:key"`
|
||||
Created time.Time `gorm:"column:created"`
|
||||
UserID uint `gorm:"column:user_id"`
|
||||
// User columns (prefixed to avoid collision with UserID)
|
||||
UID uint `gorm:"column:u_id"`
|
||||
UUsername string `gorm:"column:u_username"`
|
||||
UEmail string `gorm:"column:u_email"`
|
||||
UFirstName string `gorm:"column:u_first_name"`
|
||||
ULastName string `gorm:"column:u_last_name"`
|
||||
UPassword string `gorm:"column:u_password"`
|
||||
UIsActive bool `gorm:"column:u_is_active"`
|
||||
UIsStaff bool `gorm:"column:u_is_staff"`
|
||||
UIsSuper bool `gorm:"column:u_is_superuser"`
|
||||
UDateJoined time.Time `gorm:"column:u_date_joined"`
|
||||
ULastLogin *time.Time `gorm:"column:u_last_login"`
|
||||
}
|
||||
|
||||
var row joinedRow
|
||||
err := m.db.WithContext(ctx).
|
||||
Table("user_authtoken AS t").
|
||||
Select(`
|
||||
t.key, t.created, t.user_id,
|
||||
u.id AS u_id,
|
||||
u.username AS u_username,
|
||||
u.email AS u_email,
|
||||
u.first_name AS u_first_name,
|
||||
u.last_name AS u_last_name,
|
||||
u.password AS u_password,
|
||||
u.is_active AS u_is_active,
|
||||
u.is_staff AS u_is_staff,
|
||||
u.is_superuser AS u_is_superuser,
|
||||
u.date_joined AS u_date_joined,
|
||||
u.last_login AS u_last_login
|
||||
`).
|
||||
Joins("INNER JOIN auth_user u ON u.id = t.user_id").
|
||||
Where("t.key = ?", models.HashToken(token)). // audit C1: only the hash is stored
|
||||
Limit(1).
|
||||
Scan(&row).Error
|
||||
if err != nil || row.Key == "" {
|
||||
return nil, nil, fmt.Errorf("token not found")
|
||||
}
|
||||
|
||||
user := models.User{
|
||||
ID: row.UID,
|
||||
Username: row.UUsername,
|
||||
Email: row.UEmail,
|
||||
FirstName: row.UFirstName,
|
||||
LastName: row.ULastName,
|
||||
Password: row.UPassword,
|
||||
IsActive: row.UIsActive,
|
||||
IsStaff: row.UIsStaff,
|
||||
IsSuperuser: row.UIsSuper,
|
||||
DateJoined: row.UDateJoined,
|
||||
LastLogin: row.ULastLogin,
|
||||
}
|
||||
authToken := models.AuthToken{
|
||||
Key: row.Key,
|
||||
Created: row.Created,
|
||||
UserID: row.UserID,
|
||||
User: user,
|
||||
}
|
||||
|
||||
if !user.IsActive {
|
||||
return nil, nil, fmt.Errorf("user is inactive")
|
||||
}
|
||||
|
||||
m.userCache.Set(&user)
|
||||
return &user, &authToken, nil
|
||||
}
|
||||
|
||||
// getUserFromDatabase looks up the token in the database and caches the
|
||||
// resulting user record in memory.
|
||||
// Deprecated: Use getUserFromDatabaseWithToken for new code paths that need expiry checking.
|
||||
func (m *AuthMiddleware) getUserFromDatabase(ctx context.Context, token string) (*models.User, error) {
|
||||
user, _, err := m.getUserFromDatabaseWithToken(ctx, token)
|
||||
return user, err
|
||||
}
|
||||
|
||||
// cacheTokenInfo caches the user ID and token creation time for a token
|
||||
func (m *AuthMiddleware) cacheTokenInfo(ctx context.Context, token string, userID uint, created time.Time) error {
|
||||
if m.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return m.cache.CacheAuthTokenWithCreated(ctx, token, userID, created.Unix())
|
||||
}
|
||||
|
||||
// cacheUserID caches the user ID for a token
|
||||
func (m *AuthMiddleware) cacheUserID(ctx context.Context, token string, userID uint) error {
|
||||
if m.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return m.cache.CacheAuthToken(ctx, token, userID)
|
||||
}
|
||||
|
||||
// InvalidateToken removes a token from the cache
|
||||
func (m *AuthMiddleware) InvalidateToken(ctx context.Context, token string) error {
|
||||
if m.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return m.cache.InvalidateAuthToken(ctx, token)
|
||||
}
|
||||
|
||||
// GetAuthUser retrieves the authenticated user from the Echo context.
|
||||
// Returns nil if the context value is missing or not of the expected type.
|
||||
func GetAuthUser(c echo.Context) *models.User {
|
||||
val := c.Get(AuthUserKey)
|
||||
if val == nil {
|
||||
return nil
|
||||
}
|
||||
user, ok := val.(*models.User)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
// GetAuthToken retrieves the auth token from the Echo context
|
||||
func GetAuthToken(c echo.Context) string {
|
||||
token := c.Get(AuthTokenKey)
|
||||
if token == nil {
|
||||
return ""
|
||||
}
|
||||
tokenStr, ok := token.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return tokenStr
|
||||
}
|
||||
|
||||
// MustGetAuthUser retrieves the authenticated user or returns error with 401
|
||||
func MustGetAuthUser(c echo.Context) (*models.User, error) {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return nil, apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// truncateToken safely truncates a token string for logging.
|
||||
// Returns at most the first 8 characters followed by "...".
|
||||
func truncateToken(token string) string {
|
||||
if len(token) > 8 {
|
||||
return token[:8] + "..."
|
||||
}
|
||||
return token + "..."
|
||||
}
|
||||
@@ -1,165 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
// setupTestDB creates a temporary in-memory SQLite database with the required
|
||||
// tables for auth middleware tests.
|
||||
func setupTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.AutoMigrate(&models.User{}, &models.AuthToken{})
|
||||
require.NoError(t, err)
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// createTestUserAndToken creates a user and an auth token, then backdates the
|
||||
// token's Created timestamp by the specified number of days.
|
||||
func createTestUserAndToken(t *testing.T, db *gorm.DB, username string, ageDays int) (*models.User, *models.AuthToken) {
|
||||
t.Helper()
|
||||
|
||||
user := &models.User{
|
||||
Username: username,
|
||||
Email: username + "@test.com",
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, user.SetPassword("Password123"))
|
||||
require.NoError(t, db.Create(user).Error)
|
||||
|
||||
token := &models.AuthToken{
|
||||
UserID: user.ID,
|
||||
}
|
||||
require.NoError(t, db.Create(token).Error)
|
||||
|
||||
// Backdate the token's Created timestamp after creation to bypass autoCreateTime
|
||||
backdated := time.Now().UTC().AddDate(0, 0, -ageDays)
|
||||
require.NoError(t, db.Model(token).Update("created", backdated).Error)
|
||||
token.Created = backdated
|
||||
|
||||
return user, token
|
||||
}
|
||||
|
||||
func TestTokenAuth_RejectsExpiredToken(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "expired_user", 91) // 91 days old > 90 day expiry
|
||||
|
||||
m := NewAuthMiddleware(db, nil) // No Redis cache for these tests
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.token_expired")
|
||||
}
|
||||
|
||||
func TestTokenAuth_AcceptsValidToken(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "valid_user", 30) // 30 days old < 90 day expiry
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Verify user was set in context
|
||||
user := GetAuthUser(c)
|
||||
require.NotNil(t, user)
|
||||
assert.Equal(t, "valid_user", user.Username)
|
||||
}
|
||||
|
||||
func TestTokenAuth_AcceptsTokenAtBoundary(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "boundary_user", 89) // 89 days old, just under 90 day expiry
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestTokenAuth_RejectsInvalidToken(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token nonexistent-token")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||
}
|
||||
|
||||
func TestTokenAuth_RejectsNoAuthHeader(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.not_authenticated")
|
||||
}
|
||||
@@ -1,337 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
func TestTokenAuth_BearerScheme_Accepted(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "bearer_user", 10)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
user := GetAuthUser(c)
|
||||
require.NotNil(t, user)
|
||||
assert.Equal(t, "bearer_user", user.Username)
|
||||
}
|
||||
|
||||
func TestTokenAuth_InvalidScheme_Rejected(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "scheme_user", 10)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Basic "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.not_authenticated")
|
||||
}
|
||||
|
||||
func TestTokenAuth_MalformedHeader_Rejected(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "JustATokenWithNoScheme")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.not_authenticated")
|
||||
}
|
||||
|
||||
func TestTokenAuth_EmptyToken_Rejected(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token ")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.not_authenticated")
|
||||
}
|
||||
|
||||
func TestTokenAuth_InactiveUser_Rejected(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
user, token := createTestUserAndToken(t, db, "inactive_user", 10)
|
||||
|
||||
// Deactivate the user
|
||||
require.NoError(t, db.Model(user).Update("is_active", false).Error)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||
}
|
||||
|
||||
func TestOptionalTokenAuth_NoToken_PassesThrough(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
// No Authorization header
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return c.String(http.StatusOK, "no-user")
|
||||
}
|
||||
return c.String(http.StatusOK, user.Username)
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no-user", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestOptionalTokenAuth_ValidToken_SetsUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "opt_user", 10)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return c.String(http.StatusOK, "no-user")
|
||||
}
|
||||
return c.String(http.StatusOK, user.Username)
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "opt_user", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestOptionalTokenAuth_ExpiredToken_IgnoresUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "expired_opt_user", 91)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return c.String(http.StatusOK, "no-user")
|
||||
}
|
||||
return c.String(http.StatusOK, user.Username)
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no-user", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestOptionalTokenAuth_InvalidToken_IgnoresUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token nonexistent-token")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return c.String(http.StatusOK, "no-user")
|
||||
}
|
||||
return c.String(http.StatusOK, user.Username)
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no-user", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestNewAuthMiddlewareWithConfig_CustomExpiryDays(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
TokenExpiryDays: 30,
|
||||
},
|
||||
}
|
||||
|
||||
m := NewAuthMiddlewareWithConfig(db, nil, cfg)
|
||||
assert.NotNil(t, m)
|
||||
assert.Equal(t, 30, m.tokenExpiryDays)
|
||||
|
||||
// Token at 29 days should be valid
|
||||
_, token := createTestUserAndToken(t, db, "short_expiry_user", 29)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestNewAuthMiddlewareWithConfig_ExpiredWithCustomExpiry(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
TokenExpiryDays: 30,
|
||||
},
|
||||
}
|
||||
|
||||
m := NewAuthMiddlewareWithConfig(db, nil, cfg)
|
||||
|
||||
// Token at 31 days should be expired with 30-day config
|
||||
_, token := createTestUserAndToken(t, db, "custom_expired_user", 31)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.token_expired")
|
||||
}
|
||||
|
||||
func TestNewAuthMiddlewareWithConfig_NilConfig_UsesDefault(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddlewareWithConfig(db, nil, nil)
|
||||
assert.Equal(t, DefaultTokenExpiryDays, m.tokenExpiryDays)
|
||||
}
|
||||
|
||||
func TestGetAuthToken_ReturnsToken(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c.Set(AuthTokenKey, "test-token-value")
|
||||
assert.Equal(t, "test-token-value", GetAuthToken(c))
|
||||
}
|
||||
|
||||
func TestGetAuthToken_NilContext_ReturnsEmpty(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// No token set
|
||||
assert.Equal(t, "", GetAuthToken(c))
|
||||
}
|
||||
|
||||
func TestGetAuthToken_WrongType_ReturnsEmpty(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c.Set(AuthTokenKey, 12345) // Wrong type
|
||||
assert.Equal(t, "", GetAuthToken(c))
|
||||
}
|
||||
|
||||
func TestIsTokenExpired_ZeroTime_NotExpired(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
// Legacy tokens without created time should not be expired
|
||||
assert.False(t, m.isTokenExpired(models.AuthToken{}.Created))
|
||||
}
|
||||
|
||||
func TestInvalidateToken_NilCache_NoError(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
m := NewAuthMiddleware(db, nil) // nil cache
|
||||
|
||||
err := m.InvalidateToken(nil, "some-token")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/apperrors"
|
||||
"github.com/treytartt/honeydue-api/internal/kratos"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/services"
|
||||
)
|
||||
|
||||
const (
|
||||
// AuthUserKey stores the authenticated *models.User in the echo context.
|
||||
AuthUserKey = "auth_user"
|
||||
// AuthTokenKey stores the raw session credential in the echo context.
|
||||
AuthTokenKey = "auth_token"
|
||||
// authVerifiedKey stores the Kratos email-verified flag in the context.
|
||||
authVerifiedKey = "auth_email_verified"
|
||||
|
||||
// UserCacheTTL / UserCacheMaxSize bound the in-memory local-user cache.
|
||||
UserCacheTTL = 5 * time.Minute
|
||||
UserCacheMaxSize = 5000
|
||||
|
||||
// kratosSessionCacheTTL is how long a validated session is cached in
|
||||
// Redis, so most authed requests skip the Kratos /whoami round trip.
|
||||
kratosSessionCacheTTL = 5 * time.Minute
|
||||
kratosSessionPrefix = "kratos_sess:"
|
||||
)
|
||||
|
||||
// KratosAuth authenticates requests against an Ory Kratos session. It
|
||||
// replaces the hand-rolled token auth: the session is validated via Kratos
|
||||
// /sessions/whoami (Redis-cached), and the matching local auth_user row is
|
||||
// lazily provisioned on first sight of a Kratos identity.
|
||||
type KratosAuth struct {
|
||||
kratos *kratos.Client
|
||||
cache *services.CacheService
|
||||
db *gorm.DB
|
||||
userCache *UserCache
|
||||
}
|
||||
|
||||
// NewKratosAuth builds the Kratos auth middleware.
|
||||
func NewKratosAuth(k *kratos.Client, cache *services.CacheService, db *gorm.DB) *KratosAuth {
|
||||
return &KratosAuth{
|
||||
kratos: k,
|
||||
cache: cache,
|
||||
db: db,
|
||||
userCache: NewUserCache(UserCacheTTL, UserCacheMaxSize),
|
||||
}
|
||||
}
|
||||
|
||||
// Authenticate validates the Kratos session and requires it.
|
||||
func (m *KratosAuth) Authenticate() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
user, verified, cred, err := m.resolve(c)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Msg("Kratos authentication failed")
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
c.Set(AuthUserKey, user)
|
||||
c.Set(AuthTokenKey, cred)
|
||||
c.Set(authVerifiedKey, verified)
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OptionalAuthenticate authenticates if a session is present, else continues
|
||||
// unauthenticated.
|
||||
func (m *KratosAuth) OptionalAuthenticate() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if user, verified, cred, err := m.resolve(c); err == nil {
|
||||
c.Set(AuthUserKey, user)
|
||||
c.Set(AuthTokenKey, cred)
|
||||
c.Set(authVerifiedKey, verified)
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RequireVerified rejects users whose Kratos email address is not verified.
|
||||
// Apply after Authenticate.
|
||||
func (m *KratosAuth) RequireVerified() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if GetAuthUser(c) == nil {
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
if verified, _ := c.Get(authVerifiedKey).(bool); !verified {
|
||||
return apperrors.Forbidden("error.email_not_verified")
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// resolve validates the request's session and returns the local user.
|
||||
func (m *KratosAuth) resolve(c echo.Context) (*models.User, bool, string, error) {
|
||||
token, cookie := extractSession(c)
|
||||
if token == "" && cookie == "" {
|
||||
return nil, false, "", errors.New("no session credential")
|
||||
}
|
||||
cred := token
|
||||
if cred == "" {
|
||||
cred = cookie
|
||||
}
|
||||
ctx := c.Request().Context()
|
||||
|
||||
// Redis cache: kratos_sess:<hash(cred)> -> "<userID>|<0|1>"
|
||||
cacheKey := kratosSessionPrefix + hashCredential(cred)
|
||||
if m.cache != nil {
|
||||
if v, err := m.cache.GetString(ctx, cacheKey); err == nil && v != "" {
|
||||
if user, verified, ok := m.userFromCacheValue(ctx, v); ok {
|
||||
return user, verified, cred, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sess, err := m.kratos.Whoami(ctx, token, cookie)
|
||||
if err != nil {
|
||||
return nil, false, "", err
|
||||
}
|
||||
user, err := m.provision(ctx, sess)
|
||||
if err != nil {
|
||||
return nil, false, "", err
|
||||
}
|
||||
if m.cache != nil {
|
||||
_ = m.cache.SetString(ctx, cacheKey,
|
||||
fmt.Sprintf("%d|%s", user.ID, boolDigit(sess.EmailVerified())), kratosSessionCacheTTL)
|
||||
}
|
||||
return user, sess.EmailVerified(), cred, nil
|
||||
}
|
||||
|
||||
// provision finds the local auth_user row for a Kratos identity, creating it
|
||||
// (and a UserProfile) on first sight. Concurrent first requests are handled
|
||||
// by re-reading after a unique-constraint conflict.
|
||||
func (m *KratosAuth) provision(ctx context.Context, sess *kratos.Session) (*models.User, error) {
|
||||
var user models.User
|
||||
err := m.db.WithContext(ctx).Where("kratos_id = ?", sess.Identity.ID).First(&user).Error
|
||||
if err == nil {
|
||||
return &user, nil
|
||||
}
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user = models.User{
|
||||
KratosID: sess.Identity.ID,
|
||||
Email: sess.Identity.Traits.Email,
|
||||
Username: sess.Identity.Traits.Email,
|
||||
FirstName: sess.Identity.Traits.Name.First,
|
||||
LastName: sess.Identity.Traits.Name.Last,
|
||||
IsActive: true,
|
||||
DateJoined: time.Now().UTC(),
|
||||
}
|
||||
txErr := m.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Create(&user).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Create(&models.UserProfile{
|
||||
UserID: user.ID,
|
||||
Verified: sess.EmailVerified(),
|
||||
}).Error
|
||||
})
|
||||
if txErr != nil {
|
||||
// Likely a concurrent provision of the same identity — re-read.
|
||||
if e := m.db.WithContext(ctx).Where("kratos_id = ?", sess.Identity.ID).First(&user).Error; e == nil {
|
||||
return &user, nil
|
||||
}
|
||||
return nil, txErr
|
||||
}
|
||||
log.Info().Str("kratos_id", sess.Identity.ID).Uint("user_id", user.ID).
|
||||
Msg("provisioned local user from Kratos identity")
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// userFromCacheValue resolves a cached "<userID>|<0|1>" value to a user.
|
||||
func (m *KratosAuth) userFromCacheValue(ctx context.Context, v string) (*models.User, bool, bool) {
|
||||
parts := strings.SplitN(v, "|", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, false, false
|
||||
}
|
||||
var id uint
|
||||
if _, err := fmt.Sscanf(parts[0], "%d", &id); err != nil || id == 0 {
|
||||
return nil, false, false
|
||||
}
|
||||
verified := parts[1] == "1"
|
||||
if cached := m.userCache.Get(id); cached != nil {
|
||||
return cached, verified, true
|
||||
}
|
||||
var user models.User
|
||||
if err := m.db.WithContext(ctx).First(&user, id).Error; err != nil {
|
||||
return nil, false, false
|
||||
}
|
||||
m.userCache.Set(&user)
|
||||
return &user, verified, true
|
||||
}
|
||||
|
||||
// extractSession pulls the session credential from the request: the
|
||||
// X-Session-Token header or Authorization bearer (mobile clients), or the
|
||||
// ory_kratos_session cookie (web).
|
||||
func extractSession(c echo.Context) (token, cookie string) {
|
||||
if t := c.Request().Header.Get("X-Session-Token"); t != "" {
|
||||
token = t
|
||||
} else if ah := c.Request().Header.Get("Authorization"); ah != "" {
|
||||
parts := strings.SplitN(ah, " ", 2)
|
||||
if len(parts) == 2 && (parts[0] == "Bearer" || parts[0] == "Token") {
|
||||
token = parts[1]
|
||||
}
|
||||
}
|
||||
if token == "" {
|
||||
if ck := c.Request().Header.Get("Cookie"); strings.Contains(ck, "ory_kratos_session") {
|
||||
cookie = ck
|
||||
}
|
||||
}
|
||||
return token, cookie
|
||||
}
|
||||
|
||||
func hashCredential(cred string) string {
|
||||
sum := sha256.Sum256([]byte(cred))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func boolDigit(b bool) string {
|
||||
if b {
|
||||
return "1"
|
||||
}
|
||||
return "0"
|
||||
}
|
||||
|
||||
// truncateToken returns the first 8 characters of a credential followed by
|
||||
// "..." for safe inclusion in log lines.
|
||||
func truncateToken(tok string) string {
|
||||
if len(tok) <= 8 {
|
||||
return tok + "..."
|
||||
}
|
||||
return tok[:8] + "..."
|
||||
}
|
||||
|
||||
// GetAuthUser retrieves the authenticated user from the echo context.
|
||||
func GetAuthUser(c echo.Context) *models.User {
|
||||
user, _ := c.Get(AuthUserKey).(*models.User)
|
||||
return user
|
||||
}
|
||||
|
||||
// GetAuthToken retrieves the session credential from the echo context.
|
||||
func GetAuthToken(c echo.Context) string {
|
||||
tok, _ := c.Get(AuthTokenKey).(string)
|
||||
return tok
|
||||
}
|
||||
|
||||
// MustGetAuthUser retrieves the authenticated user or returns a 401 error.
|
||||
func MustGetAuthUser(c echo.Context) (*models.User, error) {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return nil, apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
@@ -19,7 +19,7 @@ func setupModelsTestDB(t *testing.T) *gorm.DB {
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = db.AutoMigrate(&User{}, &AuthToken{}, &UserProfile{})
|
||||
err = db.AutoMigrate(&User{}, &UserProfile{})
|
||||
require.NoError(t, err)
|
||||
return db
|
||||
}
|
||||
@@ -233,105 +233,6 @@ func TestNotificationType_Constants(t *testing.T) {
|
||||
assert.Equal(t, NotificationType("warranty_expiring"), NotificationWarrantyExpiring)
|
||||
}
|
||||
|
||||
// === AuthToken model tests ===
|
||||
|
||||
func TestAuthToken_BeforeCreate_GeneratesKey(t *testing.T) {
|
||||
db := setupModelsTestDB(t)
|
||||
|
||||
user := &User{
|
||||
Username: "tokenuser",
|
||||
Email: "token@test.com",
|
||||
Password: "dummy",
|
||||
IsActive: true,
|
||||
}
|
||||
err := db.Create(user).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
token := &AuthToken{UserID: user.ID}
|
||||
err = db.Create(token).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, token.Key)
|
||||
assert.Len(t, token.Key, 64) // SHA-256 hex hash (audit C1)
|
||||
assert.Len(t, token.Plaintext, 40) // raw 20-byte token, returned to the client
|
||||
assert.False(t, token.Created.IsZero())
|
||||
}
|
||||
|
||||
func TestAuthToken_BeforeCreate_PreservesExistingKey(t *testing.T) {
|
||||
db := setupModelsTestDB(t)
|
||||
|
||||
user := &User{
|
||||
Username: "tokenuser",
|
||||
Email: "token@test.com",
|
||||
Password: "dummy",
|
||||
IsActive: true,
|
||||
}
|
||||
err := db.Create(user).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
existingKey := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
|
||||
token := &AuthToken{
|
||||
Key: existingKey,
|
||||
UserID: user.ID,
|
||||
}
|
||||
err = db.Create(token).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, existingKey, token.Key)
|
||||
}
|
||||
|
||||
func TestGetOrCreateToken_CreatesNew(t *testing.T) {
|
||||
db := setupModelsTestDB(t)
|
||||
|
||||
user := &User{
|
||||
Username: "newtoken",
|
||||
Email: "newtoken@test.com",
|
||||
Password: "dummy",
|
||||
IsActive: true,
|
||||
}
|
||||
err := db.Create(user).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := GetOrCreateToken(db, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token.Key)
|
||||
assert.Equal(t, user.ID, token.UserID)
|
||||
}
|
||||
|
||||
func TestGetOrCreateToken_ReturnsExisting(t *testing.T) {
|
||||
db := setupModelsTestDB(t)
|
||||
|
||||
user := &User{
|
||||
Username: "existingtoken",
|
||||
Email: "existingtoken@test.com",
|
||||
Password: "dummy",
|
||||
IsActive: true,
|
||||
}
|
||||
err := db.Create(user).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
token1, err := GetOrCreateToken(db, user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
token2, err := GetOrCreateToken(db, user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, token1.Key, token2.Key)
|
||||
}
|
||||
|
||||
// === User model additional tests ===
|
||||
|
||||
func TestUser_SetPassword_And_CheckPassword_Integration(t *testing.T) {
|
||||
user := &User{}
|
||||
err := user.SetPassword("Password123")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, user.CheckPassword("Password123"))
|
||||
assert.False(t, user.CheckPassword("WrongPassword"))
|
||||
assert.False(t, user.CheckPassword(""))
|
||||
assert.False(t, user.CheckPassword("password123")) // case sensitive
|
||||
}
|
||||
|
||||
// === Task model additional tests ===
|
||||
|
||||
func TestTask_IsOverdue_CancelledNotOverdue(t *testing.T) {
|
||||
@@ -565,31 +466,6 @@ func TestGetDefaultProLimits(t *testing.T) {
|
||||
assert.Nil(t, limits.DocumentsLimit)
|
||||
}
|
||||
|
||||
// === ConfirmationCode additional tests ===
|
||||
|
||||
func TestConfirmationCode_TableName(t *testing.T) {
|
||||
cc := ConfirmationCode{}
|
||||
assert.Equal(t, "user_confirmationcode", cc.TableName())
|
||||
}
|
||||
|
||||
// === PasswordResetCode additional tests ===
|
||||
|
||||
func TestPasswordResetCode_TableName(t *testing.T) {
|
||||
prc := PasswordResetCode{}
|
||||
assert.Equal(t, "user_passwordresetcode", prc.TableName())
|
||||
}
|
||||
|
||||
// === Social Auth TableName tests ===
|
||||
|
||||
func TestAppleSocialAuth_TableName(t *testing.T) {
|
||||
a := AppleSocialAuth{}
|
||||
assert.Equal(t, "user_applesocialauth", a.TableName())
|
||||
}
|
||||
|
||||
func TestGoogleSocialAuth_TableName(t *testing.T) {
|
||||
g := GoogleSocialAuth{}
|
||||
assert.Equal(t, "user_googlesocialauth", g.TableName())
|
||||
}
|
||||
|
||||
// === BaseModel tests ===
|
||||
|
||||
|
||||
+16
-253
@@ -1,69 +1,38 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
import "time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// User represents the auth_user table (Django's default User model)
|
||||
// User represents the auth_user table. Identity — credentials, email
|
||||
// verification, sessions, social sign-in — is owned by Ory Kratos (phase 2).
|
||||
// This row is honeyDue's local mirror of a Kratos identity, linked by
|
||||
// KratosID; every domain table keeps its existing integer FK to auth_user.id.
|
||||
type User struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Password string `gorm:"column:password;size:128;not null" json:"-"`
|
||||
LastLogin *time.Time `gorm:"column:last_login" json:"last_login,omitempty"`
|
||||
IsSuperuser bool `gorm:"column:is_superuser;default:false" json:"is_superuser"`
|
||||
Username string `gorm:"column:username;uniqueIndex;size:150;not null" json:"username"`
|
||||
KratosID string `gorm:"column:kratos_id;uniqueIndex;size:36" json:"-"` // Kratos identity UUID
|
||||
Username string `gorm:"column:username;uniqueIndex;size:150" json:"username"`
|
||||
FirstName string `gorm:"column:first_name;size:150" json:"first_name"`
|
||||
LastName string `gorm:"column:last_name;size:150" json:"last_name"`
|
||||
Email string `gorm:"column:email;uniqueIndex;size:254" json:"email"`
|
||||
IsStaff bool `gorm:"column:is_staff;default:false" json:"is_staff"`
|
||||
IsActive bool `gorm:"column:is_active;default:true" json:"is_active"`
|
||||
IsSuperuser bool `gorm:"column:is_superuser;default:false" json:"is_superuser"`
|
||||
DateJoined time.Time `gorm:"column:date_joined;autoCreateTime" json:"date_joined"`
|
||||
LastLogin *time.Time `gorm:"column:last_login" json:"last_login,omitempty"`
|
||||
|
||||
// Relations (not stored in auth_user table)
|
||||
// Relations — not columns on auth_user.
|
||||
Profile *UserProfile `gorm:"foreignKey:UserID" json:"profile,omitempty"`
|
||||
AuthToken *AuthToken `gorm:"foreignKey:UserID" json:"-"`
|
||||
OwnedResidences []Residence `gorm:"foreignKey:OwnerID" json:"-"`
|
||||
SharedResidences []Residence `gorm:"many2many:residence_residence_users;" json:"-"`
|
||||
NotificationPref *NotificationPreference `gorm:"foreignKey:UserID" json:"-"`
|
||||
Subscription *UserSubscription `gorm:"foreignKey:UserID" json:"-"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for GORM
|
||||
// TableName returns the table name for GORM.
|
||||
func (User) TableName() string {
|
||||
return "auth_user"
|
||||
}
|
||||
|
||||
// BcryptCost is the bcrypt work factor for password and code hashing.
|
||||
// 12 (audit M2) is stronger than bcrypt.DefaultCost (10).
|
||||
const BcryptCost = 12
|
||||
|
||||
// SetPassword hashes and sets the password
|
||||
func (u *User) SetPassword(password string) error {
|
||||
// Django uses PBKDF2_SHA256 by default, but we use bcrypt for Go.
|
||||
// Passwords set by Django won't verify with Go's bcrypt check — those
|
||||
// users must reset their password after migration.
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), BcryptCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.Password = string(hash)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPassword verifies a password against the stored hash
|
||||
func (u *User) CheckPassword(password string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// GetFullName returns the user's full name
|
||||
// GetFullName returns the user's display name.
|
||||
func (u *User) GetFullName() string {
|
||||
if u.FirstName != "" && u.LastName != "" {
|
||||
return u.FirstName + " " + u.LastName
|
||||
@@ -74,80 +43,9 @@ func (u *User) GetFullName() string {
|
||||
return u.Username
|
||||
}
|
||||
|
||||
// AuthToken represents the user_authtoken table.
|
||||
//
|
||||
// Audit C1: the Key column stores the SHA-256 hash of the token, never the
|
||||
// token itself. The raw token is handed to the client exactly once, at
|
||||
// creation, via the non-persisted Plaintext field — it is never stored or
|
||||
// logged. A database compromise therefore yields no usable session tokens.
|
||||
type AuthToken struct {
|
||||
Key string `gorm:"column:key;primaryKey;size:64" json:"-"`
|
||||
UserID uint `gorm:"column:user_id;uniqueIndex;not null" json:"user_id"`
|
||||
Created time.Time `gorm:"column:created;autoCreateTime" json:"created"`
|
||||
|
||||
// Plaintext is the raw token value. It is never persisted (gorm:"-")
|
||||
// and is only populated on a freshly-created token so the caller can
|
||||
// return it to the client. On a token loaded from the DB it is "".
|
||||
Plaintext string `gorm:"-" json:"-"`
|
||||
|
||||
// Relations
|
||||
User User `gorm:"foreignKey:UserID" json:"-"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for GORM
|
||||
func (AuthToken) TableName() string {
|
||||
return "user_authtoken"
|
||||
}
|
||||
|
||||
// BeforeCreate generates a token if one is not already set, storing only
|
||||
// its hash in Key and the raw value in the non-persisted Plaintext field.
|
||||
func (t *AuthToken) BeforeCreate(tx *gorm.DB) error {
|
||||
if t.Key == "" {
|
||||
raw := generateToken()
|
||||
t.Plaintext = raw
|
||||
t.Key = HashToken(raw)
|
||||
}
|
||||
if t.Created.IsZero() {
|
||||
t.Created = time.Now().UTC()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateToken creates a random 40-character hex token (the raw value).
|
||||
func generateToken() string {
|
||||
b := make([]byte, 20)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// HashToken returns the at-rest representation of an auth token: the
|
||||
// hex-encoded SHA-256 hash. Auth tokens are 160-bit random values, so a
|
||||
// fast deterministic hash is appropriate — there is nothing to brute-force,
|
||||
// and determinism preserves the single indexed-lookup query in the auth
|
||||
// middleware. The raw token is never stored.
|
||||
func HashToken(raw string) string {
|
||||
sum := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// GetOrCreate gets an existing token or creates a new one for the user
|
||||
func GetOrCreateToken(tx *gorm.DB, userID uint) (*AuthToken, error) {
|
||||
var token AuthToken
|
||||
result := tx.Where("user_id = ?", userID).First(&token)
|
||||
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
token = AuthToken{UserID: userID}
|
||||
if err := tx.Create(&token).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// UserProfile represents the user_userprofile table
|
||||
// UserProfile represents the user_userprofile table — honeyDue-specific
|
||||
// profile data, keyed to a local user. Email-verification state is owned by
|
||||
// Kratos; the Verified column is a convenience mirror set at provision time.
|
||||
type UserProfile struct {
|
||||
BaseModel
|
||||
UserID uint `gorm:"column:user_id;uniqueIndex;not null" json:"user_id"`
|
||||
@@ -161,142 +59,7 @@ type UserProfile struct {
|
||||
User User `gorm:"foreignKey:UserID" json:"-"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for GORM
|
||||
// TableName returns the table name for GORM.
|
||||
func (UserProfile) TableName() string {
|
||||
return "user_userprofile"
|
||||
}
|
||||
|
||||
// ConfirmationCode represents the user_confirmationcode table
|
||||
type ConfirmationCode struct {
|
||||
BaseModel
|
||||
UserID uint `gorm:"column:user_id;index;not null" json:"user_id"`
|
||||
Code string `gorm:"column:code;size:6;not null" json:"-"`
|
||||
ExpiresAt time.Time `gorm:"column:expires_at;not null" json:"expires_at"`
|
||||
IsUsed bool `gorm:"column:is_used;default:false" json:"is_used"`
|
||||
|
||||
// Relations
|
||||
User User `gorm:"foreignKey:UserID" json:"-"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for GORM
|
||||
func (ConfirmationCode) TableName() string {
|
||||
return "user_confirmationcode"
|
||||
}
|
||||
|
||||
// IsValid checks if the confirmation code is still valid
|
||||
func (c *ConfirmationCode) IsValid() bool {
|
||||
return !c.IsUsed && time.Now().UTC().Before(c.ExpiresAt)
|
||||
}
|
||||
|
||||
// GenerateConfirmationCode creates a uniformly-random 6-digit code using
|
||||
// rejection sampling on crypto/rand (audit H4 — removes the modulo bias of
|
||||
// the previous implementation).
|
||||
func GenerateConfirmationCode() string {
|
||||
for {
|
||||
var b [4]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
continue
|
||||
}
|
||||
// 4294000000 is the largest multiple of 1e6 <= MaxUint32; rejecting
|
||||
// the tail above it makes n % 1000000 perfectly uniform.
|
||||
n := binary.BigEndian.Uint32(b[:])
|
||||
if n < 4294000000 {
|
||||
return fmt.Sprintf("%06d", n%1000000)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PasswordResetCode represents the user_passwordresetcode table
|
||||
type PasswordResetCode struct {
|
||||
BaseModel
|
||||
UserID uint `gorm:"column:user_id;index;not null" json:"user_id"`
|
||||
CodeHash string `gorm:"column:code_hash;size:128;not null" json:"-"`
|
||||
ResetToken string `gorm:"column:reset_token;uniqueIndex;size:64;not null" json:"reset_token"`
|
||||
ExpiresAt time.Time `gorm:"column:expires_at;not null" json:"expires_at"`
|
||||
Used bool `gorm:"column:used;default:false" json:"used"`
|
||||
Attempts int `gorm:"column:attempts;default:0" json:"attempts"`
|
||||
MaxAttempts int `gorm:"column:max_attempts;default:5" json:"max_attempts"`
|
||||
|
||||
// Relations
|
||||
User User `gorm:"foreignKey:UserID" json:"-"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for GORM
|
||||
func (PasswordResetCode) TableName() string {
|
||||
return "user_passwordresetcode"
|
||||
}
|
||||
|
||||
// SetCode hashes and stores the reset code
|
||||
func (p *PasswordResetCode) SetCode(code string) error {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(code), BcryptCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.CodeHash = string(hash)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckCode verifies a code against the stored hash
|
||||
func (p *PasswordResetCode) CheckCode(code string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(p.CodeHash), []byte(code))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// IsValid checks if the reset code is still valid
|
||||
func (p *PasswordResetCode) IsValid() bool {
|
||||
return !p.Used && time.Now().UTC().Before(p.ExpiresAt) && p.Attempts < p.MaxAttempts
|
||||
}
|
||||
|
||||
// IncrementAttempts increments the attempt counter
|
||||
func (p *PasswordResetCode) IncrementAttempts(tx *gorm.DB) error {
|
||||
p.Attempts++
|
||||
return tx.Model(p).Update("attempts", p.Attempts).Error
|
||||
}
|
||||
|
||||
// MarkAsUsed marks the code as used
|
||||
func (p *PasswordResetCode) MarkAsUsed(tx *gorm.DB) error {
|
||||
p.Used = true
|
||||
return tx.Model(p).Update("used", true).Error
|
||||
}
|
||||
|
||||
// GenerateResetToken creates a URL-safe token
|
||||
func GenerateResetToken() string {
|
||||
b := make([]byte, 32)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// AppleSocialAuth represents a user's linked Apple ID for Sign in with Apple
|
||||
type AppleSocialAuth struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
UserID uint `gorm:"uniqueIndex;not null" json:"user_id"`
|
||||
User User `gorm:"foreignKey:UserID" json:"-"`
|
||||
AppleID string `gorm:"column:apple_id;size:255;uniqueIndex;not null" json:"apple_id"` // Apple's unique subject ID
|
||||
Email string `gorm:"column:email;size:254" json:"email"` // May be private relay
|
||||
IsPrivateEmail bool `gorm:"column:is_private_email;default:false" json:"is_private_email"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for GORM
|
||||
func (AppleSocialAuth) TableName() string {
|
||||
return "user_applesocialauth"
|
||||
}
|
||||
|
||||
// GoogleSocialAuth represents a user's linked Google account for Sign in with Google
|
||||
type GoogleSocialAuth struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
UserID uint `gorm:"uniqueIndex;not null" json:"user_id"`
|
||||
User User `gorm:"foreignKey:UserID" json:"-"`
|
||||
GoogleID string `gorm:"column:google_id;size:255;uniqueIndex;not null" json:"google_id"` // Google's unique subject ID
|
||||
Email string `gorm:"column:email;size:254" json:"email"`
|
||||
Name string `gorm:"column:name;size:255" json:"name"`
|
||||
Picture string `gorm:"column:picture;size:512" json:"picture"` // Profile picture URL
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for GORM
|
||||
func (GoogleSocialAuth) TableName() string {
|
||||
return "user_googlesocialauth"
|
||||
}
|
||||
|
||||
@@ -2,45 +2,10 @@ package models
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUser_SetPassword(t *testing.T) {
|
||||
user := &User{}
|
||||
|
||||
err := user.SetPassword("testPassword123")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, user.Password)
|
||||
assert.NotEqual(t, "testPassword123", user.Password) // Should be hashed
|
||||
}
|
||||
|
||||
func TestUser_CheckPassword(t *testing.T) {
|
||||
user := &User{}
|
||||
err := user.SetPassword("correctpassword")
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
expected bool
|
||||
}{
|
||||
{"correct password", "correctpassword", true},
|
||||
{"wrong password", "wrongpassword", false},
|
||||
{"empty password", "", false},
|
||||
{"similar password", "correctpassword1", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := user.CheckPassword(tt.password)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUser_GetFullName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -82,136 +47,7 @@ func TestUser_TableName(t *testing.T) {
|
||||
assert.Equal(t, "auth_user", user.TableName())
|
||||
}
|
||||
|
||||
func TestAuthToken_TableName(t *testing.T) {
|
||||
token := AuthToken{}
|
||||
assert.Equal(t, "user_authtoken", token.TableName())
|
||||
}
|
||||
|
||||
func TestUserProfile_TableName(t *testing.T) {
|
||||
profile := UserProfile{}
|
||||
assert.Equal(t, "user_userprofile", profile.TableName())
|
||||
}
|
||||
|
||||
func TestConfirmationCode_IsValid(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
future := now.Add(1 * time.Hour)
|
||||
past := now.Add(-1 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code ConfirmationCode
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "valid code",
|
||||
code: ConfirmationCode{IsUsed: false, ExpiresAt: future},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "used code",
|
||||
code: ConfirmationCode{IsUsed: true, ExpiresAt: future},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "expired code",
|
||||
code: ConfirmationCode{IsUsed: false, ExpiresAt: past},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "used and expired",
|
||||
code: ConfirmationCode{IsUsed: true, ExpiresAt: past},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.code.IsValid()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetCode_IsValid(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
future := now.Add(1 * time.Hour)
|
||||
past := now.Add(-1 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code PasswordResetCode
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "valid code",
|
||||
code: PasswordResetCode{Used: false, ExpiresAt: future, Attempts: 0, MaxAttempts: 5},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "used code",
|
||||
code: PasswordResetCode{Used: true, ExpiresAt: future, Attempts: 0, MaxAttempts: 5},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "expired code",
|
||||
code: PasswordResetCode{Used: false, ExpiresAt: past, Attempts: 0, MaxAttempts: 5},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "max attempts reached",
|
||||
code: PasswordResetCode{Used: false, ExpiresAt: future, Attempts: 5, MaxAttempts: 5},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "attempts under max",
|
||||
code: PasswordResetCode{Used: false, ExpiresAt: future, Attempts: 4, MaxAttempts: 5},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.code.IsValid()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetCode_SetAndCheckCode(t *testing.T) {
|
||||
code := &PasswordResetCode{}
|
||||
|
||||
err := code.SetCode("123456")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, code.CodeHash)
|
||||
|
||||
// Check correct code
|
||||
assert.True(t, code.CheckCode("123456"))
|
||||
|
||||
// Check wrong code
|
||||
assert.False(t, code.CheckCode("654321"))
|
||||
assert.False(t, code.CheckCode(""))
|
||||
}
|
||||
|
||||
func TestGenerateConfirmationCode(t *testing.T) {
|
||||
code := GenerateConfirmationCode()
|
||||
assert.Len(t, code, 6)
|
||||
|
||||
// Generate multiple codes and ensure they're different
|
||||
codes := make(map[string]bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
c := GenerateConfirmationCode()
|
||||
assert.Len(t, c, 6)
|
||||
codes[c] = true
|
||||
}
|
||||
// Most codes should be unique (very unlikely to have collisions)
|
||||
assert.Greater(t, len(codes), 5)
|
||||
}
|
||||
|
||||
func TestGenerateResetToken(t *testing.T) {
|
||||
token := GenerateResetToken()
|
||||
assert.Len(t, token, 64) // 32 bytes = 64 hex chars
|
||||
|
||||
// Ensure uniqueness
|
||||
token2 := GenerateResetToken()
|
||||
assert.NotEqual(t, token, token2)
|
||||
}
|
||||
|
||||
@@ -11,18 +11,21 @@ import (
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
// FindByKratosID finds a user by Kratos identity UUID.
|
||||
func (r *UserRepository) FindByKratosID(kratosID string) (*models.User, error) {
|
||||
var user models.User
|
||||
if err := r.db.Where("kratos_id = ?", kratosID).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrUserExists = errors.New("user already exists")
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
ErrTokenNotFound = errors.New("token not found")
|
||||
ErrCodeNotFound = errors.New("code not found")
|
||||
ErrCodeExpired = errors.New("code expired")
|
||||
ErrCodeUsed = errors.New("code already used")
|
||||
ErrTooManyAttempts = errors.New("too many attempts")
|
||||
ErrRateLimitExceeded = errors.New("rate limit exceeded")
|
||||
ErrAppleAuthNotFound = errors.New("apple social auth not found")
|
||||
ErrGoogleAuthNotFound = errors.New("google social auth not found")
|
||||
)
|
||||
|
||||
// UserRepository handles user-related database operations
|
||||
@@ -145,111 +148,6 @@ func (r *UserRepository) ExistsByEmail(email string) (bool, error) {
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// --- Auth Token Methods ---
|
||||
|
||||
// GetOrCreateToken gets or creates an auth token for a user.
|
||||
// Wrapped in a transaction to prevent race conditions where two
|
||||
// concurrent requests could create duplicate tokens for the same user.
|
||||
func (r *UserRepository) GetOrCreateToken(userID uint) (*models.AuthToken, error) {
|
||||
var token models.AuthToken
|
||||
|
||||
err := r.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.Where("user_id = ?", userID).First(&token)
|
||||
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
token = models.AuthToken{UserID: userID}
|
||||
if err := tx.Create(&token).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
} else if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// FindTokenByKey looks up an auth token by its raw key value. The raw token
|
||||
// is hashed (audit C1) before the indexed lookup, since only the hash is
|
||||
// stored.
|
||||
func (r *UserRepository) FindTokenByKey(rawKey string) (*models.AuthToken, error) {
|
||||
var token models.AuthToken
|
||||
if err := r.db.Where("key = ?", models.HashToken(rawKey)).First(&token).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrTokenNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// CreateToken creates a new auth token for a user.
|
||||
func (r *UserRepository) CreateToken(userID uint) (*models.AuthToken, error) {
|
||||
token := models.AuthToken{UserID: userID}
|
||||
if err := r.db.Create(&token).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// CreateFreshToken issues a new auth token for the user, replacing any
|
||||
// existing one. Because tokens are stored hashed (audit C1) the server
|
||||
// cannot re-issue a previously-minted token's plaintext, so every login
|
||||
// mints a fresh token. The returned token's Plaintext field carries the
|
||||
// raw value to hand to the client; it is never persisted.
|
||||
//
|
||||
// It also returns the stored hashes of the token rows it deleted, so the
|
||||
// caller can evict those entries from the Redis token cache (audit MEDIUM-1).
|
||||
// Without that, a prior (e.g. stolen) token keeps authenticating via a cache
|
||||
// hit for up to the cache TTL even though its DB row is gone.
|
||||
func (r *UserRepository) CreateFreshToken(userID uint) (*models.AuthToken, []string, error) {
|
||||
var token models.AuthToken
|
||||
var oldHashes []string
|
||||
err := r.db.Transaction(func(tx *gorm.DB) error {
|
||||
var old []models.AuthToken
|
||||
if err := tx.Where("user_id = ?", userID).Find(&old).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
oldHashes = make([]string, 0, len(old))
|
||||
for i := range old {
|
||||
if old[i].Key != "" {
|
||||
oldHashes = append(oldHashes, old[i].Key)
|
||||
}
|
||||
}
|
||||
if err := tx.Where("user_id = ?", userID).Delete(&models.AuthToken{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
token = models.AuthToken{UserID: userID}
|
||||
return tx.Create(&token).Error
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return &token, oldHashes, nil
|
||||
}
|
||||
|
||||
// DeleteToken deletes an auth token by its raw key value. The raw token is
|
||||
// hashed (audit C1) before the lookup, since only the hash is stored.
|
||||
func (r *UserRepository) DeleteToken(token string) error {
|
||||
result := r.db.Where("key = ?", models.HashToken(token)).Delete(&models.AuthToken{})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return ErrTokenNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteTokenByUserID deletes an auth token by user ID
|
||||
func (r *UserRepository) DeleteTokenByUserID(userID uint) error {
|
||||
return r.db.Where("user_id = ?", userID).Delete(&models.AuthToken{}).Error
|
||||
}
|
||||
|
||||
// --- User Profile Methods ---
|
||||
|
||||
@@ -280,146 +178,6 @@ func (r *UserRepository) SetProfileVerified(userID uint, verified bool) error {
|
||||
return r.db.Model(&models.UserProfile{}).Where("user_id = ?", userID).Update("verified", verified).Error
|
||||
}
|
||||
|
||||
// --- Confirmation Code Methods ---
|
||||
|
||||
// CreateConfirmationCode creates a new confirmation code
|
||||
func (r *UserRepository) CreateConfirmationCode(userID uint, code string, expiresAt time.Time) (*models.ConfirmationCode, error) {
|
||||
// Invalidate any existing unused codes for this user
|
||||
r.db.Model(&models.ConfirmationCode{}).
|
||||
Where("user_id = ? AND is_used = ?", userID, false).
|
||||
Update("is_used", true)
|
||||
|
||||
confirmCode := &models.ConfirmationCode{
|
||||
UserID: userID,
|
||||
Code: code,
|
||||
ExpiresAt: expiresAt,
|
||||
IsUsed: false,
|
||||
}
|
||||
|
||||
if err := r.db.Create(confirmCode).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return confirmCode, nil
|
||||
}
|
||||
|
||||
// FindConfirmationCode finds a valid confirmation code for a user
|
||||
func (r *UserRepository) FindConfirmationCode(userID uint, code string) (*models.ConfirmationCode, error) {
|
||||
var confirmCode models.ConfirmationCode
|
||||
if err := r.db.Where("user_id = ? AND code = ? AND is_used = ?", userID, code, false).
|
||||
First(&confirmCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !confirmCode.IsValid() {
|
||||
if confirmCode.IsUsed {
|
||||
return nil, ErrCodeUsed
|
||||
}
|
||||
return nil, ErrCodeExpired
|
||||
}
|
||||
|
||||
return &confirmCode, nil
|
||||
}
|
||||
|
||||
// MarkConfirmationCodeUsed marks a confirmation code as used
|
||||
func (r *UserRepository) MarkConfirmationCodeUsed(codeID uint) error {
|
||||
return r.db.Model(&models.ConfirmationCode{}).Where("id = ?", codeID).Update("is_used", true).Error
|
||||
}
|
||||
|
||||
// --- Password Reset Code Methods ---
|
||||
|
||||
// CreatePasswordResetCode creates a new password reset code
|
||||
func (r *UserRepository) CreatePasswordResetCode(userID uint, codeHash string, resetToken string, expiresAt time.Time) (*models.PasswordResetCode, error) {
|
||||
// Invalidate any existing unused codes for this user
|
||||
r.db.Model(&models.PasswordResetCode{}).
|
||||
Where("user_id = ? AND used = ?", userID, false).
|
||||
Update("used", true)
|
||||
|
||||
resetCode := &models.PasswordResetCode{
|
||||
UserID: userID,
|
||||
CodeHash: codeHash,
|
||||
ResetToken: resetToken,
|
||||
ExpiresAt: expiresAt,
|
||||
Used: false,
|
||||
Attempts: 0,
|
||||
MaxAttempts: 5,
|
||||
}
|
||||
|
||||
if err := r.db.Create(resetCode).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resetCode, nil
|
||||
}
|
||||
|
||||
// FindPasswordResetCode finds a password reset code by email and checks validity
|
||||
func (r *UserRepository) FindPasswordResetCodeByEmail(email string) (*models.PasswordResetCode, *models.User, error) {
|
||||
user, err := r.FindByEmail(email)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var resetCode models.PasswordResetCode
|
||||
if err := r.db.Where("user_id = ? AND used = ?", user.ID, false).
|
||||
Order("created_at DESC").
|
||||
First(&resetCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil, ErrCodeNotFound
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return &resetCode, user, nil
|
||||
}
|
||||
|
||||
// FindPasswordResetCodeByToken finds a password reset code by reset token
|
||||
func (r *UserRepository) FindPasswordResetCodeByToken(resetToken string) (*models.PasswordResetCode, error) {
|
||||
var resetCode models.PasswordResetCode
|
||||
if err := r.db.Where("reset_token = ?", resetToken).First(&resetCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !resetCode.IsValid() {
|
||||
if resetCode.Used {
|
||||
return nil, ErrCodeUsed
|
||||
}
|
||||
if resetCode.Attempts >= resetCode.MaxAttempts {
|
||||
return nil, ErrTooManyAttempts
|
||||
}
|
||||
return nil, ErrCodeExpired
|
||||
}
|
||||
|
||||
return &resetCode, nil
|
||||
}
|
||||
|
||||
// IncrementResetCodeAttempts increments the attempt counter
|
||||
func (r *UserRepository) IncrementResetCodeAttempts(codeID uint) error {
|
||||
return r.db.Model(&models.PasswordResetCode{}).Where("id = ?", codeID).
|
||||
Update("attempts", gorm.Expr("attempts + 1")).Error
|
||||
}
|
||||
|
||||
// MarkPasswordResetCodeUsed marks a password reset code as used
|
||||
func (r *UserRepository) MarkPasswordResetCodeUsed(codeID uint) error {
|
||||
return r.db.Model(&models.PasswordResetCode{}).Where("id = ?", codeID).Update("used", true).Error
|
||||
}
|
||||
|
||||
// CountRecentPasswordResetRequests counts reset requests in the last hour
|
||||
func (r *UserRepository) CountRecentPasswordResetRequests(userID uint) (int64, error) {
|
||||
var count int64
|
||||
oneHourAgo := time.Now().UTC().Add(-1 * time.Hour)
|
||||
if err := r.db.Model(&models.PasswordResetCode{}).
|
||||
Where("user_id = ? AND created_at > ?", userID, oneHourAgo).
|
||||
Count(&count).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// --- Search Methods ---
|
||||
|
||||
@@ -576,27 +334,11 @@ func (r *UserRepository) FindProfilesInSharedResidences(userID uint) ([]models.U
|
||||
return profiles, err
|
||||
}
|
||||
|
||||
// --- Auth Provider Detection ---
|
||||
|
||||
// FindAuthProvider determines the auth provider for a user.
|
||||
// Returns "apple", "google", or "email".
|
||||
func (r *UserRepository) FindAuthProvider(userID uint) (string, error) {
|
||||
var count int64
|
||||
if err := r.db.Model(&models.AppleSocialAuth{}).Where("user_id = ?", userID).Count(&count).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
if count > 0 {
|
||||
return "apple", nil
|
||||
}
|
||||
|
||||
if err := r.db.Model(&models.GoogleSocialAuth{}).Where("user_id = ?", userID).Count(&count).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
if count > 0 {
|
||||
return "google", nil
|
||||
}
|
||||
|
||||
return "email", nil
|
||||
// FindAuthProvider returns "kratos" for all Kratos-managed users (the sole
|
||||
// provider after the Ory Kratos migration). Kept for compatibility with
|
||||
// callers that still check the provider string.
|
||||
func (r *UserRepository) FindAuthProvider(_ uint) (string, error) {
|
||||
return "kratos", nil
|
||||
}
|
||||
|
||||
// --- Account Deletion ---
|
||||
@@ -721,35 +463,12 @@ func (r *UserRepository) DeleteUserCascade(userID uint) ([]string, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 8. Social auth records
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.AppleSocialAuth{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.GoogleSocialAuth{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 9. Confirmation codes
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.ConfirmationCode{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 10. Password reset codes
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.PasswordResetCode{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 11. Auth tokens
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.AuthToken{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 12. User profile
|
||||
// 8. User profile
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.UserProfile{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 13. User
|
||||
// 9. User
|
||||
if err := db.Where("id = ?", userID).Delete(&models.User{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -765,53 +484,6 @@ func (r *UserRepository) DeleteUserCascade(userID uint) ([]string, error) {
|
||||
return cleanURLs, nil
|
||||
}
|
||||
|
||||
// --- Apple Social Auth Methods ---
|
||||
|
||||
// FindByAppleID finds an Apple social auth by Apple ID
|
||||
func (r *UserRepository) FindByAppleID(appleID string) (*models.AppleSocialAuth, error) {
|
||||
var auth models.AppleSocialAuth
|
||||
if err := r.db.Where("apple_id = ?", appleID).First(&auth).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrAppleAuthNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &auth, nil
|
||||
}
|
||||
|
||||
// CreateAppleSocialAuth creates a new Apple social auth record
|
||||
func (r *UserRepository) CreateAppleSocialAuth(auth *models.AppleSocialAuth) error {
|
||||
return r.db.Create(auth).Error
|
||||
}
|
||||
|
||||
// UpdateAppleSocialAuth updates an Apple social auth record
|
||||
func (r *UserRepository) UpdateAppleSocialAuth(auth *models.AppleSocialAuth) error {
|
||||
return r.db.Save(auth).Error
|
||||
}
|
||||
|
||||
// --- Google Social Auth Methods ---
|
||||
|
||||
// FindByGoogleID finds a Google social auth by Google ID
|
||||
func (r *UserRepository) FindByGoogleID(googleID string) (*models.GoogleSocialAuth, error) {
|
||||
var auth models.GoogleSocialAuth
|
||||
if err := r.db.Where("google_id = ?", googleID).First(&auth).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrGoogleAuthNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &auth, nil
|
||||
}
|
||||
|
||||
// CreateGoogleSocialAuth creates a new Google social auth record
|
||||
func (r *UserRepository) CreateGoogleSocialAuth(auth *models.GoogleSocialAuth) error {
|
||||
return r.db.Create(auth).Error
|
||||
}
|
||||
|
||||
// UpdateGoogleSocialAuth updates a Google social auth record
|
||||
func (r *UserRepository) UpdateGoogleSocialAuth(auth *models.GoogleSocialAuth) error {
|
||||
return r.db.Save(auth).Error
|
||||
}
|
||||
|
||||
// WithContext returns a copy of the repository whose underlying *gorm.DB carries
|
||||
// the supplied context. SQL emitted via this copy gets attached to ctx's trace span
|
||||
|
||||
@@ -2,7 +2,6 @@ package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -78,99 +77,25 @@ func TestUserRepository_ExistsByEmail_CaseInsensitive(t *testing.T) {
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
func TestUserRepository_GetOrCreateToken(t *testing.T) {
|
||||
func TestUserRepository_FindByKratosID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
user := testutil.CreateTestUser(t, db, "kratosuser", "kratos@example.com", "")
|
||||
|
||||
// Create token
|
||||
token1, err := repo.GetOrCreateToken(user.ID)
|
||||
found, err := repo.FindByKratosID(user.KratosID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token1.Key)
|
||||
|
||||
// Should return same token
|
||||
token2, err := repo.GetOrCreateToken(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token1.Key, token2.Key)
|
||||
assert.Equal(t, user.ID, found.ID)
|
||||
assert.Equal(t, user.KratosID, found.KratosID)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindTokenByKey(t *testing.T) {
|
||||
func TestUserRepository_FindByKratosID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
token, err := repo.GetOrCreateToken(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindTokenByKey(token.Plaintext)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token.Key, found.Key)
|
||||
assert.Equal(t, user.ID, found.UserID)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindTokenByKey_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, err := repo.FindTokenByKey("nonexistent-token-key")
|
||||
_, err := repo.FindByKratosID("nonexistent-kratos-id")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrTokenNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_DeleteToken(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
token, err := repo.GetOrCreateToken(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.DeleteToken(token.Plaintext)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindTokenByKey(token.Plaintext)
|
||||
assert.ErrorIs(t, err, ErrTokenNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_DeleteToken_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
err := repo.DeleteToken("nonexistent-key")
|
||||
assert.ErrorIs(t, err, ErrTokenNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_DeleteTokenByUserID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
_, err := repo.GetOrCreateToken(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.DeleteTokenByUserID(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Token should be gone
|
||||
var count int64
|
||||
db.Model(&models.AuthToken{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
func TestUserRepository_CreateToken(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
token, err := repo.CreateToken(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token.Key)
|
||||
assert.Equal(t, user.ID, token.UserID)
|
||||
assert.ErrorIs(t, err, ErrUserNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_UpdateLastLogin(t *testing.T) {
|
||||
@@ -255,54 +180,6 @@ func TestUserRepository_FindByIDWithProfile_NotFound(t *testing.T) {
|
||||
assert.ErrorIs(t, err, ErrUserNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_ConfirmationCode_Lifecycle(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
// Create confirmation code
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
code, err := repo.CreateConfirmationCode(user.ID, "123456", expiresAt)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, code.ID)
|
||||
|
||||
// Find it
|
||||
found, err := repo.FindConfirmationCode(user.ID, "123456")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, code.ID, found.ID)
|
||||
|
||||
// Mark as used
|
||||
err = repo.MarkConfirmationCodeUsed(code.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should not find used code
|
||||
_, err = repo.FindConfirmationCode(user.ID, "123456")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUserRepository_ConfirmationCode_InvalidatesExisting(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
|
||||
// Create first code
|
||||
code1, err := repo.CreateConfirmationCode(user.ID, "111111", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create second code (should invalidate first)
|
||||
_, err = repo.CreateConfirmationCode(user.ID, "222222", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First code should be used/invalidated
|
||||
var c models.ConfirmationCode
|
||||
db.First(&c, code1.ID)
|
||||
assert.True(t, c.IsUsed)
|
||||
}
|
||||
|
||||
func TestUserRepository_Transaction(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
@@ -331,105 +208,6 @@ func TestUserRepository_DB(t *testing.T) {
|
||||
assert.NotNil(t, repo.DB())
|
||||
}
|
||||
|
||||
func TestUserRepository_FindByAppleID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "Password123")
|
||||
appleAuth := &models.AppleSocialAuth{
|
||||
UserID: user.ID,
|
||||
AppleID: "apple_sub_123",
|
||||
Email: "apple@test.com",
|
||||
}
|
||||
require.NoError(t, db.Create(appleAuth).Error)
|
||||
|
||||
found, err := repo.FindByAppleID("apple_sub_123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.ID, found.UserID)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindByAppleID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, err := repo.FindByAppleID("nonexistent_apple_id")
|
||||
assert.ErrorIs(t, err, ErrAppleAuthNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindByGoogleID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "Password123")
|
||||
googleAuth := &models.GoogleSocialAuth{
|
||||
UserID: user.ID,
|
||||
GoogleID: "google_sub_123",
|
||||
Email: "google@test.com",
|
||||
}
|
||||
require.NoError(t, db.Create(googleAuth).Error)
|
||||
|
||||
found, err := repo.FindByGoogleID("google_sub_123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.ID, found.UserID)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindByGoogleID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, err := repo.FindByGoogleID("nonexistent_google_id")
|
||||
assert.ErrorIs(t, err, ErrGoogleAuthNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_CreateAndUpdateAppleSocialAuth(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "Password123")
|
||||
|
||||
auth := &models.AppleSocialAuth{
|
||||
UserID: user.ID,
|
||||
AppleID: "apple_sub_456",
|
||||
Email: "apple@test.com",
|
||||
}
|
||||
err := repo.CreateAppleSocialAuth(auth)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, auth.ID)
|
||||
|
||||
auth.Email = "updated@test.com"
|
||||
err = repo.UpdateAppleSocialAuth(auth)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByAppleID("apple_sub_456")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "updated@test.com", found.Email)
|
||||
}
|
||||
|
||||
func TestUserRepository_CreateAndUpdateGoogleSocialAuth(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "Password123")
|
||||
|
||||
auth := &models.GoogleSocialAuth{
|
||||
UserID: user.ID,
|
||||
GoogleID: "google_sub_456",
|
||||
Email: "google@test.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
err := repo.CreateGoogleSocialAuth(auth)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, auth.ID)
|
||||
|
||||
auth.Name = "Updated Name"
|
||||
err = repo.UpdateGoogleSocialAuth(auth)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByGoogleID("google_sub_456")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated Name", found.Name)
|
||||
}
|
||||
|
||||
func TestUserRepository_SearchUsers(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
@@ -2,7 +2,6 @@ package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -11,207 +10,6 @@ import (
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
// === Password Reset Code Lifecycle ===
|
||||
|
||||
func TestUserRepository_PasswordResetCode_Lifecycle(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
code, err := repo.CreatePasswordResetCode(user.ID, "hash_abc123", "reset_token_xyz", expiresAt)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, code.ID)
|
||||
assert.Equal(t, "hash_abc123", code.CodeHash)
|
||||
assert.Equal(t, "reset_token_xyz", code.ResetToken)
|
||||
assert.False(t, code.Used)
|
||||
assert.Equal(t, 0, code.Attempts)
|
||||
}
|
||||
|
||||
func TestUserRepository_CreatePasswordResetCode_InvalidatesExisting(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
|
||||
code1, err := repo.CreatePasswordResetCode(user.ID, "hash1", "token1", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.CreatePasswordResetCode(user.ID, "hash2", "token2", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First code should be marked as used
|
||||
var c models.PasswordResetCode
|
||||
db.First(&c, code1.ID)
|
||||
assert.True(t, c.Used)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
_, err := repo.CreatePasswordResetCode(user.ID, "hash_abc", "token_abc", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, foundUser, err := repo.FindPasswordResetCodeByEmail("test@example.com")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.ID, foundUser.ID)
|
||||
assert.Equal(t, "hash_abc", found.CodeHash)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByEmail_UserNotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, _, err := repo.FindPasswordResetCodeByEmail("nonexistent@example.com")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByEmail_NoCode(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
_, _, err := repo.FindPasswordResetCodeByEmail("test@example.com")
|
||||
assert.ErrorIs(t, err, ErrCodeNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByToken(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
_, err := repo.CreatePasswordResetCode(user.ID, "hash_xyz", "token_xyz", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindPasswordResetCodeByToken("token_xyz")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "hash_xyz", found.CodeHash)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByToken_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, err := repo.FindPasswordResetCodeByToken("nonexistent_token")
|
||||
assert.ErrorIs(t, err, ErrCodeNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByToken_Expired(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
// Already expired
|
||||
expiresAt := time.Now().UTC().Add(-1 * time.Hour)
|
||||
_, err := repo.CreatePasswordResetCode(user.ID, "hash_exp", "token_exp", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindPasswordResetCodeByToken("token_exp")
|
||||
assert.ErrorIs(t, err, ErrCodeExpired)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByToken_Used(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
code, err := repo.CreatePasswordResetCode(user.ID, "hash_used", "token_used", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mark as used
|
||||
err = repo.MarkPasswordResetCodeUsed(code.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindPasswordResetCodeByToken("token_used")
|
||||
assert.ErrorIs(t, err, ErrCodeUsed)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByToken_TooManyAttempts(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
code, err := repo.CreatePasswordResetCode(user.ID, "hash_attempts", "token_attempts", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Max out attempts
|
||||
for i := 0; i < 5; i++ {
|
||||
err = repo.IncrementResetCodeAttempts(code.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
_, err = repo.FindPasswordResetCodeByToken("token_attempts")
|
||||
assert.ErrorIs(t, err, ErrTooManyAttempts)
|
||||
}
|
||||
|
||||
func TestUserRepository_IncrementResetCodeAttempts(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
code, err := repo.CreatePasswordResetCode(user.ID, "hash_inc", "token_inc", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.IncrementResetCodeAttempts(code.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var updated models.PasswordResetCode
|
||||
db.First(&updated, code.ID)
|
||||
assert.Equal(t, 1, updated.Attempts)
|
||||
}
|
||||
|
||||
func TestUserRepository_MarkPasswordResetCodeUsed(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
code, err := repo.CreatePasswordResetCode(user.ID, "hash_mark", "token_mark", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.MarkPasswordResetCodeUsed(code.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var updated models.PasswordResetCode
|
||||
db.First(&updated, code.ID)
|
||||
assert.True(t, updated.Used)
|
||||
}
|
||||
|
||||
func TestUserRepository_CountRecentPasswordResetRequests(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
_, err := repo.CreatePasswordResetCode(user.ID, "hash1", "token1", expiresAt)
|
||||
require.NoError(t, err)
|
||||
_, err = repo.CreatePasswordResetCode(user.ID, "hash2", "token2", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err := repo.CountRecentPasswordResetRequests(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(2), count)
|
||||
}
|
||||
|
||||
// === FindUsersInSharedResidences ===
|
||||
|
||||
func TestUserRepository_FindUsersInSharedResidences(t *testing.T) {
|
||||
@@ -301,33 +99,6 @@ func TestUserRepository_FindProfilesInSharedResidences(t *testing.T) {
|
||||
assert.Len(t, profiles, 2)
|
||||
}
|
||||
|
||||
// === ConfirmationCode Expired ===
|
||||
|
||||
func TestUserRepository_FindConfirmationCode_Expired(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
// Create already-expired code
|
||||
expiresAt := time.Now().UTC().Add(-1 * time.Hour)
|
||||
_, err := repo.CreateConfirmationCode(user.ID, "999999", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindConfirmationCode(user.ID, "999999")
|
||||
assert.ErrorIs(t, err, ErrCodeExpired)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindConfirmationCode_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
_, err := repo.FindConfirmationCode(user.ID, "000000")
|
||||
assert.ErrorIs(t, err, ErrCodeNotFound)
|
||||
}
|
||||
|
||||
// === Transaction Rollback ===
|
||||
|
||||
func TestUserRepository_Transaction_Rollback(t *testing.T) {
|
||||
|
||||
@@ -19,7 +19,6 @@ func TestUserRepository_Create(t *testing.T) {
|
||||
Email: "test@example.com",
|
||||
IsActive: true,
|
||||
}
|
||||
user.SetPassword("Password123")
|
||||
|
||||
err := repo.Create(user)
|
||||
require.NoError(t, err)
|
||||
@@ -192,39 +191,11 @@ func TestUserRepository_FindAuthProvider(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
t.Run("email user", func(t *testing.T) {
|
||||
t.Run("kratos user", func(t *testing.T) {
|
||||
user := testutil.CreateTestUser(t, db, "emailuser", "email@test.com", "Password123")
|
||||
provider, err := repo.FindAuthProvider(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "email", provider)
|
||||
})
|
||||
|
||||
t.Run("apple user", func(t *testing.T) {
|
||||
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "Password123")
|
||||
appleAuth := &models.AppleSocialAuth{
|
||||
UserID: user.ID,
|
||||
AppleID: "apple_sub_test",
|
||||
Email: "apple@test.com",
|
||||
}
|
||||
require.NoError(t, db.Create(appleAuth).Error)
|
||||
|
||||
provider, err := repo.FindAuthProvider(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "apple", provider)
|
||||
})
|
||||
|
||||
t.Run("google user", func(t *testing.T) {
|
||||
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "Password123")
|
||||
googleAuth := &models.GoogleSocialAuth{
|
||||
UserID: user.ID,
|
||||
GoogleID: "google_sub_test",
|
||||
Email: "google@test.com",
|
||||
}
|
||||
require.NoError(t, db.Create(googleAuth).Error)
|
||||
|
||||
provider, err := repo.FindAuthProvider(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "google", provider)
|
||||
assert.Equal(t, "kratos", provider) // All users are Kratos-managed
|
||||
})
|
||||
}
|
||||
|
||||
@@ -235,11 +206,9 @@ func TestUserRepository_DeleteUserCascade(t *testing.T) {
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "deletebare", "deletebare@test.com", "Password123")
|
||||
|
||||
// Create profile and token
|
||||
// Create profile
|
||||
profile := &models.UserProfile{UserID: user.ID, Verified: true}
|
||||
require.NoError(t, db.Create(profile).Error)
|
||||
_, err := models.GetOrCreateToken(db, user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var fileURLs []string
|
||||
txErr := repo.Transaction(func(txRepo *UserRepository) error {
|
||||
@@ -261,10 +230,6 @@ func TestUserRepository_DeleteUserCascade(t *testing.T) {
|
||||
// Verify profile is gone
|
||||
db.Model(&models.UserProfile{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
|
||||
// Verify token is gone
|
||||
db.Model(&models.AuthToken{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
})
|
||||
|
||||
t.Run("returns file URLs for cleanup", func(t *testing.T) {
|
||||
|
||||
+14
-50
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/treytartt/honeydue-api/internal/dto/responses"
|
||||
"github.com/treytartt/honeydue-api/internal/handlers"
|
||||
"github.com/treytartt/honeydue-api/internal/i18n"
|
||||
"github.com/treytartt/honeydue-api/internal/kratos"
|
||||
custommiddleware "github.com/treytartt/honeydue-api/internal/middleware"
|
||||
"github.com/treytartt/honeydue-api/internal/monitoring"
|
||||
"github.com/treytartt/honeydue-api/internal/prom"
|
||||
@@ -200,7 +201,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
|
||||
// Initialize services
|
||||
authService := services.NewAuthService(userRepo, cfg)
|
||||
authService.SetNotificationRepository(notificationRepo) // For creating notification preferences on registration
|
||||
authService.SetNotificationRepository(notificationRepo)
|
||||
userService := services.NewUserService(userRepo)
|
||||
residenceService := services.NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
residenceService.SetTaskRepository(taskRepo) // Wire up task repo for statistics
|
||||
@@ -220,7 +221,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
// Wire Redis cache for residence-ID lookups across the four services that
|
||||
// read it on the request hot path. Cache is best-effort; nil cache is OK.
|
||||
if deps.Cache != nil {
|
||||
authService.SetCacheService(deps.Cache) // per-account login lockout (audit M5)
|
||||
authService.SetCacheService(deps.Cache)
|
||||
residenceService.SetCacheService(deps.Cache)
|
||||
taskService.SetCacheService(deps.Cache)
|
||||
contractorService.SetCacheService(deps.Cache)
|
||||
@@ -244,20 +245,15 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
subscriptionWebhookHandler.SetStripeService(stripeService)
|
||||
subscriptionWebhookHandler.SetCacheService(deps.Cache)
|
||||
|
||||
// Initialize middleware
|
||||
authMiddleware := custommiddleware.NewAuthMiddlewareWithConfig(deps.DB, deps.Cache, cfg)
|
||||
|
||||
// Initialize Apple auth service
|
||||
appleAuthService := services.NewAppleAuthService(deps.Cache, cfg)
|
||||
googleAuthService := services.NewGoogleAuthService(deps.Cache, cfg)
|
||||
// Initialize Kratos auth middleware (replaces hand-rolled token auth).
|
||||
kratosClient := kratos.NewClient(cfg.Security.KratosPublicURL)
|
||||
authMiddleware := custommiddleware.NewKratosAuth(kratosClient, deps.Cache, deps.DB)
|
||||
|
||||
// Initialize audit service for security event logging
|
||||
auditService := services.NewAuditService(deps.DB)
|
||||
|
||||
// Initialize handlers
|
||||
authHandler := handlers.NewAuthHandler(authService, deps.EmailService, deps.Cache)
|
||||
authHandler.SetAppleAuthService(appleAuthService)
|
||||
authHandler.SetGoogleAuthService(googleAuthService)
|
||||
authHandler.SetStorageService(deps.StorageService)
|
||||
authHandler.SetAuditService(auditService)
|
||||
userHandler := handlers.NewUserHandler(userService)
|
||||
@@ -318,8 +314,8 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
// API group
|
||||
api := e.Group("/api")
|
||||
{
|
||||
// Public auth routes (no auth required)
|
||||
setupPublicAuthRoutes(api, authHandler, cfg.Server.Debug)
|
||||
// Session lifecycle (login, register, logout, password reset) is
|
||||
// handled by Ory Kratos — no public auth routes in this service.
|
||||
|
||||
// Public data routes (no auth required)
|
||||
setupPublicDataRoutes(api, residenceHandler, taskHandler, contractorHandler, staticDataHandler, subscriptionHandler, taskTemplateHandler)
|
||||
@@ -329,7 +325,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
|
||||
// Protected routes (auth required)
|
||||
protected := api.Group("")
|
||||
protected.Use(authMiddleware.TokenAuth())
|
||||
protected.Use(authMiddleware.Authenticate())
|
||||
protected.Use(custommiddleware.TimezoneMiddleware())
|
||||
{
|
||||
setupProtectedAuthRoutes(protected, authHandler)
|
||||
@@ -516,50 +512,18 @@ func prometheusMetrics(monSvc *monitoring.Service) echo.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// setupPublicAuthRoutes configures public authentication routes with
|
||||
// per-endpoint rate limiters to mitigate brute-force and credential-stuffing.
|
||||
// Rate limiters are disabled in debug mode to allow UI test suites to run
|
||||
// without hitting 429 errors.
|
||||
func setupPublicAuthRoutes(api *echo.Group, authHandler *handlers.AuthHandler, debug bool) {
|
||||
auth := api.Group("/auth")
|
||||
// setupPublicAuthRoutes was removed — session lifecycle (login, register,
|
||||
// logout, password reset, Apple/Google sign-in) is delegated to Ory Kratos.
|
||||
|
||||
if debug {
|
||||
// No rate limiters in debug/local mode
|
||||
auth.POST("/login/", authHandler.Login)
|
||||
auth.POST("/register/", authHandler.Register)
|
||||
auth.POST("/forgot-password/", authHandler.ForgotPassword)
|
||||
auth.POST("/verify-reset-code/", authHandler.VerifyResetCode)
|
||||
auth.POST("/reset-password/", authHandler.ResetPassword)
|
||||
auth.POST("/apple-sign-in/", authHandler.AppleSignIn)
|
||||
auth.POST("/google-sign-in/", authHandler.GoogleSignIn)
|
||||
} else {
|
||||
// Rate limiters — created once, shared across requests.
|
||||
loginRL := custommiddleware.LoginRateLimiter() // 10 req/min
|
||||
registerRL := custommiddleware.RegistrationRateLimiter() // 5 req/min
|
||||
passwordRL := custommiddleware.PasswordResetRateLimiter() // 3 req/min
|
||||
|
||||
auth.POST("/login/", authHandler.Login, loginRL)
|
||||
auth.POST("/register/", authHandler.Register, registerRL)
|
||||
auth.POST("/forgot-password/", authHandler.ForgotPassword, passwordRL)
|
||||
auth.POST("/verify-reset-code/", authHandler.VerifyResetCode, passwordRL)
|
||||
auth.POST("/reset-password/", authHandler.ResetPassword, passwordRL)
|
||||
auth.POST("/apple-sign-in/", authHandler.AppleSignIn, loginRL)
|
||||
auth.POST("/google-sign-in/", authHandler.GoogleSignIn, loginRL)
|
||||
}
|
||||
}
|
||||
|
||||
// setupProtectedAuthRoutes configures protected authentication routes
|
||||
// setupProtectedAuthRoutes configures protected auth routes.
|
||||
// Session lifecycle (login, logout, password reset, email verification) is
|
||||
// delegated to Ory Kratos — only profile and account-deletion routes remain.
|
||||
func setupProtectedAuthRoutes(api *echo.Group, authHandler *handlers.AuthHandler) {
|
||||
auth := api.Group("/auth")
|
||||
{
|
||||
auth.POST("/logout/", authHandler.Logout)
|
||||
auth.POST("/refresh/", authHandler.RefreshToken)
|
||||
auth.GET("/me/", authHandler.CurrentUser)
|
||||
auth.PUT("/profile/", authHandler.UpdateProfile)
|
||||
auth.PATCH("/profile/", authHandler.UpdateProfile)
|
||||
auth.POST("/verify/", authHandler.VerifyEmail) // Alias for mobile app compatibility
|
||||
auth.POST("/verify-email/", authHandler.VerifyEmail) // Original route
|
||||
auth.POST("/resend-verification/", authHandler.ResendVerification)
|
||||
auth.DELETE("/account/", authHandler.DeleteAccount)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,301 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
appleKeysURL = "https://appleid.apple.com/auth/keys"
|
||||
appleIssuer = "https://appleid.apple.com"
|
||||
appleKeysCacheTTL = 24 * time.Hour
|
||||
appleKeysCacheKey = "apple:public_keys"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidAppleToken = errors.New("invalid Apple identity token")
|
||||
ErrAppleTokenExpired = errors.New("Apple identity token has expired")
|
||||
ErrInvalidAppleAudience = errors.New("invalid Apple token audience")
|
||||
ErrInvalidAppleIssuer = errors.New("invalid Apple token issuer")
|
||||
ErrAppleKeyNotFound = errors.New("Apple public key not found")
|
||||
)
|
||||
|
||||
// AppleJWKS represents Apple's JSON Web Key Set
|
||||
type AppleJWKS struct {
|
||||
Keys []AppleJWK `json:"keys"`
|
||||
}
|
||||
|
||||
// AppleJWK represents a single JSON Web Key from Apple
|
||||
type AppleJWK struct {
|
||||
Kty string `json:"kty"` // Key type (RSA)
|
||||
Kid string `json:"kid"` // Key ID
|
||||
Use string `json:"use"` // Key use (sig)
|
||||
Alg string `json:"alg"` // Algorithm (RS256)
|
||||
N string `json:"n"` // RSA modulus
|
||||
E string `json:"e"` // RSA exponent
|
||||
}
|
||||
|
||||
// AppleTokenClaims represents the claims in an Apple identity token
|
||||
type AppleTokenClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
Email string `json:"email,omitempty"`
|
||||
EmailVerified any `json:"email_verified,omitempty"` // Can be bool or string
|
||||
IsPrivateEmail any `json:"is_private_email,omitempty"` // Can be bool or string
|
||||
AuthTime int64 `json:"auth_time,omitempty"`
|
||||
}
|
||||
|
||||
// IsEmailVerified returns whether the email is verified (handles both bool and string types)
|
||||
func (c *AppleTokenClaims) IsEmailVerified() bool {
|
||||
switch v := c.EmailVerified.(type) {
|
||||
case bool:
|
||||
return v
|
||||
case string:
|
||||
return v == "true"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsPrivateRelayEmail returns whether the email is a private relay email
|
||||
func (c *AppleTokenClaims) IsPrivateRelayEmail() bool {
|
||||
switch v := c.IsPrivateEmail.(type) {
|
||||
case bool:
|
||||
return v
|
||||
case string:
|
||||
return v == "true"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// AppleAuthService handles Apple Sign In token verification
|
||||
type AppleAuthService struct {
|
||||
cache *CacheService
|
||||
config *config.Config
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewAppleAuthService creates a new Apple auth service
|
||||
func NewAppleAuthService(cache *CacheService, cfg *config.Config) *AppleAuthService {
|
||||
return &AppleAuthService{
|
||||
cache: cache,
|
||||
config: cfg,
|
||||
client: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyIdentityToken verifies an Apple identity token and returns the claims
|
||||
func (s *AppleAuthService) VerifyIdentityToken(ctx context.Context, idToken string) (*AppleTokenClaims, error) {
|
||||
// Parse the token header to get the key ID
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, ErrInvalidAppleToken
|
||||
}
|
||||
|
||||
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode token header: %w", err)
|
||||
}
|
||||
|
||||
var header struct {
|
||||
Kid string `json:"kid"`
|
||||
Alg string `json:"alg"`
|
||||
}
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token header: %w", err)
|
||||
}
|
||||
|
||||
// Get the public key for this key ID
|
||||
publicKey, err := s.getPublicKey(ctx, header.Kid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse and verify the token
|
||||
token, err := jwt.ParseWithClaims(idToken, &AppleTokenClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify the signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return publicKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return nil, ErrAppleTokenExpired
|
||||
}
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*AppleTokenClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, ErrInvalidAppleToken
|
||||
}
|
||||
|
||||
// Verify the issuer
|
||||
if claims.Issuer != appleIssuer {
|
||||
return nil, ErrInvalidAppleIssuer
|
||||
}
|
||||
|
||||
// Verify the audience (should be our bundle ID)
|
||||
if !s.verifyAudience(claims.Audience) {
|
||||
return nil, ErrInvalidAppleAudience
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// verifyAudience checks if the token audience matches our client ID.
|
||||
// In production (non-debug), an empty clientID causes verification to fail
|
||||
// rather than silently bypassing the check.
|
||||
func (s *AppleAuthService) verifyAudience(audience jwt.ClaimStrings) bool {
|
||||
clientID := s.config.AppleAuth.ClientID
|
||||
if clientID == "" {
|
||||
if s.config.Server.Debug {
|
||||
// In debug mode only, skip audience verification for local development
|
||||
return true
|
||||
}
|
||||
// In production, missing client ID means we cannot verify the audience
|
||||
return false
|
||||
}
|
||||
|
||||
for _, aud := range audience {
|
||||
if aud == clientID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getPublicKey retrieves the public key for the given key ID
|
||||
func (s *AppleAuthService) getPublicKey(ctx context.Context, kid string) (*rsa.PublicKey, error) {
|
||||
// Try to get from cache first
|
||||
keys, err := s.getCachedKeys(ctx)
|
||||
if err != nil || keys == nil {
|
||||
// Fetch fresh keys
|
||||
keys, err = s.fetchApplePublicKeys(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Find the key with the matching ID
|
||||
for keyID, pubKey := range keys {
|
||||
if keyID == kid {
|
||||
return pubKey, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Key not found in cache, try fetching fresh keys
|
||||
keys, err = s.fetchApplePublicKeys(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if pubKey, ok := keys[kid]; ok {
|
||||
return pubKey, nil
|
||||
}
|
||||
|
||||
return nil, ErrAppleKeyNotFound
|
||||
}
|
||||
|
||||
// getCachedKeys retrieves cached Apple public keys from Redis
|
||||
func (s *AppleAuthService) getCachedKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
|
||||
if s.cache == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := s.cache.GetString(ctx, appleKeysCacheKey)
|
||||
if err != nil || data == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var jwks AppleJWKS
|
||||
if err := json.Unmarshal([]byte(data), &jwks); err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return s.parseJWKS(&jwks)
|
||||
}
|
||||
|
||||
// fetchApplePublicKeys fetches Apple's public keys and caches them
|
||||
func (s *AppleAuthService) fetchApplePublicKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, appleKeysURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch Apple keys: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Apple keys endpoint returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var jwks AppleJWKS
|
||||
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode Apple keys: %w", err)
|
||||
}
|
||||
|
||||
// Cache the keys
|
||||
if s.cache != nil {
|
||||
keysJSON, _ := json.Marshal(jwks)
|
||||
_ = s.cache.SetString(ctx, appleKeysCacheKey, string(keysJSON), appleKeysCacheTTL)
|
||||
}
|
||||
|
||||
return s.parseJWKS(&jwks)
|
||||
}
|
||||
|
||||
// parseJWKS converts Apple's JWKS to RSA public keys
|
||||
func (s *AppleAuthService) parseJWKS(jwks *AppleJWKS) (map[string]*rsa.PublicKey, error) {
|
||||
keys := make(map[string]*rsa.PublicKey)
|
||||
|
||||
for _, key := range jwks.Keys {
|
||||
if key.Kty != "RSA" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Decode the modulus (N)
|
||||
nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
n := new(big.Int).SetBytes(nBytes)
|
||||
|
||||
// Decode the exponent (E)
|
||||
eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
e := 0
|
||||
for _, b := range eBytes {
|
||||
e = e<<8 + int(b)
|
||||
}
|
||||
|
||||
pubKey := &rsa.PublicKey{
|
||||
N: n,
|
||||
E: e,
|
||||
}
|
||||
|
||||
keys[key.Kid] = pubKey
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
@@ -1,176 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/repositories"
|
||||
)
|
||||
|
||||
func setupRefreshTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.AutoMigrate(&models.User{}, &models.UserProfile{}, &models.AuthToken{})
|
||||
require.NoError(t, err)
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func createRefreshTestUser(t *testing.T, db *gorm.DB) *models.User {
|
||||
t.Helper()
|
||||
user := &models.User{
|
||||
Username: "refreshtest",
|
||||
Email: "refresh@test.com",
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, user.SetPassword("Password123"))
|
||||
require.NoError(t, db.Create(user).Error)
|
||||
return user
|
||||
}
|
||||
|
||||
func createTokenWithAge(t *testing.T, db *gorm.DB, userID uint, ageDays int) *models.AuthToken {
|
||||
t.Helper()
|
||||
token := &models.AuthToken{
|
||||
UserID: userID,
|
||||
}
|
||||
require.NoError(t, db.Create(token).Error)
|
||||
|
||||
// Backdate the token's Created timestamp after creation to bypass autoCreateTime
|
||||
backdated := time.Now().UTC().AddDate(0, 0, -ageDays)
|
||||
require.NoError(t, db.Model(token).Update("created", backdated).Error)
|
||||
token.Created = backdated
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
func newTestAuthService(db *gorm.DB) *AuthService {
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
SecretKey: "test-secret",
|
||||
TokenExpiryDays: 90,
|
||||
TokenRefreshDays: 60,
|
||||
},
|
||||
}
|
||||
return NewAuthService(userRepo, cfg)
|
||||
}
|
||||
|
||||
func TestRefreshToken_FreshToken_ReturnsExisting(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
token := createTokenWithAge(t, db, user.ID, 30) // 30 days old, well within fresh window
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token.Plaintext, resp.Token, "fresh token should return the same token")
|
||||
assert.Contains(t, resp.Message, "still valid")
|
||||
}
|
||||
|
||||
func TestRefreshToken_InRenewalWindow_ReturnsNewToken(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
token := createTokenWithAge(t, db, user.ID, 75) // 75 days old, in renewal window (60-90)
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, token.Plaintext, resp.Token, "should return a new token")
|
||||
assert.Contains(t, resp.Message, "refreshed")
|
||||
|
||||
// Verify old token was deleted
|
||||
var count int64
|
||||
// The DB stores the SHA-256 hash, so query by token.Key (the hash).
|
||||
db.Model(&models.AuthToken{}).Where("key = ?", token.Key).Count(&count)
|
||||
assert.Equal(t, int64(0), count, "old token should be deleted")
|
||||
|
||||
// Verify new token exists in DB
|
||||
// resp.Token is the raw token; the DB stores its hash.
|
||||
db.Model(&models.AuthToken{}).Where("key = ?", models.HashToken(resp.Token)).Count(&count)
|
||||
assert.Equal(t, int64(1), count, "new token should exist in DB")
|
||||
|
||||
// Verify new token belongs to the same user
|
||||
var newToken models.AuthToken
|
||||
require.NoError(t, db.Where("key = ?", models.HashToken(resp.Token)).First(&newToken).Error)
|
||||
assert.Equal(t, user.ID, newToken.UserID)
|
||||
}
|
||||
|
||||
func TestRefreshToken_ExpiredToken_Returns401(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
token := createTokenWithAge(t, db, user.ID, 91) // 91 days old, past 90-day expiry
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
assert.Contains(t, err.Error(), "error.token_expired")
|
||||
}
|
||||
|
||||
func TestRefreshToken_AtExactBoundary60Days(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
// Exactly 60 days: token age == refreshDays, so tokenAge < refreshDuration is false,
|
||||
// meaning it enters the renewal window
|
||||
token := createTokenWithAge(t, db, user.ID, 61)
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, token.Plaintext, resp.Token, "token at 61 days should be refreshed")
|
||||
}
|
||||
|
||||
func TestRefreshToken_InvalidToken_Returns401(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), "nonexistent-token-key", user.ID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||
}
|
||||
|
||||
func TestRefreshToken_WrongUser_Returns401(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
token := createTokenWithAge(t, db, user.ID, 75)
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
// Try to refresh with a different user ID
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID+999)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||
}
|
||||
|
||||
func TestRefreshToken_FreshTokenAt59Days_ReturnsExisting(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
token := createTokenWithAge(t, db, user.ID, 59) // 59 days, just under the 60-day threshold
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token.Plaintext, resp.Token, "token at 59 days should NOT be refreshed")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -19,195 +18,18 @@ func setupAuthService(t *testing.T) (*AuthService, *repositories.UserRepository)
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
notifRepo := repositories.NewNotificationRepository(db)
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
DebugFixedCodes: true,
|
||||
},
|
||||
Security: config.SecurityConfig{
|
||||
SecretKey: "test-secret",
|
||||
ConfirmationExpiry: 24 * time.Hour,
|
||||
PasswordResetExpiry: 15 * time.Minute,
|
||||
MaxPasswordResetRate: 3,
|
||||
TokenExpiryDays: 90,
|
||||
TokenRefreshDays: 60,
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
service.SetNotificationRepository(notifRepo)
|
||||
return service, userRepo
|
||||
}
|
||||
|
||||
// === Login ===
|
||||
|
||||
func TestAuthService_Login(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
|
||||
req := &requests.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
resp, err := service.Login(context.Background(), req, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.Token)
|
||||
assert.Equal(t, "testuser", resp.User.Username)
|
||||
}
|
||||
|
||||
func TestAuthService_Login_ByEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
|
||||
req := &requests.LoginRequest{
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
resp, err := service.Login(context.Background(), req, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.Token)
|
||||
}
|
||||
|
||||
func TestAuthService_Login_InvalidCredentials(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
|
||||
req := &requests.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "WrongPassword1",
|
||||
}
|
||||
|
||||
_, err := service.Login(context.Background(), req, "")
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
func TestAuthService_Login_UserNotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
req := &requests.LoginRequest{
|
||||
Username: "nonexistent",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
_, err := service.Login(context.Background(), req, "")
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
func TestAuthService_Login_InactiveUser(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "inactive", "inactive@test.com", "Password123")
|
||||
// Deactivate
|
||||
user.IsActive = false
|
||||
db.Save(user)
|
||||
|
||||
req := &requests.LoginRequest{
|
||||
Username: "inactive",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
_, err := service.Login(context.Background(), req, "")
|
||||
// Audit L1: inactive accounts return the same generic error as bad
|
||||
// credentials so login does not disclose which accounts exist.
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
// === Register ===
|
||||
|
||||
func TestAuthService_Register(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "new@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
resp, code, err := service.Register(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.Token)
|
||||
assert.Equal(t, "newuser", resp.User.Username)
|
||||
assert.Equal(t, "123456", code) // DebugFixedCodes=true
|
||||
}
|
||||
|
||||
func TestAuthService_Register_DuplicateUsername(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{DebugFixedCodes: true},
|
||||
Security: config.SecurityConfig{SecretKey: "test", ConfirmationExpiry: 24 * time.Hour},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "taken", "taken@test.com", "Password123")
|
||||
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "taken",
|
||||
Email: "different@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
_, _, err := service.Register(context.Background(), req)
|
||||
testutil.AssertAppError(t, err, http.StatusConflict, "error.username_taken")
|
||||
}
|
||||
|
||||
func TestAuthService_Register_DuplicateEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{DebugFixedCodes: true},
|
||||
Security: config.SecurityConfig{SecretKey: "test", ConfirmationExpiry: 24 * time.Hour},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "existing", "taken@test.com", "Password123")
|
||||
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "taken@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
_, _, err := service.Register(context.Background(), req)
|
||||
testutil.AssertAppError(t, err, http.StatusConflict, "error.email_taken")
|
||||
}
|
||||
|
||||
// === GetCurrentUser ===
|
||||
|
||||
func TestAuthService_GetCurrentUser(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
@@ -218,7 +40,7 @@ func TestAuthService_GetCurrentUser(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "testuser", resp.Username)
|
||||
assert.Equal(t, "test@test.com", resp.Email)
|
||||
assert.Equal(t, "email", resp.AuthProvider) // Default for no social auth
|
||||
assert.Equal(t, "kratos", resp.AuthProvider) // All users are Kratos-managed
|
||||
}
|
||||
|
||||
// === UpdateProfile ===
|
||||
@@ -226,9 +48,7 @@ func TestAuthService_GetCurrentUser(t *testing.T) {
|
||||
func TestAuthService_UpdateProfile(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
@@ -250,9 +70,7 @@ func TestAuthService_UpdateProfile(t *testing.T) {
|
||||
func TestAuthService_UpdateProfile_DuplicateEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "user1", "user1@test.com", "Password123")
|
||||
@@ -271,9 +89,7 @@ func TestAuthService_UpdateProfile_DuplicateEmail(t *testing.T) {
|
||||
func TestAuthService_UpdateProfile_SameEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
@@ -290,443 +106,10 @@ func TestAuthService_UpdateProfile_SameEmail(t *testing.T) {
|
||||
assert.Equal(t, "test@test.com", resp.Email)
|
||||
}
|
||||
|
||||
// === VerifyEmail ===
|
||||
|
||||
func TestAuthService_VerifyEmail(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register a user (creates confirmation code)
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "new@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the user ID
|
||||
user, err := service.userRepo.FindByEmail("new@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify with the debug code
|
||||
err = service.VerifyEmail(context.Background(), user.ID, "123456")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify again — should get already verified error
|
||||
err = service.VerifyEmail(context.Background(), user.ID, "123456")
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
|
||||
}
|
||||
|
||||
func TestAuthService_VerifyEmail_InvalidCode(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "new@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("new@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wrong code — with DebugFixedCodes enabled, "123456" bypasses normal lookup,
|
||||
// but a wrong code should use the normal path
|
||||
err = service.VerifyEmail(context.Background(), user.ID, "000000")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === ResendVerificationCode ===
|
||||
|
||||
func TestAuthService_ResendVerificationCode(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "new@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("new@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
code, err := service.ResendVerificationCode(context.Background(), user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "123456", code) // DebugFixedCodes
|
||||
}
|
||||
|
||||
func TestAuthService_ResendVerificationCode_AlreadyVerified(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register and verify
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "new@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("new@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = service.VerifyEmail(context.Background(), user.ID, "123456")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.ResendVerificationCode(context.Background(), user.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
|
||||
}
|
||||
|
||||
// === ForgotPassword ===
|
||||
|
||||
func TestAuthService_ForgotPassword(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register a user first
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
code, user, err := service.ForgotPassword(context.Background(), "test@test.com")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "123456", code) // DebugFixedCodes
|
||||
assert.NotNil(t, user)
|
||||
assert.Equal(t, "test@test.com", user.Email)
|
||||
}
|
||||
|
||||
func TestAuthService_ForgotPassword_NonexistentEmail(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Should not reveal that email doesn't exist
|
||||
code, user, err := service.ForgotPassword(context.Background(), "nonexistent@test.com")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, code)
|
||||
assert.Nil(t, user)
|
||||
}
|
||||
|
||||
// === ResetPassword ===
|
||||
|
||||
func TestAuthService_ResetPassword(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Forgot password
|
||||
_, _, err = service.ForgotPassword(context.Background(), "test@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify reset code to get the token
|
||||
resetToken, err := service.VerifyResetCode(context.Background(), "test@test.com", "123456")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resetToken)
|
||||
|
||||
// Reset password
|
||||
err = service.ResetPassword(context.Background(), resetToken, "NewPassword123")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Login with new password
|
||||
loginReq := &requests.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "NewPassword123",
|
||||
}
|
||||
loginResp, err := service.Login(context.Background(), loginReq, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, loginResp.Token)
|
||||
}
|
||||
|
||||
func TestAuthService_ResetPassword_InvalidToken(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
err := service.ResetPassword(context.Background(), "invalid-token", "NewPassword123")
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_reset_token")
|
||||
}
|
||||
|
||||
// === Logout ===
|
||||
|
||||
func TestAuthService_Logout(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
|
||||
// Login first
|
||||
loginReq := &requests.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "Password123",
|
||||
}
|
||||
loginResp, err := service.Login(context.Background(), loginReq, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Logout
|
||||
err = service.Logout(context.Background(), loginResp.Token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Token should be deleted — refreshing should fail
|
||||
_, err = service.RefreshToken(context.Background(), loginResp.Token, user.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === DeleteAccount ===
|
||||
|
||||
func TestAuthService_DeleteAccount_EmailAuth(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("test@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
password := "Password123"
|
||||
_, err = service.DeleteAccount(context.Background(), user.ID, &password, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuthService_DeleteAccount_WrongPassword(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("test@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
wrongPassword := "WrongPassword1"
|
||||
_, err = service.DeleteAccount(context.Background(), user.ID, &wrongPassword, nil)
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
func TestAuthService_DeleteAccount_NoPassword(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("test@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.DeleteAccount(context.Background(), user.ID, nil, nil)
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
|
||||
}
|
||||
|
||||
func TestAuthService_DeleteAccount_UserNotFound(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
password := "Password123"
|
||||
_, err := service.DeleteAccount(context.Background(), 99999, &password, nil)
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.user_not_found")
|
||||
}
|
||||
|
||||
// === Helper functions ===
|
||||
|
||||
func TestGenerateSixDigitCode(t *testing.T) {
|
||||
code := generateSixDigitCode()
|
||||
assert.Len(t, code, 6)
|
||||
// Should be numeric
|
||||
for _, c := range code {
|
||||
assert.True(t, c >= '0' && c <= '9', "code should contain only digits")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateResetToken(t *testing.T) {
|
||||
token := generateResetToken()
|
||||
assert.NotEmpty(t, token)
|
||||
assert.Len(t, token, 64) // 32 bytes = 64 hex chars
|
||||
}
|
||||
|
||||
func TestGetStringOrEmpty(t *testing.T) {
|
||||
s := "hello"
|
||||
assert.Equal(t, "hello", getStringOrEmpty(&s))
|
||||
assert.Equal(t, "", getStringOrEmpty(nil))
|
||||
}
|
||||
|
||||
func TestIsPrivateRelayEmail(t *testing.T) {
|
||||
assert.True(t, isPrivateRelayEmail("abc@privaterelay.appleid.com"))
|
||||
assert.True(t, isPrivateRelayEmail("ABC@PRIVATERELAY.APPLEID.COM"))
|
||||
assert.False(t, isPrivateRelayEmail("user@gmail.com"))
|
||||
}
|
||||
|
||||
func TestGetEmailFromRequest(t *testing.T) {
|
||||
email := "req@test.com"
|
||||
assert.Equal(t, "req@test.com", getEmailFromRequest(&email, "claims@test.com"))
|
||||
assert.Equal(t, "claims@test.com", getEmailFromRequest(nil, "claims@test.com"))
|
||||
empty := ""
|
||||
assert.Equal(t, "claims@test.com", getEmailFromRequest(&empty, "claims@test.com"))
|
||||
}
|
||||
|
||||
// === getEmailOrDefault ===
|
||||
|
||||
func TestGetEmailOrDefault(t *testing.T) {
|
||||
// Non-empty email returns itself
|
||||
assert.Equal(t, "user@test.com", getEmailOrDefault("user@test.com"))
|
||||
|
||||
// Empty email returns a generated placeholder
|
||||
result := getEmailOrDefault("")
|
||||
assert.Contains(t, result, "@privaterelay.appleid.com")
|
||||
assert.Contains(t, result, "apple_")
|
||||
}
|
||||
|
||||
// === generateUniqueUsername ===
|
||||
|
||||
func TestGenerateUniqueUsername(t *testing.T) {
|
||||
// Normal email generates username from email prefix
|
||||
username := generateUniqueUsername("john@test.com", nil)
|
||||
assert.Contains(t, username, "john_")
|
||||
|
||||
// Private relay email falls back to first name
|
||||
firstName := "Jane"
|
||||
username = generateUniqueUsername("abc@privaterelay.appleid.com", &firstName)
|
||||
assert.Contains(t, username, "jane_")
|
||||
|
||||
// Private relay email and no first name — fallback
|
||||
username = generateUniqueUsername("abc@privaterelay.appleid.com", nil)
|
||||
assert.Contains(t, username, "user_")
|
||||
|
||||
// Empty email with first name
|
||||
firstName2 := "Bob"
|
||||
username = generateUniqueUsername("", &firstName2)
|
||||
assert.Contains(t, username, "bob_")
|
||||
|
||||
// Empty email and no first name
|
||||
username = generateUniqueUsername("", nil)
|
||||
assert.Contains(t, username, "user_")
|
||||
}
|
||||
|
||||
// === generateGoogleUsername ===
|
||||
|
||||
func TestGenerateGoogleUsername(t *testing.T) {
|
||||
// Normal email
|
||||
username := generateGoogleUsername("john@gmail.com", "John")
|
||||
assert.Contains(t, username, "john_")
|
||||
|
||||
// Empty email falls back to first name
|
||||
username = generateGoogleUsername("", "Alice")
|
||||
assert.Contains(t, username, "alice_")
|
||||
|
||||
// Empty email and empty first name — fallback
|
||||
username = generateGoogleUsername("", "")
|
||||
assert.Contains(t, username, "google_")
|
||||
}
|
||||
|
||||
// === Login with empty password ===
|
||||
|
||||
func TestAuthService_Login_EmptyPassword(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
|
||||
req := &requests.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "",
|
||||
}
|
||||
|
||||
_, err := service.Login(context.Background(), req, "")
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
// === ForgotPassword rate limiting ===
|
||||
|
||||
func TestAuthService_ForgotPassword_RateLimit(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make max allowed reset requests (3 based on setup)
|
||||
for i := 0; i < 3; i++ {
|
||||
_, _, err := service.ForgotPassword(context.Background(), "test@test.com")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// The 4th should be rate limited
|
||||
_, _, err = service.ForgotPassword(context.Background(), "test@test.com")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === VerifyResetCode with wrong code ===
|
||||
|
||||
func TestAuthService_VerifyResetCode_WrongCode(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = service.ForgotPassword(context.Background(), "test@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wrong code but with debug mode, "123456" works, "000000" should fail
|
||||
_, err = service.VerifyResetCode(context.Background(), "test@test.com", "000000")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === VerifyResetCode with nonexistent email ===
|
||||
|
||||
func TestAuthService_VerifyResetCode_NonexistentEmail(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
_, err := service.VerifyResetCode(context.Background(), "nonexistent@test.com", "123456")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === UpdateProfile — change email to new email ===
|
||||
|
||||
func TestAuthService_UpdateProfile_ChangeEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
@@ -742,25 +125,44 @@ func TestAuthService_UpdateProfile_ChangeEmail(t *testing.T) {
|
||||
assert.Equal(t, "newemail@test.com", resp.Email)
|
||||
}
|
||||
|
||||
// === DeleteAccount — empty password string ===
|
||||
// === DeleteAccount ===
|
||||
|
||||
func TestAuthService_DeleteAccount_EmptyPassword(t *testing.T) {
|
||||
func TestAuthService_DeleteAccount_WithConfirmation(t *testing.T) {
|
||||
service, userRepo := setupAuthService(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, (*userRepo).DB(), "testuser", "test@test.com", "")
|
||||
_ = user
|
||||
|
||||
confirmation := "DELETE"
|
||||
_, err := service.DeleteAccount(context.Background(), user.ID, nil, &confirmation)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuthService_DeleteAccount_WrongConfirmation(t *testing.T) {
|
||||
service, userRepo := setupAuthService(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, (*userRepo).DB(), "testuser", "test@test.com", "")
|
||||
|
||||
wrongConf := "delete"
|
||||
_, err := service.DeleteAccount(context.Background(), user.ID, nil, &wrongConf)
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.confirmation_required")
|
||||
}
|
||||
|
||||
func TestAuthService_DeleteAccount_NoConfirmation(t *testing.T) {
|
||||
service, userRepo := setupAuthService(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, (*userRepo).DB(), "testuser", "test@test.com", "")
|
||||
|
||||
_, err := service.DeleteAccount(context.Background(), user.ID, nil, nil)
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.confirmation_required")
|
||||
}
|
||||
|
||||
func TestAuthService_DeleteAccount_UserNotFound(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("test@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
emptyPw := ""
|
||||
_, err = service.DeleteAccount(context.Background(), user.ID, &emptyPw, nil)
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
|
||||
confirmation := "DELETE"
|
||||
_, err := service.DeleteAccount(context.Background(), 99999, nil, &confirmation)
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.user_not_found")
|
||||
}
|
||||
|
||||
// === SetNotificationRepository ===
|
||||
@@ -769,35 +171,10 @@ func TestAuthService_SetNotificationRepository(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
notifRepo := repositories.NewNotificationRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
assert.Nil(t, service.notificationRepo)
|
||||
|
||||
service.SetNotificationRepository(notifRepo)
|
||||
assert.NotNil(t, service.notificationRepo)
|
||||
}
|
||||
|
||||
// === Register creates profile and notification preferences ===
|
||||
|
||||
func TestAuthService_Register_CreatesProfile(t *testing.T) {
|
||||
service, userRepo := setupAuthService(t)
|
||||
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "profileuser",
|
||||
Email: "profile@test.com",
|
||||
Password: "Password123",
|
||||
FirstName: "John",
|
||||
LastName: "Doe",
|
||||
}
|
||||
|
||||
resp, _, err := service.Register(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "profileuser", resp.User.Username)
|
||||
|
||||
// Profile should exist
|
||||
profile, err := userRepo.GetOrCreateProfile(resp.User.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, profile)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
// CacheService provides Redis caching functionality
|
||||
@@ -134,116 +133,6 @@ func (c *CacheService) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Auth token cache helpers
|
||||
const (
|
||||
AuthTokenPrefix = "auth_token_"
|
||||
TokenCacheTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
// authTokenCacheKey returns the Redis key for an auth token. The raw token
|
||||
// is hashed (audit C1) so the plaintext token never appears in a Redis key.
|
||||
func authTokenCacheKey(token string) string {
|
||||
return AuthTokenPrefix + models.HashToken(token)
|
||||
}
|
||||
|
||||
// CacheAuthToken caches a user ID for a token
|
||||
func (c *CacheService) CacheAuthToken(ctx context.Context, token string, userID uint) error {
|
||||
return c.SetString(ctx, authTokenCacheKey(token), fmt.Sprintf("%d", userID), TokenCacheTTL)
|
||||
}
|
||||
|
||||
// CacheAuthTokenWithCreated caches a user ID and token creation time for a token
|
||||
func (c *CacheService) CacheAuthTokenWithCreated(ctx context.Context, token string, userID uint, createdUnix int64) error {
|
||||
return c.SetString(ctx, authTokenCacheKey(token), fmt.Sprintf("%d|%d", userID, createdUnix), TokenCacheTTL)
|
||||
}
|
||||
|
||||
// GetCachedAuthToken gets a cached user ID for a token
|
||||
func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (uint, error) {
|
||||
val, err := c.GetString(ctx, authTokenCacheKey(token))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var userID uint
|
||||
_, err = fmt.Sscanf(val, "%d", &userID)
|
||||
return userID, err
|
||||
}
|
||||
|
||||
// GetCachedAuthTokenWithCreated gets a cached user ID and token creation time.
|
||||
// Returns userID, createdUnix, error. createdUnix is 0 if not stored (legacy format).
|
||||
func (c *CacheService) GetCachedAuthTokenWithCreated(ctx context.Context, token string) (uint, int64, error) {
|
||||
val, err := c.GetString(ctx, authTokenCacheKey(token))
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
var userID uint
|
||||
var createdUnix int64
|
||||
n, _ := fmt.Sscanf(val, "%d|%d", &userID, &createdUnix)
|
||||
if n < 1 {
|
||||
return 0, 0, fmt.Errorf("invalid cached token format")
|
||||
}
|
||||
return userID, createdUnix, nil
|
||||
}
|
||||
|
||||
// InvalidateAuthToken removes a cached token
|
||||
func (c *CacheService) InvalidateAuthToken(ctx context.Context, token string) error {
|
||||
return c.Delete(ctx, authTokenCacheKey(token))
|
||||
}
|
||||
|
||||
// InvalidateAuthTokenHashes removes cached entries for already-hashed token
|
||||
// keys. Unlike InvalidateAuthToken (which hashes a plaintext), this takes the
|
||||
// stored hash directly — used to evict a user's prior token on re-login
|
||||
// (audit MEDIUM-1), where the server no longer has the plaintext.
|
||||
func (c *CacheService) InvalidateAuthTokenHashes(ctx context.Context, hashes ...string) error {
|
||||
keys := make([]string, 0, len(hashes))
|
||||
for _, h := range hashes {
|
||||
if h != "" {
|
||||
keys = append(keys, AuthTokenPrefix+h)
|
||||
}
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.Delete(ctx, keys...)
|
||||
}
|
||||
|
||||
// --- Per-account login-failure tracking (audit M5) ---
|
||||
|
||||
const loginFailPrefix = "login_fail:"
|
||||
|
||||
// RegisterLoginFailure records a failed login for an account from a given
|
||||
// source IP, and returns the number of DISTINCT source IPs that have failed
|
||||
// for this account within the window. Tracking distinct IPs as a set rather
|
||||
// than a raw counter (audit MEDIUM-3) means one attacker, from one IP, cannot
|
||||
// run the count up and lock a victim out by knowing only their email — a
|
||||
// single IP is bounded by the per-IP edge/app rate limiters instead. A
|
||||
// genuinely distributed credential-stuffing attack still trips the lockout.
|
||||
func (c *CacheService) RegisterLoginFailure(ctx context.Context, identifier, ip string, window time.Duration) (int64, error) {
|
||||
key := loginFailPrefix + identifier
|
||||
member := ip
|
||||
if member == "" {
|
||||
member = "unknown"
|
||||
}
|
||||
if err := c.client.SAdd(ctx, key, member).Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// Refresh the TTL on each failure: an active attack keeps the window
|
||||
// open, while a quiet account ages out `window` after its last failure.
|
||||
_ = c.client.Expire(ctx, key, window).Err()
|
||||
return c.client.SCard(ctx, key).Result()
|
||||
}
|
||||
|
||||
// LoginFailureIPCount returns how many distinct source IPs have failed to log
|
||||
// in to this account within the window (audit MEDIUM-3). SCard on a missing
|
||||
// key returns 0.
|
||||
func (c *CacheService) LoginFailureIPCount(ctx context.Context, identifier string) (int64, error) {
|
||||
return c.client.SCard(ctx, loginFailPrefix+identifier).Result()
|
||||
}
|
||||
|
||||
// ClearLoginFailures resets the failed-login IP set after a successful login.
|
||||
func (c *CacheService) ClearLoginFailures(ctx context.Context, identifier string) error {
|
||||
return c.client.Del(ctx, loginFailPrefix+identifier).Err()
|
||||
}
|
||||
|
||||
// Static data cache helpers
|
||||
const (
|
||||
|
||||
@@ -1,307 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
// googleKeysURL is Google's JWKS endpoint for ID-token signature verification.
|
||||
googleKeysURL = "https://www.googleapis.com/oauth2/v3/certs"
|
||||
googleKeysCacheTTL = 24 * time.Hour
|
||||
googleKeysCacheKey = "google:public_keys"
|
||||
)
|
||||
|
||||
// googleIssuers is the set of valid `iss` claim values for a Google ID token.
|
||||
var googleIssuers = map[string]bool{
|
||||
"accounts.google.com": true,
|
||||
"https://accounts.google.com": true,
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidGoogleToken = errors.New("invalid Google ID token")
|
||||
ErrGoogleTokenExpired = errors.New("Google ID token has expired")
|
||||
ErrInvalidGoogleAudience = errors.New("invalid Google token audience")
|
||||
ErrInvalidGoogleIssuer = errors.New("invalid Google token issuer")
|
||||
ErrGoogleKeyNotFound = errors.New("Google public key not found")
|
||||
)
|
||||
|
||||
// GoogleJWKS represents Google's JSON Web Key Set.
|
||||
type GoogleJWKS struct {
|
||||
Keys []GoogleJWK `json:"keys"`
|
||||
}
|
||||
|
||||
// GoogleJWK represents a single JSON Web Key from Google.
|
||||
type GoogleJWK struct {
|
||||
Kty string `json:"kty"` // Key type (RSA)
|
||||
Kid string `json:"kid"` // Key ID
|
||||
Use string `json:"use"` // Key use (sig)
|
||||
Alg string `json:"alg"` // Algorithm (RS256)
|
||||
N string `json:"n"` // RSA modulus
|
||||
E string `json:"e"` // RSA exponent
|
||||
}
|
||||
|
||||
// GoogleTokenClaims represents the claims in a Google ID token JWT.
|
||||
type GoogleTokenClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
Email string `json:"email,omitempty"`
|
||||
EmailVerified bool `json:"email_verified,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
GivenName string `json:"given_name,omitempty"`
|
||||
FamilyName string `json:"family_name,omitempty"`
|
||||
Picture string `json:"picture,omitempty"`
|
||||
Azp string `json:"azp,omitempty"` // Authorized party
|
||||
}
|
||||
|
||||
// GoogleTokenInfo is the verified, caller-facing view of a Google ID token.
|
||||
type GoogleTokenInfo struct {
|
||||
Sub string // Unique Google user ID
|
||||
Email string
|
||||
EmailVerified string // "true" or "false" — string for caller compatibility
|
||||
Name string
|
||||
GivenName string
|
||||
FamilyName string
|
||||
Picture string
|
||||
Aud string
|
||||
Azp string
|
||||
Iss string
|
||||
}
|
||||
|
||||
// IsEmailVerified returns whether the email is verified.
|
||||
func (t *GoogleTokenInfo) IsEmailVerified() bool {
|
||||
return t.EmailVerified == "true"
|
||||
}
|
||||
|
||||
// GoogleAuthService handles Google Sign In token verification.
|
||||
type GoogleAuthService struct {
|
||||
cache *CacheService
|
||||
config *config.Config
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewGoogleAuthService creates a new Google auth service.
|
||||
func NewGoogleAuthService(cache *CacheService, cfg *config.Config) *GoogleAuthService {
|
||||
return &GoogleAuthService{
|
||||
cache: cache,
|
||||
config: cfg,
|
||||
client: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyIDToken verifies a Google ID token locally (audit C2/C3): it checks
|
||||
// the RS256 signature against Google's published JWKS and the iss, aud, and
|
||||
// exp claims. It never sends the token to a third-party endpoint, so it no
|
||||
// longer depends on the deprecated tokeninfo service and never leaks the
|
||||
// token in a request URL.
|
||||
func (s *GoogleAuthService) VerifyIDToken(ctx context.Context, idToken string) (*GoogleTokenInfo, error) {
|
||||
// Parse the token header to get the key ID.
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, ErrInvalidGoogleToken
|
||||
}
|
||||
|
||||
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode token header: %w", err)
|
||||
}
|
||||
var header struct {
|
||||
Kid string `json:"kid"`
|
||||
Alg string `json:"alg"`
|
||||
}
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token header: %w", err)
|
||||
}
|
||||
|
||||
publicKey, err := s.getPublicKey(ctx, header.Kid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse and verify the signature. jwt v5 validates exp/iat/nbf automatically.
|
||||
token, err := jwt.ParseWithClaims(idToken, &GoogleTokenClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return publicKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return nil, ErrGoogleTokenExpired
|
||||
}
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*GoogleTokenClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, ErrInvalidGoogleToken
|
||||
}
|
||||
|
||||
// Verify the issuer (audit C3).
|
||||
if !googleIssuers[claims.Issuer] {
|
||||
return nil, ErrInvalidGoogleIssuer
|
||||
}
|
||||
|
||||
// Verify the audience matches one of our configured client IDs.
|
||||
if !s.verifyAudience(claims.Audience, claims.Azp) {
|
||||
return nil, ErrInvalidGoogleAudience
|
||||
}
|
||||
|
||||
if claims.Subject == "" {
|
||||
return nil, ErrInvalidGoogleToken
|
||||
}
|
||||
|
||||
emailVerified := "false"
|
||||
if claims.EmailVerified {
|
||||
emailVerified = "true"
|
||||
}
|
||||
aud := ""
|
||||
if len(claims.Audience) > 0 {
|
||||
aud = claims.Audience[0]
|
||||
}
|
||||
return &GoogleTokenInfo{
|
||||
Sub: claims.Subject,
|
||||
Email: claims.Email,
|
||||
EmailVerified: emailVerified,
|
||||
Name: claims.Name,
|
||||
GivenName: claims.GivenName,
|
||||
FamilyName: claims.FamilyName,
|
||||
Picture: claims.Picture,
|
||||
Aud: aud,
|
||||
Azp: claims.Azp,
|
||||
Iss: claims.Issuer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// verifyAudience checks the token audience against our configured client IDs.
|
||||
// In production (non-debug) an empty client ID fails verification rather than
|
||||
// silently bypassing the check.
|
||||
func (s *GoogleAuthService) verifyAudience(audience jwt.ClaimStrings, azp string) bool {
|
||||
clientID := s.config.GoogleAuth.ClientID
|
||||
if clientID == "" {
|
||||
// In debug mode only, skip audience verification for local development.
|
||||
return s.config.Server.Debug
|
||||
}
|
||||
|
||||
candidates := []string{clientID}
|
||||
if id := s.config.GoogleAuth.AndroidClientID; id != "" {
|
||||
candidates = append(candidates, id)
|
||||
}
|
||||
if id := s.config.GoogleAuth.IOSClientID; id != "" {
|
||||
candidates = append(candidates, id)
|
||||
}
|
||||
|
||||
for _, want := range candidates {
|
||||
if azp == want {
|
||||
return true
|
||||
}
|
||||
for _, aud := range audience {
|
||||
if aud == want {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getPublicKey returns the RSA public key for the given key ID, using a
|
||||
// Redis-cached copy of Google's JWKS and re-fetching once on a cache miss
|
||||
// (Google rotates signing keys roughly daily).
|
||||
func (s *GoogleAuthService) getPublicKey(ctx context.Context, kid string) (*rsa.PublicKey, error) {
|
||||
keys, err := s.getCachedKeys(ctx)
|
||||
if err != nil || keys == nil {
|
||||
keys, err = s.fetchGooglePublicKeys(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if pubKey, ok := keys[kid]; ok {
|
||||
return pubKey, nil
|
||||
}
|
||||
|
||||
// Cache miss for this kid — keys may have rotated; fetch fresh.
|
||||
keys, err = s.fetchGooglePublicKeys(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pubKey, ok := keys[kid]; ok {
|
||||
return pubKey, nil
|
||||
}
|
||||
return nil, ErrGoogleKeyNotFound
|
||||
}
|
||||
|
||||
// getCachedKeys retrieves cached Google public keys from Redis.
|
||||
func (s *GoogleAuthService) getCachedKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
|
||||
if s.cache == nil {
|
||||
return nil, nil
|
||||
}
|
||||
data, err := s.cache.GetString(ctx, googleKeysCacheKey)
|
||||
if err != nil || data == "" {
|
||||
return nil, nil
|
||||
}
|
||||
var jwks GoogleJWKS
|
||||
if err := json.Unmarshal([]byte(data), &jwks); err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.parseJWKS(&jwks), nil
|
||||
}
|
||||
|
||||
// fetchGooglePublicKeys fetches Google's JWKS and caches it.
|
||||
func (s *GoogleAuthService) fetchGooglePublicKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, googleKeysURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch Google keys: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Google keys endpoint returned status %d", resp.StatusCode)
|
||||
}
|
||||
var jwks GoogleJWKS
|
||||
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode Google keys: %w", err)
|
||||
}
|
||||
if s.cache != nil {
|
||||
keysJSON, _ := json.Marshal(jwks)
|
||||
_ = s.cache.SetString(ctx, googleKeysCacheKey, string(keysJSON), googleKeysCacheTTL)
|
||||
}
|
||||
return s.parseJWKS(&jwks), nil
|
||||
}
|
||||
|
||||
// parseJWKS converts Google's JWKS into a map of RSA public keys by key ID.
|
||||
func (s *GoogleAuthService) parseJWKS(jwks *GoogleJWKS) map[string]*rsa.PublicKey {
|
||||
keys := make(map[string]*rsa.PublicKey)
|
||||
for _, key := range jwks.Keys {
|
||||
if key.Kty != "RSA" {
|
||||
continue
|
||||
}
|
||||
nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
e := 0
|
||||
for _, b := range eBytes {
|
||||
e = e<<8 + int(b)
|
||||
}
|
||||
keys[key.Kid] = &rsa.PublicKey{N: new(big.Int).SetBytes(nBytes), E: e}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
@@ -22,6 +22,14 @@ import (
|
||||
"github.com/treytartt/honeydue-api/internal/validator"
|
||||
)
|
||||
|
||||
// authUserKey and authTokenKey mirror middleware.AuthUserKey / middleware.AuthTokenKey.
|
||||
// We duplicate the string constants here to avoid an import cycle
|
||||
// (testutil <- middleware <- repositories <- testutil).
|
||||
const (
|
||||
authUserKey = "auth_user"
|
||||
authTokenKey = "auth_token"
|
||||
)
|
||||
|
||||
var (
|
||||
i18nOnce sync.Once
|
||||
testDBCounter uint64
|
||||
@@ -52,9 +60,6 @@ func SetupTestDB(t *testing.T) *gorm.DB {
|
||||
err = db.AutoMigrate(
|
||||
&models.User{},
|
||||
&models.UserProfile{},
|
||||
&models.AuthToken{},
|
||||
&models.ConfirmationCode{},
|
||||
&models.PasswordResetCode{},
|
||||
&models.AdminUser{},
|
||||
&models.Residence{},
|
||||
&models.ResidenceType{},
|
||||
@@ -73,8 +78,6 @@ func SetupTestDB(t *testing.T) *gorm.DB {
|
||||
&models.NotificationPreference{},
|
||||
&models.APNSDevice{},
|
||||
&models.GCMDevice{},
|
||||
&models.AppleSocialAuth{},
|
||||
&models.GoogleSocialAuth{},
|
||||
&models.TaskReminderLog{},
|
||||
&models.UserSubscription{},
|
||||
&models.SubscriptionSettings{},
|
||||
@@ -177,29 +180,24 @@ func ParseJSONArray(t *testing.T, body []byte) []map[string]interface{} {
|
||||
return result
|
||||
}
|
||||
|
||||
// CreateTestUser creates a test user in the database
|
||||
func CreateTestUser(t *testing.T, db *gorm.DB, username, email, password string) *models.User {
|
||||
// CreateTestUser creates a test user in the database.
|
||||
// password is accepted for API compatibility but ignored — Kratos owns credentials.
|
||||
// A synthetic KratosID is generated so the user satisfies the unique-index constraint.
|
||||
func CreateTestUser(t *testing.T, db *gorm.DB, username, email, _ string) *models.User {
|
||||
t.Helper()
|
||||
user := &models.User{
|
||||
KratosID: "test-kratos-" + username,
|
||||
Username: username,
|
||||
Email: email,
|
||||
IsActive: true,
|
||||
}
|
||||
err := user.SetPassword(password)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Create(user).Error
|
||||
err := db.Create(user).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
return user
|
||||
}
|
||||
|
||||
// CreateTestToken creates an auth token for a user
|
||||
func CreateTestToken(t *testing.T, db *gorm.DB, userID uint) *models.AuthToken {
|
||||
token, err := models.GetOrCreateToken(db, userID)
|
||||
require.NoError(t, err)
|
||||
return token
|
||||
}
|
||||
|
||||
// CreateTestResidenceType creates a test residence type
|
||||
func CreateTestResidenceType(t *testing.T, db *gorm.DB, name string) *models.ResidenceType {
|
||||
rt := &models.ResidenceType{Name: name}
|
||||
@@ -362,12 +360,15 @@ func AssertStatusCode(t *testing.T, w *httptest.ResponseRecorder, expected int)
|
||||
require.Equal(t, expected, w.Code, "unexpected status code: %s", w.Body.String())
|
||||
}
|
||||
|
||||
// MockAuthMiddleware creates middleware that sets a test user in context
|
||||
// MockAuthMiddleware creates middleware that sets a test user in context.
|
||||
// Uses the same context keys as KratosAuth (authUserKey / authTokenKey) so
|
||||
// handlers are unaware of the swap. The constants are duplicated here to
|
||||
// avoid an import cycle (testutil <- middleware <- repositories <- testutil).
|
||||
func MockAuthMiddleware(user *models.User) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
c.Set("auth_user", user)
|
||||
c.Set("auth_token", "test-token")
|
||||
c.Set(authUserKey, user)
|
||||
c.Set(authTokenKey, "test-token")
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
-- +goose Up
|
||||
-- Phase 2: hand-rolled auth replaced by Ory Kratos. Kratos owns identities,
|
||||
-- credentials, sessions, email verification, recovery and social sign-in.
|
||||
-- honeyDue keeps a slim auth_user row linked to the Kratos identity by
|
||||
-- kratos_id; all domain tables keep their existing integer auth_user FKs.
|
||||
--
|
||||
-- Pre-production: a clean slate is taken. auth_user is truncated (cascading
|
||||
-- to all user-scoped domain data) so no auth_user row exists without a
|
||||
-- Kratos identity behind it. There is no data migration.
|
||||
|
||||
-- honeyDue's hand-rolled auth tables are no longer used — Kratos owns this.
|
||||
DROP TABLE IF EXISTS user_authtoken;
|
||||
DROP TABLE IF EXISTS user_confirmationcode;
|
||||
DROP TABLE IF EXISTS user_passwordresetcode;
|
||||
DROP TABLE IF EXISTS user_applesocialauth;
|
||||
DROP TABLE IF EXISTS user_googlesocialauth;
|
||||
|
||||
-- Link each auth_user row to its Kratos identity (UUID).
|
||||
ALTER TABLE auth_user ADD COLUMN IF NOT EXISTS kratos_id uuid;
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS uq_auth_user_kratos_id
|
||||
ON auth_user (kratos_id) WHERE kratos_id IS NOT NULL;
|
||||
|
||||
-- password is NOT NULL in the Django-era schema but is no longer used —
|
||||
-- Kratos holds credentials. Make it nullable so provisioning need not
|
||||
-- invent a placeholder hash.
|
||||
ALTER TABLE auth_user ALTER COLUMN password DROP NOT NULL;
|
||||
|
||||
-- Clean slate (pre-production): drop every existing account and all
|
||||
-- user-scoped domain data so nothing is left orphaned without a Kratos id.
|
||||
TRUNCATE TABLE auth_user CASCADE;
|
||||
|
||||
-- +goose Down
|
||||
-- The dropped tables' data cannot be restored. Down only removes the
|
||||
-- kratos_id column and restores the password NOT NULL constraint; reverting
|
||||
-- to hand-rolled auth means reverting the Phase 2 application code.
|
||||
DROP INDEX IF EXISTS uq_auth_user_kratos_id;
|
||||
ALTER TABLE auth_user DROP COLUMN IF EXISTS kratos_id;
|
||||
Reference in New Issue
Block a user