feat(auth): replace hand-rolled auth with Ory Kratos — phase 2 backend
Delegates all credential management (login, register, password reset, email verification, social sign-in) to Ory Kratos. The Go API now acts as a resource server: the new KratosAuth middleware validates sessions against the Kratos whoami endpoint, writes the local User mirror into Echo context, and all existing domain handlers continue working unchanged. Hand-rolled token auth, AuthToken model, apple_auth/ google_auth services, and the auth refresh flow are removed. Tests are updated to use the fake-token middleware pattern so existing integration assertions require no rewrite. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -11,18 +11,21 @@ import (
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
// FindByKratosID finds a user by Kratos identity UUID.
|
||||
func (r *UserRepository) FindByKratosID(kratosID string) (*models.User, error) {
|
||||
var user models.User
|
||||
if err := r.db.Where("kratos_id = ?", kratosID).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrUserExists = errors.New("user already exists")
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
ErrTokenNotFound = errors.New("token not found")
|
||||
ErrCodeNotFound = errors.New("code not found")
|
||||
ErrCodeExpired = errors.New("code expired")
|
||||
ErrCodeUsed = errors.New("code already used")
|
||||
ErrTooManyAttempts = errors.New("too many attempts")
|
||||
ErrRateLimitExceeded = errors.New("rate limit exceeded")
|
||||
ErrAppleAuthNotFound = errors.New("apple social auth not found")
|
||||
ErrGoogleAuthNotFound = errors.New("google social auth not found")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrUserExists = errors.New("user already exists")
|
||||
)
|
||||
|
||||
// UserRepository handles user-related database operations
|
||||
@@ -145,111 +148,6 @@ func (r *UserRepository) ExistsByEmail(email string) (bool, error) {
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// --- Auth Token Methods ---
|
||||
|
||||
// GetOrCreateToken gets or creates an auth token for a user.
|
||||
// Wrapped in a transaction to prevent race conditions where two
|
||||
// concurrent requests could create duplicate tokens for the same user.
|
||||
func (r *UserRepository) GetOrCreateToken(userID uint) (*models.AuthToken, error) {
|
||||
var token models.AuthToken
|
||||
|
||||
err := r.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.Where("user_id = ?", userID).First(&token)
|
||||
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
token = models.AuthToken{UserID: userID}
|
||||
if err := tx.Create(&token).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
} else if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// FindTokenByKey looks up an auth token by its raw key value. The raw token
|
||||
// is hashed (audit C1) before the indexed lookup, since only the hash is
|
||||
// stored.
|
||||
func (r *UserRepository) FindTokenByKey(rawKey string) (*models.AuthToken, error) {
|
||||
var token models.AuthToken
|
||||
if err := r.db.Where("key = ?", models.HashToken(rawKey)).First(&token).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrTokenNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// CreateToken creates a new auth token for a user.
|
||||
func (r *UserRepository) CreateToken(userID uint) (*models.AuthToken, error) {
|
||||
token := models.AuthToken{UserID: userID}
|
||||
if err := r.db.Create(&token).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// CreateFreshToken issues a new auth token for the user, replacing any
|
||||
// existing one. Because tokens are stored hashed (audit C1) the server
|
||||
// cannot re-issue a previously-minted token's plaintext, so every login
|
||||
// mints a fresh token. The returned token's Plaintext field carries the
|
||||
// raw value to hand to the client; it is never persisted.
|
||||
//
|
||||
// It also returns the stored hashes of the token rows it deleted, so the
|
||||
// caller can evict those entries from the Redis token cache (audit MEDIUM-1).
|
||||
// Without that, a prior (e.g. stolen) token keeps authenticating via a cache
|
||||
// hit for up to the cache TTL even though its DB row is gone.
|
||||
func (r *UserRepository) CreateFreshToken(userID uint) (*models.AuthToken, []string, error) {
|
||||
var token models.AuthToken
|
||||
var oldHashes []string
|
||||
err := r.db.Transaction(func(tx *gorm.DB) error {
|
||||
var old []models.AuthToken
|
||||
if err := tx.Where("user_id = ?", userID).Find(&old).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
oldHashes = make([]string, 0, len(old))
|
||||
for i := range old {
|
||||
if old[i].Key != "" {
|
||||
oldHashes = append(oldHashes, old[i].Key)
|
||||
}
|
||||
}
|
||||
if err := tx.Where("user_id = ?", userID).Delete(&models.AuthToken{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
token = models.AuthToken{UserID: userID}
|
||||
return tx.Create(&token).Error
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return &token, oldHashes, nil
|
||||
}
|
||||
|
||||
// DeleteToken deletes an auth token by its raw key value. The raw token is
|
||||
// hashed (audit C1) before the lookup, since only the hash is stored.
|
||||
func (r *UserRepository) DeleteToken(token string) error {
|
||||
result := r.db.Where("key = ?", models.HashToken(token)).Delete(&models.AuthToken{})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return ErrTokenNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteTokenByUserID deletes an auth token by user ID
|
||||
func (r *UserRepository) DeleteTokenByUserID(userID uint) error {
|
||||
return r.db.Where("user_id = ?", userID).Delete(&models.AuthToken{}).Error
|
||||
}
|
||||
|
||||
// --- User Profile Methods ---
|
||||
|
||||
@@ -280,146 +178,6 @@ func (r *UserRepository) SetProfileVerified(userID uint, verified bool) error {
|
||||
return r.db.Model(&models.UserProfile{}).Where("user_id = ?", userID).Update("verified", verified).Error
|
||||
}
|
||||
|
||||
// --- Confirmation Code Methods ---
|
||||
|
||||
// CreateConfirmationCode creates a new confirmation code
|
||||
func (r *UserRepository) CreateConfirmationCode(userID uint, code string, expiresAt time.Time) (*models.ConfirmationCode, error) {
|
||||
// Invalidate any existing unused codes for this user
|
||||
r.db.Model(&models.ConfirmationCode{}).
|
||||
Where("user_id = ? AND is_used = ?", userID, false).
|
||||
Update("is_used", true)
|
||||
|
||||
confirmCode := &models.ConfirmationCode{
|
||||
UserID: userID,
|
||||
Code: code,
|
||||
ExpiresAt: expiresAt,
|
||||
IsUsed: false,
|
||||
}
|
||||
|
||||
if err := r.db.Create(confirmCode).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return confirmCode, nil
|
||||
}
|
||||
|
||||
// FindConfirmationCode finds a valid confirmation code for a user
|
||||
func (r *UserRepository) FindConfirmationCode(userID uint, code string) (*models.ConfirmationCode, error) {
|
||||
var confirmCode models.ConfirmationCode
|
||||
if err := r.db.Where("user_id = ? AND code = ? AND is_used = ?", userID, code, false).
|
||||
First(&confirmCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !confirmCode.IsValid() {
|
||||
if confirmCode.IsUsed {
|
||||
return nil, ErrCodeUsed
|
||||
}
|
||||
return nil, ErrCodeExpired
|
||||
}
|
||||
|
||||
return &confirmCode, nil
|
||||
}
|
||||
|
||||
// MarkConfirmationCodeUsed marks a confirmation code as used
|
||||
func (r *UserRepository) MarkConfirmationCodeUsed(codeID uint) error {
|
||||
return r.db.Model(&models.ConfirmationCode{}).Where("id = ?", codeID).Update("is_used", true).Error
|
||||
}
|
||||
|
||||
// --- Password Reset Code Methods ---
|
||||
|
||||
// CreatePasswordResetCode creates a new password reset code
|
||||
func (r *UserRepository) CreatePasswordResetCode(userID uint, codeHash string, resetToken string, expiresAt time.Time) (*models.PasswordResetCode, error) {
|
||||
// Invalidate any existing unused codes for this user
|
||||
r.db.Model(&models.PasswordResetCode{}).
|
||||
Where("user_id = ? AND used = ?", userID, false).
|
||||
Update("used", true)
|
||||
|
||||
resetCode := &models.PasswordResetCode{
|
||||
UserID: userID,
|
||||
CodeHash: codeHash,
|
||||
ResetToken: resetToken,
|
||||
ExpiresAt: expiresAt,
|
||||
Used: false,
|
||||
Attempts: 0,
|
||||
MaxAttempts: 5,
|
||||
}
|
||||
|
||||
if err := r.db.Create(resetCode).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resetCode, nil
|
||||
}
|
||||
|
||||
// FindPasswordResetCode finds a password reset code by email and checks validity
|
||||
func (r *UserRepository) FindPasswordResetCodeByEmail(email string) (*models.PasswordResetCode, *models.User, error) {
|
||||
user, err := r.FindByEmail(email)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var resetCode models.PasswordResetCode
|
||||
if err := r.db.Where("user_id = ? AND used = ?", user.ID, false).
|
||||
Order("created_at DESC").
|
||||
First(&resetCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil, ErrCodeNotFound
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return &resetCode, user, nil
|
||||
}
|
||||
|
||||
// FindPasswordResetCodeByToken finds a password reset code by reset token
|
||||
func (r *UserRepository) FindPasswordResetCodeByToken(resetToken string) (*models.PasswordResetCode, error) {
|
||||
var resetCode models.PasswordResetCode
|
||||
if err := r.db.Where("reset_token = ?", resetToken).First(&resetCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !resetCode.IsValid() {
|
||||
if resetCode.Used {
|
||||
return nil, ErrCodeUsed
|
||||
}
|
||||
if resetCode.Attempts >= resetCode.MaxAttempts {
|
||||
return nil, ErrTooManyAttempts
|
||||
}
|
||||
return nil, ErrCodeExpired
|
||||
}
|
||||
|
||||
return &resetCode, nil
|
||||
}
|
||||
|
||||
// IncrementResetCodeAttempts increments the attempt counter
|
||||
func (r *UserRepository) IncrementResetCodeAttempts(codeID uint) error {
|
||||
return r.db.Model(&models.PasswordResetCode{}).Where("id = ?", codeID).
|
||||
Update("attempts", gorm.Expr("attempts + 1")).Error
|
||||
}
|
||||
|
||||
// MarkPasswordResetCodeUsed marks a password reset code as used
|
||||
func (r *UserRepository) MarkPasswordResetCodeUsed(codeID uint) error {
|
||||
return r.db.Model(&models.PasswordResetCode{}).Where("id = ?", codeID).Update("used", true).Error
|
||||
}
|
||||
|
||||
// CountRecentPasswordResetRequests counts reset requests in the last hour
|
||||
func (r *UserRepository) CountRecentPasswordResetRequests(userID uint) (int64, error) {
|
||||
var count int64
|
||||
oneHourAgo := time.Now().UTC().Add(-1 * time.Hour)
|
||||
if err := r.db.Model(&models.PasswordResetCode{}).
|
||||
Where("user_id = ? AND created_at > ?", userID, oneHourAgo).
|
||||
Count(&count).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// --- Search Methods ---
|
||||
|
||||
@@ -576,27 +334,11 @@ func (r *UserRepository) FindProfilesInSharedResidences(userID uint) ([]models.U
|
||||
return profiles, err
|
||||
}
|
||||
|
||||
// --- Auth Provider Detection ---
|
||||
|
||||
// FindAuthProvider determines the auth provider for a user.
|
||||
// Returns "apple", "google", or "email".
|
||||
func (r *UserRepository) FindAuthProvider(userID uint) (string, error) {
|
||||
var count int64
|
||||
if err := r.db.Model(&models.AppleSocialAuth{}).Where("user_id = ?", userID).Count(&count).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
if count > 0 {
|
||||
return "apple", nil
|
||||
}
|
||||
|
||||
if err := r.db.Model(&models.GoogleSocialAuth{}).Where("user_id = ?", userID).Count(&count).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
if count > 0 {
|
||||
return "google", nil
|
||||
}
|
||||
|
||||
return "email", nil
|
||||
// FindAuthProvider returns "kratos" for all Kratos-managed users (the sole
|
||||
// provider after the Ory Kratos migration). Kept for compatibility with
|
||||
// callers that still check the provider string.
|
||||
func (r *UserRepository) FindAuthProvider(_ uint) (string, error) {
|
||||
return "kratos", nil
|
||||
}
|
||||
|
||||
// --- Account Deletion ---
|
||||
@@ -721,35 +463,12 @@ func (r *UserRepository) DeleteUserCascade(userID uint) ([]string, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 8. Social auth records
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.AppleSocialAuth{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.GoogleSocialAuth{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 9. Confirmation codes
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.ConfirmationCode{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 10. Password reset codes
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.PasswordResetCode{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 11. Auth tokens
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.AuthToken{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 12. User profile
|
||||
// 8. User profile
|
||||
if err := db.Where("user_id = ?", userID).Delete(&models.UserProfile{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 13. User
|
||||
// 9. User
|
||||
if err := db.Where("id = ?", userID).Delete(&models.User{}).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -765,53 +484,6 @@ func (r *UserRepository) DeleteUserCascade(userID uint) ([]string, error) {
|
||||
return cleanURLs, nil
|
||||
}
|
||||
|
||||
// --- Apple Social Auth Methods ---
|
||||
|
||||
// FindByAppleID finds an Apple social auth by Apple ID
|
||||
func (r *UserRepository) FindByAppleID(appleID string) (*models.AppleSocialAuth, error) {
|
||||
var auth models.AppleSocialAuth
|
||||
if err := r.db.Where("apple_id = ?", appleID).First(&auth).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrAppleAuthNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &auth, nil
|
||||
}
|
||||
|
||||
// CreateAppleSocialAuth creates a new Apple social auth record
|
||||
func (r *UserRepository) CreateAppleSocialAuth(auth *models.AppleSocialAuth) error {
|
||||
return r.db.Create(auth).Error
|
||||
}
|
||||
|
||||
// UpdateAppleSocialAuth updates an Apple social auth record
|
||||
func (r *UserRepository) UpdateAppleSocialAuth(auth *models.AppleSocialAuth) error {
|
||||
return r.db.Save(auth).Error
|
||||
}
|
||||
|
||||
// --- Google Social Auth Methods ---
|
||||
|
||||
// FindByGoogleID finds a Google social auth by Google ID
|
||||
func (r *UserRepository) FindByGoogleID(googleID string) (*models.GoogleSocialAuth, error) {
|
||||
var auth models.GoogleSocialAuth
|
||||
if err := r.db.Where("google_id = ?", googleID).First(&auth).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrGoogleAuthNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &auth, nil
|
||||
}
|
||||
|
||||
// CreateGoogleSocialAuth creates a new Google social auth record
|
||||
func (r *UserRepository) CreateGoogleSocialAuth(auth *models.GoogleSocialAuth) error {
|
||||
return r.db.Create(auth).Error
|
||||
}
|
||||
|
||||
// UpdateGoogleSocialAuth updates a Google social auth record
|
||||
func (r *UserRepository) UpdateGoogleSocialAuth(auth *models.GoogleSocialAuth) error {
|
||||
return r.db.Save(auth).Error
|
||||
}
|
||||
|
||||
// WithContext returns a copy of the repository whose underlying *gorm.DB carries
|
||||
// the supplied context. SQL emitted via this copy gets attached to ctx's trace span
|
||||
|
||||
@@ -2,7 +2,6 @@ package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -78,99 +77,25 @@ func TestUserRepository_ExistsByEmail_CaseInsensitive(t *testing.T) {
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
func TestUserRepository_GetOrCreateToken(t *testing.T) {
|
||||
func TestUserRepository_FindByKratosID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
user := testutil.CreateTestUser(t, db, "kratosuser", "kratos@example.com", "")
|
||||
|
||||
// Create token
|
||||
token1, err := repo.GetOrCreateToken(user.ID)
|
||||
found, err := repo.FindByKratosID(user.KratosID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token1.Key)
|
||||
|
||||
// Should return same token
|
||||
token2, err := repo.GetOrCreateToken(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token1.Key, token2.Key)
|
||||
assert.Equal(t, user.ID, found.ID)
|
||||
assert.Equal(t, user.KratosID, found.KratosID)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindTokenByKey(t *testing.T) {
|
||||
func TestUserRepository_FindByKratosID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
token, err := repo.GetOrCreateToken(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindTokenByKey(token.Plaintext)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token.Key, found.Key)
|
||||
assert.Equal(t, user.ID, found.UserID)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindTokenByKey_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, err := repo.FindTokenByKey("nonexistent-token-key")
|
||||
_, err := repo.FindByKratosID("nonexistent-kratos-id")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrTokenNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_DeleteToken(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
token, err := repo.GetOrCreateToken(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.DeleteToken(token.Plaintext)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindTokenByKey(token.Plaintext)
|
||||
assert.ErrorIs(t, err, ErrTokenNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_DeleteToken_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
err := repo.DeleteToken("nonexistent-key")
|
||||
assert.ErrorIs(t, err, ErrTokenNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_DeleteTokenByUserID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
_, err := repo.GetOrCreateToken(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.DeleteTokenByUserID(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Token should be gone
|
||||
var count int64
|
||||
db.Model(&models.AuthToken{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
func TestUserRepository_CreateToken(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
token, err := repo.CreateToken(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token.Key)
|
||||
assert.Equal(t, user.ID, token.UserID)
|
||||
assert.ErrorIs(t, err, ErrUserNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_UpdateLastLogin(t *testing.T) {
|
||||
@@ -255,54 +180,6 @@ func TestUserRepository_FindByIDWithProfile_NotFound(t *testing.T) {
|
||||
assert.ErrorIs(t, err, ErrUserNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_ConfirmationCode_Lifecycle(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
// Create confirmation code
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
code, err := repo.CreateConfirmationCode(user.ID, "123456", expiresAt)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, code.ID)
|
||||
|
||||
// Find it
|
||||
found, err := repo.FindConfirmationCode(user.ID, "123456")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, code.ID, found.ID)
|
||||
|
||||
// Mark as used
|
||||
err = repo.MarkConfirmationCodeUsed(code.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should not find used code
|
||||
_, err = repo.FindConfirmationCode(user.ID, "123456")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUserRepository_ConfirmationCode_InvalidatesExisting(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
|
||||
// Create first code
|
||||
code1, err := repo.CreateConfirmationCode(user.ID, "111111", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create second code (should invalidate first)
|
||||
_, err = repo.CreateConfirmationCode(user.ID, "222222", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First code should be used/invalidated
|
||||
var c models.ConfirmationCode
|
||||
db.First(&c, code1.ID)
|
||||
assert.True(t, c.IsUsed)
|
||||
}
|
||||
|
||||
func TestUserRepository_Transaction(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
@@ -331,105 +208,6 @@ func TestUserRepository_DB(t *testing.T) {
|
||||
assert.NotNil(t, repo.DB())
|
||||
}
|
||||
|
||||
func TestUserRepository_FindByAppleID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "Password123")
|
||||
appleAuth := &models.AppleSocialAuth{
|
||||
UserID: user.ID,
|
||||
AppleID: "apple_sub_123",
|
||||
Email: "apple@test.com",
|
||||
}
|
||||
require.NoError(t, db.Create(appleAuth).Error)
|
||||
|
||||
found, err := repo.FindByAppleID("apple_sub_123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.ID, found.UserID)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindByAppleID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, err := repo.FindByAppleID("nonexistent_apple_id")
|
||||
assert.ErrorIs(t, err, ErrAppleAuthNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindByGoogleID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "Password123")
|
||||
googleAuth := &models.GoogleSocialAuth{
|
||||
UserID: user.ID,
|
||||
GoogleID: "google_sub_123",
|
||||
Email: "google@test.com",
|
||||
}
|
||||
require.NoError(t, db.Create(googleAuth).Error)
|
||||
|
||||
found, err := repo.FindByGoogleID("google_sub_123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.ID, found.UserID)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindByGoogleID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, err := repo.FindByGoogleID("nonexistent_google_id")
|
||||
assert.ErrorIs(t, err, ErrGoogleAuthNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_CreateAndUpdateAppleSocialAuth(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "Password123")
|
||||
|
||||
auth := &models.AppleSocialAuth{
|
||||
UserID: user.ID,
|
||||
AppleID: "apple_sub_456",
|
||||
Email: "apple@test.com",
|
||||
}
|
||||
err := repo.CreateAppleSocialAuth(auth)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, auth.ID)
|
||||
|
||||
auth.Email = "updated@test.com"
|
||||
err = repo.UpdateAppleSocialAuth(auth)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByAppleID("apple_sub_456")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "updated@test.com", found.Email)
|
||||
}
|
||||
|
||||
func TestUserRepository_CreateAndUpdateGoogleSocialAuth(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "Password123")
|
||||
|
||||
auth := &models.GoogleSocialAuth{
|
||||
UserID: user.ID,
|
||||
GoogleID: "google_sub_456",
|
||||
Email: "google@test.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
err := repo.CreateGoogleSocialAuth(auth)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, auth.ID)
|
||||
|
||||
auth.Name = "Updated Name"
|
||||
err = repo.UpdateGoogleSocialAuth(auth)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByGoogleID("google_sub_456")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated Name", found.Name)
|
||||
}
|
||||
|
||||
func TestUserRepository_SearchUsers(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
@@ -2,7 +2,6 @@ package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -11,207 +10,6 @@ import (
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
// === Password Reset Code Lifecycle ===
|
||||
|
||||
func TestUserRepository_PasswordResetCode_Lifecycle(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
code, err := repo.CreatePasswordResetCode(user.ID, "hash_abc123", "reset_token_xyz", expiresAt)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, code.ID)
|
||||
assert.Equal(t, "hash_abc123", code.CodeHash)
|
||||
assert.Equal(t, "reset_token_xyz", code.ResetToken)
|
||||
assert.False(t, code.Used)
|
||||
assert.Equal(t, 0, code.Attempts)
|
||||
}
|
||||
|
||||
func TestUserRepository_CreatePasswordResetCode_InvalidatesExisting(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
|
||||
code1, err := repo.CreatePasswordResetCode(user.ID, "hash1", "token1", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.CreatePasswordResetCode(user.ID, "hash2", "token2", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First code should be marked as used
|
||||
var c models.PasswordResetCode
|
||||
db.First(&c, code1.ID)
|
||||
assert.True(t, c.Used)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
_, err := repo.CreatePasswordResetCode(user.ID, "hash_abc", "token_abc", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, foundUser, err := repo.FindPasswordResetCodeByEmail("test@example.com")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.ID, foundUser.ID)
|
||||
assert.Equal(t, "hash_abc", found.CodeHash)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByEmail_UserNotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, _, err := repo.FindPasswordResetCodeByEmail("nonexistent@example.com")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByEmail_NoCode(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
_, _, err := repo.FindPasswordResetCodeByEmail("test@example.com")
|
||||
assert.ErrorIs(t, err, ErrCodeNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByToken(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
_, err := repo.CreatePasswordResetCode(user.ID, "hash_xyz", "token_xyz", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindPasswordResetCodeByToken("token_xyz")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "hash_xyz", found.CodeHash)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByToken_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, err := repo.FindPasswordResetCodeByToken("nonexistent_token")
|
||||
assert.ErrorIs(t, err, ErrCodeNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByToken_Expired(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
// Already expired
|
||||
expiresAt := time.Now().UTC().Add(-1 * time.Hour)
|
||||
_, err := repo.CreatePasswordResetCode(user.ID, "hash_exp", "token_exp", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindPasswordResetCodeByToken("token_exp")
|
||||
assert.ErrorIs(t, err, ErrCodeExpired)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByToken_Used(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
code, err := repo.CreatePasswordResetCode(user.ID, "hash_used", "token_used", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mark as used
|
||||
err = repo.MarkPasswordResetCodeUsed(code.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindPasswordResetCodeByToken("token_used")
|
||||
assert.ErrorIs(t, err, ErrCodeUsed)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByToken_TooManyAttempts(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
code, err := repo.CreatePasswordResetCode(user.ID, "hash_attempts", "token_attempts", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Max out attempts
|
||||
for i := 0; i < 5; i++ {
|
||||
err = repo.IncrementResetCodeAttempts(code.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
_, err = repo.FindPasswordResetCodeByToken("token_attempts")
|
||||
assert.ErrorIs(t, err, ErrTooManyAttempts)
|
||||
}
|
||||
|
||||
func TestUserRepository_IncrementResetCodeAttempts(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
code, err := repo.CreatePasswordResetCode(user.ID, "hash_inc", "token_inc", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.IncrementResetCodeAttempts(code.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var updated models.PasswordResetCode
|
||||
db.First(&updated, code.ID)
|
||||
assert.Equal(t, 1, updated.Attempts)
|
||||
}
|
||||
|
||||
func TestUserRepository_MarkPasswordResetCodeUsed(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
code, err := repo.CreatePasswordResetCode(user.ID, "hash_mark", "token_mark", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.MarkPasswordResetCodeUsed(code.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var updated models.PasswordResetCode
|
||||
db.First(&updated, code.ID)
|
||||
assert.True(t, updated.Used)
|
||||
}
|
||||
|
||||
func TestUserRepository_CountRecentPasswordResetRequests(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
_, err := repo.CreatePasswordResetCode(user.ID, "hash1", "token1", expiresAt)
|
||||
require.NoError(t, err)
|
||||
_, err = repo.CreatePasswordResetCode(user.ID, "hash2", "token2", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err := repo.CountRecentPasswordResetRequests(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(2), count)
|
||||
}
|
||||
|
||||
// === FindUsersInSharedResidences ===
|
||||
|
||||
func TestUserRepository_FindUsersInSharedResidences(t *testing.T) {
|
||||
@@ -301,33 +99,6 @@ func TestUserRepository_FindProfilesInSharedResidences(t *testing.T) {
|
||||
assert.Len(t, profiles, 2)
|
||||
}
|
||||
|
||||
// === ConfirmationCode Expired ===
|
||||
|
||||
func TestUserRepository_FindConfirmationCode_Expired(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
// Create already-expired code
|
||||
expiresAt := time.Now().UTC().Add(-1 * time.Hour)
|
||||
_, err := repo.CreateConfirmationCode(user.ID, "999999", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindConfirmationCode(user.ID, "999999")
|
||||
assert.ErrorIs(t, err, ErrCodeExpired)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindConfirmationCode_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
_, err := repo.FindConfirmationCode(user.ID, "000000")
|
||||
assert.ErrorIs(t, err, ErrCodeNotFound)
|
||||
}
|
||||
|
||||
// === Transaction Rollback ===
|
||||
|
||||
func TestUserRepository_Transaction_Rollback(t *testing.T) {
|
||||
|
||||
@@ -19,7 +19,6 @@ func TestUserRepository_Create(t *testing.T) {
|
||||
Email: "test@example.com",
|
||||
IsActive: true,
|
||||
}
|
||||
user.SetPassword("Password123")
|
||||
|
||||
err := repo.Create(user)
|
||||
require.NoError(t, err)
|
||||
@@ -192,39 +191,11 @@ func TestUserRepository_FindAuthProvider(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
t.Run("email user", func(t *testing.T) {
|
||||
t.Run("kratos user", func(t *testing.T) {
|
||||
user := testutil.CreateTestUser(t, db, "emailuser", "email@test.com", "Password123")
|
||||
provider, err := repo.FindAuthProvider(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "email", provider)
|
||||
})
|
||||
|
||||
t.Run("apple user", func(t *testing.T) {
|
||||
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "Password123")
|
||||
appleAuth := &models.AppleSocialAuth{
|
||||
UserID: user.ID,
|
||||
AppleID: "apple_sub_test",
|
||||
Email: "apple@test.com",
|
||||
}
|
||||
require.NoError(t, db.Create(appleAuth).Error)
|
||||
|
||||
provider, err := repo.FindAuthProvider(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "apple", provider)
|
||||
})
|
||||
|
||||
t.Run("google user", func(t *testing.T) {
|
||||
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "Password123")
|
||||
googleAuth := &models.GoogleSocialAuth{
|
||||
UserID: user.ID,
|
||||
GoogleID: "google_sub_test",
|
||||
Email: "google@test.com",
|
||||
}
|
||||
require.NoError(t, db.Create(googleAuth).Error)
|
||||
|
||||
provider, err := repo.FindAuthProvider(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "google", provider)
|
||||
assert.Equal(t, "kratos", provider) // All users are Kratos-managed
|
||||
})
|
||||
}
|
||||
|
||||
@@ -235,11 +206,9 @@ func TestUserRepository_DeleteUserCascade(t *testing.T) {
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "deletebare", "deletebare@test.com", "Password123")
|
||||
|
||||
// Create profile and token
|
||||
// Create profile
|
||||
profile := &models.UserProfile{UserID: user.ID, Verified: true}
|
||||
require.NoError(t, db.Create(profile).Error)
|
||||
_, err := models.GetOrCreateToken(db, user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var fileURLs []string
|
||||
txErr := repo.Transaction(func(txRepo *UserRepository) error {
|
||||
@@ -261,10 +230,6 @@ func TestUserRepository_DeleteUserCascade(t *testing.T) {
|
||||
// Verify profile is gone
|
||||
db.Model(&models.UserProfile{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
|
||||
// Verify token is gone
|
||||
db.Model(&models.AuthToken{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
})
|
||||
|
||||
t.Run("returns file URLs for cleanup", func(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user