feat(auth): replace hand-rolled auth with Ory Kratos — phase 2 backend
Delegates all credential management (login, register, password reset, email verification, social sign-in) to Ory Kratos. The Go API now acts as a resource server: the new KratosAuth middleware validates sessions against the Kratos whoami endpoint, writes the local User mirror into Echo context, and all existing domain handlers continue working unchanged. Hand-rolled token auth, AuthToken model, apple_auth/ google_auth services, and the auth refresh flow are removed. Tests are updated to use the fake-token middleware pattern so existing integration assertions require no rewrite. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,301 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
appleKeysURL = "https://appleid.apple.com/auth/keys"
|
||||
appleIssuer = "https://appleid.apple.com"
|
||||
appleKeysCacheTTL = 24 * time.Hour
|
||||
appleKeysCacheKey = "apple:public_keys"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidAppleToken = errors.New("invalid Apple identity token")
|
||||
ErrAppleTokenExpired = errors.New("Apple identity token has expired")
|
||||
ErrInvalidAppleAudience = errors.New("invalid Apple token audience")
|
||||
ErrInvalidAppleIssuer = errors.New("invalid Apple token issuer")
|
||||
ErrAppleKeyNotFound = errors.New("Apple public key not found")
|
||||
)
|
||||
|
||||
// AppleJWKS represents Apple's JSON Web Key Set
|
||||
type AppleJWKS struct {
|
||||
Keys []AppleJWK `json:"keys"`
|
||||
}
|
||||
|
||||
// AppleJWK represents a single JSON Web Key from Apple
|
||||
type AppleJWK struct {
|
||||
Kty string `json:"kty"` // Key type (RSA)
|
||||
Kid string `json:"kid"` // Key ID
|
||||
Use string `json:"use"` // Key use (sig)
|
||||
Alg string `json:"alg"` // Algorithm (RS256)
|
||||
N string `json:"n"` // RSA modulus
|
||||
E string `json:"e"` // RSA exponent
|
||||
}
|
||||
|
||||
// AppleTokenClaims represents the claims in an Apple identity token
|
||||
type AppleTokenClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
Email string `json:"email,omitempty"`
|
||||
EmailVerified any `json:"email_verified,omitempty"` // Can be bool or string
|
||||
IsPrivateEmail any `json:"is_private_email,omitempty"` // Can be bool or string
|
||||
AuthTime int64 `json:"auth_time,omitempty"`
|
||||
}
|
||||
|
||||
// IsEmailVerified returns whether the email is verified (handles both bool and string types)
|
||||
func (c *AppleTokenClaims) IsEmailVerified() bool {
|
||||
switch v := c.EmailVerified.(type) {
|
||||
case bool:
|
||||
return v
|
||||
case string:
|
||||
return v == "true"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsPrivateRelayEmail returns whether the email is a private relay email
|
||||
func (c *AppleTokenClaims) IsPrivateRelayEmail() bool {
|
||||
switch v := c.IsPrivateEmail.(type) {
|
||||
case bool:
|
||||
return v
|
||||
case string:
|
||||
return v == "true"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// AppleAuthService handles Apple Sign In token verification
|
||||
type AppleAuthService struct {
|
||||
cache *CacheService
|
||||
config *config.Config
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewAppleAuthService creates a new Apple auth service
|
||||
func NewAppleAuthService(cache *CacheService, cfg *config.Config) *AppleAuthService {
|
||||
return &AppleAuthService{
|
||||
cache: cache,
|
||||
config: cfg,
|
||||
client: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyIdentityToken verifies an Apple identity token and returns the claims
|
||||
func (s *AppleAuthService) VerifyIdentityToken(ctx context.Context, idToken string) (*AppleTokenClaims, error) {
|
||||
// Parse the token header to get the key ID
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, ErrInvalidAppleToken
|
||||
}
|
||||
|
||||
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode token header: %w", err)
|
||||
}
|
||||
|
||||
var header struct {
|
||||
Kid string `json:"kid"`
|
||||
Alg string `json:"alg"`
|
||||
}
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token header: %w", err)
|
||||
}
|
||||
|
||||
// Get the public key for this key ID
|
||||
publicKey, err := s.getPublicKey(ctx, header.Kid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse and verify the token
|
||||
token, err := jwt.ParseWithClaims(idToken, &AppleTokenClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify the signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return publicKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return nil, ErrAppleTokenExpired
|
||||
}
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*AppleTokenClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, ErrInvalidAppleToken
|
||||
}
|
||||
|
||||
// Verify the issuer
|
||||
if claims.Issuer != appleIssuer {
|
||||
return nil, ErrInvalidAppleIssuer
|
||||
}
|
||||
|
||||
// Verify the audience (should be our bundle ID)
|
||||
if !s.verifyAudience(claims.Audience) {
|
||||
return nil, ErrInvalidAppleAudience
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// verifyAudience checks if the token audience matches our client ID.
|
||||
// In production (non-debug), an empty clientID causes verification to fail
|
||||
// rather than silently bypassing the check.
|
||||
func (s *AppleAuthService) verifyAudience(audience jwt.ClaimStrings) bool {
|
||||
clientID := s.config.AppleAuth.ClientID
|
||||
if clientID == "" {
|
||||
if s.config.Server.Debug {
|
||||
// In debug mode only, skip audience verification for local development
|
||||
return true
|
||||
}
|
||||
// In production, missing client ID means we cannot verify the audience
|
||||
return false
|
||||
}
|
||||
|
||||
for _, aud := range audience {
|
||||
if aud == clientID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getPublicKey retrieves the public key for the given key ID
|
||||
func (s *AppleAuthService) getPublicKey(ctx context.Context, kid string) (*rsa.PublicKey, error) {
|
||||
// Try to get from cache first
|
||||
keys, err := s.getCachedKeys(ctx)
|
||||
if err != nil || keys == nil {
|
||||
// Fetch fresh keys
|
||||
keys, err = s.fetchApplePublicKeys(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Find the key with the matching ID
|
||||
for keyID, pubKey := range keys {
|
||||
if keyID == kid {
|
||||
return pubKey, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Key not found in cache, try fetching fresh keys
|
||||
keys, err = s.fetchApplePublicKeys(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if pubKey, ok := keys[kid]; ok {
|
||||
return pubKey, nil
|
||||
}
|
||||
|
||||
return nil, ErrAppleKeyNotFound
|
||||
}
|
||||
|
||||
// getCachedKeys retrieves cached Apple public keys from Redis
|
||||
func (s *AppleAuthService) getCachedKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
|
||||
if s.cache == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, err := s.cache.GetString(ctx, appleKeysCacheKey)
|
||||
if err != nil || data == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var jwks AppleJWKS
|
||||
if err := json.Unmarshal([]byte(data), &jwks); err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return s.parseJWKS(&jwks)
|
||||
}
|
||||
|
||||
// fetchApplePublicKeys fetches Apple's public keys and caches them
|
||||
func (s *AppleAuthService) fetchApplePublicKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, appleKeysURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch Apple keys: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Apple keys endpoint returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var jwks AppleJWKS
|
||||
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode Apple keys: %w", err)
|
||||
}
|
||||
|
||||
// Cache the keys
|
||||
if s.cache != nil {
|
||||
keysJSON, _ := json.Marshal(jwks)
|
||||
_ = s.cache.SetString(ctx, appleKeysCacheKey, string(keysJSON), appleKeysCacheTTL)
|
||||
}
|
||||
|
||||
return s.parseJWKS(&jwks)
|
||||
}
|
||||
|
||||
// parseJWKS converts Apple's JWKS to RSA public keys
|
||||
func (s *AppleAuthService) parseJWKS(jwks *AppleJWKS) (map[string]*rsa.PublicKey, error) {
|
||||
keys := make(map[string]*rsa.PublicKey)
|
||||
|
||||
for _, key := range jwks.Keys {
|
||||
if key.Kty != "RSA" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Decode the modulus (N)
|
||||
nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
n := new(big.Int).SetBytes(nBytes)
|
||||
|
||||
// Decode the exponent (E)
|
||||
eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
e := 0
|
||||
for _, b := range eBytes {
|
||||
e = e<<8 + int(b)
|
||||
}
|
||||
|
||||
pubKey := &rsa.PublicKey{
|
||||
N: n,
|
||||
E: e,
|
||||
}
|
||||
|
||||
keys[key.Kid] = pubKey
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
@@ -1,176 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/repositories"
|
||||
)
|
||||
|
||||
func setupRefreshTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.AutoMigrate(&models.User{}, &models.UserProfile{}, &models.AuthToken{})
|
||||
require.NoError(t, err)
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func createRefreshTestUser(t *testing.T, db *gorm.DB) *models.User {
|
||||
t.Helper()
|
||||
user := &models.User{
|
||||
Username: "refreshtest",
|
||||
Email: "refresh@test.com",
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, user.SetPassword("Password123"))
|
||||
require.NoError(t, db.Create(user).Error)
|
||||
return user
|
||||
}
|
||||
|
||||
func createTokenWithAge(t *testing.T, db *gorm.DB, userID uint, ageDays int) *models.AuthToken {
|
||||
t.Helper()
|
||||
token := &models.AuthToken{
|
||||
UserID: userID,
|
||||
}
|
||||
require.NoError(t, db.Create(token).Error)
|
||||
|
||||
// Backdate the token's Created timestamp after creation to bypass autoCreateTime
|
||||
backdated := time.Now().UTC().AddDate(0, 0, -ageDays)
|
||||
require.NoError(t, db.Model(token).Update("created", backdated).Error)
|
||||
token.Created = backdated
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
func newTestAuthService(db *gorm.DB) *AuthService {
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
SecretKey: "test-secret",
|
||||
TokenExpiryDays: 90,
|
||||
TokenRefreshDays: 60,
|
||||
},
|
||||
}
|
||||
return NewAuthService(userRepo, cfg)
|
||||
}
|
||||
|
||||
func TestRefreshToken_FreshToken_ReturnsExisting(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
token := createTokenWithAge(t, db, user.ID, 30) // 30 days old, well within fresh window
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token.Plaintext, resp.Token, "fresh token should return the same token")
|
||||
assert.Contains(t, resp.Message, "still valid")
|
||||
}
|
||||
|
||||
func TestRefreshToken_InRenewalWindow_ReturnsNewToken(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
token := createTokenWithAge(t, db, user.ID, 75) // 75 days old, in renewal window (60-90)
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, token.Plaintext, resp.Token, "should return a new token")
|
||||
assert.Contains(t, resp.Message, "refreshed")
|
||||
|
||||
// Verify old token was deleted
|
||||
var count int64
|
||||
// The DB stores the SHA-256 hash, so query by token.Key (the hash).
|
||||
db.Model(&models.AuthToken{}).Where("key = ?", token.Key).Count(&count)
|
||||
assert.Equal(t, int64(0), count, "old token should be deleted")
|
||||
|
||||
// Verify new token exists in DB
|
||||
// resp.Token is the raw token; the DB stores its hash.
|
||||
db.Model(&models.AuthToken{}).Where("key = ?", models.HashToken(resp.Token)).Count(&count)
|
||||
assert.Equal(t, int64(1), count, "new token should exist in DB")
|
||||
|
||||
// Verify new token belongs to the same user
|
||||
var newToken models.AuthToken
|
||||
require.NoError(t, db.Where("key = ?", models.HashToken(resp.Token)).First(&newToken).Error)
|
||||
assert.Equal(t, user.ID, newToken.UserID)
|
||||
}
|
||||
|
||||
func TestRefreshToken_ExpiredToken_Returns401(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
token := createTokenWithAge(t, db, user.ID, 91) // 91 days old, past 90-day expiry
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
assert.Contains(t, err.Error(), "error.token_expired")
|
||||
}
|
||||
|
||||
func TestRefreshToken_AtExactBoundary60Days(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
// Exactly 60 days: token age == refreshDays, so tokenAge < refreshDuration is false,
|
||||
// meaning it enters the renewal window
|
||||
token := createTokenWithAge(t, db, user.ID, 61)
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, token.Plaintext, resp.Token, "token at 61 days should be refreshed")
|
||||
}
|
||||
|
||||
func TestRefreshToken_InvalidToken_Returns401(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), "nonexistent-token-key", user.ID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||
}
|
||||
|
||||
func TestRefreshToken_WrongUser_Returns401(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
token := createTokenWithAge(t, db, user.ID, 75)
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
// Try to refresh with a different user ID
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID+999)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||
}
|
||||
|
||||
func TestRefreshToken_FreshTokenAt59Days_ReturnsExisting(t *testing.T) {
|
||||
db := setupRefreshTestDB(t)
|
||||
user := createRefreshTestUser(t, db)
|
||||
token := createTokenWithAge(t, db, user.ID, 59) // 59 days, just under the 60-day threshold
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token.Plaintext, resp.Token, "token at 59 days should NOT be refreshed")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -19,195 +18,18 @@ func setupAuthService(t *testing.T) (*AuthService, *repositories.UserRepository)
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
notifRepo := repositories.NewNotificationRepository(db)
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
DebugFixedCodes: true,
|
||||
},
|
||||
Security: config.SecurityConfig{
|
||||
SecretKey: "test-secret",
|
||||
ConfirmationExpiry: 24 * time.Hour,
|
||||
PasswordResetExpiry: 15 * time.Minute,
|
||||
MaxPasswordResetRate: 3,
|
||||
TokenExpiryDays: 90,
|
||||
TokenRefreshDays: 60,
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
service.SetNotificationRepository(notifRepo)
|
||||
return service, userRepo
|
||||
}
|
||||
|
||||
// === Login ===
|
||||
|
||||
func TestAuthService_Login(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
|
||||
req := &requests.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
resp, err := service.Login(context.Background(), req, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.Token)
|
||||
assert.Equal(t, "testuser", resp.User.Username)
|
||||
}
|
||||
|
||||
func TestAuthService_Login_ByEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
|
||||
req := &requests.LoginRequest{
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
resp, err := service.Login(context.Background(), req, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.Token)
|
||||
}
|
||||
|
||||
func TestAuthService_Login_InvalidCredentials(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
|
||||
req := &requests.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "WrongPassword1",
|
||||
}
|
||||
|
||||
_, err := service.Login(context.Background(), req, "")
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
func TestAuthService_Login_UserNotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
req := &requests.LoginRequest{
|
||||
Username: "nonexistent",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
_, err := service.Login(context.Background(), req, "")
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
func TestAuthService_Login_InactiveUser(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "inactive", "inactive@test.com", "Password123")
|
||||
// Deactivate
|
||||
user.IsActive = false
|
||||
db.Save(user)
|
||||
|
||||
req := &requests.LoginRequest{
|
||||
Username: "inactive",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
_, err := service.Login(context.Background(), req, "")
|
||||
// Audit L1: inactive accounts return the same generic error as bad
|
||||
// credentials so login does not disclose which accounts exist.
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
// === Register ===
|
||||
|
||||
func TestAuthService_Register(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "new@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
resp, code, err := service.Register(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.Token)
|
||||
assert.Equal(t, "newuser", resp.User.Username)
|
||||
assert.Equal(t, "123456", code) // DebugFixedCodes=true
|
||||
}
|
||||
|
||||
func TestAuthService_Register_DuplicateUsername(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{DebugFixedCodes: true},
|
||||
Security: config.SecurityConfig{SecretKey: "test", ConfirmationExpiry: 24 * time.Hour},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "taken", "taken@test.com", "Password123")
|
||||
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "taken",
|
||||
Email: "different@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
_, _, err := service.Register(context.Background(), req)
|
||||
testutil.AssertAppError(t, err, http.StatusConflict, "error.username_taken")
|
||||
}
|
||||
|
||||
func TestAuthService_Register_DuplicateEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{DebugFixedCodes: true},
|
||||
Security: config.SecurityConfig{SecretKey: "test", ConfirmationExpiry: 24 * time.Hour},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "existing", "taken@test.com", "Password123")
|
||||
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "taken@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
_, _, err := service.Register(context.Background(), req)
|
||||
testutil.AssertAppError(t, err, http.StatusConflict, "error.email_taken")
|
||||
}
|
||||
|
||||
// === GetCurrentUser ===
|
||||
|
||||
func TestAuthService_GetCurrentUser(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
@@ -218,7 +40,7 @@ func TestAuthService_GetCurrentUser(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "testuser", resp.Username)
|
||||
assert.Equal(t, "test@test.com", resp.Email)
|
||||
assert.Equal(t, "email", resp.AuthProvider) // Default for no social auth
|
||||
assert.Equal(t, "kratos", resp.AuthProvider) // All users are Kratos-managed
|
||||
}
|
||||
|
||||
// === UpdateProfile ===
|
||||
@@ -226,9 +48,7 @@ func TestAuthService_GetCurrentUser(t *testing.T) {
|
||||
func TestAuthService_UpdateProfile(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
@@ -250,9 +70,7 @@ func TestAuthService_UpdateProfile(t *testing.T) {
|
||||
func TestAuthService_UpdateProfile_DuplicateEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "user1", "user1@test.com", "Password123")
|
||||
@@ -271,9 +89,7 @@ func TestAuthService_UpdateProfile_DuplicateEmail(t *testing.T) {
|
||||
func TestAuthService_UpdateProfile_SameEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
@@ -290,443 +106,10 @@ func TestAuthService_UpdateProfile_SameEmail(t *testing.T) {
|
||||
assert.Equal(t, "test@test.com", resp.Email)
|
||||
}
|
||||
|
||||
// === VerifyEmail ===
|
||||
|
||||
func TestAuthService_VerifyEmail(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register a user (creates confirmation code)
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "new@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the user ID
|
||||
user, err := service.userRepo.FindByEmail("new@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify with the debug code
|
||||
err = service.VerifyEmail(context.Background(), user.ID, "123456")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify again — should get already verified error
|
||||
err = service.VerifyEmail(context.Background(), user.ID, "123456")
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
|
||||
}
|
||||
|
||||
func TestAuthService_VerifyEmail_InvalidCode(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "new@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("new@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wrong code — with DebugFixedCodes enabled, "123456" bypasses normal lookup,
|
||||
// but a wrong code should use the normal path
|
||||
err = service.VerifyEmail(context.Background(), user.ID, "000000")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === ResendVerificationCode ===
|
||||
|
||||
func TestAuthService_ResendVerificationCode(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "new@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("new@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
code, err := service.ResendVerificationCode(context.Background(), user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "123456", code) // DebugFixedCodes
|
||||
}
|
||||
|
||||
func TestAuthService_ResendVerificationCode_AlreadyVerified(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register and verify
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "new@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("new@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = service.VerifyEmail(context.Background(), user.ID, "123456")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.ResendVerificationCode(context.Background(), user.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
|
||||
}
|
||||
|
||||
// === ForgotPassword ===
|
||||
|
||||
func TestAuthService_ForgotPassword(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register a user first
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
code, user, err := service.ForgotPassword(context.Background(), "test@test.com")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "123456", code) // DebugFixedCodes
|
||||
assert.NotNil(t, user)
|
||||
assert.Equal(t, "test@test.com", user.Email)
|
||||
}
|
||||
|
||||
func TestAuthService_ForgotPassword_NonexistentEmail(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Should not reveal that email doesn't exist
|
||||
code, user, err := service.ForgotPassword(context.Background(), "nonexistent@test.com")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, code)
|
||||
assert.Nil(t, user)
|
||||
}
|
||||
|
||||
// === ResetPassword ===
|
||||
|
||||
func TestAuthService_ResetPassword(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Forgot password
|
||||
_, _, err = service.ForgotPassword(context.Background(), "test@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify reset code to get the token
|
||||
resetToken, err := service.VerifyResetCode(context.Background(), "test@test.com", "123456")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resetToken)
|
||||
|
||||
// Reset password
|
||||
err = service.ResetPassword(context.Background(), resetToken, "NewPassword123")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Login with new password
|
||||
loginReq := &requests.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "NewPassword123",
|
||||
}
|
||||
loginResp, err := service.Login(context.Background(), loginReq, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, loginResp.Token)
|
||||
}
|
||||
|
||||
func TestAuthService_ResetPassword_InvalidToken(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
err := service.ResetPassword(context.Background(), "invalid-token", "NewPassword123")
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_reset_token")
|
||||
}
|
||||
|
||||
// === Logout ===
|
||||
|
||||
func TestAuthService_Logout(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
|
||||
// Login first
|
||||
loginReq := &requests.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "Password123",
|
||||
}
|
||||
loginResp, err := service.Login(context.Background(), loginReq, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Logout
|
||||
err = service.Logout(context.Background(), loginResp.Token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Token should be deleted — refreshing should fail
|
||||
_, err = service.RefreshToken(context.Background(), loginResp.Token, user.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === DeleteAccount ===
|
||||
|
||||
func TestAuthService_DeleteAccount_EmailAuth(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("test@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
password := "Password123"
|
||||
_, err = service.DeleteAccount(context.Background(), user.ID, &password, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuthService_DeleteAccount_WrongPassword(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("test@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
wrongPassword := "WrongPassword1"
|
||||
_, err = service.DeleteAccount(context.Background(), user.ID, &wrongPassword, nil)
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
func TestAuthService_DeleteAccount_NoPassword(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("test@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.DeleteAccount(context.Background(), user.ID, nil, nil)
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
|
||||
}
|
||||
|
||||
func TestAuthService_DeleteAccount_UserNotFound(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
password := "Password123"
|
||||
_, err := service.DeleteAccount(context.Background(), 99999, &password, nil)
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.user_not_found")
|
||||
}
|
||||
|
||||
// === Helper functions ===
|
||||
|
||||
func TestGenerateSixDigitCode(t *testing.T) {
|
||||
code := generateSixDigitCode()
|
||||
assert.Len(t, code, 6)
|
||||
// Should be numeric
|
||||
for _, c := range code {
|
||||
assert.True(t, c >= '0' && c <= '9', "code should contain only digits")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateResetToken(t *testing.T) {
|
||||
token := generateResetToken()
|
||||
assert.NotEmpty(t, token)
|
||||
assert.Len(t, token, 64) // 32 bytes = 64 hex chars
|
||||
}
|
||||
|
||||
func TestGetStringOrEmpty(t *testing.T) {
|
||||
s := "hello"
|
||||
assert.Equal(t, "hello", getStringOrEmpty(&s))
|
||||
assert.Equal(t, "", getStringOrEmpty(nil))
|
||||
}
|
||||
|
||||
func TestIsPrivateRelayEmail(t *testing.T) {
|
||||
assert.True(t, isPrivateRelayEmail("abc@privaterelay.appleid.com"))
|
||||
assert.True(t, isPrivateRelayEmail("ABC@PRIVATERELAY.APPLEID.COM"))
|
||||
assert.False(t, isPrivateRelayEmail("user@gmail.com"))
|
||||
}
|
||||
|
||||
func TestGetEmailFromRequest(t *testing.T) {
|
||||
email := "req@test.com"
|
||||
assert.Equal(t, "req@test.com", getEmailFromRequest(&email, "claims@test.com"))
|
||||
assert.Equal(t, "claims@test.com", getEmailFromRequest(nil, "claims@test.com"))
|
||||
empty := ""
|
||||
assert.Equal(t, "claims@test.com", getEmailFromRequest(&empty, "claims@test.com"))
|
||||
}
|
||||
|
||||
// === getEmailOrDefault ===
|
||||
|
||||
func TestGetEmailOrDefault(t *testing.T) {
|
||||
// Non-empty email returns itself
|
||||
assert.Equal(t, "user@test.com", getEmailOrDefault("user@test.com"))
|
||||
|
||||
// Empty email returns a generated placeholder
|
||||
result := getEmailOrDefault("")
|
||||
assert.Contains(t, result, "@privaterelay.appleid.com")
|
||||
assert.Contains(t, result, "apple_")
|
||||
}
|
||||
|
||||
// === generateUniqueUsername ===
|
||||
|
||||
func TestGenerateUniqueUsername(t *testing.T) {
|
||||
// Normal email generates username from email prefix
|
||||
username := generateUniqueUsername("john@test.com", nil)
|
||||
assert.Contains(t, username, "john_")
|
||||
|
||||
// Private relay email falls back to first name
|
||||
firstName := "Jane"
|
||||
username = generateUniqueUsername("abc@privaterelay.appleid.com", &firstName)
|
||||
assert.Contains(t, username, "jane_")
|
||||
|
||||
// Private relay email and no first name — fallback
|
||||
username = generateUniqueUsername("abc@privaterelay.appleid.com", nil)
|
||||
assert.Contains(t, username, "user_")
|
||||
|
||||
// Empty email with first name
|
||||
firstName2 := "Bob"
|
||||
username = generateUniqueUsername("", &firstName2)
|
||||
assert.Contains(t, username, "bob_")
|
||||
|
||||
// Empty email and no first name
|
||||
username = generateUniqueUsername("", nil)
|
||||
assert.Contains(t, username, "user_")
|
||||
}
|
||||
|
||||
// === generateGoogleUsername ===
|
||||
|
||||
func TestGenerateGoogleUsername(t *testing.T) {
|
||||
// Normal email
|
||||
username := generateGoogleUsername("john@gmail.com", "John")
|
||||
assert.Contains(t, username, "john_")
|
||||
|
||||
// Empty email falls back to first name
|
||||
username = generateGoogleUsername("", "Alice")
|
||||
assert.Contains(t, username, "alice_")
|
||||
|
||||
// Empty email and empty first name — fallback
|
||||
username = generateGoogleUsername("", "")
|
||||
assert.Contains(t, username, "google_")
|
||||
}
|
||||
|
||||
// === Login with empty password ===
|
||||
|
||||
func TestAuthService_Login_EmptyPassword(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
|
||||
req := &requests.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "",
|
||||
}
|
||||
|
||||
_, err := service.Login(context.Background(), req, "")
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
// === ForgotPassword rate limiting ===
|
||||
|
||||
func TestAuthService_ForgotPassword_RateLimit(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make max allowed reset requests (3 based on setup)
|
||||
for i := 0; i < 3; i++ {
|
||||
_, _, err := service.ForgotPassword(context.Background(), "test@test.com")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// The 4th should be rate limited
|
||||
_, _, err = service.ForgotPassword(context.Background(), "test@test.com")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === VerifyResetCode with wrong code ===
|
||||
|
||||
func TestAuthService_VerifyResetCode_WrongCode(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = service.ForgotPassword(context.Background(), "test@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wrong code but with debug mode, "123456" works, "000000" should fail
|
||||
_, err = service.VerifyResetCode(context.Background(), "test@test.com", "000000")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === VerifyResetCode with nonexistent email ===
|
||||
|
||||
func TestAuthService_VerifyResetCode_NonexistentEmail(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
_, err := service.VerifyResetCode(context.Background(), "nonexistent@test.com", "123456")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === UpdateProfile — change email to new email ===
|
||||
|
||||
func TestAuthService_UpdateProfile_ChangeEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
@@ -742,25 +125,44 @@ func TestAuthService_UpdateProfile_ChangeEmail(t *testing.T) {
|
||||
assert.Equal(t, "newemail@test.com", resp.Email)
|
||||
}
|
||||
|
||||
// === DeleteAccount — empty password string ===
|
||||
// === DeleteAccount ===
|
||||
|
||||
func TestAuthService_DeleteAccount_EmptyPassword(t *testing.T) {
|
||||
func TestAuthService_DeleteAccount_WithConfirmation(t *testing.T) {
|
||||
service, userRepo := setupAuthService(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, (*userRepo).DB(), "testuser", "test@test.com", "")
|
||||
_ = user
|
||||
|
||||
confirmation := "DELETE"
|
||||
_, err := service.DeleteAccount(context.Background(), user.ID, nil, &confirmation)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuthService_DeleteAccount_WrongConfirmation(t *testing.T) {
|
||||
service, userRepo := setupAuthService(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, (*userRepo).DB(), "testuser", "test@test.com", "")
|
||||
|
||||
wrongConf := "delete"
|
||||
_, err := service.DeleteAccount(context.Background(), user.ID, nil, &wrongConf)
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.confirmation_required")
|
||||
}
|
||||
|
||||
func TestAuthService_DeleteAccount_NoConfirmation(t *testing.T) {
|
||||
service, userRepo := setupAuthService(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, (*userRepo).DB(), "testuser", "test@test.com", "")
|
||||
|
||||
_, err := service.DeleteAccount(context.Background(), user.ID, nil, nil)
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.confirmation_required")
|
||||
}
|
||||
|
||||
func TestAuthService_DeleteAccount_UserNotFound(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(context.Background(), registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("test@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
emptyPw := ""
|
||||
_, err = service.DeleteAccount(context.Background(), user.ID, &emptyPw, nil)
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
|
||||
confirmation := "DELETE"
|
||||
_, err := service.DeleteAccount(context.Background(), 99999, nil, &confirmation)
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.user_not_found")
|
||||
}
|
||||
|
||||
// === SetNotificationRepository ===
|
||||
@@ -769,35 +171,10 @@ func TestAuthService_SetNotificationRepository(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
notifRepo := repositories.NewNotificationRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
assert.Nil(t, service.notificationRepo)
|
||||
|
||||
service.SetNotificationRepository(notifRepo)
|
||||
assert.NotNil(t, service.notificationRepo)
|
||||
}
|
||||
|
||||
// === Register creates profile and notification preferences ===
|
||||
|
||||
func TestAuthService_Register_CreatesProfile(t *testing.T) {
|
||||
service, userRepo := setupAuthService(t)
|
||||
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "profileuser",
|
||||
Email: "profile@test.com",
|
||||
Password: "Password123",
|
||||
FirstName: "John",
|
||||
LastName: "Doe",
|
||||
}
|
||||
|
||||
resp, _, err := service.Register(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "profileuser", resp.User.Username)
|
||||
|
||||
// Profile should exist
|
||||
profile, err := userRepo.GetOrCreateProfile(resp.User.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, profile)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
// CacheService provides Redis caching functionality
|
||||
@@ -134,116 +133,6 @@ func (c *CacheService) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Auth token cache helpers
|
||||
const (
|
||||
AuthTokenPrefix = "auth_token_"
|
||||
TokenCacheTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
// authTokenCacheKey returns the Redis key for an auth token. The raw token
|
||||
// is hashed (audit C1) so the plaintext token never appears in a Redis key.
|
||||
func authTokenCacheKey(token string) string {
|
||||
return AuthTokenPrefix + models.HashToken(token)
|
||||
}
|
||||
|
||||
// CacheAuthToken caches a user ID for a token
|
||||
func (c *CacheService) CacheAuthToken(ctx context.Context, token string, userID uint) error {
|
||||
return c.SetString(ctx, authTokenCacheKey(token), fmt.Sprintf("%d", userID), TokenCacheTTL)
|
||||
}
|
||||
|
||||
// CacheAuthTokenWithCreated caches a user ID and token creation time for a token
|
||||
func (c *CacheService) CacheAuthTokenWithCreated(ctx context.Context, token string, userID uint, createdUnix int64) error {
|
||||
return c.SetString(ctx, authTokenCacheKey(token), fmt.Sprintf("%d|%d", userID, createdUnix), TokenCacheTTL)
|
||||
}
|
||||
|
||||
// GetCachedAuthToken gets a cached user ID for a token
|
||||
func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (uint, error) {
|
||||
val, err := c.GetString(ctx, authTokenCacheKey(token))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var userID uint
|
||||
_, err = fmt.Sscanf(val, "%d", &userID)
|
||||
return userID, err
|
||||
}
|
||||
|
||||
// GetCachedAuthTokenWithCreated gets a cached user ID and token creation time.
|
||||
// Returns userID, createdUnix, error. createdUnix is 0 if not stored (legacy format).
|
||||
func (c *CacheService) GetCachedAuthTokenWithCreated(ctx context.Context, token string) (uint, int64, error) {
|
||||
val, err := c.GetString(ctx, authTokenCacheKey(token))
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
var userID uint
|
||||
var createdUnix int64
|
||||
n, _ := fmt.Sscanf(val, "%d|%d", &userID, &createdUnix)
|
||||
if n < 1 {
|
||||
return 0, 0, fmt.Errorf("invalid cached token format")
|
||||
}
|
||||
return userID, createdUnix, nil
|
||||
}
|
||||
|
||||
// InvalidateAuthToken removes a cached token
|
||||
func (c *CacheService) InvalidateAuthToken(ctx context.Context, token string) error {
|
||||
return c.Delete(ctx, authTokenCacheKey(token))
|
||||
}
|
||||
|
||||
// InvalidateAuthTokenHashes removes cached entries for already-hashed token
|
||||
// keys. Unlike InvalidateAuthToken (which hashes a plaintext), this takes the
|
||||
// stored hash directly — used to evict a user's prior token on re-login
|
||||
// (audit MEDIUM-1), where the server no longer has the plaintext.
|
||||
func (c *CacheService) InvalidateAuthTokenHashes(ctx context.Context, hashes ...string) error {
|
||||
keys := make([]string, 0, len(hashes))
|
||||
for _, h := range hashes {
|
||||
if h != "" {
|
||||
keys = append(keys, AuthTokenPrefix+h)
|
||||
}
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.Delete(ctx, keys...)
|
||||
}
|
||||
|
||||
// --- Per-account login-failure tracking (audit M5) ---
|
||||
|
||||
const loginFailPrefix = "login_fail:"
|
||||
|
||||
// RegisterLoginFailure records a failed login for an account from a given
|
||||
// source IP, and returns the number of DISTINCT source IPs that have failed
|
||||
// for this account within the window. Tracking distinct IPs as a set rather
|
||||
// than a raw counter (audit MEDIUM-3) means one attacker, from one IP, cannot
|
||||
// run the count up and lock a victim out by knowing only their email — a
|
||||
// single IP is bounded by the per-IP edge/app rate limiters instead. A
|
||||
// genuinely distributed credential-stuffing attack still trips the lockout.
|
||||
func (c *CacheService) RegisterLoginFailure(ctx context.Context, identifier, ip string, window time.Duration) (int64, error) {
|
||||
key := loginFailPrefix + identifier
|
||||
member := ip
|
||||
if member == "" {
|
||||
member = "unknown"
|
||||
}
|
||||
if err := c.client.SAdd(ctx, key, member).Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// Refresh the TTL on each failure: an active attack keeps the window
|
||||
// open, while a quiet account ages out `window` after its last failure.
|
||||
_ = c.client.Expire(ctx, key, window).Err()
|
||||
return c.client.SCard(ctx, key).Result()
|
||||
}
|
||||
|
||||
// LoginFailureIPCount returns how many distinct source IPs have failed to log
|
||||
// in to this account within the window (audit MEDIUM-3). SCard on a missing
|
||||
// key returns 0.
|
||||
func (c *CacheService) LoginFailureIPCount(ctx context.Context, identifier string) (int64, error) {
|
||||
return c.client.SCard(ctx, loginFailPrefix+identifier).Result()
|
||||
}
|
||||
|
||||
// ClearLoginFailures resets the failed-login IP set after a successful login.
|
||||
func (c *CacheService) ClearLoginFailures(ctx context.Context, identifier string) error {
|
||||
return c.client.Del(ctx, loginFailPrefix+identifier).Err()
|
||||
}
|
||||
|
||||
// Static data cache helpers
|
||||
const (
|
||||
|
||||
@@ -1,307 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
// googleKeysURL is Google's JWKS endpoint for ID-token signature verification.
|
||||
googleKeysURL = "https://www.googleapis.com/oauth2/v3/certs"
|
||||
googleKeysCacheTTL = 24 * time.Hour
|
||||
googleKeysCacheKey = "google:public_keys"
|
||||
)
|
||||
|
||||
// googleIssuers is the set of valid `iss` claim values for a Google ID token.
|
||||
var googleIssuers = map[string]bool{
|
||||
"accounts.google.com": true,
|
||||
"https://accounts.google.com": true,
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidGoogleToken = errors.New("invalid Google ID token")
|
||||
ErrGoogleTokenExpired = errors.New("Google ID token has expired")
|
||||
ErrInvalidGoogleAudience = errors.New("invalid Google token audience")
|
||||
ErrInvalidGoogleIssuer = errors.New("invalid Google token issuer")
|
||||
ErrGoogleKeyNotFound = errors.New("Google public key not found")
|
||||
)
|
||||
|
||||
// GoogleJWKS represents Google's JSON Web Key Set.
|
||||
type GoogleJWKS struct {
|
||||
Keys []GoogleJWK `json:"keys"`
|
||||
}
|
||||
|
||||
// GoogleJWK represents a single JSON Web Key from Google.
|
||||
type GoogleJWK struct {
|
||||
Kty string `json:"kty"` // Key type (RSA)
|
||||
Kid string `json:"kid"` // Key ID
|
||||
Use string `json:"use"` // Key use (sig)
|
||||
Alg string `json:"alg"` // Algorithm (RS256)
|
||||
N string `json:"n"` // RSA modulus
|
||||
E string `json:"e"` // RSA exponent
|
||||
}
|
||||
|
||||
// GoogleTokenClaims represents the claims in a Google ID token JWT.
|
||||
type GoogleTokenClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
Email string `json:"email,omitempty"`
|
||||
EmailVerified bool `json:"email_verified,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
GivenName string `json:"given_name,omitempty"`
|
||||
FamilyName string `json:"family_name,omitempty"`
|
||||
Picture string `json:"picture,omitempty"`
|
||||
Azp string `json:"azp,omitempty"` // Authorized party
|
||||
}
|
||||
|
||||
// GoogleTokenInfo is the verified, caller-facing view of a Google ID token.
|
||||
type GoogleTokenInfo struct {
|
||||
Sub string // Unique Google user ID
|
||||
Email string
|
||||
EmailVerified string // "true" or "false" — string for caller compatibility
|
||||
Name string
|
||||
GivenName string
|
||||
FamilyName string
|
||||
Picture string
|
||||
Aud string
|
||||
Azp string
|
||||
Iss string
|
||||
}
|
||||
|
||||
// IsEmailVerified returns whether the email is verified.
|
||||
func (t *GoogleTokenInfo) IsEmailVerified() bool {
|
||||
return t.EmailVerified == "true"
|
||||
}
|
||||
|
||||
// GoogleAuthService handles Google Sign In token verification.
|
||||
type GoogleAuthService struct {
|
||||
cache *CacheService
|
||||
config *config.Config
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewGoogleAuthService creates a new Google auth service.
|
||||
func NewGoogleAuthService(cache *CacheService, cfg *config.Config) *GoogleAuthService {
|
||||
return &GoogleAuthService{
|
||||
cache: cache,
|
||||
config: cfg,
|
||||
client: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyIDToken verifies a Google ID token locally (audit C2/C3): it checks
|
||||
// the RS256 signature against Google's published JWKS and the iss, aud, and
|
||||
// exp claims. It never sends the token to a third-party endpoint, so it no
|
||||
// longer depends on the deprecated tokeninfo service and never leaks the
|
||||
// token in a request URL.
|
||||
func (s *GoogleAuthService) VerifyIDToken(ctx context.Context, idToken string) (*GoogleTokenInfo, error) {
|
||||
// Parse the token header to get the key ID.
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, ErrInvalidGoogleToken
|
||||
}
|
||||
|
||||
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode token header: %w", err)
|
||||
}
|
||||
var header struct {
|
||||
Kid string `json:"kid"`
|
||||
Alg string `json:"alg"`
|
||||
}
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token header: %w", err)
|
||||
}
|
||||
|
||||
publicKey, err := s.getPublicKey(ctx, header.Kid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse and verify the signature. jwt v5 validates exp/iat/nbf automatically.
|
||||
token, err := jwt.ParseWithClaims(idToken, &GoogleTokenClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return publicKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return nil, ErrGoogleTokenExpired
|
||||
}
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*GoogleTokenClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, ErrInvalidGoogleToken
|
||||
}
|
||||
|
||||
// Verify the issuer (audit C3).
|
||||
if !googleIssuers[claims.Issuer] {
|
||||
return nil, ErrInvalidGoogleIssuer
|
||||
}
|
||||
|
||||
// Verify the audience matches one of our configured client IDs.
|
||||
if !s.verifyAudience(claims.Audience, claims.Azp) {
|
||||
return nil, ErrInvalidGoogleAudience
|
||||
}
|
||||
|
||||
if claims.Subject == "" {
|
||||
return nil, ErrInvalidGoogleToken
|
||||
}
|
||||
|
||||
emailVerified := "false"
|
||||
if claims.EmailVerified {
|
||||
emailVerified = "true"
|
||||
}
|
||||
aud := ""
|
||||
if len(claims.Audience) > 0 {
|
||||
aud = claims.Audience[0]
|
||||
}
|
||||
return &GoogleTokenInfo{
|
||||
Sub: claims.Subject,
|
||||
Email: claims.Email,
|
||||
EmailVerified: emailVerified,
|
||||
Name: claims.Name,
|
||||
GivenName: claims.GivenName,
|
||||
FamilyName: claims.FamilyName,
|
||||
Picture: claims.Picture,
|
||||
Aud: aud,
|
||||
Azp: claims.Azp,
|
||||
Iss: claims.Issuer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// verifyAudience checks the token audience against our configured client IDs.
|
||||
// In production (non-debug) an empty client ID fails verification rather than
|
||||
// silently bypassing the check.
|
||||
func (s *GoogleAuthService) verifyAudience(audience jwt.ClaimStrings, azp string) bool {
|
||||
clientID := s.config.GoogleAuth.ClientID
|
||||
if clientID == "" {
|
||||
// In debug mode only, skip audience verification for local development.
|
||||
return s.config.Server.Debug
|
||||
}
|
||||
|
||||
candidates := []string{clientID}
|
||||
if id := s.config.GoogleAuth.AndroidClientID; id != "" {
|
||||
candidates = append(candidates, id)
|
||||
}
|
||||
if id := s.config.GoogleAuth.IOSClientID; id != "" {
|
||||
candidates = append(candidates, id)
|
||||
}
|
||||
|
||||
for _, want := range candidates {
|
||||
if azp == want {
|
||||
return true
|
||||
}
|
||||
for _, aud := range audience {
|
||||
if aud == want {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getPublicKey returns the RSA public key for the given key ID, using a
|
||||
// Redis-cached copy of Google's JWKS and re-fetching once on a cache miss
|
||||
// (Google rotates signing keys roughly daily).
|
||||
func (s *GoogleAuthService) getPublicKey(ctx context.Context, kid string) (*rsa.PublicKey, error) {
|
||||
keys, err := s.getCachedKeys(ctx)
|
||||
if err != nil || keys == nil {
|
||||
keys, err = s.fetchGooglePublicKeys(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if pubKey, ok := keys[kid]; ok {
|
||||
return pubKey, nil
|
||||
}
|
||||
|
||||
// Cache miss for this kid — keys may have rotated; fetch fresh.
|
||||
keys, err = s.fetchGooglePublicKeys(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pubKey, ok := keys[kid]; ok {
|
||||
return pubKey, nil
|
||||
}
|
||||
return nil, ErrGoogleKeyNotFound
|
||||
}
|
||||
|
||||
// getCachedKeys retrieves cached Google public keys from Redis.
|
||||
func (s *GoogleAuthService) getCachedKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
|
||||
if s.cache == nil {
|
||||
return nil, nil
|
||||
}
|
||||
data, err := s.cache.GetString(ctx, googleKeysCacheKey)
|
||||
if err != nil || data == "" {
|
||||
return nil, nil
|
||||
}
|
||||
var jwks GoogleJWKS
|
||||
if err := json.Unmarshal([]byte(data), &jwks); err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.parseJWKS(&jwks), nil
|
||||
}
|
||||
|
||||
// fetchGooglePublicKeys fetches Google's JWKS and caches it.
|
||||
func (s *GoogleAuthService) fetchGooglePublicKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, googleKeysURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch Google keys: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Google keys endpoint returned status %d", resp.StatusCode)
|
||||
}
|
||||
var jwks GoogleJWKS
|
||||
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode Google keys: %w", err)
|
||||
}
|
||||
if s.cache != nil {
|
||||
keysJSON, _ := json.Marshal(jwks)
|
||||
_ = s.cache.SetString(ctx, googleKeysCacheKey, string(keysJSON), googleKeysCacheTTL)
|
||||
}
|
||||
return s.parseJWKS(&jwks), nil
|
||||
}
|
||||
|
||||
// parseJWKS converts Google's JWKS into a map of RSA public keys by key ID.
|
||||
func (s *GoogleAuthService) parseJWKS(jwks *GoogleJWKS) map[string]*rsa.PublicKey {
|
||||
keys := make(map[string]*rsa.PublicKey)
|
||||
for _, key := range jwks.Keys {
|
||||
if key.Kty != "RSA" {
|
||||
continue
|
||||
}
|
||||
nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
e := 0
|
||||
for _, b := range eBytes {
|
||||
e = e<<8 + int(b)
|
||||
}
|
||||
keys[key.Kid] = &rsa.PublicKey{N: new(big.Int).SetBytes(nBytes), E: e}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
Reference in New Issue
Block a user