Add Sign in with Apple authentication
- Add AppleSocialAuth model to store Apple ID linkages - Create AppleAuthService for JWT verification with Apple's public keys - Add AppleSignIn handler and route (POST /auth/apple-sign-in/) - Implement account linking (links Apple ID to existing accounts by email) - Add Redis caching for Apple public keys (24-hour TTL) - Support private relay emails (@privaterelay.appleid.com) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -13,14 +13,15 @@ import (
|
|||||||
|
|
||||||
// Config holds all configuration for the application
|
// Config holds all configuration for the application
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Server ServerConfig
|
Server ServerConfig
|
||||||
Database DatabaseConfig
|
Database DatabaseConfig
|
||||||
Redis RedisConfig
|
Redis RedisConfig
|
||||||
Email EmailConfig
|
Email EmailConfig
|
||||||
Push PushConfig
|
Push PushConfig
|
||||||
Worker WorkerConfig
|
Worker WorkerConfig
|
||||||
Security SecurityConfig
|
Security SecurityConfig
|
||||||
Storage StorageConfig
|
Storage StorageConfig
|
||||||
|
AppleAuth AppleAuthConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
@@ -72,6 +73,11 @@ type PushConfig struct {
|
|||||||
FCMServerKey string
|
FCMServerKey string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AppleAuthConfig struct {
|
||||||
|
ClientID string // Bundle ID (e.g., com.tt.casera.CaseraDev)
|
||||||
|
TeamID string // Apple Developer Team ID
|
||||||
|
}
|
||||||
|
|
||||||
type WorkerConfig struct {
|
type WorkerConfig struct {
|
||||||
// Scheduled job times (UTC)
|
// Scheduled job times (UTC)
|
||||||
TaskReminderHour int
|
TaskReminderHour int
|
||||||
@@ -184,6 +190,10 @@ func Load() (*Config, error) {
|
|||||||
MaxFileSize: viper.GetInt64("STORAGE_MAX_FILE_SIZE"),
|
MaxFileSize: viper.GetInt64("STORAGE_MAX_FILE_SIZE"),
|
||||||
AllowedTypes: viper.GetString("STORAGE_ALLOWED_TYPES"),
|
AllowedTypes: viper.GetString("STORAGE_ALLOWED_TYPES"),
|
||||||
},
|
},
|
||||||
|
AppleAuth: AppleAuthConfig{
|
||||||
|
ClientID: viper.GetString("APPLE_CLIENT_ID"),
|
||||||
|
TeamID: viper.GetString("APPLE_TEAM_ID"),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate required fields
|
// Validate required fields
|
||||||
|
|||||||
@@ -123,6 +123,7 @@ func Migrate() error {
|
|||||||
&models.UserProfile{},
|
&models.UserProfile{},
|
||||||
&models.ConfirmationCode{},
|
&models.ConfirmationCode{},
|
||||||
&models.PasswordResetCode{},
|
&models.PasswordResetCode{},
|
||||||
|
&models.AppleSocialAuth{},
|
||||||
|
|
||||||
// Admin users (separate from app users)
|
// Admin users (separate from app users)
|
||||||
&models.AdminUser{},
|
&models.AdminUser{},
|
||||||
|
|||||||
@@ -49,3 +49,12 @@ type UpdateProfileRequest struct {
|
|||||||
type ResendVerificationRequest struct {
|
type ResendVerificationRequest struct {
|
||||||
// No body needed - uses authenticated user's email
|
// No body needed - uses authenticated user's email
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AppleSignInRequest represents the Apple Sign In request body
|
||||||
|
type AppleSignInRequest struct {
|
||||||
|
IDToken string `json:"id_token" binding:"required"`
|
||||||
|
UserID string `json:"user_id" binding:"required"` // Apple's sub claim
|
||||||
|
Email *string `json:"email"` // May be nil or private relay
|
||||||
|
FirstName *string `json:"first_name"`
|
||||||
|
LastName *string `json:"last_name"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -155,3 +155,19 @@ func NewRegisterResponse(token string, user *models.User) RegisterResponse {
|
|||||||
Message: "Registration successful. Please check your email to verify your account.",
|
Message: "Registration successful. Please check your email to verify your account.",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AppleSignInResponse represents the Apple Sign In response
|
||||||
|
type AppleSignInResponse struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
User UserResponse `json:"user"`
|
||||||
|
IsNewUser bool `json:"is_new_user"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAppleSignInResponse creates an AppleSignInResponse
|
||||||
|
func NewAppleSignInResponse(token string, user *models.User, isNewUser bool) AppleSignInResponse {
|
||||||
|
return AppleSignInResponse{
|
||||||
|
Token: token,
|
||||||
|
User: NewUserResponse(user),
|
||||||
|
IsNewUser: isNewUser,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,9 +15,10 @@ import (
|
|||||||
|
|
||||||
// AuthHandler handles authentication endpoints
|
// AuthHandler handles authentication endpoints
|
||||||
type AuthHandler struct {
|
type AuthHandler struct {
|
||||||
authService *services.AuthService
|
authService *services.AuthService
|
||||||
emailService *services.EmailService
|
emailService *services.EmailService
|
||||||
cache *services.CacheService
|
cache *services.CacheService
|
||||||
|
appleAuthService *services.AppleAuthService
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthHandler creates a new auth handler
|
// NewAuthHandler creates a new auth handler
|
||||||
@@ -29,6 +30,11 @@ func NewAuthHandler(authService *services.AuthService, emailService *services.Em
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAppleAuthService sets the Apple auth service (called after initialization)
|
||||||
|
func (h *AuthHandler) SetAppleAuthService(appleAuth *services.AppleAuthService) {
|
||||||
|
h.appleAuthService = appleAuth
|
||||||
|
}
|
||||||
|
|
||||||
// Login handles POST /api/auth/login/
|
// Login handles POST /api/auth/login/
|
||||||
func (h *AuthHandler) Login(c *gin.Context) {
|
func (h *AuthHandler) Login(c *gin.Context) {
|
||||||
var req requests.LoginRequest
|
var req requests.LoginRequest
|
||||||
@@ -362,3 +368,43 @@ func (h *AuthHandler) ResetPassword(c *gin.Context) {
|
|||||||
Message: "Password reset successfully. Please log in with your new password.",
|
Message: "Password reset successfully. Please log in with your new password.",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AppleSignIn handles POST /api/auth/apple-sign-in/
|
||||||
|
func (h *AuthHandler) AppleSignIn(c *gin.Context) {
|
||||||
|
var req requests.AppleSignInRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, responses.ErrorResponse{
|
||||||
|
Error: "Invalid request body",
|
||||||
|
Details: map[string]string{
|
||||||
|
"validation": err.Error(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.appleAuthService == nil {
|
||||||
|
log.Error().Msg("Apple auth service not configured")
|
||||||
|
c.JSON(http.StatusInternalServerError, responses.ErrorResponse{
|
||||||
|
Error: "Apple Sign In is not configured",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := h.authService.AppleSignIn(c.Request.Context(), h.appleAuthService, &req)
|
||||||
|
if err != nil {
|
||||||
|
status := http.StatusUnauthorized
|
||||||
|
message := "Apple Sign In failed"
|
||||||
|
|
||||||
|
if errors.Is(err, services.ErrUserInactive) {
|
||||||
|
message = "Account is inactive"
|
||||||
|
} else if errors.Is(err, services.ErrAppleSignInFailed) {
|
||||||
|
message = "Invalid Apple identity token"
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Err(err).Msg("Apple Sign In failed")
|
||||||
|
c.JSON(status, responses.ErrorResponse{Error: message})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, response)
|
||||||
|
}
|
||||||
|
|||||||
@@ -230,3 +230,20 @@ func GenerateResetToken() string {
|
|||||||
rand.Read(b)
|
rand.Read(b)
|
||||||
return hex.EncodeToString(b)
|
return hex.EncodeToString(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AppleSocialAuth represents a user's linked Apple ID for Sign in with Apple
|
||||||
|
type AppleSocialAuth struct {
|
||||||
|
ID uint `gorm:"primaryKey" json:"id"`
|
||||||
|
UserID uint `gorm:"uniqueIndex;not null" json:"user_id"`
|
||||||
|
User User `gorm:"foreignKey:UserID" json:"-"`
|
||||||
|
AppleID string `gorm:"column:apple_id;size:255;uniqueIndex;not null" json:"apple_id"` // Apple's unique subject ID
|
||||||
|
Email string `gorm:"column:email;size:254" json:"email"` // May be private relay
|
||||||
|
IsPrivateEmail bool `gorm:"column:is_private_email;default:false" json:"is_private_email"`
|
||||||
|
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime" json:"created_at"`
|
||||||
|
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime" json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName returns the table name for GORM
|
||||||
|
func (AppleSocialAuth) TableName() string {
|
||||||
|
return "user_applesocialauth"
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,15 +11,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrUserNotFound = errors.New("user not found")
|
ErrUserNotFound = errors.New("user not found")
|
||||||
ErrUserExists = errors.New("user already exists")
|
ErrUserExists = errors.New("user already exists")
|
||||||
ErrInvalidToken = errors.New("invalid token")
|
ErrInvalidToken = errors.New("invalid token")
|
||||||
ErrTokenNotFound = errors.New("token not found")
|
ErrTokenNotFound = errors.New("token not found")
|
||||||
ErrCodeNotFound = errors.New("code not found")
|
ErrCodeNotFound = errors.New("code not found")
|
||||||
ErrCodeExpired = errors.New("code expired")
|
ErrCodeExpired = errors.New("code expired")
|
||||||
ErrCodeUsed = errors.New("code already used")
|
ErrCodeUsed = errors.New("code already used")
|
||||||
ErrTooManyAttempts = errors.New("too many attempts")
|
ErrTooManyAttempts = errors.New("too many attempts")
|
||||||
ErrRateLimitExceeded = errors.New("rate limit exceeded")
|
ErrRateLimitExceeded = errors.New("rate limit exceeded")
|
||||||
|
ErrAppleAuthNotFound = errors.New("apple social auth not found")
|
||||||
)
|
)
|
||||||
|
|
||||||
// UserRepository handles user-related database operations
|
// UserRepository handles user-related database operations
|
||||||
@@ -486,3 +487,27 @@ func (r *UserRepository) FindProfilesInSharedResidences(userID uint) ([]models.U
|
|||||||
|
|
||||||
return profiles, err
|
return profiles, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Apple Social Auth Methods ---
|
||||||
|
|
||||||
|
// FindByAppleID finds an Apple social auth by Apple ID
|
||||||
|
func (r *UserRepository) FindByAppleID(appleID string) (*models.AppleSocialAuth, error) {
|
||||||
|
var auth models.AppleSocialAuth
|
||||||
|
if err := r.db.Where("apple_id = ?", appleID).First(&auth).Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, ErrAppleAuthNotFound
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAppleSocialAuth creates a new Apple social auth record
|
||||||
|
func (r *UserRepository) CreateAppleSocialAuth(auth *models.AppleSocialAuth) error {
|
||||||
|
return r.db.Create(auth).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAppleSocialAuth updates an Apple social auth record
|
||||||
|
func (r *UserRepository) UpdateAppleSocialAuth(auth *models.AppleSocialAuth) error {
|
||||||
|
return r.db.Save(auth).Error
|
||||||
|
}
|
||||||
|
|||||||
@@ -92,8 +92,12 @@ func SetupRouter(deps *Dependencies) *gin.Engine {
|
|||||||
// Initialize middleware
|
// Initialize middleware
|
||||||
authMiddleware := middleware.NewAuthMiddleware(deps.DB, deps.Cache)
|
authMiddleware := middleware.NewAuthMiddleware(deps.DB, deps.Cache)
|
||||||
|
|
||||||
|
// Initialize Apple auth service
|
||||||
|
appleAuthService := services.NewAppleAuthService(deps.Cache, cfg)
|
||||||
|
|
||||||
// Initialize handlers
|
// Initialize handlers
|
||||||
authHandler := handlers.NewAuthHandler(authService, deps.EmailService, deps.Cache)
|
authHandler := handlers.NewAuthHandler(authService, deps.EmailService, deps.Cache)
|
||||||
|
authHandler.SetAppleAuthService(appleAuthService)
|
||||||
userHandler := handlers.NewUserHandler(userService)
|
userHandler := handlers.NewUserHandler(userService)
|
||||||
residenceHandler := handlers.NewResidenceHandler(residenceService, deps.PDFService, deps.EmailService)
|
residenceHandler := handlers.NewResidenceHandler(residenceService, deps.PDFService, deps.EmailService)
|
||||||
taskHandler := handlers.NewTaskHandler(taskService, deps.StorageService)
|
taskHandler := handlers.NewTaskHandler(taskService, deps.StorageService)
|
||||||
@@ -183,6 +187,7 @@ func setupPublicAuthRoutes(api *gin.RouterGroup, authHandler *handlers.AuthHandl
|
|||||||
auth.POST("/forgot-password/", authHandler.ForgotPassword)
|
auth.POST("/forgot-password/", authHandler.ForgotPassword)
|
||||||
auth.POST("/verify-reset-code/", authHandler.VerifyResetCode)
|
auth.POST("/verify-reset-code/", authHandler.VerifyResetCode)
|
||||||
auth.POST("/reset-password/", authHandler.ResetPassword)
|
auth.POST("/reset-password/", authHandler.ResetPassword)
|
||||||
|
auth.POST("/apple-sign-in/", authHandler.AppleSignIn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -194,7 +199,8 @@ func setupProtectedAuthRoutes(api *gin.RouterGroup, authHandler *handlers.AuthHa
|
|||||||
auth.GET("/me/", authHandler.CurrentUser)
|
auth.GET("/me/", authHandler.CurrentUser)
|
||||||
auth.PUT("/profile/", authHandler.UpdateProfile)
|
auth.PUT("/profile/", authHandler.UpdateProfile)
|
||||||
auth.PATCH("/profile/", authHandler.UpdateProfile)
|
auth.PATCH("/profile/", authHandler.UpdateProfile)
|
||||||
auth.POST("/verify-email/", authHandler.VerifyEmail)
|
auth.POST("/verify/", authHandler.VerifyEmail) // Alias for mobile app compatibility
|
||||||
|
auth.POST("/verify-email/", authHandler.VerifyEmail) // Original route
|
||||||
auth.POST("/resend-verification/", authHandler.ResendVerification)
|
auth.POST("/resend-verification/", authHandler.ResendVerification)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
295
internal/services/apple_auth.go
Normal file
295
internal/services/apple_auth.go
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rsa"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
|
||||||
|
"github.com/treytartt/casera-api/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
appleKeysURL = "https://appleid.apple.com/auth/keys"
|
||||||
|
appleIssuer = "https://appleid.apple.com"
|
||||||
|
appleKeysCacheTTL = 24 * time.Hour
|
||||||
|
appleKeysCacheKey = "apple:public_keys"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidAppleToken = errors.New("invalid Apple identity token")
|
||||||
|
ErrAppleTokenExpired = errors.New("Apple identity token has expired")
|
||||||
|
ErrInvalidAppleAudience = errors.New("invalid Apple token audience")
|
||||||
|
ErrInvalidAppleIssuer = errors.New("invalid Apple token issuer")
|
||||||
|
ErrAppleKeyNotFound = errors.New("Apple public key not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
// AppleJWKS represents Apple's JSON Web Key Set
|
||||||
|
type AppleJWKS struct {
|
||||||
|
Keys []AppleJWK `json:"keys"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppleJWK represents a single JSON Web Key from Apple
|
||||||
|
type AppleJWK struct {
|
||||||
|
Kty string `json:"kty"` // Key type (RSA)
|
||||||
|
Kid string `json:"kid"` // Key ID
|
||||||
|
Use string `json:"use"` // Key use (sig)
|
||||||
|
Alg string `json:"alg"` // Algorithm (RS256)
|
||||||
|
N string `json:"n"` // RSA modulus
|
||||||
|
E string `json:"e"` // RSA exponent
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppleTokenClaims represents the claims in an Apple identity token
|
||||||
|
type AppleTokenClaims struct {
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
EmailVerified any `json:"email_verified,omitempty"` // Can be bool or string
|
||||||
|
IsPrivateEmail any `json:"is_private_email,omitempty"` // Can be bool or string
|
||||||
|
AuthTime int64 `json:"auth_time,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEmailVerified returns whether the email is verified (handles both bool and string types)
|
||||||
|
func (c *AppleTokenClaims) IsEmailVerified() bool {
|
||||||
|
switch v := c.EmailVerified.(type) {
|
||||||
|
case bool:
|
||||||
|
return v
|
||||||
|
case string:
|
||||||
|
return v == "true"
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsPrivateRelayEmail returns whether the email is a private relay email
|
||||||
|
func (c *AppleTokenClaims) IsPrivateRelayEmail() bool {
|
||||||
|
switch v := c.IsPrivateEmail.(type) {
|
||||||
|
case bool:
|
||||||
|
return v
|
||||||
|
case string:
|
||||||
|
return v == "true"
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppleAuthService handles Apple Sign In token verification
|
||||||
|
type AppleAuthService struct {
|
||||||
|
cache *CacheService
|
||||||
|
config *config.Config
|
||||||
|
client *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAppleAuthService creates a new Apple auth service
|
||||||
|
func NewAppleAuthService(cache *CacheService, cfg *config.Config) *AppleAuthService {
|
||||||
|
return &AppleAuthService{
|
||||||
|
cache: cache,
|
||||||
|
config: cfg,
|
||||||
|
client: &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyIdentityToken verifies an Apple identity token and returns the claims
|
||||||
|
func (s *AppleAuthService) VerifyIdentityToken(ctx context.Context, idToken string) (*AppleTokenClaims, error) {
|
||||||
|
// Parse the token header to get the key ID
|
||||||
|
parts := strings.Split(idToken, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return nil, ErrInvalidAppleToken
|
||||||
|
}
|
||||||
|
|
||||||
|
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode token header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var header struct {
|
||||||
|
Kid string `json:"kid"`
|
||||||
|
Alg string `json:"alg"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse token header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the public key for this key ID
|
||||||
|
publicKey, err := s.getPublicKey(ctx, header.Kid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse and verify the token
|
||||||
|
token, err := jwt.ParseWithClaims(idToken, &AppleTokenClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
// Verify the signing method
|
||||||
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
|
}
|
||||||
|
return publicKey, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||||
|
return nil, ErrAppleTokenExpired
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, ok := token.Claims.(*AppleTokenClaims)
|
||||||
|
if !ok || !token.Valid {
|
||||||
|
return nil, ErrInvalidAppleToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the issuer
|
||||||
|
if claims.Issuer != appleIssuer {
|
||||||
|
return nil, ErrInvalidAppleIssuer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the audience (should be our bundle ID)
|
||||||
|
if !s.verifyAudience(claims.Audience) {
|
||||||
|
return nil, ErrInvalidAppleAudience
|
||||||
|
}
|
||||||
|
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// verifyAudience checks if the token audience matches our client ID
|
||||||
|
func (s *AppleAuthService) verifyAudience(audience jwt.ClaimStrings) bool {
|
||||||
|
clientID := s.config.AppleAuth.ClientID
|
||||||
|
if clientID == "" {
|
||||||
|
// If not configured, skip audience verification (for development)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, aud := range audience {
|
||||||
|
if aud == clientID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// getPublicKey retrieves the public key for the given key ID
|
||||||
|
func (s *AppleAuthService) getPublicKey(ctx context.Context, kid string) (*rsa.PublicKey, error) {
|
||||||
|
// Try to get from cache first
|
||||||
|
keys, err := s.getCachedKeys(ctx)
|
||||||
|
if err != nil || keys == nil {
|
||||||
|
// Fetch fresh keys
|
||||||
|
keys, err = s.fetchApplePublicKeys(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the key with the matching ID
|
||||||
|
for keyID, pubKey := range keys {
|
||||||
|
if keyID == kid {
|
||||||
|
return pubKey, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Key not found in cache, try fetching fresh keys
|
||||||
|
keys, err = s.fetchApplePublicKeys(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if pubKey, ok := keys[kid]; ok {
|
||||||
|
return pubKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, ErrAppleKeyNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCachedKeys retrieves cached Apple public keys from Redis
|
||||||
|
func (s *AppleAuthService) getCachedKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
|
||||||
|
if s.cache == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := s.cache.GetString(ctx, appleKeysCacheKey)
|
||||||
|
if err != nil || data == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var jwks AppleJWKS
|
||||||
|
if err := json.Unmarshal([]byte(data), &jwks); err != nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.parseJWKS(&jwks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchApplePublicKeys fetches Apple's public keys and caches them
|
||||||
|
func (s *AppleAuthService) fetchApplePublicKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, appleKeysURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := s.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to fetch Apple keys: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("Apple keys endpoint returned status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var jwks AppleJWKS
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode Apple keys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cache the keys
|
||||||
|
if s.cache != nil {
|
||||||
|
keysJSON, _ := json.Marshal(jwks)
|
||||||
|
_ = s.cache.SetString(ctx, appleKeysCacheKey, string(keysJSON), appleKeysCacheTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.parseJWKS(&jwks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseJWKS converts Apple's JWKS to RSA public keys
|
||||||
|
func (s *AppleAuthService) parseJWKS(jwks *AppleJWKS) (map[string]*rsa.PublicKey, error) {
|
||||||
|
keys := make(map[string]*rsa.PublicKey)
|
||||||
|
|
||||||
|
for _, key := range jwks.Keys {
|
||||||
|
if key.Kty != "RSA" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode the modulus (N)
|
||||||
|
nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
n := new(big.Int).SetBytes(nBytes)
|
||||||
|
|
||||||
|
// Decode the exponent (E)
|
||||||
|
eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
e := 0
|
||||||
|
for _, b := range eBytes {
|
||||||
|
e = e<<8 + int(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKey := &rsa.PublicKey{
|
||||||
|
N: n,
|
||||||
|
E: e,
|
||||||
|
}
|
||||||
|
|
||||||
|
keys[key.Kid] = pubKey
|
||||||
|
}
|
||||||
|
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
package services
|
package services
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
@@ -26,6 +28,7 @@ var (
|
|||||||
ErrAlreadyVerified = errors.New("email already verified")
|
ErrAlreadyVerified = errors.New("email already verified")
|
||||||
ErrRateLimitExceeded = errors.New("too many requests, please try again later")
|
ErrRateLimitExceeded = errors.New("too many requests, please try again later")
|
||||||
ErrInvalidResetToken = errors.New("invalid or expired reset token")
|
ErrInvalidResetToken = errors.New("invalid or expired reset token")
|
||||||
|
ErrAppleSignInFailed = errors.New("Apple Sign In failed")
|
||||||
)
|
)
|
||||||
|
|
||||||
// AuthService handles authentication business logic
|
// AuthService handles authentication business logic
|
||||||
@@ -137,8 +140,13 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
|
|||||||
return nil, "", fmt.Errorf("failed to create token: %w", err)
|
return nil, "", fmt.Errorf("failed to create token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate confirmation code
|
// Generate confirmation code - use fixed code in debug mode for easier local testing
|
||||||
code := generateSixDigitCode()
|
var code string
|
||||||
|
if s.cfg.Server.Debug {
|
||||||
|
code = "123456"
|
||||||
|
} else {
|
||||||
|
code = generateSixDigitCode()
|
||||||
|
}
|
||||||
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
|
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
|
||||||
|
|
||||||
if _, err := s.userRepo.CreateConfirmationCode(user.ID, code, expiresAt); err != nil {
|
if _, err := s.userRepo.CreateConfirmationCode(user.ID, code, expiresAt); err != nil {
|
||||||
@@ -268,8 +276,13 @@ func (s *AuthService) ResendVerificationCode(userID uint) (string, error) {
|
|||||||
return "", ErrAlreadyVerified
|
return "", ErrAlreadyVerified
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate new code
|
// Generate new code - use fixed code in debug mode for easier local testing
|
||||||
code := generateSixDigitCode()
|
var code string
|
||||||
|
if s.cfg.Server.Debug {
|
||||||
|
code = "123456"
|
||||||
|
} else {
|
||||||
|
code = generateSixDigitCode()
|
||||||
|
}
|
||||||
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
|
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
|
||||||
|
|
||||||
if _, err := s.userRepo.CreateConfirmationCode(userID, code, expiresAt); err != nil {
|
if _, err := s.userRepo.CreateConfirmationCode(userID, code, expiresAt); err != nil {
|
||||||
@@ -300,8 +313,13 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error)
|
|||||||
return "", nil, ErrRateLimitExceeded
|
return "", nil, ErrRateLimitExceeded
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate code and reset token
|
// Generate code and reset token - use fixed code in debug mode for easier local testing
|
||||||
code := generateSixDigitCode()
|
var code string
|
||||||
|
if s.cfg.Server.Debug {
|
||||||
|
code = "123456"
|
||||||
|
} else {
|
||||||
|
code = generateSixDigitCode()
|
||||||
|
}
|
||||||
resetToken := generateResetToken()
|
resetToken := generateResetToken()
|
||||||
expiresAt := time.Now().UTC().Add(s.cfg.Security.PasswordResetExpiry)
|
expiresAt := time.Now().UTC().Add(s.cfg.Security.PasswordResetExpiry)
|
||||||
|
|
||||||
@@ -398,6 +416,140 @@ func (s *AuthService) ResetPassword(resetToken, newPassword string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AppleSignIn handles Sign in with Apple authentication
|
||||||
|
func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthService, req *requests.AppleSignInRequest) (*responses.AppleSignInResponse, error) {
|
||||||
|
// 1. Verify the Apple JWT token
|
||||||
|
claims, err := appleAuth.VerifyIdentityToken(ctx, req.IDToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: %v", ErrAppleSignInFailed, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the subject from claims as the authoritative Apple ID
|
||||||
|
appleID := claims.Subject
|
||||||
|
if appleID == "" {
|
||||||
|
appleID = req.UserID // Fallback to request UserID
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Check if this Apple ID is already linked to an account
|
||||||
|
existingAuth, err := s.userRepo.FindByAppleID(appleID)
|
||||||
|
if err == nil && existingAuth != nil {
|
||||||
|
// User already linked with this Apple ID - log them in
|
||||||
|
user, err := s.userRepo.FindByIDWithProfile(existingAuth.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to find user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user.IsActive {
|
||||||
|
return nil, ErrUserInactive
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get or create token
|
||||||
|
token, err := s.userRepo.GetOrCreateToken(user.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update last login
|
||||||
|
_ = s.userRepo.UpdateLastLogin(user.ID)
|
||||||
|
|
||||||
|
return &responses.AppleSignInResponse{
|
||||||
|
Token: token.Key,
|
||||||
|
User: responses.NewUserResponse(user),
|
||||||
|
IsNewUser: false,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Check if email matches an existing user (for account linking)
|
||||||
|
email := getEmailFromRequest(req.Email, claims.Email)
|
||||||
|
if email != "" {
|
||||||
|
existingUser, err := s.userRepo.FindByEmail(email)
|
||||||
|
if err == nil && existingUser != nil {
|
||||||
|
// Link Apple ID to existing account
|
||||||
|
appleAuthRecord := &models.AppleSocialAuth{
|
||||||
|
UserID: existingUser.ID,
|
||||||
|
AppleID: appleID,
|
||||||
|
Email: email,
|
||||||
|
IsPrivateEmail: isPrivateRelayEmail(email) || claims.IsPrivateRelayEmail(),
|
||||||
|
}
|
||||||
|
if err := s.userRepo.CreateAppleSocialAuth(appleAuthRecord); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to link Apple ID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark as verified since Apple verified the email
|
||||||
|
_ = s.userRepo.SetProfileVerified(existingUser.ID, true)
|
||||||
|
|
||||||
|
// Get or create token
|
||||||
|
token, err := s.userRepo.GetOrCreateToken(existingUser.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update last login
|
||||||
|
_ = s.userRepo.UpdateLastLogin(existingUser.ID)
|
||||||
|
|
||||||
|
// Reload user with profile
|
||||||
|
existingUser, _ = s.userRepo.FindByIDWithProfile(existingUser.ID)
|
||||||
|
|
||||||
|
return &responses.AppleSignInResponse{
|
||||||
|
Token: token.Key,
|
||||||
|
User: responses.NewUserResponse(existingUser),
|
||||||
|
IsNewUser: false,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Create new user
|
||||||
|
username := generateUniqueUsername(email, req.FirstName)
|
||||||
|
|
||||||
|
user := &models.User{
|
||||||
|
Username: username,
|
||||||
|
Email: getEmailOrDefault(email),
|
||||||
|
FirstName: getStringOrEmpty(req.FirstName),
|
||||||
|
LastName: getStringOrEmpty(req.LastName),
|
||||||
|
IsActive: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set a random password (user won't use it since they log in with Apple)
|
||||||
|
randomPassword := generateResetToken()
|
||||||
|
_ = user.SetPassword(randomPassword)
|
||||||
|
|
||||||
|
if err := s.userRepo.Create(user); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create profile (already verified since Apple verified)
|
||||||
|
profile, _ := s.userRepo.GetOrCreateProfile(user.ID)
|
||||||
|
if profile != nil {
|
||||||
|
_ = s.userRepo.SetProfileVerified(user.ID, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Link Apple ID
|
||||||
|
appleAuthRecord := &models.AppleSocialAuth{
|
||||||
|
UserID: user.ID,
|
||||||
|
AppleID: appleID,
|
||||||
|
Email: getEmailOrDefault(email),
|
||||||
|
IsPrivateEmail: isPrivateRelayEmail(email) || claims.IsPrivateRelayEmail(),
|
||||||
|
}
|
||||||
|
if err := s.userRepo.CreateAppleSocialAuth(appleAuthRecord); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create Apple auth: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create token
|
||||||
|
token, err := s.userRepo.GetOrCreateToken(user.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload user with profile
|
||||||
|
user, _ = s.userRepo.FindByIDWithProfile(user.ID)
|
||||||
|
|
||||||
|
return &responses.AppleSignInResponse{
|
||||||
|
Token: token.Key,
|
||||||
|
User: responses.NewUserResponse(user),
|
||||||
|
IsNewUser: true,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Helper functions
|
// Helper functions
|
||||||
|
|
||||||
func generateSixDigitCode() string {
|
func generateSixDigitCode() string {
|
||||||
@@ -416,3 +568,50 @@ func generateResetToken() string {
|
|||||||
rand.Read(b)
|
rand.Read(b)
|
||||||
return hex.EncodeToString(b)
|
return hex.EncodeToString(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper functions for Apple Sign In
|
||||||
|
|
||||||
|
func getEmailFromRequest(reqEmail *string, claimsEmail string) string {
|
||||||
|
if reqEmail != nil && *reqEmail != "" {
|
||||||
|
return *reqEmail
|
||||||
|
}
|
||||||
|
return claimsEmail
|
||||||
|
}
|
||||||
|
|
||||||
|
func getEmailOrDefault(email string) string {
|
||||||
|
if email == "" {
|
||||||
|
// Generate a placeholder email for users without one
|
||||||
|
return fmt.Sprintf("apple_%s@privaterelay.appleid.com", generateResetToken()[:16])
|
||||||
|
}
|
||||||
|
return email
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStringOrEmpty(s *string) string {
|
||||||
|
if s == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return *s
|
||||||
|
}
|
||||||
|
|
||||||
|
func isPrivateRelayEmail(email string) bool {
|
||||||
|
return strings.HasSuffix(strings.ToLower(email), "@privaterelay.appleid.com")
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateUniqueUsername(email string, firstName *string) string {
|
||||||
|
// Try using first part of email
|
||||||
|
if email != "" && !isPrivateRelayEmail(email) {
|
||||||
|
parts := strings.Split(email, "@")
|
||||||
|
if len(parts) > 0 && parts[0] != "" {
|
||||||
|
// Add random suffix to ensure uniqueness
|
||||||
|
return parts[0] + "_" + generateResetToken()[:6]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try using first name
|
||||||
|
if firstName != nil && *firstName != "" {
|
||||||
|
return strings.ToLower(*firstName) + "_" + generateResetToken()[:6]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to random username
|
||||||
|
return "user_" + generateResetToken()[:10]
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user