feat(auth): replace hand-rolled auth with Ory Kratos — phase 2 backend
Backend CI / Test (push) Has been cancelled
Backend CI / Contract Tests (push) Has been cancelled
Backend CI / Lint (push) Has been cancelled
Backend CI / Secret Scanning (push) Has been cancelled
Backend CI / Build (push) Has been cancelled

Delegates all credential management (login, register, password reset,
email verification, social sign-in) to Ory Kratos. The Go API now acts
as a resource server: the new KratosAuth middleware validates sessions
against the Kratos whoami endpoint, writes the local User mirror into
Echo context, and all existing domain handlers continue working
unchanged. Hand-rolled token auth, AuthToken model, apple_auth/
google_auth services, and the auth refresh flow are removed. Tests are
updated to use the fake-token middleware pattern so existing integration
assertions require no rewrite.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Trey t
2026-05-18 17:55:56 -05:00
parent b66151ddd9
commit 81578f6e27
36 changed files with 927 additions and 7002 deletions
-301
View File
@@ -1,301 +0,0 @@
package services
import (
"context"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/treytartt/honeydue-api/internal/config"
)
const (
appleKeysURL = "https://appleid.apple.com/auth/keys"
appleIssuer = "https://appleid.apple.com"
appleKeysCacheTTL = 24 * time.Hour
appleKeysCacheKey = "apple:public_keys"
)
var (
ErrInvalidAppleToken = errors.New("invalid Apple identity token")
ErrAppleTokenExpired = errors.New("Apple identity token has expired")
ErrInvalidAppleAudience = errors.New("invalid Apple token audience")
ErrInvalidAppleIssuer = errors.New("invalid Apple token issuer")
ErrAppleKeyNotFound = errors.New("Apple public key not found")
)
// AppleJWKS represents Apple's JSON Web Key Set
type AppleJWKS struct {
Keys []AppleJWK `json:"keys"`
}
// AppleJWK represents a single JSON Web Key from Apple
type AppleJWK struct {
Kty string `json:"kty"` // Key type (RSA)
Kid string `json:"kid"` // Key ID
Use string `json:"use"` // Key use (sig)
Alg string `json:"alg"` // Algorithm (RS256)
N string `json:"n"` // RSA modulus
E string `json:"e"` // RSA exponent
}
// AppleTokenClaims represents the claims in an Apple identity token
type AppleTokenClaims struct {
jwt.RegisteredClaims
Email string `json:"email,omitempty"`
EmailVerified any `json:"email_verified,omitempty"` // Can be bool or string
IsPrivateEmail any `json:"is_private_email,omitempty"` // Can be bool or string
AuthTime int64 `json:"auth_time,omitempty"`
}
// IsEmailVerified returns whether the email is verified (handles both bool and string types)
func (c *AppleTokenClaims) IsEmailVerified() bool {
switch v := c.EmailVerified.(type) {
case bool:
return v
case string:
return v == "true"
default:
return false
}
}
// IsPrivateRelayEmail returns whether the email is a private relay email
func (c *AppleTokenClaims) IsPrivateRelayEmail() bool {
switch v := c.IsPrivateEmail.(type) {
case bool:
return v
case string:
return v == "true"
default:
return false
}
}
// AppleAuthService handles Apple Sign In token verification
type AppleAuthService struct {
cache *CacheService
config *config.Config
client *http.Client
}
// NewAppleAuthService creates a new Apple auth service
func NewAppleAuthService(cache *CacheService, cfg *config.Config) *AppleAuthService {
return &AppleAuthService{
cache: cache,
config: cfg,
client: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// VerifyIdentityToken verifies an Apple identity token and returns the claims
func (s *AppleAuthService) VerifyIdentityToken(ctx context.Context, idToken string) (*AppleTokenClaims, error) {
// Parse the token header to get the key ID
parts := strings.Split(idToken, ".")
if len(parts) != 3 {
return nil, ErrInvalidAppleToken
}
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return nil, fmt.Errorf("failed to decode token header: %w", err)
}
var header struct {
Kid string `json:"kid"`
Alg string `json:"alg"`
}
if err := json.Unmarshal(headerBytes, &header); err != nil {
return nil, fmt.Errorf("failed to parse token header: %w", err)
}
// Get the public key for this key ID
publicKey, err := s.getPublicKey(ctx, header.Kid)
if err != nil {
return nil, err
}
// Parse and verify the token
token, err := jwt.ParseWithClaims(idToken, &AppleTokenClaims{}, func(token *jwt.Token) (interface{}, error) {
// Verify the signing method
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return publicKey, nil
})
if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
return nil, ErrAppleTokenExpired
}
return nil, fmt.Errorf("failed to parse token: %w", err)
}
claims, ok := token.Claims.(*AppleTokenClaims)
if !ok || !token.Valid {
return nil, ErrInvalidAppleToken
}
// Verify the issuer
if claims.Issuer != appleIssuer {
return nil, ErrInvalidAppleIssuer
}
// Verify the audience (should be our bundle ID)
if !s.verifyAudience(claims.Audience) {
return nil, ErrInvalidAppleAudience
}
return claims, nil
}
// verifyAudience checks if the token audience matches our client ID.
// In production (non-debug), an empty clientID causes verification to fail
// rather than silently bypassing the check.
func (s *AppleAuthService) verifyAudience(audience jwt.ClaimStrings) bool {
clientID := s.config.AppleAuth.ClientID
if clientID == "" {
if s.config.Server.Debug {
// In debug mode only, skip audience verification for local development
return true
}
// In production, missing client ID means we cannot verify the audience
return false
}
for _, aud := range audience {
if aud == clientID {
return true
}
}
return false
}
// getPublicKey retrieves the public key for the given key ID
func (s *AppleAuthService) getPublicKey(ctx context.Context, kid string) (*rsa.PublicKey, error) {
// Try to get from cache first
keys, err := s.getCachedKeys(ctx)
if err != nil || keys == nil {
// Fetch fresh keys
keys, err = s.fetchApplePublicKeys(ctx)
if err != nil {
return nil, err
}
}
// Find the key with the matching ID
for keyID, pubKey := range keys {
if keyID == kid {
return pubKey, nil
}
}
// Key not found in cache, try fetching fresh keys
keys, err = s.fetchApplePublicKeys(ctx)
if err != nil {
return nil, err
}
if pubKey, ok := keys[kid]; ok {
return pubKey, nil
}
return nil, ErrAppleKeyNotFound
}
// getCachedKeys retrieves cached Apple public keys from Redis
func (s *AppleAuthService) getCachedKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
if s.cache == nil {
return nil, nil
}
data, err := s.cache.GetString(ctx, appleKeysCacheKey)
if err != nil || data == "" {
return nil, nil
}
var jwks AppleJWKS
if err := json.Unmarshal([]byte(data), &jwks); err != nil {
return nil, nil
}
return s.parseJWKS(&jwks)
}
// fetchApplePublicKeys fetches Apple's public keys and caches them
func (s *AppleAuthService) fetchApplePublicKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, appleKeysURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := s.client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to fetch Apple keys: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Apple keys endpoint returned status %d", resp.StatusCode)
}
var jwks AppleJWKS
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
return nil, fmt.Errorf("failed to decode Apple keys: %w", err)
}
// Cache the keys
if s.cache != nil {
keysJSON, _ := json.Marshal(jwks)
_ = s.cache.SetString(ctx, appleKeysCacheKey, string(keysJSON), appleKeysCacheTTL)
}
return s.parseJWKS(&jwks)
}
// parseJWKS converts Apple's JWKS to RSA public keys
func (s *AppleAuthService) parseJWKS(jwks *AppleJWKS) (map[string]*rsa.PublicKey, error) {
keys := make(map[string]*rsa.PublicKey)
for _, key := range jwks.Keys {
if key.Kty != "RSA" {
continue
}
// Decode the modulus (N)
nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
if err != nil {
continue
}
n := new(big.Int).SetBytes(nBytes)
// Decode the exponent (E)
eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
if err != nil {
continue
}
e := 0
for _, b := range eBytes {
e = e<<8 + int(b)
}
pubKey := &rsa.PublicKey{
N: n,
E: e,
}
keys[key.Kid] = pubKey
}
return keys, nil
}
-176
View File
@@ -1,176 +0,0 @@
package services
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/treytartt/honeydue-api/internal/config"
"github.com/treytartt/honeydue-api/internal/models"
"github.com/treytartt/honeydue-api/internal/repositories"
)
func setupRefreshTestDB(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
require.NoError(t, err)
err = db.AutoMigrate(&models.User{}, &models.UserProfile{}, &models.AuthToken{})
require.NoError(t, err)
return db
}
func createRefreshTestUser(t *testing.T, db *gorm.DB) *models.User {
t.Helper()
user := &models.User{
Username: "refreshtest",
Email: "refresh@test.com",
IsActive: true,
}
require.NoError(t, user.SetPassword("Password123"))
require.NoError(t, db.Create(user).Error)
return user
}
func createTokenWithAge(t *testing.T, db *gorm.DB, userID uint, ageDays int) *models.AuthToken {
t.Helper()
token := &models.AuthToken{
UserID: userID,
}
require.NoError(t, db.Create(token).Error)
// Backdate the token's Created timestamp after creation to bypass autoCreateTime
backdated := time.Now().UTC().AddDate(0, 0, -ageDays)
require.NoError(t, db.Model(token).Update("created", backdated).Error)
token.Created = backdated
return token
}
func newTestAuthService(db *gorm.DB) *AuthService {
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{
SecretKey: "test-secret",
TokenExpiryDays: 90,
TokenRefreshDays: 60,
},
}
return NewAuthService(userRepo, cfg)
}
func TestRefreshToken_FreshToken_ReturnsExisting(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
token := createTokenWithAge(t, db, user.ID, 30) // 30 days old, well within fresh window
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
require.NoError(t, err)
assert.Equal(t, token.Plaintext, resp.Token, "fresh token should return the same token")
assert.Contains(t, resp.Message, "still valid")
}
func TestRefreshToken_InRenewalWindow_ReturnsNewToken(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
token := createTokenWithAge(t, db, user.ID, 75) // 75 days old, in renewal window (60-90)
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
require.NoError(t, err)
assert.NotEqual(t, token.Plaintext, resp.Token, "should return a new token")
assert.Contains(t, resp.Message, "refreshed")
// Verify old token was deleted
var count int64
// The DB stores the SHA-256 hash, so query by token.Key (the hash).
db.Model(&models.AuthToken{}).Where("key = ?", token.Key).Count(&count)
assert.Equal(t, int64(0), count, "old token should be deleted")
// Verify new token exists in DB
// resp.Token is the raw token; the DB stores its hash.
db.Model(&models.AuthToken{}).Where("key = ?", models.HashToken(resp.Token)).Count(&count)
assert.Equal(t, int64(1), count, "new token should exist in DB")
// Verify new token belongs to the same user
var newToken models.AuthToken
require.NoError(t, db.Where("key = ?", models.HashToken(resp.Token)).First(&newToken).Error)
assert.Equal(t, user.ID, newToken.UserID)
}
func TestRefreshToken_ExpiredToken_Returns401(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
token := createTokenWithAge(t, db, user.ID, 91) // 91 days old, past 90-day expiry
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
require.Error(t, err)
assert.Nil(t, resp)
assert.Contains(t, err.Error(), "error.token_expired")
}
func TestRefreshToken_AtExactBoundary60Days(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
// Exactly 60 days: token age == refreshDays, so tokenAge < refreshDuration is false,
// meaning it enters the renewal window
token := createTokenWithAge(t, db, user.ID, 61)
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
require.NoError(t, err)
assert.NotEqual(t, token.Plaintext, resp.Token, "token at 61 days should be refreshed")
}
func TestRefreshToken_InvalidToken_Returns401(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(context.Background(), "nonexistent-token-key", user.ID)
require.Error(t, err)
assert.Nil(t, resp)
assert.Contains(t, err.Error(), "error.invalid_token")
}
func TestRefreshToken_WrongUser_Returns401(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
token := createTokenWithAge(t, db, user.ID, 75)
svc := newTestAuthService(db)
// Try to refresh with a different user ID
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID+999)
require.Error(t, err)
assert.Nil(t, resp)
assert.Contains(t, err.Error(), "error.invalid_token")
}
func TestRefreshToken_FreshTokenAt59Days_ReturnsExisting(t *testing.T) {
db := setupRefreshTestDB(t)
user := createRefreshTestUser(t, db)
token := createTokenWithAge(t, db, user.ID, 59) // 59 days, just under the 60-day threshold
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
require.NoError(t, err)
assert.Equal(t, token.Plaintext, resp.Token, "token at 59 days should NOT be refreshed")
}
File diff suppressed because it is too large Load Diff
+43 -666
View File
@@ -4,7 +4,6 @@ import (
"context"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -19,195 +18,18 @@ func setupAuthService(t *testing.T) (*AuthService, *repositories.UserRepository)
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
notifRepo := repositories.NewNotificationRepository(db)
cfg := &config.Config{
Server: config.ServerConfig{
DebugFixedCodes: true,
},
Security: config.SecurityConfig{
SecretKey: "test-secret",
ConfirmationExpiry: 24 * time.Hour,
PasswordResetExpiry: 15 * time.Minute,
MaxPasswordResetRate: 3,
TokenExpiryDays: 90,
TokenRefreshDays: 60,
},
}
cfg := &config.Config{}
service := NewAuthService(userRepo, cfg)
service.SetNotificationRepository(notifRepo)
return service, userRepo
}
// === Login ===
func TestAuthService_Login(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
req := &requests.LoginRequest{
Username: "testuser",
Password: "Password123",
}
resp, err := service.Login(context.Background(), req, "")
require.NoError(t, err)
assert.NotEmpty(t, resp.Token)
assert.Equal(t, "testuser", resp.User.Username)
}
func TestAuthService_Login_ByEmail(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
req := &requests.LoginRequest{
Email: "test@test.com",
Password: "Password123",
}
resp, err := service.Login(context.Background(), req, "")
require.NoError(t, err)
assert.NotEmpty(t, resp.Token)
}
func TestAuthService_Login_InvalidCredentials(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
req := &requests.LoginRequest{
Username: "testuser",
Password: "WrongPassword1",
}
_, err := service.Login(context.Background(), req, "")
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
func TestAuthService_Login_UserNotFound(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
req := &requests.LoginRequest{
Username: "nonexistent",
Password: "Password123",
}
_, err := service.Login(context.Background(), req, "")
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
func TestAuthService_Login_InactiveUser(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
user := testutil.CreateTestUser(t, db, "inactive", "inactive@test.com", "Password123")
// Deactivate
user.IsActive = false
db.Save(user)
req := &requests.LoginRequest{
Username: "inactive",
Password: "Password123",
}
_, err := service.Login(context.Background(), req, "")
// Audit L1: inactive accounts return the same generic error as bad
// credentials so login does not disclose which accounts exist.
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
// === Register ===
func TestAuthService_Register(t *testing.T) {
service, _ := setupAuthService(t)
req := &requests.RegisterRequest{
Username: "newuser",
Email: "new@test.com",
Password: "Password123",
}
resp, code, err := service.Register(context.Background(), req)
require.NoError(t, err)
assert.NotEmpty(t, resp.Token)
assert.Equal(t, "newuser", resp.User.Username)
assert.Equal(t, "123456", code) // DebugFixedCodes=true
}
func TestAuthService_Register_DuplicateUsername(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Server: config.ServerConfig{DebugFixedCodes: true},
Security: config.SecurityConfig{SecretKey: "test", ConfirmationExpiry: 24 * time.Hour},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "taken", "taken@test.com", "Password123")
req := &requests.RegisterRequest{
Username: "taken",
Email: "different@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), req)
testutil.AssertAppError(t, err, http.StatusConflict, "error.username_taken")
}
func TestAuthService_Register_DuplicateEmail(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Server: config.ServerConfig{DebugFixedCodes: true},
Security: config.SecurityConfig{SecretKey: "test", ConfirmationExpiry: 24 * time.Hour},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "existing", "taken@test.com", "Password123")
req := &requests.RegisterRequest{
Username: "newuser",
Email: "taken@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), req)
testutil.AssertAppError(t, err, http.StatusConflict, "error.email_taken")
}
// === GetCurrentUser ===
func TestAuthService_GetCurrentUser(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
cfg := &config.Config{}
service := NewAuthService(userRepo, cfg)
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
@@ -218,7 +40,7 @@ func TestAuthService_GetCurrentUser(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, "testuser", resp.Username)
assert.Equal(t, "test@test.com", resp.Email)
assert.Equal(t, "email", resp.AuthProvider) // Default for no social auth
assert.Equal(t, "kratos", resp.AuthProvider) // All users are Kratos-managed
}
// === UpdateProfile ===
@@ -226,9 +48,7 @@ func TestAuthService_GetCurrentUser(t *testing.T) {
func TestAuthService_UpdateProfile(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
cfg := &config.Config{}
service := NewAuthService(userRepo, cfg)
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
@@ -250,9 +70,7 @@ func TestAuthService_UpdateProfile(t *testing.T) {
func TestAuthService_UpdateProfile_DuplicateEmail(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
cfg := &config.Config{}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "user1", "user1@test.com", "Password123")
@@ -271,9 +89,7 @@ func TestAuthService_UpdateProfile_DuplicateEmail(t *testing.T) {
func TestAuthService_UpdateProfile_SameEmail(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
cfg := &config.Config{}
service := NewAuthService(userRepo, cfg)
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
@@ -290,443 +106,10 @@ func TestAuthService_UpdateProfile_SameEmail(t *testing.T) {
assert.Equal(t, "test@test.com", resp.Email)
}
// === VerifyEmail ===
func TestAuthService_VerifyEmail(t *testing.T) {
service, _ := setupAuthService(t)
// Register a user (creates confirmation code)
req := &requests.RegisterRequest{
Username: "newuser",
Email: "new@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), req)
require.NoError(t, err)
// Get the user ID
user, err := service.userRepo.FindByEmail("new@test.com")
require.NoError(t, err)
// Verify with the debug code
err = service.VerifyEmail(context.Background(), user.ID, "123456")
require.NoError(t, err)
// Verify again — should get already verified error
err = service.VerifyEmail(context.Background(), user.ID, "123456")
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
}
func TestAuthService_VerifyEmail_InvalidCode(t *testing.T) {
service, _ := setupAuthService(t)
// Register
req := &requests.RegisterRequest{
Username: "newuser",
Email: "new@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), req)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("new@test.com")
require.NoError(t, err)
// Wrong code — with DebugFixedCodes enabled, "123456" bypasses normal lookup,
// but a wrong code should use the normal path
err = service.VerifyEmail(context.Background(), user.ID, "000000")
assert.Error(t, err)
}
// === ResendVerificationCode ===
func TestAuthService_ResendVerificationCode(t *testing.T) {
service, _ := setupAuthService(t)
// Register
req := &requests.RegisterRequest{
Username: "newuser",
Email: "new@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), req)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("new@test.com")
require.NoError(t, err)
code, err := service.ResendVerificationCode(context.Background(), user.ID)
require.NoError(t, err)
assert.Equal(t, "123456", code) // DebugFixedCodes
}
func TestAuthService_ResendVerificationCode_AlreadyVerified(t *testing.T) {
service, _ := setupAuthService(t)
// Register and verify
req := &requests.RegisterRequest{
Username: "newuser",
Email: "new@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), req)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("new@test.com")
require.NoError(t, err)
err = service.VerifyEmail(context.Background(), user.ID, "123456")
require.NoError(t, err)
_, err = service.ResendVerificationCode(context.Background(), user.ID)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
}
// === ForgotPassword ===
func TestAuthService_ForgotPassword(t *testing.T) {
service, _ := setupAuthService(t)
// Register a user first
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
code, user, err := service.ForgotPassword(context.Background(), "test@test.com")
require.NoError(t, err)
assert.Equal(t, "123456", code) // DebugFixedCodes
assert.NotNil(t, user)
assert.Equal(t, "test@test.com", user.Email)
}
func TestAuthService_ForgotPassword_NonexistentEmail(t *testing.T) {
service, _ := setupAuthService(t)
// Should not reveal that email doesn't exist
code, user, err := service.ForgotPassword(context.Background(), "nonexistent@test.com")
require.NoError(t, err)
assert.Empty(t, code)
assert.Nil(t, user)
}
// === ResetPassword ===
func TestAuthService_ResetPassword(t *testing.T) {
service, _ := setupAuthService(t)
// Register
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
// Forgot password
_, _, err = service.ForgotPassword(context.Background(), "test@test.com")
require.NoError(t, err)
// Verify reset code to get the token
resetToken, err := service.VerifyResetCode(context.Background(), "test@test.com", "123456")
require.NoError(t, err)
assert.NotEmpty(t, resetToken)
// Reset password
err = service.ResetPassword(context.Background(), resetToken, "NewPassword123")
require.NoError(t, err)
// Login with new password
loginReq := &requests.LoginRequest{
Username: "testuser",
Password: "NewPassword123",
}
loginResp, err := service.Login(context.Background(), loginReq, "")
require.NoError(t, err)
assert.NotEmpty(t, loginResp.Token)
}
func TestAuthService_ResetPassword_InvalidToken(t *testing.T) {
service, _ := setupAuthService(t)
err := service.ResetPassword(context.Background(), "invalid-token", "NewPassword123")
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_reset_token")
}
// === Logout ===
func TestAuthService_Logout(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
// Login first
loginReq := &requests.LoginRequest{
Username: "testuser",
Password: "Password123",
}
loginResp, err := service.Login(context.Background(), loginReq, "")
require.NoError(t, err)
// Logout
err = service.Logout(context.Background(), loginResp.Token)
require.NoError(t, err)
// Token should be deleted — refreshing should fail
_, err = service.RefreshToken(context.Background(), loginResp.Token, user.ID)
assert.Error(t, err)
}
// === DeleteAccount ===
func TestAuthService_DeleteAccount_EmailAuth(t *testing.T) {
service, _ := setupAuthService(t)
// Register
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err)
password := "Password123"
_, err = service.DeleteAccount(context.Background(), user.ID, &password, nil)
require.NoError(t, err)
}
func TestAuthService_DeleteAccount_WrongPassword(t *testing.T) {
service, _ := setupAuthService(t)
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err)
wrongPassword := "WrongPassword1"
_, err = service.DeleteAccount(context.Background(), user.ID, &wrongPassword, nil)
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
func TestAuthService_DeleteAccount_NoPassword(t *testing.T) {
service, _ := setupAuthService(t)
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err)
_, err = service.DeleteAccount(context.Background(), user.ID, nil, nil)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
}
func TestAuthService_DeleteAccount_UserNotFound(t *testing.T) {
service, _ := setupAuthService(t)
password := "Password123"
_, err := service.DeleteAccount(context.Background(), 99999, &password, nil)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.user_not_found")
}
// === Helper functions ===
func TestGenerateSixDigitCode(t *testing.T) {
code := generateSixDigitCode()
assert.Len(t, code, 6)
// Should be numeric
for _, c := range code {
assert.True(t, c >= '0' && c <= '9', "code should contain only digits")
}
}
func TestGenerateResetToken(t *testing.T) {
token := generateResetToken()
assert.NotEmpty(t, token)
assert.Len(t, token, 64) // 32 bytes = 64 hex chars
}
func TestGetStringOrEmpty(t *testing.T) {
s := "hello"
assert.Equal(t, "hello", getStringOrEmpty(&s))
assert.Equal(t, "", getStringOrEmpty(nil))
}
func TestIsPrivateRelayEmail(t *testing.T) {
assert.True(t, isPrivateRelayEmail("abc@privaterelay.appleid.com"))
assert.True(t, isPrivateRelayEmail("ABC@PRIVATERELAY.APPLEID.COM"))
assert.False(t, isPrivateRelayEmail("user@gmail.com"))
}
func TestGetEmailFromRequest(t *testing.T) {
email := "req@test.com"
assert.Equal(t, "req@test.com", getEmailFromRequest(&email, "claims@test.com"))
assert.Equal(t, "claims@test.com", getEmailFromRequest(nil, "claims@test.com"))
empty := ""
assert.Equal(t, "claims@test.com", getEmailFromRequest(&empty, "claims@test.com"))
}
// === getEmailOrDefault ===
func TestGetEmailOrDefault(t *testing.T) {
// Non-empty email returns itself
assert.Equal(t, "user@test.com", getEmailOrDefault("user@test.com"))
// Empty email returns a generated placeholder
result := getEmailOrDefault("")
assert.Contains(t, result, "@privaterelay.appleid.com")
assert.Contains(t, result, "apple_")
}
// === generateUniqueUsername ===
func TestGenerateUniqueUsername(t *testing.T) {
// Normal email generates username from email prefix
username := generateUniqueUsername("john@test.com", nil)
assert.Contains(t, username, "john_")
// Private relay email falls back to first name
firstName := "Jane"
username = generateUniqueUsername("abc@privaterelay.appleid.com", &firstName)
assert.Contains(t, username, "jane_")
// Private relay email and no first name — fallback
username = generateUniqueUsername("abc@privaterelay.appleid.com", nil)
assert.Contains(t, username, "user_")
// Empty email with first name
firstName2 := "Bob"
username = generateUniqueUsername("", &firstName2)
assert.Contains(t, username, "bob_")
// Empty email and no first name
username = generateUniqueUsername("", nil)
assert.Contains(t, username, "user_")
}
// === generateGoogleUsername ===
func TestGenerateGoogleUsername(t *testing.T) {
// Normal email
username := generateGoogleUsername("john@gmail.com", "John")
assert.Contains(t, username, "john_")
// Empty email falls back to first name
username = generateGoogleUsername("", "Alice")
assert.Contains(t, username, "alice_")
// Empty email and empty first name — fallback
username = generateGoogleUsername("", "")
assert.Contains(t, username, "google_")
}
// === Login with empty password ===
func TestAuthService_Login_EmptyPassword(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
service := NewAuthService(userRepo, cfg)
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
req := &requests.LoginRequest{
Username: "testuser",
Password: "",
}
_, err := service.Login(context.Background(), req, "")
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
// === ForgotPassword rate limiting ===
func TestAuthService_ForgotPassword_RateLimit(t *testing.T) {
service, _ := setupAuthService(t)
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
// Make max allowed reset requests (3 based on setup)
for i := 0; i < 3; i++ {
_, _, err := service.ForgotPassword(context.Background(), "test@test.com")
require.NoError(t, err)
}
// The 4th should be rate limited
_, _, err = service.ForgotPassword(context.Background(), "test@test.com")
assert.Error(t, err)
}
// === VerifyResetCode with wrong code ===
func TestAuthService_VerifyResetCode_WrongCode(t *testing.T) {
service, _ := setupAuthService(t)
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
_, _, err = service.ForgotPassword(context.Background(), "test@test.com")
require.NoError(t, err)
// Wrong code but with debug mode, "123456" works, "000000" should fail
_, err = service.VerifyResetCode(context.Background(), "test@test.com", "000000")
assert.Error(t, err)
}
// === VerifyResetCode with nonexistent email ===
func TestAuthService_VerifyResetCode_NonexistentEmail(t *testing.T) {
service, _ := setupAuthService(t)
_, err := service.VerifyResetCode(context.Background(), "nonexistent@test.com", "123456")
assert.Error(t, err)
}
// === UpdateProfile — change email to new email ===
func TestAuthService_UpdateProfile_ChangeEmail(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
cfg := &config.Config{}
service := NewAuthService(userRepo, cfg)
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
@@ -742,25 +125,44 @@ func TestAuthService_UpdateProfile_ChangeEmail(t *testing.T) {
assert.Equal(t, "newemail@test.com", resp.Email)
}
// === DeleteAccount — empty password string ===
// === DeleteAccount ===
func TestAuthService_DeleteAccount_EmptyPassword(t *testing.T) {
func TestAuthService_DeleteAccount_WithConfirmation(t *testing.T) {
service, userRepo := setupAuthService(t)
user := testutil.CreateTestUser(t, (*userRepo).DB(), "testuser", "test@test.com", "")
_ = user
confirmation := "DELETE"
_, err := service.DeleteAccount(context.Background(), user.ID, nil, &confirmation)
require.NoError(t, err)
}
func TestAuthService_DeleteAccount_WrongConfirmation(t *testing.T) {
service, userRepo := setupAuthService(t)
user := testutil.CreateTestUser(t, (*userRepo).DB(), "testuser", "test@test.com", "")
wrongConf := "delete"
_, err := service.DeleteAccount(context.Background(), user.ID, nil, &wrongConf)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.confirmation_required")
}
func TestAuthService_DeleteAccount_NoConfirmation(t *testing.T) {
service, userRepo := setupAuthService(t)
user := testutil.CreateTestUser(t, (*userRepo).DB(), "testuser", "test@test.com", "")
_, err := service.DeleteAccount(context.Background(), user.ID, nil, nil)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.confirmation_required")
}
func TestAuthService_DeleteAccount_UserNotFound(t *testing.T) {
service, _ := setupAuthService(t)
registerReq := &requests.RegisterRequest{
Username: "testuser",
Email: "test@test.com",
Password: "Password123",
}
_, _, err := service.Register(context.Background(), registerReq)
require.NoError(t, err)
user, err := service.userRepo.FindByEmail("test@test.com")
require.NoError(t, err)
emptyPw := ""
_, err = service.DeleteAccount(context.Background(), user.ID, &emptyPw, nil)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
confirmation := "DELETE"
_, err := service.DeleteAccount(context.Background(), 99999, nil, &confirmation)
testutil.AssertAppError(t, err, http.StatusNotFound, "error.user_not_found")
}
// === SetNotificationRepository ===
@@ -769,35 +171,10 @@ func TestAuthService_SetNotificationRepository(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
notifRepo := repositories.NewNotificationRepository(db)
cfg := &config.Config{
Security: config.SecurityConfig{SecretKey: "test-secret"},
}
cfg := &config.Config{}
service := NewAuthService(userRepo, cfg)
assert.Nil(t, service.notificationRepo)
service.SetNotificationRepository(notifRepo)
assert.NotNil(t, service.notificationRepo)
}
// === Register creates profile and notification preferences ===
func TestAuthService_Register_CreatesProfile(t *testing.T) {
service, userRepo := setupAuthService(t)
req := &requests.RegisterRequest{
Username: "profileuser",
Email: "profile@test.com",
Password: "Password123",
FirstName: "John",
LastName: "Doe",
}
resp, _, err := service.Register(context.Background(), req)
require.NoError(t, err)
assert.Equal(t, "profileuser", resp.User.Username)
// Profile should exist
profile, err := userRepo.GetOrCreateProfile(resp.User.ID)
require.NoError(t, err)
assert.NotNil(t, profile)
}
-111
View File
@@ -12,7 +12,6 @@ import (
"github.com/rs/zerolog/log"
"github.com/treytartt/honeydue-api/internal/config"
"github.com/treytartt/honeydue-api/internal/models"
)
// CacheService provides Redis caching functionality
@@ -134,116 +133,6 @@ func (c *CacheService) Close() error {
return nil
}
// Auth token cache helpers
const (
AuthTokenPrefix = "auth_token_"
TokenCacheTTL = 5 * time.Minute
)
// authTokenCacheKey returns the Redis key for an auth token. The raw token
// is hashed (audit C1) so the plaintext token never appears in a Redis key.
func authTokenCacheKey(token string) string {
return AuthTokenPrefix + models.HashToken(token)
}
// CacheAuthToken caches a user ID for a token
func (c *CacheService) CacheAuthToken(ctx context.Context, token string, userID uint) error {
return c.SetString(ctx, authTokenCacheKey(token), fmt.Sprintf("%d", userID), TokenCacheTTL)
}
// CacheAuthTokenWithCreated caches a user ID and token creation time for a token
func (c *CacheService) CacheAuthTokenWithCreated(ctx context.Context, token string, userID uint, createdUnix int64) error {
return c.SetString(ctx, authTokenCacheKey(token), fmt.Sprintf("%d|%d", userID, createdUnix), TokenCacheTTL)
}
// GetCachedAuthToken gets a cached user ID for a token
func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (uint, error) {
val, err := c.GetString(ctx, authTokenCacheKey(token))
if err != nil {
return 0, err
}
var userID uint
_, err = fmt.Sscanf(val, "%d", &userID)
return userID, err
}
// GetCachedAuthTokenWithCreated gets a cached user ID and token creation time.
// Returns userID, createdUnix, error. createdUnix is 0 if not stored (legacy format).
func (c *CacheService) GetCachedAuthTokenWithCreated(ctx context.Context, token string) (uint, int64, error) {
val, err := c.GetString(ctx, authTokenCacheKey(token))
if err != nil {
return 0, 0, err
}
var userID uint
var createdUnix int64
n, _ := fmt.Sscanf(val, "%d|%d", &userID, &createdUnix)
if n < 1 {
return 0, 0, fmt.Errorf("invalid cached token format")
}
return userID, createdUnix, nil
}
// InvalidateAuthToken removes a cached token
func (c *CacheService) InvalidateAuthToken(ctx context.Context, token string) error {
return c.Delete(ctx, authTokenCacheKey(token))
}
// InvalidateAuthTokenHashes removes cached entries for already-hashed token
// keys. Unlike InvalidateAuthToken (which hashes a plaintext), this takes the
// stored hash directly — used to evict a user's prior token on re-login
// (audit MEDIUM-1), where the server no longer has the plaintext.
func (c *CacheService) InvalidateAuthTokenHashes(ctx context.Context, hashes ...string) error {
keys := make([]string, 0, len(hashes))
for _, h := range hashes {
if h != "" {
keys = append(keys, AuthTokenPrefix+h)
}
}
if len(keys) == 0 {
return nil
}
return c.Delete(ctx, keys...)
}
// --- Per-account login-failure tracking (audit M5) ---
const loginFailPrefix = "login_fail:"
// RegisterLoginFailure records a failed login for an account from a given
// source IP, and returns the number of DISTINCT source IPs that have failed
// for this account within the window. Tracking distinct IPs as a set rather
// than a raw counter (audit MEDIUM-3) means one attacker, from one IP, cannot
// run the count up and lock a victim out by knowing only their email — a
// single IP is bounded by the per-IP edge/app rate limiters instead. A
// genuinely distributed credential-stuffing attack still trips the lockout.
func (c *CacheService) RegisterLoginFailure(ctx context.Context, identifier, ip string, window time.Duration) (int64, error) {
key := loginFailPrefix + identifier
member := ip
if member == "" {
member = "unknown"
}
if err := c.client.SAdd(ctx, key, member).Err(); err != nil {
return 0, err
}
// Refresh the TTL on each failure: an active attack keeps the window
// open, while a quiet account ages out `window` after its last failure.
_ = c.client.Expire(ctx, key, window).Err()
return c.client.SCard(ctx, key).Result()
}
// LoginFailureIPCount returns how many distinct source IPs have failed to log
// in to this account within the window (audit MEDIUM-3). SCard on a missing
// key returns 0.
func (c *CacheService) LoginFailureIPCount(ctx context.Context, identifier string) (int64, error) {
return c.client.SCard(ctx, loginFailPrefix+identifier).Result()
}
// ClearLoginFailures resets the failed-login IP set after a successful login.
func (c *CacheService) ClearLoginFailures(ctx context.Context, identifier string) error {
return c.client.Del(ctx, loginFailPrefix+identifier).Err()
}
// Static data cache helpers
const (
-307
View File
@@ -1,307 +0,0 @@
package services
import (
"context"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/treytartt/honeydue-api/internal/config"
)
const (
// googleKeysURL is Google's JWKS endpoint for ID-token signature verification.
googleKeysURL = "https://www.googleapis.com/oauth2/v3/certs"
googleKeysCacheTTL = 24 * time.Hour
googleKeysCacheKey = "google:public_keys"
)
// googleIssuers is the set of valid `iss` claim values for a Google ID token.
var googleIssuers = map[string]bool{
"accounts.google.com": true,
"https://accounts.google.com": true,
}
var (
ErrInvalidGoogleToken = errors.New("invalid Google ID token")
ErrGoogleTokenExpired = errors.New("Google ID token has expired")
ErrInvalidGoogleAudience = errors.New("invalid Google token audience")
ErrInvalidGoogleIssuer = errors.New("invalid Google token issuer")
ErrGoogleKeyNotFound = errors.New("Google public key not found")
)
// GoogleJWKS represents Google's JSON Web Key Set.
type GoogleJWKS struct {
Keys []GoogleJWK `json:"keys"`
}
// GoogleJWK represents a single JSON Web Key from Google.
type GoogleJWK struct {
Kty string `json:"kty"` // Key type (RSA)
Kid string `json:"kid"` // Key ID
Use string `json:"use"` // Key use (sig)
Alg string `json:"alg"` // Algorithm (RS256)
N string `json:"n"` // RSA modulus
E string `json:"e"` // RSA exponent
}
// GoogleTokenClaims represents the claims in a Google ID token JWT.
type GoogleTokenClaims struct {
jwt.RegisteredClaims
Email string `json:"email,omitempty"`
EmailVerified bool `json:"email_verified,omitempty"`
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
Picture string `json:"picture,omitempty"`
Azp string `json:"azp,omitempty"` // Authorized party
}
// GoogleTokenInfo is the verified, caller-facing view of a Google ID token.
type GoogleTokenInfo struct {
Sub string // Unique Google user ID
Email string
EmailVerified string // "true" or "false" — string for caller compatibility
Name string
GivenName string
FamilyName string
Picture string
Aud string
Azp string
Iss string
}
// IsEmailVerified returns whether the email is verified.
func (t *GoogleTokenInfo) IsEmailVerified() bool {
return t.EmailVerified == "true"
}
// GoogleAuthService handles Google Sign In token verification.
type GoogleAuthService struct {
cache *CacheService
config *config.Config
client *http.Client
}
// NewGoogleAuthService creates a new Google auth service.
func NewGoogleAuthService(cache *CacheService, cfg *config.Config) *GoogleAuthService {
return &GoogleAuthService{
cache: cache,
config: cfg,
client: &http.Client{Timeout: 10 * time.Second},
}
}
// VerifyIDToken verifies a Google ID token locally (audit C2/C3): it checks
// the RS256 signature against Google's published JWKS and the iss, aud, and
// exp claims. It never sends the token to a third-party endpoint, so it no
// longer depends on the deprecated tokeninfo service and never leaks the
// token in a request URL.
func (s *GoogleAuthService) VerifyIDToken(ctx context.Context, idToken string) (*GoogleTokenInfo, error) {
// Parse the token header to get the key ID.
parts := strings.Split(idToken, ".")
if len(parts) != 3 {
return nil, ErrInvalidGoogleToken
}
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return nil, fmt.Errorf("failed to decode token header: %w", err)
}
var header struct {
Kid string `json:"kid"`
Alg string `json:"alg"`
}
if err := json.Unmarshal(headerBytes, &header); err != nil {
return nil, fmt.Errorf("failed to parse token header: %w", err)
}
publicKey, err := s.getPublicKey(ctx, header.Kid)
if err != nil {
return nil, err
}
// Parse and verify the signature. jwt v5 validates exp/iat/nbf automatically.
token, err := jwt.ParseWithClaims(idToken, &GoogleTokenClaims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return publicKey, nil
})
if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
return nil, ErrGoogleTokenExpired
}
return nil, fmt.Errorf("failed to parse token: %w", err)
}
claims, ok := token.Claims.(*GoogleTokenClaims)
if !ok || !token.Valid {
return nil, ErrInvalidGoogleToken
}
// Verify the issuer (audit C3).
if !googleIssuers[claims.Issuer] {
return nil, ErrInvalidGoogleIssuer
}
// Verify the audience matches one of our configured client IDs.
if !s.verifyAudience(claims.Audience, claims.Azp) {
return nil, ErrInvalidGoogleAudience
}
if claims.Subject == "" {
return nil, ErrInvalidGoogleToken
}
emailVerified := "false"
if claims.EmailVerified {
emailVerified = "true"
}
aud := ""
if len(claims.Audience) > 0 {
aud = claims.Audience[0]
}
return &GoogleTokenInfo{
Sub: claims.Subject,
Email: claims.Email,
EmailVerified: emailVerified,
Name: claims.Name,
GivenName: claims.GivenName,
FamilyName: claims.FamilyName,
Picture: claims.Picture,
Aud: aud,
Azp: claims.Azp,
Iss: claims.Issuer,
}, nil
}
// verifyAudience checks the token audience against our configured client IDs.
// In production (non-debug) an empty client ID fails verification rather than
// silently bypassing the check.
func (s *GoogleAuthService) verifyAudience(audience jwt.ClaimStrings, azp string) bool {
clientID := s.config.GoogleAuth.ClientID
if clientID == "" {
// In debug mode only, skip audience verification for local development.
return s.config.Server.Debug
}
candidates := []string{clientID}
if id := s.config.GoogleAuth.AndroidClientID; id != "" {
candidates = append(candidates, id)
}
if id := s.config.GoogleAuth.IOSClientID; id != "" {
candidates = append(candidates, id)
}
for _, want := range candidates {
if azp == want {
return true
}
for _, aud := range audience {
if aud == want {
return true
}
}
}
return false
}
// getPublicKey returns the RSA public key for the given key ID, using a
// Redis-cached copy of Google's JWKS and re-fetching once on a cache miss
// (Google rotates signing keys roughly daily).
func (s *GoogleAuthService) getPublicKey(ctx context.Context, kid string) (*rsa.PublicKey, error) {
keys, err := s.getCachedKeys(ctx)
if err != nil || keys == nil {
keys, err = s.fetchGooglePublicKeys(ctx)
if err != nil {
return nil, err
}
}
if pubKey, ok := keys[kid]; ok {
return pubKey, nil
}
// Cache miss for this kid — keys may have rotated; fetch fresh.
keys, err = s.fetchGooglePublicKeys(ctx)
if err != nil {
return nil, err
}
if pubKey, ok := keys[kid]; ok {
return pubKey, nil
}
return nil, ErrGoogleKeyNotFound
}
// getCachedKeys retrieves cached Google public keys from Redis.
func (s *GoogleAuthService) getCachedKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
if s.cache == nil {
return nil, nil
}
data, err := s.cache.GetString(ctx, googleKeysCacheKey)
if err != nil || data == "" {
return nil, nil
}
var jwks GoogleJWKS
if err := json.Unmarshal([]byte(data), &jwks); err != nil {
return nil, nil
}
return s.parseJWKS(&jwks), nil
}
// fetchGooglePublicKeys fetches Google's JWKS and caches it.
func (s *GoogleAuthService) fetchGooglePublicKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, googleKeysURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := s.client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to fetch Google keys: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Google keys endpoint returned status %d", resp.StatusCode)
}
var jwks GoogleJWKS
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
return nil, fmt.Errorf("failed to decode Google keys: %w", err)
}
if s.cache != nil {
keysJSON, _ := json.Marshal(jwks)
_ = s.cache.SetString(ctx, googleKeysCacheKey, string(keysJSON), googleKeysCacheTTL)
}
return s.parseJWKS(&jwks), nil
}
// parseJWKS converts Google's JWKS into a map of RSA public keys by key ID.
func (s *GoogleAuthService) parseJWKS(jwks *GoogleJWKS) map[string]*rsa.PublicKey {
keys := make(map[string]*rsa.PublicKey)
for _, key := range jwks.Keys {
if key.Kty != "RSA" {
continue
}
nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
if err != nil {
continue
}
eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
if err != nil {
continue
}
e := 0
for _, b := range eBytes {
e = e<<8 + int(b)
}
keys[key.Kid] = &rsa.PublicKey{N: new(big.Int).SetBytes(nBytes), E: e}
}
return keys
}