fix(security): remediate 2026-05-12 audit findings (Stages 2–5)
Remediation of the 2026-05-12/13 audits (78 findings + cluster gaps), tracked in deploy-k3s/SECURITY.md, plus fixes from two independent post-remediation reviews. Auth & sessions: - SHA-256 hashed auth-token storage (C1); prior-token cache eviction on re-login (MEDIUM-1) - local Google JWKS verification, iss/aud/exp checks (C2/C3) - constant-time login + generic errors (L1/LIVE-L11/LIVE-L13) - per-account login lockout keyed on distinct source IPs (M5/MEDIUM-3) - verified-email gating, login rate limiting (LIVE-L19, H1-H3) IAP & webhooks: - Apple/Google cross-account replay protection (C5/C6/C10/C13, H5/H6) - migrations 000003-000006 (token hashing, IAP replay, audit_log + webhook_event_log table creation, append-only audit log) Authorization & races: - file-ownership owner-OR-member fix (C7), atomic share-code join (C9/H9), device-token reassignment (C8/LOW-3) Secrets & deploy: - secrets file-mounted at /etc/honeydue/secrets, not env (F8); Redis password out of the ConfigMap (HIGH-1); B2 keys reconciled - digest-pinned images, admin ingress hardening, CSP/HSTS, /metrics lockdown; kubeconfig 0600, etcd secrets-encryption, fail2ban + unattended-upgrades at provision; secret-rotation runbook Build, vet, and the full test suite (incl. -race) pass; the goose migration chain is verified against PostgreSQL 16. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -75,9 +75,9 @@ func TestRefreshToken_FreshToken_ReturnsExisting(t *testing.T) {
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token.Key, resp.Token, "fresh token should return the same token")
|
||||
assert.Equal(t, token.Plaintext, resp.Token, "fresh token should return the same token")
|
||||
assert.Contains(t, resp.Message, "still valid")
|
||||
}
|
||||
|
||||
@@ -88,23 +88,25 @@ func TestRefreshToken_InRenewalWindow_ReturnsNewToken(t *testing.T) {
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, token.Key, resp.Token, "should return a new token")
|
||||
assert.NotEqual(t, token.Plaintext, resp.Token, "should return a new token")
|
||||
assert.Contains(t, resp.Message, "refreshed")
|
||||
|
||||
// Verify old token was deleted
|
||||
var count int64
|
||||
// The DB stores the SHA-256 hash, so query by token.Key (the hash).
|
||||
db.Model(&models.AuthToken{}).Where("key = ?", token.Key).Count(&count)
|
||||
assert.Equal(t, int64(0), count, "old token should be deleted")
|
||||
|
||||
// Verify new token exists in DB
|
||||
db.Model(&models.AuthToken{}).Where("key = ?", resp.Token).Count(&count)
|
||||
// resp.Token is the raw token; the DB stores its hash.
|
||||
db.Model(&models.AuthToken{}).Where("key = ?", models.HashToken(resp.Token)).Count(&count)
|
||||
assert.Equal(t, int64(1), count, "new token should exist in DB")
|
||||
|
||||
// Verify new token belongs to the same user
|
||||
var newToken models.AuthToken
|
||||
require.NoError(t, db.Where("key = ?", resp.Token).First(&newToken).Error)
|
||||
require.NoError(t, db.Where("key = ?", models.HashToken(resp.Token)).First(&newToken).Error)
|
||||
assert.Equal(t, user.ID, newToken.UserID)
|
||||
}
|
||||
|
||||
@@ -115,7 +117,7 @@ func TestRefreshToken_ExpiredToken_Returns401(t *testing.T) {
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
assert.Contains(t, err.Error(), "error.token_expired")
|
||||
@@ -130,9 +132,9 @@ func TestRefreshToken_AtExactBoundary60Days(t *testing.T) {
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, token.Key, resp.Token, "token at 61 days should be refreshed")
|
||||
assert.NotEqual(t, token.Plaintext, resp.Token, "token at 61 days should be refreshed")
|
||||
}
|
||||
|
||||
func TestRefreshToken_InvalidToken_Returns401(t *testing.T) {
|
||||
@@ -155,7 +157,7 @@ func TestRefreshToken_WrongUser_Returns401(t *testing.T) {
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
// Try to refresh with a different user ID
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID+999)
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID+999)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||
@@ -168,7 +170,7 @@ func TestRefreshToken_FreshTokenAt59Days_ReturnsExisting(t *testing.T) {
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token.Key, resp.Token, "token at 59 days should NOT be refreshed")
|
||||
assert.Equal(t, token.Plaintext, resp.Token, "token at 59 days should NOT be refreshed")
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package services
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -36,13 +37,32 @@ var (
|
||||
ErrGoogleSignInFailed = errors.New("Google Sign In failed")
|
||||
)
|
||||
|
||||
// Per-account login lockout (audit M5, hardened per MEDIUM-3).
|
||||
const (
|
||||
// maxLoginFailureIPs is how many DISTINCT source IPs may fail to log in to
|
||||
// one account within the window before that account is locked. Counting
|
||||
// distinct IPs (not raw attempts) means a single attacker who knows a
|
||||
// victim's email cannot lock the victim out by spamming failures — only a
|
||||
// genuinely distributed credential-stuffing attack reaches this threshold.
|
||||
maxLoginFailureIPs = 5
|
||||
// loginLockWindow is how long the failed-IP set persists; it is refreshed
|
||||
// on each failure so an active attack keeps the window open.
|
||||
loginLockWindow = 15 * time.Minute
|
||||
)
|
||||
|
||||
// AuthService handles authentication business logic
|
||||
type AuthService struct {
|
||||
userRepo *repositories.UserRepository
|
||||
notificationRepo *repositories.NotificationRepository
|
||||
cache *CacheService
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// SetCacheService wires Redis for per-account login-failure tracking (M5).
|
||||
func (s *AuthService) SetCacheService(cache *CacheService) {
|
||||
s.cache = cache
|
||||
}
|
||||
|
||||
// NewAuthService creates a new auth service
|
||||
func NewAuthService(userRepo *repositories.UserRepository, cfg *config.Config) *AuthService {
|
||||
return &AuthService{
|
||||
@@ -56,34 +76,89 @@ func (s *AuthService) SetNotificationRepository(notificationRepo *repositories.N
|
||||
s.notificationRepo = notificationRepo
|
||||
}
|
||||
|
||||
// Login authenticates a user and returns a token
|
||||
func (s *AuthService) Login(ctx context.Context, req *requests.LoginRequest) (*responses.LoginResponse, error) {
|
||||
// dummyPasswordHash is a valid bcrypt hash used to keep login response time
|
||||
// constant when the account does not exist (audit LIVE-L11). It is computed
|
||||
// once at startup; the plaintext it hashes is irrelevant and never used.
|
||||
var dummyPasswordHash = func() string {
|
||||
h, err := bcrypt.GenerateFromPassword([]byte("honeydue-login-timing-equalizer"), models.BcryptCost)
|
||||
if err != nil {
|
||||
return "" // CompareHashAndPassword against "" always fails — safe
|
||||
}
|
||||
return string(h)
|
||||
}()
|
||||
|
||||
// freshToken mints a new auth token for the user and evicts any prior token's
|
||||
// Redis cache entry (audit MEDIUM-1). Without the eviction a re-login would
|
||||
// not actually kill a previously-issued token until the cache TTL lapsed — a
|
||||
// stolen token would keep working for up to 5 minutes after the victim
|
||||
// re-authenticates. A cache-eviction failure is logged, not fatal: the token
|
||||
// row is already gone, so the stale entry simply ages out on its own.
|
||||
func (s *AuthService) freshToken(ctx context.Context, userID uint) (*models.AuthToken, error) {
|
||||
token, oldHashes, err := s.userRepo.WithContext(ctx).CreateFreshToken(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.cache != nil && len(oldHashes) > 0 {
|
||||
if cErr := s.cache.InvalidateAuthTokenHashes(ctx, oldHashes...); cErr != nil {
|
||||
log.Warn().Err(cErr).Uint("user_id", userID).
|
||||
Msg("failed to evict prior auth-token cache entries on re-login")
|
||||
}
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Login authenticates a user and returns a token. clientIP is the request's
|
||||
// source IP (echo c.RealIP()), used for the distributed-attack lockout.
|
||||
func (s *AuthService) Login(ctx context.Context, req *requests.LoginRequest, clientIP string) (*responses.LoginResponse, error) {
|
||||
// Find user by username or email
|
||||
identifier := req.Username
|
||||
if identifier == "" {
|
||||
identifier = req.Email
|
||||
}
|
||||
lockKey := strings.ToLower(strings.TrimSpace(identifier))
|
||||
|
||||
// Audit M5 (hardened per MEDIUM-3): per-account lockout keyed on the set
|
||||
// of distinct source IPs that have failed. Once enough distinct IPs have
|
||||
// failed for one account within the window, reject — this still catches
|
||||
// distributed credential stuffing, without letting a single attacker lock
|
||||
// a victim out by spamming failed logins from one IP.
|
||||
if s.cache != nil && lockKey != "" {
|
||||
if n, cErr := s.cache.LoginFailureIPCount(ctx, lockKey); cErr == nil && n >= maxLoginFailureIPs {
|
||||
return nil, apperrors.TooManyRequests("error.too_many_login_attempts")
|
||||
}
|
||||
}
|
||||
|
||||
user, err := s.userRepo.WithContext(ctx).FindByUsernameOrEmail(identifier)
|
||||
if err != nil {
|
||||
if errors.Is(err, repositories.ErrUserNotFound) {
|
||||
return nil, apperrors.Unauthorized("error.invalid_credentials")
|
||||
}
|
||||
if err != nil && !errors.Is(err, repositories.ErrUserNotFound) {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Check if user is active
|
||||
if !user.IsActive {
|
||||
return nil, apperrors.Unauthorized("error.account_inactive")
|
||||
// Constant-time login (audit LIVE-L11): always run a bcrypt comparison,
|
||||
// even when the account does not exist or is inactive, so response
|
||||
// timing never reveals which emails are real accounts. Compare against
|
||||
// the user's hash when available, otherwise a fixed dummy hash.
|
||||
passwordHash := dummyPasswordHash
|
||||
if user != nil {
|
||||
passwordHash = user.Password
|
||||
}
|
||||
passwordOK := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(req.Password)) == nil
|
||||
|
||||
// Verify password
|
||||
if !user.CheckPassword(req.Password) {
|
||||
// One generic error for not-found, inactive, and wrong-password
|
||||
// (audit L1) — none of them disclose which condition failed.
|
||||
if user == nil || !user.IsActive || !passwordOK {
|
||||
if s.cache != nil && lockKey != "" {
|
||||
_, _ = s.cache.RegisterLoginFailure(ctx, lockKey, clientIP, loginLockWindow)
|
||||
}
|
||||
return nil, apperrors.Unauthorized("error.invalid_credentials")
|
||||
}
|
||||
|
||||
// Successful authentication — clear the failure counter (audit M5).
|
||||
if s.cache != nil && lockKey != "" {
|
||||
_ = s.cache.ClearLoginFailures(ctx, lockKey)
|
||||
}
|
||||
|
||||
// Get or create auth token
|
||||
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
|
||||
token, err := s.freshToken(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
@@ -95,7 +170,7 @@ func (s *AuthService) Login(ctx context.Context, req *requests.LoginRequest) (*r
|
||||
}
|
||||
|
||||
return &responses.LoginResponse{
|
||||
Token: token.Key,
|
||||
Token: token.Plaintext,
|
||||
User: responses.NewUserResponse(user),
|
||||
}, nil
|
||||
}
|
||||
@@ -176,13 +251,13 @@ func (s *AuthService) Register(ctx context.Context, req *requests.RegisterReques
|
||||
}
|
||||
|
||||
// Create auth token (outside transaction since token generation is idempotent)
|
||||
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
|
||||
token, err := s.freshToken(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
return &responses.RegisterResponse{
|
||||
Token: token.Key,
|
||||
Token: token.Plaintext,
|
||||
User: responses.NewUserResponse(user),
|
||||
Message: "Registration successful. Please check your email to verify your account.",
|
||||
}, code, nil
|
||||
@@ -243,7 +318,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, tokenKey string, userID
|
||||
}
|
||||
|
||||
return &responses.RefreshTokenResponse{
|
||||
Token: newToken.Key,
|
||||
Token: newToken.Plaintext,
|
||||
Message: "Token refreshed successfully.",
|
||||
}, nil
|
||||
}
|
||||
@@ -390,26 +465,26 @@ func (s *AuthService) VerifyEmail(ctx context.Context, userID uint, code string)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find and validate confirmation code
|
||||
confirmCode, err := s.userRepo.WithContext(ctx).FindConfirmationCode(userID, code)
|
||||
if err != nil {
|
||||
if errors.Is(err, repositories.ErrCodeNotFound) {
|
||||
// Audit M4: validate the code, consume it, and flip the verified flag in
|
||||
// one transaction so the three writes commit or roll back together.
|
||||
txErr := s.userRepo.WithContext(ctx).Transaction(func(txRepo *repositories.UserRepository) error {
|
||||
confirmCode, err := txRepo.FindConfirmationCode(userID, code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := txRepo.MarkConfirmationCodeUsed(confirmCode.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
return txRepo.SetProfileVerified(userID, true)
|
||||
})
|
||||
if txErr != nil {
|
||||
if errors.Is(txErr, repositories.ErrCodeNotFound) {
|
||||
return apperrors.BadRequest("error.invalid_verification_code")
|
||||
}
|
||||
if errors.Is(err, repositories.ErrCodeExpired) {
|
||||
if errors.Is(txErr, repositories.ErrCodeExpired) {
|
||||
return apperrors.BadRequest("error.verification_code_expired")
|
||||
}
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Mark code as used
|
||||
if err := s.userRepo.WithContext(ctx).MarkConfirmationCodeUsed(confirmCode.ID); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Set profile as verified
|
||||
if err := s.userRepo.WithContext(ctx).SetProfileVerified(userID, true); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
return apperrors.Internal(txErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -476,7 +551,7 @@ func (s *AuthService) ForgotPassword(ctx context.Context, email string) (string,
|
||||
expiresAt := time.Now().UTC().Add(s.cfg.Security.PasswordResetExpiry)
|
||||
|
||||
// Hash the code before storing
|
||||
codeHash, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
|
||||
codeHash, err := bcrypt.GenerateFromPassword([]byte(code), models.BcryptCost)
|
||||
if err != nil {
|
||||
return "", nil, apperrors.Internal(err)
|
||||
}
|
||||
@@ -596,7 +671,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
||||
}
|
||||
|
||||
// Get or create token
|
||||
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
|
||||
token, err := s.freshToken(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
@@ -605,7 +680,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
||||
_ = s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID)
|
||||
|
||||
return &responses.AppleSignInResponse{
|
||||
Token: token.Key,
|
||||
Token: token.Plaintext,
|
||||
User: responses.NewUserResponse(user),
|
||||
IsNewUser: false,
|
||||
}, nil
|
||||
@@ -638,7 +713,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
||||
_ = s.userRepo.WithContext(ctx).SetProfileVerified(existingUser.ID, true)
|
||||
|
||||
// Get or create token
|
||||
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(existingUser.ID)
|
||||
token, err := s.freshToken(ctx, existingUser.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
@@ -653,7 +728,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
||||
}
|
||||
|
||||
return &responses.AppleSignInResponse{
|
||||
Token: token.Key,
|
||||
Token: token.Plaintext,
|
||||
User: responses.NewUserResponse(existingUser),
|
||||
IsNewUser: false,
|
||||
}, nil
|
||||
@@ -704,7 +779,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
||||
}
|
||||
|
||||
// Create token
|
||||
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
|
||||
token, err := s.freshToken(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
@@ -716,7 +791,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
||||
}
|
||||
|
||||
return &responses.AppleSignInResponse{
|
||||
Token: token.Key,
|
||||
Token: token.Plaintext,
|
||||
User: responses.NewUserResponse(user),
|
||||
IsNewUser: true,
|
||||
}, nil
|
||||
@@ -749,7 +824,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
||||
}
|
||||
|
||||
// Get or create token
|
||||
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
|
||||
token, err := s.freshToken(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
@@ -758,7 +833,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
||||
_ = s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID)
|
||||
|
||||
return &responses.GoogleSignInResponse{
|
||||
Token: token.Key,
|
||||
Token: token.Plaintext,
|
||||
User: responses.NewUserResponse(user),
|
||||
IsNewUser: false,
|
||||
}, nil
|
||||
@@ -794,7 +869,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
||||
}
|
||||
|
||||
// Get or create token
|
||||
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(existingUser.ID)
|
||||
token, err := s.freshToken(ctx, existingUser.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
@@ -809,7 +884,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
||||
}
|
||||
|
||||
return &responses.GoogleSignInResponse{
|
||||
Token: token.Key,
|
||||
Token: token.Plaintext,
|
||||
User: responses.NewUserResponse(existingUser),
|
||||
IsNewUser: false,
|
||||
}, nil
|
||||
@@ -861,7 +936,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
||||
}
|
||||
|
||||
// Create token
|
||||
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
|
||||
token, err := s.freshToken(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
@@ -873,7 +948,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
||||
}
|
||||
|
||||
return &responses.GoogleSignInResponse{
|
||||
Token: token.Key,
|
||||
Token: token.Plaintext,
|
||||
User: responses.NewUserResponse(user),
|
||||
IsNewUser: true,
|
||||
}, nil
|
||||
@@ -882,14 +957,19 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
||||
// Helper functions
|
||||
|
||||
func generateSixDigitCode() string {
|
||||
b := make([]byte, 4)
|
||||
rand.Read(b)
|
||||
num := int(b[0])<<24 | int(b[1])<<16 | int(b[2])<<8 | int(b[3])
|
||||
if num < 0 {
|
||||
num = -num
|
||||
// Uniform 000000–999999 via rejection sampling on crypto/rand,
|
||||
// removing the modulo bias of `n % 1000000` (audit H4).
|
||||
for {
|
||||
var b [4]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
continue
|
||||
}
|
||||
// 4294000000 is the largest multiple of 1e6 <= MaxUint32.
|
||||
n := binary.BigEndian.Uint32(b[:])
|
||||
if n < 4294000000 {
|
||||
return fmt.Sprintf("%06d", n%1000000)
|
||||
}
|
||||
}
|
||||
code := num % 1000000
|
||||
return fmt.Sprintf("%06d", code)
|
||||
}
|
||||
|
||||
func generateResetToken() string {
|
||||
|
||||
@@ -54,7 +54,7 @@ func TestAuthService_Login(t *testing.T) {
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
resp, err := service.Login(context.Background(), req)
|
||||
resp, err := service.Login(context.Background(), req, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.Token)
|
||||
assert.Equal(t, "testuser", resp.User.Username)
|
||||
@@ -75,7 +75,7 @@ func TestAuthService_Login_ByEmail(t *testing.T) {
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
resp, err := service.Login(context.Background(), req)
|
||||
resp, err := service.Login(context.Background(), req, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.Token)
|
||||
}
|
||||
@@ -95,7 +95,7 @@ func TestAuthService_Login_InvalidCredentials(t *testing.T) {
|
||||
Password: "WrongPassword1",
|
||||
}
|
||||
|
||||
_, err := service.Login(context.Background(), req)
|
||||
_, err := service.Login(context.Background(), req, "")
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
@@ -112,7 +112,7 @@ func TestAuthService_Login_UserNotFound(t *testing.T) {
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
_, err := service.Login(context.Background(), req)
|
||||
_, err := service.Login(context.Background(), req, "")
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
@@ -134,8 +134,10 @@ func TestAuthService_Login_InactiveUser(t *testing.T) {
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
_, err := service.Login(context.Background(), req)
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.account_inactive")
|
||||
_, err := service.Login(context.Background(), req, "")
|
||||
// Audit L1: inactive accounts return the same generic error as bad
|
||||
// credentials so login does not disclose which accounts exist.
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
// === Register ===
|
||||
@@ -443,7 +445,7 @@ func TestAuthService_ResetPassword(t *testing.T) {
|
||||
Username: "testuser",
|
||||
Password: "NewPassword123",
|
||||
}
|
||||
loginResp, err := service.Login(context.Background(), loginReq)
|
||||
loginResp, err := service.Login(context.Background(), loginReq, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, loginResp.Token)
|
||||
}
|
||||
@@ -472,7 +474,7 @@ func TestAuthService_Logout(t *testing.T) {
|
||||
Username: "testuser",
|
||||
Password: "Password123",
|
||||
}
|
||||
loginResp, err := service.Login(context.Background(), loginReq)
|
||||
loginResp, err := service.Login(context.Background(), loginReq, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Logout
|
||||
@@ -659,7 +661,7 @@ func TestAuthService_Login_EmptyPassword(t *testing.T) {
|
||||
Password: "",
|
||||
}
|
||||
|
||||
_, err := service.Login(context.Background(), req)
|
||||
_, err := service.Login(context.Background(), req, "")
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
// CacheService provides Redis caching functionality
|
||||
@@ -139,22 +140,25 @@ const (
|
||||
TokenCacheTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
// authTokenCacheKey returns the Redis key for an auth token. The raw token
|
||||
// is hashed (audit C1) so the plaintext token never appears in a Redis key.
|
||||
func authTokenCacheKey(token string) string {
|
||||
return AuthTokenPrefix + models.HashToken(token)
|
||||
}
|
||||
|
||||
// CacheAuthToken caches a user ID for a token
|
||||
func (c *CacheService) CacheAuthToken(ctx context.Context, token string, userID uint) error {
|
||||
key := AuthTokenPrefix + token
|
||||
return c.SetString(ctx, key, fmt.Sprintf("%d", userID), TokenCacheTTL)
|
||||
return c.SetString(ctx, authTokenCacheKey(token), fmt.Sprintf("%d", userID), TokenCacheTTL)
|
||||
}
|
||||
|
||||
// CacheAuthTokenWithCreated caches a user ID and token creation time for a token
|
||||
func (c *CacheService) CacheAuthTokenWithCreated(ctx context.Context, token string, userID uint, createdUnix int64) error {
|
||||
key := AuthTokenPrefix + token
|
||||
return c.SetString(ctx, key, fmt.Sprintf("%d|%d", userID, createdUnix), TokenCacheTTL)
|
||||
return c.SetString(ctx, authTokenCacheKey(token), fmt.Sprintf("%d|%d", userID, createdUnix), TokenCacheTTL)
|
||||
}
|
||||
|
||||
// GetCachedAuthToken gets a cached user ID for a token
|
||||
func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (uint, error) {
|
||||
key := AuthTokenPrefix + token
|
||||
val, err := c.GetString(ctx, key)
|
||||
val, err := c.GetString(ctx, authTokenCacheKey(token))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -167,8 +171,7 @@ func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (ui
|
||||
// GetCachedAuthTokenWithCreated gets a cached user ID and token creation time.
|
||||
// Returns userID, createdUnix, error. createdUnix is 0 if not stored (legacy format).
|
||||
func (c *CacheService) GetCachedAuthTokenWithCreated(ctx context.Context, token string) (uint, int64, error) {
|
||||
key := AuthTokenPrefix + token
|
||||
val, err := c.GetString(ctx, key)
|
||||
val, err := c.GetString(ctx, authTokenCacheKey(token))
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
@@ -184,8 +187,62 @@ func (c *CacheService) GetCachedAuthTokenWithCreated(ctx context.Context, token
|
||||
|
||||
// InvalidateAuthToken removes a cached token
|
||||
func (c *CacheService) InvalidateAuthToken(ctx context.Context, token string) error {
|
||||
key := AuthTokenPrefix + token
|
||||
return c.Delete(ctx, key)
|
||||
return c.Delete(ctx, authTokenCacheKey(token))
|
||||
}
|
||||
|
||||
// InvalidateAuthTokenHashes removes cached entries for already-hashed token
|
||||
// keys. Unlike InvalidateAuthToken (which hashes a plaintext), this takes the
|
||||
// stored hash directly — used to evict a user's prior token on re-login
|
||||
// (audit MEDIUM-1), where the server no longer has the plaintext.
|
||||
func (c *CacheService) InvalidateAuthTokenHashes(ctx context.Context, hashes ...string) error {
|
||||
keys := make([]string, 0, len(hashes))
|
||||
for _, h := range hashes {
|
||||
if h != "" {
|
||||
keys = append(keys, AuthTokenPrefix+h)
|
||||
}
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.Delete(ctx, keys...)
|
||||
}
|
||||
|
||||
// --- Per-account login-failure tracking (audit M5) ---
|
||||
|
||||
const loginFailPrefix = "login_fail:"
|
||||
|
||||
// RegisterLoginFailure records a failed login for an account from a given
|
||||
// source IP, and returns the number of DISTINCT source IPs that have failed
|
||||
// for this account within the window. Tracking distinct IPs as a set rather
|
||||
// than a raw counter (audit MEDIUM-3) means one attacker, from one IP, cannot
|
||||
// run the count up and lock a victim out by knowing only their email — a
|
||||
// single IP is bounded by the per-IP edge/app rate limiters instead. A
|
||||
// genuinely distributed credential-stuffing attack still trips the lockout.
|
||||
func (c *CacheService) RegisterLoginFailure(ctx context.Context, identifier, ip string, window time.Duration) (int64, error) {
|
||||
key := loginFailPrefix + identifier
|
||||
member := ip
|
||||
if member == "" {
|
||||
member = "unknown"
|
||||
}
|
||||
if err := c.client.SAdd(ctx, key, member).Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// Refresh the TTL on each failure: an active attack keeps the window
|
||||
// open, while a quiet account ages out `window` after its last failure.
|
||||
_ = c.client.Expire(ctx, key, window).Err()
|
||||
return c.client.SCard(ctx, key).Result()
|
||||
}
|
||||
|
||||
// LoginFailureIPCount returns how many distinct source IPs have failed to log
|
||||
// in to this account within the window (audit MEDIUM-3). SCard on a missing
|
||||
// key returns 0.
|
||||
func (c *CacheService) LoginFailureIPCount(ctx context.Context, identifier string) (int64, error) {
|
||||
return c.client.SCard(ctx, loginFailPrefix+identifier).Result()
|
||||
}
|
||||
|
||||
// ClearLoginFailures resets the failed-login IP set after a successful login.
|
||||
func (c *CacheService) ClearLoginFailures(ctx context.Context, identifier string) error {
|
||||
return c.client.Del(ctx, loginFailPrefix+identifier).Err()
|
||||
}
|
||||
|
||||
// Static data cache helpers
|
||||
|
||||
@@ -296,9 +296,14 @@ func (s *ContractorService) ToggleFavorite(ctx context.Context, contractorID, us
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Re-fetch the contractor to get the updated state with all relations
|
||||
// Re-fetch to get the updated state with all relations. Audit M12: if the
|
||||
// contractor was deleted concurrently between the toggle and this read,
|
||||
// surface a clean 404 instead of a 500.
|
||||
contractor, err = s.contractorRepo.WithContext(ctx).FindByID(contractorID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, apperrors.NotFound("error.contractor_not_found")
|
||||
}
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// FileOwnershipService checks whether a user owns a file referenced by URL.
|
||||
// It queries task completion images, document files, and document images
|
||||
// to determine ownership through residence access.
|
||||
// FileOwnershipService checks whether a user has access to a file referenced
|
||||
// by URL. It queries task completion images, document files, and document
|
||||
// images, resolving access through residence ownership or membership.
|
||||
type FileOwnershipService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
@@ -17,16 +17,31 @@ func NewFileOwnershipService(db *gorm.DB) *FileOwnershipService {
|
||||
return &FileOwnershipService{db: db}
|
||||
}
|
||||
|
||||
// IsFileOwnedByUser checks if the given file URL belongs to a record
|
||||
// that the user has access to (via residence membership).
|
||||
// accessibleResidenceIDs returns a subquery of residence IDs the user can
|
||||
// access: residences they own (residence_residence.owner_id) UNION residences
|
||||
// they are a member of (residence_residence_users).
|
||||
//
|
||||
// Audit C7: the previous queries joined residence_residence_users only, so a
|
||||
// residence owner who was not also a member of the join table could not pass
|
||||
// the ownership check for files in their own property.
|
||||
func (s *FileOwnershipService) accessibleResidenceIDs(userID uint) *gorm.DB {
|
||||
return s.db.Raw(`
|
||||
SELECT id FROM residence_residence WHERE owner_id = ?
|
||||
UNION
|
||||
SELECT residence_id FROM residence_residence_users WHERE user_id = ?
|
||||
`, userID, userID)
|
||||
}
|
||||
|
||||
// IsFileOwnedByUser checks if the given file URL belongs to a record in a
|
||||
// residence the user owns or is a member of.
|
||||
func (s *FileOwnershipService) IsFileOwnedByUser(fileURL string, userID uint) (bool, error) {
|
||||
// Check task completion images: image_url -> completion -> task -> residence -> user access
|
||||
// Task completion images: image_url -> completion -> task -> residence.
|
||||
var completionImageCount int64
|
||||
err := s.db.Model(&models.TaskCompletionImage{}).
|
||||
Joins("JOIN task_taskcompletion ON task_taskcompletion.id = task_taskcompletionimage.completion_id").
|
||||
Joins("JOIN task_task ON task_task.id = task_taskcompletion.task_id").
|
||||
Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_task.residence_id").
|
||||
Where("task_taskcompletionimage.image_url = ? AND residence_residence_users.user_id = ?", fileURL, userID).
|
||||
Where("task_taskcompletionimage.image_url = ?", fileURL).
|
||||
Where("task_task.residence_id IN (?)", s.accessibleResidenceIDs(userID)).
|
||||
Count(&completionImageCount).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -35,11 +50,11 @@ func (s *FileOwnershipService) IsFileOwnedByUser(fileURL string, userID uint) (b
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Check document files: file_url -> document -> residence -> user access
|
||||
// Document files: file_url -> document -> residence.
|
||||
var documentCount int64
|
||||
err = s.db.Model(&models.Document{}).
|
||||
Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_document.residence_id").
|
||||
Where("task_document.file_url = ? AND residence_residence_users.user_id = ?", fileURL, userID).
|
||||
Where("task_document.file_url = ?", fileURL).
|
||||
Where("task_document.residence_id IN (?)", s.accessibleResidenceIDs(userID)).
|
||||
Count(&documentCount).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -48,12 +63,12 @@ func (s *FileOwnershipService) IsFileOwnedByUser(fileURL string, userID uint) (b
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Check document images: image_url -> document_image -> document -> residence -> user access
|
||||
// Document images: image_url -> document_image -> document -> residence.
|
||||
var documentImageCount int64
|
||||
err = s.db.Model(&models.DocumentImage{}).
|
||||
Joins("JOIN task_document ON task_document.id = task_documentimage.document_id").
|
||||
Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_document.residence_id").
|
||||
Where("task_documentimage.image_url = ? AND residence_residence_users.user_id = ?", fileURL, userID).
|
||||
Where("task_documentimage.image_url = ?", fileURL).
|
||||
Where("task_document.residence_id IN (?)", s.accessibleResidenceIDs(userID)).
|
||||
Count(&documentImageCount).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
||||
@@ -2,132 +2,306 @@ package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
googleTokenInfoURL = "https://oauth2.googleapis.com/tokeninfo"
|
||||
// googleKeysURL is Google's JWKS endpoint for ID-token signature verification.
|
||||
googleKeysURL = "https://www.googleapis.com/oauth2/v3/certs"
|
||||
googleKeysCacheTTL = 24 * time.Hour
|
||||
googleKeysCacheKey = "google:public_keys"
|
||||
)
|
||||
|
||||
// googleIssuers is the set of valid `iss` claim values for a Google ID token.
|
||||
var googleIssuers = map[string]bool{
|
||||
"accounts.google.com": true,
|
||||
"https://accounts.google.com": true,
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidGoogleToken = errors.New("invalid Google ID token")
|
||||
ErrGoogleTokenExpired = errors.New("Google ID token has expired")
|
||||
ErrInvalidGoogleAudience = errors.New("invalid Google token audience")
|
||||
ErrInvalidGoogleIssuer = errors.New("invalid Google token issuer")
|
||||
ErrGoogleKeyNotFound = errors.New("Google public key not found")
|
||||
)
|
||||
|
||||
// GoogleTokenInfo represents the response from Google's token info endpoint
|
||||
type GoogleTokenInfo struct {
|
||||
Sub string `json:"sub"` // Unique Google user ID
|
||||
Email string `json:"email"` // User's email
|
||||
EmailVerified string `json:"email_verified"` // "true" or "false"
|
||||
Name string `json:"name"` // Full name
|
||||
GivenName string `json:"given_name"` // First name
|
||||
FamilyName string `json:"family_name"` // Last name
|
||||
Picture string `json:"picture"` // Profile picture URL
|
||||
Aud string `json:"aud"` // Audience (client ID)
|
||||
Azp string `json:"azp"` // Authorized party
|
||||
Exp string `json:"exp"` // Expiration time
|
||||
Iss string `json:"iss"` // Issuer
|
||||
// GoogleJWKS represents Google's JSON Web Key Set.
|
||||
type GoogleJWKS struct {
|
||||
Keys []GoogleJWK `json:"keys"`
|
||||
}
|
||||
|
||||
// IsEmailVerified returns whether the email is verified
|
||||
// GoogleJWK represents a single JSON Web Key from Google.
|
||||
type GoogleJWK struct {
|
||||
Kty string `json:"kty"` // Key type (RSA)
|
||||
Kid string `json:"kid"` // Key ID
|
||||
Use string `json:"use"` // Key use (sig)
|
||||
Alg string `json:"alg"` // Algorithm (RS256)
|
||||
N string `json:"n"` // RSA modulus
|
||||
E string `json:"e"` // RSA exponent
|
||||
}
|
||||
|
||||
// GoogleTokenClaims represents the claims in a Google ID token JWT.
|
||||
type GoogleTokenClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
Email string `json:"email,omitempty"`
|
||||
EmailVerified bool `json:"email_verified,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
GivenName string `json:"given_name,omitempty"`
|
||||
FamilyName string `json:"family_name,omitempty"`
|
||||
Picture string `json:"picture,omitempty"`
|
||||
Azp string `json:"azp,omitempty"` // Authorized party
|
||||
}
|
||||
|
||||
// GoogleTokenInfo is the verified, caller-facing view of a Google ID token.
|
||||
type GoogleTokenInfo struct {
|
||||
Sub string // Unique Google user ID
|
||||
Email string
|
||||
EmailVerified string // "true" or "false" — string for caller compatibility
|
||||
Name string
|
||||
GivenName string
|
||||
FamilyName string
|
||||
Picture string
|
||||
Aud string
|
||||
Azp string
|
||||
Iss string
|
||||
}
|
||||
|
||||
// IsEmailVerified returns whether the email is verified.
|
||||
func (t *GoogleTokenInfo) IsEmailVerified() bool {
|
||||
return t.EmailVerified == "true"
|
||||
}
|
||||
|
||||
// GoogleAuthService handles Google Sign In token verification
|
||||
// GoogleAuthService handles Google Sign In token verification.
|
||||
type GoogleAuthService struct {
|
||||
cache *CacheService
|
||||
config *config.Config
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewGoogleAuthService creates a new Google auth service
|
||||
// NewGoogleAuthService creates a new Google auth service.
|
||||
func NewGoogleAuthService(cache *CacheService, cfg *config.Config) *GoogleAuthService {
|
||||
return &GoogleAuthService{
|
||||
cache: cache,
|
||||
config: cfg,
|
||||
client: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
client: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyIDToken verifies a Google ID token and returns the token info
|
||||
// VerifyIDToken verifies a Google ID token locally (audit C2/C3): it checks
|
||||
// the RS256 signature against Google's published JWKS and the iss, aud, and
|
||||
// exp claims. It never sends the token to a third-party endpoint, so it no
|
||||
// longer depends on the deprecated tokeninfo service and never leaks the
|
||||
// token in a request URL.
|
||||
func (s *GoogleAuthService) VerifyIDToken(ctx context.Context, idToken string) (*GoogleTokenInfo, error) {
|
||||
// Call Google's tokeninfo endpoint to verify the token
|
||||
url := fmt.Sprintf("%s?id_token=%s", googleTokenInfoURL, idToken)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to verify token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// Parse the token header to get the key ID.
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, ErrInvalidGoogleToken
|
||||
}
|
||||
|
||||
var tokenInfo GoogleTokenInfo
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode token info: %w", err)
|
||||
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode token header: %w", err)
|
||||
}
|
||||
var header struct {
|
||||
Kid string `json:"kid"`
|
||||
Alg string `json:"alg"`
|
||||
}
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token header: %w", err)
|
||||
}
|
||||
|
||||
// Verify the audience matches our client ID(s)
|
||||
if !s.verifyAudience(tokenInfo.Aud, tokenInfo.Azp) {
|
||||
publicKey, err := s.getPublicKey(ctx, header.Kid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse and verify the signature. jwt v5 validates exp/iat/nbf automatically.
|
||||
token, err := jwt.ParseWithClaims(idToken, &GoogleTokenClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return publicKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return nil, ErrGoogleTokenExpired
|
||||
}
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*GoogleTokenClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, ErrInvalidGoogleToken
|
||||
}
|
||||
|
||||
// Verify the issuer (audit C3).
|
||||
if !googleIssuers[claims.Issuer] {
|
||||
return nil, ErrInvalidGoogleIssuer
|
||||
}
|
||||
|
||||
// Verify the audience matches one of our configured client IDs.
|
||||
if !s.verifyAudience(claims.Audience, claims.Azp) {
|
||||
return nil, ErrInvalidGoogleAudience
|
||||
}
|
||||
|
||||
// Verify the token is not expired (tokeninfo endpoint already checks this,
|
||||
// but we double-check for security)
|
||||
if tokenInfo.Sub == "" {
|
||||
if claims.Subject == "" {
|
||||
return nil, ErrInvalidGoogleToken
|
||||
}
|
||||
|
||||
return &tokenInfo, nil
|
||||
emailVerified := "false"
|
||||
if claims.EmailVerified {
|
||||
emailVerified = "true"
|
||||
}
|
||||
aud := ""
|
||||
if len(claims.Audience) > 0 {
|
||||
aud = claims.Audience[0]
|
||||
}
|
||||
return &GoogleTokenInfo{
|
||||
Sub: claims.Subject,
|
||||
Email: claims.Email,
|
||||
EmailVerified: emailVerified,
|
||||
Name: claims.Name,
|
||||
GivenName: claims.GivenName,
|
||||
FamilyName: claims.FamilyName,
|
||||
Picture: claims.Picture,
|
||||
Aud: aud,
|
||||
Azp: claims.Azp,
|
||||
Iss: claims.Issuer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// verifyAudience checks if the token audience matches our client ID(s).
|
||||
// In production (non-debug), an empty clientID causes verification to fail
|
||||
// rather than silently bypassing the check.
|
||||
func (s *GoogleAuthService) verifyAudience(aud, azp string) bool {
|
||||
// verifyAudience checks the token audience against our configured client IDs.
|
||||
// In production (non-debug) an empty client ID fails verification rather than
|
||||
// silently bypassing the check.
|
||||
func (s *GoogleAuthService) verifyAudience(audience jwt.ClaimStrings, azp string) bool {
|
||||
clientID := s.config.GoogleAuth.ClientID
|
||||
if clientID == "" {
|
||||
if s.config.Server.Debug {
|
||||
// In debug mode only, skip audience verification for local development
|
||||
// In debug mode only, skip audience verification for local development.
|
||||
return s.config.Server.Debug
|
||||
}
|
||||
|
||||
candidates := []string{clientID}
|
||||
if id := s.config.GoogleAuth.AndroidClientID; id != "" {
|
||||
candidates = append(candidates, id)
|
||||
}
|
||||
if id := s.config.GoogleAuth.IOSClientID; id != "" {
|
||||
candidates = append(candidates, id)
|
||||
}
|
||||
|
||||
for _, want := range candidates {
|
||||
if azp == want {
|
||||
return true
|
||||
}
|
||||
// In production, missing client ID means we cannot verify the audience
|
||||
return false
|
||||
for _, aud := range audience {
|
||||
if aud == want {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check both aud and azp (Android vs iOS may use different values)
|
||||
if aud == clientID || azp == clientID {
|
||||
return true
|
||||
}
|
||||
|
||||
// Also check Android client ID if configured
|
||||
androidClientID := s.config.GoogleAuth.AndroidClientID
|
||||
if androidClientID != "" && (aud == androidClientID || azp == androidClientID) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Also check iOS client ID if configured
|
||||
iosClientID := s.config.GoogleAuth.IOSClientID
|
||||
if iosClientID != "" && (aud == iosClientID || azp == iosClientID) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// getPublicKey returns the RSA public key for the given key ID, using a
|
||||
// Redis-cached copy of Google's JWKS and re-fetching once on a cache miss
|
||||
// (Google rotates signing keys roughly daily).
|
||||
func (s *GoogleAuthService) getPublicKey(ctx context.Context, kid string) (*rsa.PublicKey, error) {
|
||||
keys, err := s.getCachedKeys(ctx)
|
||||
if err != nil || keys == nil {
|
||||
keys, err = s.fetchGooglePublicKeys(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if pubKey, ok := keys[kid]; ok {
|
||||
return pubKey, nil
|
||||
}
|
||||
|
||||
// Cache miss for this kid — keys may have rotated; fetch fresh.
|
||||
keys, err = s.fetchGooglePublicKeys(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pubKey, ok := keys[kid]; ok {
|
||||
return pubKey, nil
|
||||
}
|
||||
return nil, ErrGoogleKeyNotFound
|
||||
}
|
||||
|
||||
// getCachedKeys retrieves cached Google public keys from Redis.
|
||||
func (s *GoogleAuthService) getCachedKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
|
||||
if s.cache == nil {
|
||||
return nil, nil
|
||||
}
|
||||
data, err := s.cache.GetString(ctx, googleKeysCacheKey)
|
||||
if err != nil || data == "" {
|
||||
return nil, nil
|
||||
}
|
||||
var jwks GoogleJWKS
|
||||
if err := json.Unmarshal([]byte(data), &jwks); err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.parseJWKS(&jwks), nil
|
||||
}
|
||||
|
||||
// fetchGooglePublicKeys fetches Google's JWKS and caches it.
|
||||
func (s *GoogleAuthService) fetchGooglePublicKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, googleKeysURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch Google keys: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Google keys endpoint returned status %d", resp.StatusCode)
|
||||
}
|
||||
var jwks GoogleJWKS
|
||||
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode Google keys: %w", err)
|
||||
}
|
||||
if s.cache != nil {
|
||||
keysJSON, _ := json.Marshal(jwks)
|
||||
_ = s.cache.SetString(ctx, googleKeysCacheKey, string(keysJSON), googleKeysCacheTTL)
|
||||
}
|
||||
return s.parseJWKS(&jwks), nil
|
||||
}
|
||||
|
||||
// parseJWKS converts Google's JWKS into a map of RSA public keys by key ID.
|
||||
func (s *GoogleAuthService) parseJWKS(jwks *GoogleJWKS) map[string]*rsa.PublicKey {
|
||||
keys := make(map[string]*rsa.PublicKey)
|
||||
for _, key := range jwks.Keys {
|
||||
if key.Kty != "RSA" {
|
||||
continue
|
||||
}
|
||||
nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
e := 0
|
||||
for _, b := range eBytes {
|
||||
e = e<<8 + int(b)
|
||||
}
|
||||
keys[key.Kid] = &rsa.PublicKey{N: new(big.Int).SetBytes(nBytes), E: e}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
@@ -68,13 +68,14 @@ type AppleTransactionInfo struct {
|
||||
|
||||
// AppleValidationResult contains the result of Apple receipt validation
|
||||
type AppleValidationResult struct {
|
||||
Valid bool
|
||||
TransactionID string
|
||||
ProductID string
|
||||
ExpiresAt time.Time
|
||||
IsTrialPeriod bool
|
||||
AutoRenewEnabled bool
|
||||
Environment string
|
||||
Valid bool
|
||||
TransactionID string
|
||||
OriginalTransactionID string // stable across renewals — the replay key
|
||||
ProductID string
|
||||
ExpiresAt time.Time
|
||||
IsTrialPeriod bool
|
||||
AutoRenewEnabled bool
|
||||
Environment string
|
||||
}
|
||||
|
||||
// GoogleValidationResult contains the result of Google token validation
|
||||
@@ -95,6 +96,21 @@ func NewAppleIAPClient(cfg config.AppleIAPConfig) (*AppleIAPClient, error) {
|
||||
return nil, ErrIAPNotConfigured
|
||||
}
|
||||
|
||||
// Audit H5 (relaxed per MEDIUM-2): refuse to load the IAP signing key from
|
||||
// a world-accessible file — a leaked .p8 lets an attacker forge App Store
|
||||
// Server API requests. The original "0600 or stricter" check is
|
||||
// incompatible with a Kubernetes Secret volume: the kubelet widens secret
|
||||
// files to 0440 once fsGroup is set, so 0600 is unattainable for a
|
||||
// non-root container. Group access is scoped to the pod's fsGroup; the
|
||||
// real exposure is the "other" bits, so reject only those.
|
||||
info, err := os.Stat(cfg.KeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to stat Apple IAP key: %w", err)
|
||||
}
|
||||
if perm := info.Mode().Perm(); perm&0o007 != 0 {
|
||||
return nil, fmt.Errorf("Apple IAP key %s is world-accessible (permissions %#o); remove other-rwx bits", cfg.KeyPath, perm)
|
||||
}
|
||||
|
||||
// Read the private key
|
||||
keyData, err := os.ReadFile(cfg.KeyPath)
|
||||
if err != nil {
|
||||
@@ -215,11 +231,12 @@ func (c *AppleIAPClient) ValidateTransaction(ctx context.Context, transactionID
|
||||
expiresAt := time.Unix(transactionInfo.ExpiresDate/1000, 0)
|
||||
|
||||
return &AppleValidationResult{
|
||||
Valid: true,
|
||||
TransactionID: transactionInfo.TransactionID,
|
||||
ProductID: transactionInfo.ProductID,
|
||||
ExpiresAt: expiresAt,
|
||||
Environment: transactionInfo.Environment,
|
||||
Valid: true,
|
||||
TransactionID: transactionInfo.TransactionID,
|
||||
OriginalTransactionID: transactionInfo.OriginalTransactionID,
|
||||
ProductID: transactionInfo.ProductID,
|
||||
ExpiresAt: expiresAt,
|
||||
Environment: transactionInfo.Environment,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -243,11 +260,12 @@ func (c *AppleIAPClient) ValidateReceipt(ctx context.Context, receiptData string
|
||||
if err == nil {
|
||||
expiresAt := time.Unix(transactionInfo.ExpiresDate/1000, 0)
|
||||
return &AppleValidationResult{
|
||||
Valid: true,
|
||||
TransactionID: transactionInfo.TransactionID,
|
||||
ProductID: transactionInfo.ProductID,
|
||||
ExpiresAt: expiresAt,
|
||||
Environment: transactionInfo.Environment,
|
||||
Valid: true,
|
||||
TransactionID: transactionInfo.TransactionID,
|
||||
OriginalTransactionID: transactionInfo.OriginalTransactionID,
|
||||
ProductID: transactionInfo.ProductID,
|
||||
ExpiresAt: expiresAt,
|
||||
Environment: transactionInfo.Environment,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
@@ -317,11 +335,12 @@ func (c *AppleIAPClient) ValidateReceipt(ctx context.Context, receiptData string
|
||||
expiresAt := time.Unix(transactionInfo.ExpiresDate/1000, 0)
|
||||
|
||||
return &AppleValidationResult{
|
||||
Valid: true,
|
||||
TransactionID: transactionInfo.TransactionID,
|
||||
ProductID: transactionInfo.ProductID,
|
||||
ExpiresAt: expiresAt,
|
||||
Environment: transactionInfo.Environment,
|
||||
Valid: true,
|
||||
TransactionID: transactionInfo.TransactionID,
|
||||
OriginalTransactionID: transactionInfo.OriginalTransactionID,
|
||||
ProductID: transactionInfo.ProductID,
|
||||
ExpiresAt: expiresAt,
|
||||
Environment: transactionInfo.Environment,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -418,13 +437,14 @@ func (c *AppleIAPClient) validateLegacyReceiptWithSandbox(ctx context.Context, r
|
||||
}
|
||||
|
||||
return &AppleValidationResult{
|
||||
Valid: true,
|
||||
TransactionID: latestReceipt.TransactionID,
|
||||
ProductID: latestReceipt.ProductID,
|
||||
ExpiresAt: expiresAt,
|
||||
IsTrialPeriod: latestReceipt.IsTrialPeriod == "true",
|
||||
AutoRenewEnabled: autoRenew,
|
||||
Environment: legacyResponse.Environment,
|
||||
Valid: true,
|
||||
TransactionID: latestReceipt.TransactionID,
|
||||
OriginalTransactionID: latestReceipt.OriginalTransactionID,
|
||||
ProductID: latestReceipt.ProductID,
|
||||
ExpiresAt: expiresAt,
|
||||
IsTrialPeriod: latestReceipt.IsTrialPeriod == "true",
|
||||
AutoRenewEnabled: autoRenew,
|
||||
Environment: legacyResponse.Environment,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -308,7 +308,18 @@ func (s *NotificationService) registerAPNSDevice(ctx context.Context, userID uin
|
||||
// Check if device exists
|
||||
existing, err := s.notificationRepo.WithContext(ctx).FindAPNSDeviceByToken(req.RegistrationID)
|
||||
if err == nil {
|
||||
// Update existing device
|
||||
// Audit C8 / LOW-3: APNs device tokens are recycled across devices,
|
||||
// app reinstalls and OS reassignments, so a token already bound to a
|
||||
// different account is a stale binding — not a hijack. Reassign it to
|
||||
// the current (authenticated) registrant rather than reject: a 409
|
||||
// here would lock the legitimate new owner of a recycled token out of
|
||||
// push entirely. The reassignment is logged as a security-relevant
|
||||
// event so a genuine token-takeover attempt is still traceable.
|
||||
if existing.UserID != nil && *existing.UserID != userID {
|
||||
log.Warn().Uint("user_id", userID).Uint("previous_owner_id", *existing.UserID).
|
||||
Msg("APNS device token reassigned to a new account")
|
||||
}
|
||||
// Update existing device — reassign to the current user
|
||||
existing.UserID = &userID
|
||||
existing.Active = true
|
||||
existing.Name = req.Name
|
||||
@@ -337,7 +348,18 @@ func (s *NotificationService) registerGCMDevice(ctx context.Context, userID uint
|
||||
// Check if device exists
|
||||
existing, err := s.notificationRepo.WithContext(ctx).FindGCMDeviceByToken(req.RegistrationID)
|
||||
if err == nil {
|
||||
// Update existing device
|
||||
// Audit C8 / LOW-3: FCM device tokens are recycled across devices,
|
||||
// app reinstalls and OS reassignments, so a token already bound to a
|
||||
// different account is a stale binding — not a hijack. Reassign it to
|
||||
// the current (authenticated) registrant rather than reject: a 409
|
||||
// here would lock the legitimate new owner of a recycled token out of
|
||||
// push entirely. The reassignment is logged as a security-relevant
|
||||
// event so a genuine token-takeover attempt is still traceable.
|
||||
if existing.UserID != nil && *existing.UserID != userID {
|
||||
log.Warn().Uint("user_id", userID).Uint("previous_owner_id", *existing.UserID).
|
||||
Msg("GCM device token reassigned to a new account")
|
||||
}
|
||||
// Update existing device — reassign to the current user
|
||||
existing.UserID = &userID
|
||||
existing.Active = true
|
||||
existing.Name = req.Name
|
||||
|
||||
@@ -559,30 +559,22 @@ func (s *ResidenceService) GenerateSharePackage(ctx context.Context, residenceID
|
||||
}, nil
|
||||
}
|
||||
|
||||
// JoinWithCode allows a user to join a residence using a share code
|
||||
// JoinWithCode allows a user to join a residence using a share code.
|
||||
// Audit C9/H9: the code lookup, membership add, and one-time-code
|
||||
// deactivation run as a single locked transaction in the repository, so a
|
||||
// code can never be redeemed twice and a deactivation failure aborts the join.
|
||||
func (s *ResidenceService) JoinWithCode(ctx context.Context, code string, userID uint) (*responses.JoinResidenceResponse, error) {
|
||||
// Find the share code
|
||||
shareCode, err := s.residenceRepo.WithContext(ctx).FindShareCodeByCode(code)
|
||||
residenceID, alreadyMember, err := s.residenceRepo.WithContext(ctx).JoinWithShareCode(code, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, apperrors.NotFound("error.share_code_invalid")
|
||||
}
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Check if already a member
|
||||
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(shareCode.ResidenceID, userID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
if hasAccess {
|
||||
if alreadyMember {
|
||||
return nil, apperrors.Conflict("error.user_already_member")
|
||||
}
|
||||
|
||||
// Add user to residence
|
||||
if err := s.residenceRepo.WithContext(ctx).AddUser(shareCode.ResidenceID, userID); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
if s.cache != nil {
|
||||
// The joining user's residence-IDs cache is now stale, and their
|
||||
// subscription status now reflects an extra residence with all of its
|
||||
@@ -591,15 +583,8 @@ func (s *ResidenceService) JoinWithCode(ctx context.Context, code string, userID
|
||||
_ = s.cache.InvalidateSubscriptionStatusForUsers(ctx, userID)
|
||||
}
|
||||
|
||||
// Mark share code as used (one-time use)
|
||||
if err := s.residenceRepo.WithContext(ctx).DeactivateShareCode(shareCode.ID); err != nil {
|
||||
// Log the error but don't fail the join - the user has already been added
|
||||
// The code will just be usable by others until it expires
|
||||
log.Error().Err(err).Uint("code_id", shareCode.ID).Msg("Failed to deactivate share code after join")
|
||||
}
|
||||
|
||||
// Get the residence with full details
|
||||
residence, err := s.residenceRepo.WithContext(ctx).FindByID(shareCode.ResidenceID)
|
||||
residence, err := s.residenceRepo.WithContext(ctx).FindByID(residenceID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
@@ -399,99 +399,135 @@ func (s *SubscriptionService) GetActivePromotions(ctx context.Context, userID ui
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ProcessApplePurchase processes an Apple IAP purchase
|
||||
// Supports both StoreKit 1 (receiptData) and StoreKit 2 (transactionID)
|
||||
// ProcessApplePurchase processes an Apple IAP purchase.
|
||||
// Supports both StoreKit 1 (receiptData) and StoreKit 2 (transactionID).
|
||||
func (s *SubscriptionService) ProcessApplePurchase(ctx context.Context, userID uint, receiptData string, transactionID string) (*SubscriptionResponse, error) {
|
||||
// Store receipt/transaction data
|
||||
dataToStore := receiptData
|
||||
if dataToStore == "" {
|
||||
dataToStore = transactionID
|
||||
}
|
||||
if err := s.subscriptionRepo.WithContext(ctx).UpdateReceiptData(userID, dataToStore); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Apple IAP client must be configured to validate purchases.
|
||||
// Without server-side validation, we cannot trust client-provided receipts.
|
||||
// Apple IAP client must be configured — without server-side validation
|
||||
// we cannot trust client-provided receipts.
|
||||
if s.appleClient == nil {
|
||||
log.Error().Uint("user_id", userID).Msg("Apple IAP validation not configured, rejecting purchase")
|
||||
return nil, apperrors.BadRequest("error.iap_validation_not_configured")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
// Validation is a network call to Apple; detach from the request context
|
||||
// so a client disconnect cannot abort an in-flight grant.
|
||||
vctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var result *AppleValidationResult
|
||||
var err error
|
||||
|
||||
// Prefer transaction ID (StoreKit 2), fall back to receipt data (StoreKit 1)
|
||||
// Prefer transaction ID (StoreKit 2), fall back to receipt data (StoreKit 1).
|
||||
if transactionID != "" {
|
||||
result, err = s.appleClient.ValidateTransaction(ctx, transactionID)
|
||||
result, err = s.appleClient.ValidateTransaction(vctx, transactionID)
|
||||
} else if receiptData != "" {
|
||||
result, err = s.appleClient.ValidateReceipt(ctx, receiptData)
|
||||
result, err = s.appleClient.ValidateReceipt(vctx, receiptData)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Validation failed -- do NOT fall through to grant Pro.
|
||||
log.Error().Err(err).Uint("user_id", userID).Msg("Apple validation failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
return nil, apperrors.BadRequest("error.no_receipt_or_transaction")
|
||||
}
|
||||
|
||||
// Audit C5/C10: replay protection. A validated transaction may only ever
|
||||
// be bound to one account — re-submitting a valid receipt against a
|
||||
// second account must not grant Pro for free. The partial unique index
|
||||
// on apple_original_transaction_id is the backstop for the check/store
|
||||
// race below.
|
||||
if result.OriginalTransactionID != "" {
|
||||
existing, lookupErr := s.subscriptionRepo.WithContext(vctx).FindByAppleOriginalTransactionID(result.OriginalTransactionID)
|
||||
switch {
|
||||
case lookupErr == nil && existing != nil && existing.UserID != userID:
|
||||
log.Warn().Uint("user_id", userID).Uint("bound_user_id", existing.UserID).
|
||||
Msg("Apple purchase rejected — transaction already claimed by another account")
|
||||
return nil, apperrors.Forbidden("error.iap_transaction_already_claimed")
|
||||
case lookupErr != nil && !errors.Is(lookupErr, gorm.ErrRecordNotFound):
|
||||
return nil, apperrors.Internal(lookupErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Persist the receipt blob and the replay key.
|
||||
dataToStore := receiptData
|
||||
if dataToStore == "" {
|
||||
dataToStore = transactionID
|
||||
}
|
||||
if err := s.subscriptionRepo.WithContext(vctx).UpdateReceiptData(userID, dataToStore); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
if result.OriginalTransactionID != "" {
|
||||
if err := s.subscriptionRepo.WithContext(vctx).UpdateAppleOriginalTransactionID(userID, result.OriginalTransactionID); err != nil {
|
||||
// The unique index rejected the bind — a concurrent request
|
||||
// claimed the same transaction first.
|
||||
log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to bind Apple transaction ID")
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
}
|
||||
|
||||
expiresAt := result.ExpiresAt
|
||||
log.Info().Uint("user_id", userID).Str("product", result.ProductID).Time("expires", result.ExpiresAt).Str("env", result.Environment).Msg("Apple purchase validated")
|
||||
|
||||
// Upgrade to Pro with the validated expiration
|
||||
if err := s.subscriptionRepo.WithContext(ctx).UpgradeToPro(userID, expiresAt, "ios"); err != nil {
|
||||
if err := s.subscriptionRepo.WithContext(vctx).UpgradeToPro(userID, expiresAt, "ios"); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Tier flipped — drop cached SubscriptionStatusResponse so the next call
|
||||
// returns Pro immediately instead of stale Free.
|
||||
if s.cache != nil {
|
||||
_ = s.cache.InvalidateSubscriptionStatusForUsers(ctx, userID)
|
||||
_ = s.cache.InvalidateSubscriptionStatusForUsers(vctx, userID)
|
||||
}
|
||||
|
||||
return s.GetSubscription(ctx, userID)
|
||||
return s.GetSubscription(vctx, userID)
|
||||
}
|
||||
|
||||
// ProcessGooglePurchase processes a Google Play purchase
|
||||
// productID is optional but helps validate the specific subscription
|
||||
func (s *SubscriptionService) ProcessGooglePurchase(ctx context.Context, userID uint, purchaseToken string, productID string) (*SubscriptionResponse, error) {
|
||||
// Store purchase token first
|
||||
if err := s.subscriptionRepo.WithContext(ctx).UpdatePurchaseToken(userID, purchaseToken); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Google IAP client must be configured to validate purchases.
|
||||
// Without server-side validation, we cannot trust client-provided tokens.
|
||||
// Google IAP client must be configured — without server-side validation
|
||||
// we cannot trust client-provided tokens.
|
||||
if s.googleClient == nil {
|
||||
log.Error().Uint("user_id", userID).Msg("Google IAP validation not configured, rejecting purchase")
|
||||
return nil, apperrors.BadRequest("error.iap_validation_not_configured")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
// Audit C6/C10: replay protection — a purchase token may only ever be
|
||||
// bound to one account. The partial unique index on google_purchase_token
|
||||
// is the backstop for the check/store race.
|
||||
if purchaseToken != "" {
|
||||
existing, lookupErr := s.subscriptionRepo.WithContext(ctx).FindByGoogleToken(purchaseToken)
|
||||
switch {
|
||||
case lookupErr == nil && existing != nil && existing.UserID != userID:
|
||||
log.Warn().Uint("user_id", userID).Uint("bound_user_id", existing.UserID).
|
||||
Msg("Google purchase rejected — token already claimed by another account")
|
||||
return nil, apperrors.Forbidden("error.iap_transaction_already_claimed")
|
||||
case lookupErr != nil && !errors.Is(lookupErr, gorm.ErrRecordNotFound):
|
||||
return nil, apperrors.Internal(lookupErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Store the purchase token (the replay key).
|
||||
if err := s.subscriptionRepo.WithContext(ctx).UpdatePurchaseToken(userID, purchaseToken); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Validation is a network call; detach from the request context.
|
||||
vctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var result *GoogleValidationResult
|
||||
var err error
|
||||
|
||||
// If productID is provided, use it directly; otherwise try known IDs
|
||||
// If productID is provided, use it directly; otherwise try known IDs.
|
||||
if productID != "" {
|
||||
result, err = s.googleClient.ValidateSubscription(ctx, productID, purchaseToken)
|
||||
result, err = s.googleClient.ValidateSubscription(vctx, productID, purchaseToken)
|
||||
} else {
|
||||
result, err = s.googleClient.ValidatePurchaseToken(ctx, purchaseToken, KnownSubscriptionIDs)
|
||||
result, err = s.googleClient.ValidatePurchaseToken(vctx, purchaseToken, KnownSubscriptionIDs)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Validation failed -- do NOT fall through to grant Pro.
|
||||
log.Error().Err(err).Uint("user_id", userID).Msg("Google purchase validation failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
return nil, apperrors.BadRequest("error.no_purchase_token")
|
||||
}
|
||||
@@ -499,24 +535,23 @@ func (s *SubscriptionService) ProcessGooglePurchase(ctx context.Context, userID
|
||||
expiresAt := result.ExpiresAt
|
||||
log.Info().Uint("user_id", userID).Str("product", result.ProductID).Time("expires", result.ExpiresAt).Bool("auto_renew", result.AutoRenewing).Msg("Google purchase validated")
|
||||
|
||||
// Acknowledge the subscription if not already acknowledged
|
||||
// Acknowledge the subscription if not already acknowledged.
|
||||
if !result.AcknowledgedState {
|
||||
if err := s.googleClient.AcknowledgeSubscription(ctx, result.ProductID, purchaseToken); err != nil {
|
||||
if err := s.googleClient.AcknowledgeSubscription(vctx, result.ProductID, purchaseToken); err != nil {
|
||||
log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to acknowledge Google subscription")
|
||||
// Don't fail the purchase, just log the warning
|
||||
// Don't fail the purchase, just log the warning.
|
||||
}
|
||||
}
|
||||
|
||||
// Upgrade to Pro with the validated expiration
|
||||
if err := s.subscriptionRepo.WithContext(ctx).UpgradeToPro(userID, expiresAt, "android"); err != nil {
|
||||
if err := s.subscriptionRepo.WithContext(vctx).UpgradeToPro(userID, expiresAt, "android"); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
if s.cache != nil {
|
||||
_ = s.cache.InvalidateSubscriptionStatusForUsers(ctx, userID)
|
||||
_ = s.cache.InvalidateSubscriptionStatusForUsers(vctx, userID)
|
||||
}
|
||||
|
||||
return s.GetSubscription(ctx, userID)
|
||||
return s.GetSubscription(vctx, userID)
|
||||
}
|
||||
|
||||
// CancelSubscription cancels a subscription (downgrades to free at end of period)
|
||||
|
||||
Reference in New Issue
Block a user