package services import ( "context" "crypto/rsa" "encoding/base64" "encoding/json" "errors" "fmt" "math/big" "net/http" "strings" "time" "github.com/golang-jwt/jwt/v5" "github.com/treytartt/casera-api/internal/config" ) const ( appleKeysURL = "https://appleid.apple.com/auth/keys" appleIssuer = "https://appleid.apple.com" appleKeysCacheTTL = 24 * time.Hour appleKeysCacheKey = "apple:public_keys" ) var ( ErrInvalidAppleToken = errors.New("invalid Apple identity token") ErrAppleTokenExpired = errors.New("Apple identity token has expired") ErrInvalidAppleAudience = errors.New("invalid Apple token audience") ErrInvalidAppleIssuer = errors.New("invalid Apple token issuer") ErrAppleKeyNotFound = errors.New("Apple public key not found") ) // AppleJWKS represents Apple's JSON Web Key Set type AppleJWKS struct { Keys []AppleJWK `json:"keys"` } // AppleJWK represents a single JSON Web Key from Apple type AppleJWK struct { Kty string `json:"kty"` // Key type (RSA) Kid string `json:"kid"` // Key ID Use string `json:"use"` // Key use (sig) Alg string `json:"alg"` // Algorithm (RS256) N string `json:"n"` // RSA modulus E string `json:"e"` // RSA exponent } // AppleTokenClaims represents the claims in an Apple identity token type AppleTokenClaims struct { jwt.RegisteredClaims Email string `json:"email,omitempty"` EmailVerified any `json:"email_verified,omitempty"` // Can be bool or string IsPrivateEmail any `json:"is_private_email,omitempty"` // Can be bool or string AuthTime int64 `json:"auth_time,omitempty"` } // IsEmailVerified returns whether the email is verified (handles both bool and string types) func (c *AppleTokenClaims) IsEmailVerified() bool { switch v := c.EmailVerified.(type) { case bool: return v case string: return v == "true" default: return false } } // IsPrivateRelayEmail returns whether the email is a private relay email func (c *AppleTokenClaims) IsPrivateRelayEmail() bool { switch v := c.IsPrivateEmail.(type) { case bool: return v case string: return v == "true" default: return false } } // AppleAuthService handles Apple Sign In token verification type AppleAuthService struct { cache *CacheService config *config.Config client *http.Client } // NewAppleAuthService creates a new Apple auth service func NewAppleAuthService(cache *CacheService, cfg *config.Config) *AppleAuthService { return &AppleAuthService{ cache: cache, config: cfg, client: &http.Client{ Timeout: 10 * time.Second, }, } } // VerifyIdentityToken verifies an Apple identity token and returns the claims func (s *AppleAuthService) VerifyIdentityToken(ctx context.Context, idToken string) (*AppleTokenClaims, error) { // Parse the token header to get the key ID parts := strings.Split(idToken, ".") if len(parts) != 3 { return nil, ErrInvalidAppleToken } headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) if err != nil { return nil, fmt.Errorf("failed to decode token header: %w", err) } var header struct { Kid string `json:"kid"` Alg string `json:"alg"` } if err := json.Unmarshal(headerBytes, &header); err != nil { return nil, fmt.Errorf("failed to parse token header: %w", err) } // Get the public key for this key ID publicKey, err := s.getPublicKey(ctx, header.Kid) if err != nil { return nil, err } // Parse and verify the token token, err := jwt.ParseWithClaims(idToken, &AppleTokenClaims{}, func(token *jwt.Token) (interface{}, error) { // Verify the signing method if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return publicKey, nil }) if err != nil { if errors.Is(err, jwt.ErrTokenExpired) { return nil, ErrAppleTokenExpired } return nil, fmt.Errorf("failed to parse token: %w", err) } claims, ok := token.Claims.(*AppleTokenClaims) if !ok || !token.Valid { return nil, ErrInvalidAppleToken } // Verify the issuer if claims.Issuer != appleIssuer { return nil, ErrInvalidAppleIssuer } // Verify the audience (should be our bundle ID) if !s.verifyAudience(claims.Audience) { return nil, ErrInvalidAppleAudience } return claims, nil } // verifyAudience checks if the token audience matches our client ID func (s *AppleAuthService) verifyAudience(audience jwt.ClaimStrings) bool { clientID := s.config.AppleAuth.ClientID if clientID == "" { // If not configured, skip audience verification (for development) return true } for _, aud := range audience { if aud == clientID { return true } } return false } // getPublicKey retrieves the public key for the given key ID func (s *AppleAuthService) getPublicKey(ctx context.Context, kid string) (*rsa.PublicKey, error) { // Try to get from cache first keys, err := s.getCachedKeys(ctx) if err != nil || keys == nil { // Fetch fresh keys keys, err = s.fetchApplePublicKeys(ctx) if err != nil { return nil, err } } // Find the key with the matching ID for keyID, pubKey := range keys { if keyID == kid { return pubKey, nil } } // Key not found in cache, try fetching fresh keys keys, err = s.fetchApplePublicKeys(ctx) if err != nil { return nil, err } if pubKey, ok := keys[kid]; ok { return pubKey, nil } return nil, ErrAppleKeyNotFound } // getCachedKeys retrieves cached Apple public keys from Redis func (s *AppleAuthService) getCachedKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) { if s.cache == nil { return nil, nil } data, err := s.cache.GetString(ctx, appleKeysCacheKey) if err != nil || data == "" { return nil, nil } var jwks AppleJWKS if err := json.Unmarshal([]byte(data), &jwks); err != nil { return nil, nil } return s.parseJWKS(&jwks) } // fetchApplePublicKeys fetches Apple's public keys and caches them func (s *AppleAuthService) fetchApplePublicKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, appleKeysURL, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } resp, err := s.client.Do(req) if err != nil { return nil, fmt.Errorf("failed to fetch Apple keys: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("Apple keys endpoint returned status %d", resp.StatusCode) } var jwks AppleJWKS if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil { return nil, fmt.Errorf("failed to decode Apple keys: %w", err) } // Cache the keys if s.cache != nil { keysJSON, _ := json.Marshal(jwks) _ = s.cache.SetString(ctx, appleKeysCacheKey, string(keysJSON), appleKeysCacheTTL) } return s.parseJWKS(&jwks) } // parseJWKS converts Apple's JWKS to RSA public keys func (s *AppleAuthService) parseJWKS(jwks *AppleJWKS) (map[string]*rsa.PublicKey, error) { keys := make(map[string]*rsa.PublicKey) for _, key := range jwks.Keys { if key.Kty != "RSA" { continue } // Decode the modulus (N) nBytes, err := base64.RawURLEncoding.DecodeString(key.N) if err != nil { continue } n := new(big.Int).SetBytes(nBytes) // Decode the exponent (E) eBytes, err := base64.RawURLEncoding.DecodeString(key.E) if err != nil { continue } e := 0 for _, b := range eBytes { e = e<<8 + int(b) } pubKey := &rsa.PublicKey{ N: n, E: e, } keys[key.Kid] = pubKey } return keys, nil }