fix(security): remediate 2026-05-12 audit findings (Stages 2–5)
Remediation of the 2026-05-12/13 audits (78 findings + cluster gaps), tracked in deploy-k3s/SECURITY.md, plus fixes from two independent post-remediation reviews. Auth & sessions: - SHA-256 hashed auth-token storage (C1); prior-token cache eviction on re-login (MEDIUM-1) - local Google JWKS verification, iss/aud/exp checks (C2/C3) - constant-time login + generic errors (L1/LIVE-L11/LIVE-L13) - per-account login lockout keyed on distinct source IPs (M5/MEDIUM-3) - verified-email gating, login rate limiting (LIVE-L19, H1-H3) IAP & webhooks: - Apple/Google cross-account replay protection (C5/C6/C10/C13, H5/H6) - migrations 000003-000006 (token hashing, IAP replay, audit_log + webhook_event_log table creation, append-only audit log) Authorization & races: - file-ownership owner-OR-member fix (C7), atomic share-code join (C9/H9), device-token reassignment (C8/LOW-3) Secrets & deploy: - secrets file-mounted at /etc/honeydue/secrets, not env (F8); Redis password out of the ConfigMap (HIGH-1); B2 keys reconciled - digest-pinned images, admin ingress hardening, CSP/HSTS, /metrics lockdown; kubeconfig 0600, etcd secrets-encryption, fail2ban + unattended-upgrades at provision; secret-rotation runbook Build, vet, and the full test suite (incl. -race) pass; the goose migration chain is verified against PostgreSQL 16. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
@@ -216,6 +217,11 @@ func Load() (*Config, error) {
|
||||
// Set defaults
|
||||
setDefaults()
|
||||
|
||||
// Audit F8: overlay file-mounted secrets onto Viper. No-op when the
|
||||
// directory is absent (local/dev), so this is safe to ship before the
|
||||
// manifests mount honeydue-secrets as a volume.
|
||||
loadFileSecrets()
|
||||
|
||||
// Parse DATABASE_URL if set (Dokku-style)
|
||||
dbConfig := DatabaseConfig{
|
||||
Host: viper.GetString("DB_HOST"),
|
||||
@@ -432,14 +438,67 @@ func isWeakSecretKey(key string) bool {
|
||||
return knownWeakSecretKeys[strings.ToLower(strings.TrimSpace(key))]
|
||||
}
|
||||
|
||||
// loadFileSecrets overlays file-mounted secrets onto Viper (audit F8). When
|
||||
// the honeydue-secrets Secret is mounted as a volume at /etc/honeydue/secrets
|
||||
// each key is a file; reading the value here and viper.Set-ing it (highest
|
||||
// Viper precedence) keeps the secret out of the process environment
|
||||
// (/proc/<pid>/environ), which plain env-var injection cannot. When the
|
||||
// directory is absent it is a silent no-op and env vars are used as before.
|
||||
func loadFileSecrets() {
|
||||
dir := os.Getenv("HONEYDUE_SECRETS_DIR")
|
||||
if dir == "" {
|
||||
dir = "/etc/honeydue/secrets"
|
||||
}
|
||||
for _, k := range []string{
|
||||
"POSTGRES_PASSWORD", "SECRET_KEY", "EMAIL_HOST_PASSWORD", "FCM_SERVER_KEY",
|
||||
"REDIS_PASSWORD", "B2_KEY_ID", "B2_APP_KEY", "OBS_INGEST_TOKEN", "OBS_TRACES_URL",
|
||||
} {
|
||||
b, err := os.ReadFile(dir + "/" + k)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if v := strings.TrimSpace(string(b)); v != "" {
|
||||
viper.Set(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SecretValue resolves a configuration value that is not part of the typed
|
||||
// Config struct. It reads through Viper, so a value supplied via a file-mounted
|
||||
// secret (audit F8, loaded by loadFileSecrets) is found just like an env var.
|
||||
//
|
||||
// Must be called after Load(). Used by cmd/api and cmd/worker for the
|
||||
// observability endpoints, which are needed before the full Config is wired
|
||||
// and would otherwise be read with os.Getenv — which misses file-mounted
|
||||
// secrets entirely once F8 removes them from the process environment.
|
||||
func SecretValue(key string) string {
|
||||
return viper.GetString(key)
|
||||
}
|
||||
|
||||
// randomHexKey returns a cryptographically secure random hex string
|
||||
// representing n random bytes (2n hex characters).
|
||||
func randomHexKey(n int) (string, error) {
|
||||
b := make([]byte, n)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func validate(cfg *Config) error {
|
||||
// S-08: Validate SECRET_KEY against known weak defaults
|
||||
// M8: SECRET_KEY validation — no static fallback secret in the binary.
|
||||
if cfg.Security.SecretKey == "" {
|
||||
if cfg.Server.Debug {
|
||||
// In debug mode, use a default key with a warning for local development
|
||||
cfg.Security.SecretKey = "change-me-in-production-secret-key-12345"
|
||||
fmt.Println("WARNING: SECRET_KEY not set, using default (debug mode only)")
|
||||
fmt.Println("WARNING: *** DO NOT USE THIS DEFAULT KEY IN PRODUCTION ***")
|
||||
// Debug only: generate a random key per boot. Tokens signed with
|
||||
// it do not survive a restart, which is acceptable for local dev
|
||||
// and far safer than a well-known hardcoded fallback.
|
||||
randomKey, err := randomHexKey(32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate ephemeral debug SECRET_KEY: %w", err)
|
||||
}
|
||||
cfg.Security.SecretKey = randomKey
|
||||
fmt.Println("WARNING: SECRET_KEY not set, generated an ephemeral random key (debug mode only)")
|
||||
fmt.Println("WARNING: tokens will not survive a restart — set SECRET_KEY for stable local sessions")
|
||||
} else {
|
||||
// In production, refuse to start without a proper secret key
|
||||
return fmt.Errorf("FATAL: SECRET_KEY environment variable is required in production (DEBUG=false)")
|
||||
@@ -452,6 +511,12 @@ func validate(cfg *Config) error {
|
||||
}
|
||||
}
|
||||
|
||||
// C4: fixed confirmation codes ("123456") must never be enabled outside
|
||||
// debug — with DEBUG=false they are a full authentication bypass.
|
||||
if cfg.Server.DebugFixedCodes && !cfg.Server.Debug {
|
||||
return fmt.Errorf("FATAL: DEBUG_FIXED_CODES is enabled with DEBUG=false — fixed confirmation codes must never run in production")
|
||||
}
|
||||
|
||||
// Database password might come from DATABASE_URL, don't require it separately
|
||||
// The actual connection will fail if credentials are wrong
|
||||
|
||||
|
||||
@@ -106,8 +106,10 @@ func TestLoad_Validation_MissingSecretKey_DebugMode(t *testing.T) {
|
||||
|
||||
c, err := Load()
|
||||
require.NoError(t, err)
|
||||
// In debug mode, a default key is assigned
|
||||
assert.Equal(t, "change-me-in-production-secret-key-12345", c.Security.SecretKey)
|
||||
// Audit M8: in debug mode an ephemeral random key is generated per boot
|
||||
// (no static fallback). It must be a non-empty 64-char hex string.
|
||||
assert.Len(t, c.Security.SecretKey, 64)
|
||||
assert.NotEqual(t, "change-me-in-production-secret-key-12345", c.Security.SecretKey)
|
||||
}
|
||||
|
||||
func TestLoad_Validation_WeakSecretKey_Production(t *testing.T) {
|
||||
@@ -133,6 +135,33 @@ func TestLoad_Validation_WeakSecretKey_DebugMode(t *testing.T) {
|
||||
assert.Equal(t, "secret", c.Security.SecretKey)
|
||||
}
|
||||
|
||||
// Audit C4: DEBUG_FIXED_CODES makes confirmation codes a fixed "123456" — a
|
||||
// full authentication bypass. With DEBUG=false, validate() must refuse to boot
|
||||
// rather than ship that bypass to production.
|
||||
func TestLoad_Validation_DebugFixedCodes_Production(t *testing.T) {
|
||||
// validate() directly — avoids the sync.Once issue Load() has on failure.
|
||||
cfg := &Config{
|
||||
Server: ServerConfig{Debug: false, DebugFixedCodes: true},
|
||||
Security: SecurityConfig{SecretKey: "a-strong-secret-key-for-tests"},
|
||||
}
|
||||
|
||||
err := validate(cfg)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "DEBUG_FIXED_CODES")
|
||||
}
|
||||
|
||||
// With DEBUG=true the fixed codes are an intended local-dev convenience, so
|
||||
// the same combination must NOT error.
|
||||
func TestLoad_Validation_DebugFixedCodes_DebugMode(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Server: ServerConfig{Debug: true, DebugFixedCodes: true},
|
||||
Security: SecurityConfig{SecretKey: "a-strong-secret-key-for-tests"},
|
||||
}
|
||||
|
||||
err := validate(cfg)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestLoad_Validation_EncryptionKey_Valid(t *testing.T) {
|
||||
resetConfigState()
|
||||
t.Setenv("SECRET_KEY", "a-strong-secret-key-for-tests")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
@@ -55,8 +56,15 @@ func (h *AuthHandler) SetAuditService(auditService *services.AuditService) {
|
||||
h.auditService = auditService
|
||||
}
|
||||
|
||||
// noStore marks a response as non-cacheable (audit L2) — auth responses
|
||||
// carry tokens and user data that must never sit in any cache.
|
||||
func noStore(c echo.Context) {
|
||||
c.Response().Header().Set("Cache-Control", "no-store")
|
||||
}
|
||||
|
||||
// Login handles POST /api/auth/login/
|
||||
func (h *AuthHandler) Login(c echo.Context) error {
|
||||
noStore(c)
|
||||
var req requests.LoginRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
@@ -65,9 +73,11 @@ func (h *AuthHandler) Login(c echo.Context) error {
|
||||
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
||||
}
|
||||
|
||||
response, err := h.authService.Login(c.Request().Context(), &req)
|
||||
response, err := h.authService.Login(c.Request().Context(), &req, c.RealIP())
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Str("identifier", req.Username).Msg("Login failed")
|
||||
log.Debug().Err(err).Str("identifier", req.Username).
|
||||
Str("ip", c.RealIP()).Str("user_agent", c.Request().UserAgent()).
|
||||
Msg("Login failed")
|
||||
if h.auditService != nil {
|
||||
h.auditService.LogEvent(c, nil, services.AuditEventLoginFailed, map[string]interface{}{
|
||||
"identifier": req.Username,
|
||||
@@ -86,6 +96,7 @@ func (h *AuthHandler) Login(c echo.Context) error {
|
||||
|
||||
// Register handles POST /api/auth/register/
|
||||
func (h *AuthHandler) Register(c echo.Context) error {
|
||||
noStore(c)
|
||||
var req requests.RegisterRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
@@ -157,6 +168,7 @@ func (h *AuthHandler) Logout(c echo.Context) error {
|
||||
|
||||
// CurrentUser handles GET /api/auth/me/
|
||||
func (h *AuthHandler) CurrentUser(c echo.Context) error {
|
||||
noStore(c)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -276,31 +288,7 @@ func (h *AuthHandler) ForgotPassword(c echo.Context) error {
|
||||
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
|
||||
}
|
||||
|
||||
code, user, err := h.authService.ForgotPassword(c.Request().Context(), req.Email)
|
||||
if err != nil {
|
||||
var appErr *apperrors.AppError
|
||||
if errors.As(err, &appErr) && appErr.Code == http.StatusTooManyRequests {
|
||||
// Only reveal rate limit errors
|
||||
return err
|
||||
}
|
||||
|
||||
log.Error().Err(err).Str("email", req.Email).Msg("Forgot password failed")
|
||||
// Don't reveal other errors to prevent email enumeration
|
||||
}
|
||||
|
||||
// Send password reset email (async) - only if user found
|
||||
if h.emailService != nil && code != "" && user != nil {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in password reset email goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.emailService.SendPasswordResetEmail(user.Email, user.FirstName, code); err != nil {
|
||||
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send password reset email")
|
||||
}
|
||||
}()
|
||||
}
|
||||
noStore(c)
|
||||
|
||||
if h.auditService != nil {
|
||||
h.auditService.LogEvent(c, nil, services.AuditEventPasswordReset, map[string]interface{}{
|
||||
@@ -308,7 +296,33 @@ func (h *AuthHandler) ForgotPassword(c echo.Context) error {
|
||||
})
|
||||
}
|
||||
|
||||
// Always return success to prevent email enumeration
|
||||
// Audit LIVE-L13: run the user lookup, code generation, and email send
|
||||
// entirely in the background, then return the generic response
|
||||
// immediately. This makes the response time identical whether or not
|
||||
// the email belongs to a real account, defeating timing-based user
|
||||
// enumeration. context.Background() is used because the request context
|
||||
// is cancelled the moment this handler returns. Per-account rate
|
||||
// limiting still runs inside the service; the edge auth-rate-limit
|
||||
// middleware covers per-IP abuse.
|
||||
email := req.Email
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", email).Msg("Panic in forgot-password goroutine")
|
||||
}
|
||||
}()
|
||||
code, user, err := h.authService.ForgotPassword(context.Background(), email)
|
||||
if err != nil || code == "" || user == nil {
|
||||
return
|
||||
}
|
||||
if h.emailService != nil {
|
||||
if sendErr := h.emailService.SendPasswordResetEmail(user.Email, user.FirstName, code); sendErr != nil {
|
||||
log.Error().Err(sendErr).Str("email", user.Email).Msg("Failed to send password reset email")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Always return success to prevent email enumeration.
|
||||
return c.JSON(http.StatusOK, responses.ForgotPasswordResponse{
|
||||
Message: "Password reset email sent",
|
||||
})
|
||||
@@ -365,6 +379,7 @@ func (h *AuthHandler) ResetPassword(c echo.Context) error {
|
||||
|
||||
// AppleSignIn handles POST /api/auth/apple-sign-in/
|
||||
func (h *AuthHandler) AppleSignIn(c echo.Context) error {
|
||||
noStore(c)
|
||||
var req requests.AppleSignInRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
@@ -412,6 +427,7 @@ func (h *AuthHandler) AppleSignIn(c echo.Context) error {
|
||||
|
||||
// GoogleSignIn handles POST /api/auth/google-sign-in/
|
||||
func (h *AuthHandler) GoogleSignIn(c echo.Context) error {
|
||||
noStore(c)
|
||||
var req requests.GoogleSignInRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
@@ -459,6 +475,7 @@ func (h *AuthHandler) GoogleSignIn(c echo.Context) error {
|
||||
|
||||
// RefreshToken handles POST /api/auth/refresh/
|
||||
func (h *AuthHandler) RefreshToken(c echo.Context) error {
|
||||
noStore(c)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -650,14 +650,14 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
|
||||
authGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
c.Set("auth_user", user)
|
||||
c.Set("auth_token", authToken.Key)
|
||||
c.Set("auth_token", authToken.Plaintext) // raw token — repo hashes for lookup (audit C1)
|
||||
return next(c)
|
||||
}
|
||||
})
|
||||
authGroup.POST("/refresh/", handler.RefreshToken)
|
||||
|
||||
t.Run("successful refresh", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/refresh/", nil, authToken.Key)
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/refresh/", nil, authToken.Plaintext)
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
|
||||
@@ -37,6 +37,23 @@ func NewMediaHandler(
|
||||
}
|
||||
}
|
||||
|
||||
// safeContentDisposition builds an inline Content-Disposition header value
|
||||
// with a sanitized filename (audit M1). Control characters (including CR/LF),
|
||||
// double-quote and backslash are stripped so an attacker-controlled upload
|
||||
// filename cannot inject additional response headers (CWE-113).
|
||||
func safeContentDisposition(filename string) string {
|
||||
cleaned := strings.Map(func(r rune) rune {
|
||||
if r < 0x20 || r == 0x7f || r == '"' || r == '\\' {
|
||||
return -1
|
||||
}
|
||||
return r
|
||||
}, filename)
|
||||
if cleaned == "" {
|
||||
cleaned = "download"
|
||||
}
|
||||
return `inline; filename="` + cleaned + `"`
|
||||
}
|
||||
|
||||
// ServeDocument serves a document file with access control
|
||||
// GET /api/media/document/:id
|
||||
func (h *MediaHandler) ServeDocument(c echo.Context) error {
|
||||
@@ -71,7 +88,7 @@ func (h *MediaHandler) ServeDocument(c echo.Context) error {
|
||||
// Set caching and disposition headers
|
||||
c.Response().Header().Set("Cache-Control", "private, max-age=3600")
|
||||
if doc.FileName != "" {
|
||||
c.Response().Header().Set("Content-Disposition", "inline; filename=\""+doc.FileName+"\"")
|
||||
c.Response().Header().Set("Content-Disposition", safeContentDisposition(doc.FileName))
|
||||
}
|
||||
return c.Blob(http.StatusOK, mimeType, data)
|
||||
}
|
||||
@@ -114,7 +131,7 @@ func (h *MediaHandler) ServeDocumentImage(c echo.Context) error {
|
||||
}
|
||||
|
||||
c.Response().Header().Set("Cache-Control", "private, max-age=3600")
|
||||
c.Response().Header().Set("Content-Disposition", "inline; filename=\""+filepath.Base(img.ImageURL)+"\"")
|
||||
c.Response().Header().Set("Content-Disposition", safeContentDisposition(filepath.Base(img.ImageURL)))
|
||||
return c.Blob(http.StatusOK, mimeType, data)
|
||||
}
|
||||
|
||||
@@ -162,7 +179,7 @@ func (h *MediaHandler) ServeCompletionImage(c echo.Context) error {
|
||||
}
|
||||
|
||||
c.Response().Header().Set("Cache-Control", "private, max-age=3600")
|
||||
c.Response().Header().Set("Content-Disposition", "inline; filename=\""+filepath.Base(img.ImageURL)+"\"")
|
||||
c.Response().Header().Set("Content-Disposition", safeContentDisposition(filepath.Base(img.ImageURL)))
|
||||
return c.Blob(http.StatusOK, mimeType, data)
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
@@ -165,9 +167,13 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
|
||||
if notification.NotificationUUID != "" {
|
||||
alreadyProcessed, err := h.webhookEventRepo.HasProcessed("apple", notification.NotificationUUID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Apple Webhook: Failed to check dedup")
|
||||
// Continue processing on dedup check failure (fail-open)
|
||||
} else if alreadyProcessed {
|
||||
// Audit H6: fail closed. A dedup-check failure must not let a
|
||||
// possibly-duplicate event through (duplicate refunds/grants).
|
||||
// Return 500 so Apple retries once the DB is healthy.
|
||||
log.Error().Err(err).Msg("Apple Webhook: dedup check failed — returning 500 for retry")
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "dedup check failed"})
|
||||
}
|
||||
if alreadyProcessed {
|
||||
log.Info().Str("uuid", notification.NotificationUUID).Msg("Apple Webhook: Duplicate event, skipping")
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "duplicate"})
|
||||
}
|
||||
@@ -352,10 +358,24 @@ func (h *SubscriptionWebhookHandler) processAppleNotification(
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) findUserByAppleTransaction(originalTransactionID string) (*models.User, error) {
|
||||
// Look up user subscription by stored receipt data
|
||||
subscription, err := h.subscriptionRepo.FindByAppleReceiptContains(originalTransactionID)
|
||||
// Audit C13: exact match on the indexed apple_original_transaction_id
|
||||
// column. Falls back to the legacy escaped-LIKE scan over
|
||||
// apple_receipt_data only for subscriptions created before that column
|
||||
// existed (and thus not yet populated).
|
||||
subscription, err := h.subscriptionRepo.FindByAppleOriginalTransactionID(originalTransactionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Only fall back to the legacy substring scan when the exact-match
|
||||
// column genuinely had no row (a subscription created before the
|
||||
// column existed). A real DB error must propagate — masking it as
|
||||
// "not found" could bind the webhook to the wrong account via the
|
||||
// LIKE scan, or silently drop a legitimate event.
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
subscription, err = h.subscriptionRepo.FindByAppleReceiptContains(originalTransactionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
user, err := h.userRepo.FindByID(subscription.UserID)
|
||||
@@ -566,9 +586,12 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
|
||||
if messageID != "" {
|
||||
alreadyProcessed, err := h.webhookEventRepo.HasProcessed("google", messageID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Google Webhook: Failed to check dedup")
|
||||
// Continue processing on dedup check failure (fail-open)
|
||||
} else if alreadyProcessed {
|
||||
// Audit H6: fail closed — see the Apple handler. Return 500 so
|
||||
// Google Pub/Sub redelivers once the DB is healthy.
|
||||
log.Error().Err(err).Msg("Google Webhook: dedup check failed — returning 500 for retry")
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "dedup check failed"})
|
||||
}
|
||||
if alreadyProcessed {
|
||||
log.Info().Str("message_id", messageID).Msg("Google Webhook: Duplicate event, skipping")
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "duplicate"})
|
||||
}
|
||||
|
||||
@@ -169,6 +169,34 @@ func (m *AuthMiddleware) OptionalTokenAuth() echo.MiddlewareFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// RequireVerified returns middleware that rejects users whose email is not
|
||||
// verified (audit LIVE-L19). Apply it after TokenAuth to gate sensitive
|
||||
// actions — e.g. generating residence share codes — behind proof that the
|
||||
// account actually controls its email address.
|
||||
func (m *AuthMiddleware) RequireVerified() echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return apperrors.Unauthorized("error.not_authenticated")
|
||||
}
|
||||
var verified bool
|
||||
err := m.db.WithContext(c.Request().Context()).
|
||||
Model(&models.UserProfile{}).
|
||||
Where("user_id = ?", user.ID).
|
||||
Select("verified").
|
||||
Scan(&verified).Error
|
||||
if err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
if !verified {
|
||||
return apperrors.Forbidden("error.email_not_verified")
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractToken extracts the token from the Authorization header
|
||||
func extractToken(c echo.Context) (string, error) {
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
@@ -297,7 +325,7 @@ func (m *AuthMiddleware) getUserFromDatabaseWithToken(ctx context.Context, token
|
||||
u.last_login AS u_last_login
|
||||
`).
|
||||
Joins("INNER JOIN auth_user u ON u.id = t.user_id").
|
||||
Where("t.key = ?", token).
|
||||
Where("t.key = ?", models.HashToken(token)). // audit C1: only the hash is stored
|
||||
Limit(1).
|
||||
Scan(&row).Error
|
||||
if err != nil || row.Key == "" {
|
||||
|
||||
@@ -65,7 +65,7 @@ func TestTokenAuth_RejectsExpiredToken(t *testing.T) {
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
@@ -86,7 +86,7 @@ func TestTokenAuth_AcceptsValidToken(t *testing.T) {
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
@@ -112,7 +112,7 @@ func TestTokenAuth_AcceptsTokenAtBoundary(t *testing.T) {
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ func TestTokenAuth_BearerScheme_Accepted(t *testing.T) {
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token.Key)
|
||||
req.Header.Set("Authorization", "Bearer "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
@@ -46,7 +46,7 @@ func TestTokenAuth_InvalidScheme_Rejected(t *testing.T) {
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Basic "+token.Key)
|
||||
req.Header.Set("Authorization", "Basic "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
@@ -110,7 +110,7 @@ func TestTokenAuth_InactiveUser_Rejected(t *testing.T) {
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
@@ -156,7 +156,7 @@ func TestOptionalTokenAuth_ValidToken_SetsUser(t *testing.T) {
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
@@ -182,7 +182,7 @@ func TestOptionalTokenAuth_ExpiredToken_IgnoresUser(t *testing.T) {
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
@@ -242,7 +242,7 @@ func TestNewAuthMiddlewareWithConfig_CustomExpiryDays(t *testing.T) {
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
@@ -270,7 +270,7 @@ func TestNewAuthMiddlewareWithConfig_ExpiredWithCustomExpiry(t *testing.T) {
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
req.Header.Set("Authorization", "Token "+token.Plaintext)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
|
||||
@@ -99,21 +99,23 @@ func parseTimezone(tz string) *time.Location {
|
||||
return loc
|
||||
}
|
||||
|
||||
// Try parsing as UTC offset (e.g., "-08:00", "+05:30")
|
||||
// We parse a reference time with the given offset to extract the offset value
|
||||
t, err := time.Parse("-07:00", tz)
|
||||
if err == nil {
|
||||
// time.Parse returns a time, we need to extract the offset
|
||||
// The parsed time will have the offset embedded
|
||||
_, offset := t.Zone()
|
||||
return time.FixedZone(tz, offset)
|
||||
// Try parsing as a UTC offset (e.g., "-08:00", "+05:30"). Audit H8:
|
||||
// reject absurd offsets — real timezones are within ±14h of UTC — so a
|
||||
// crafted X-Timezone header cannot shift date math arbitrarily.
|
||||
const maxOffsetSeconds = 14 * 3600
|
||||
if t, err := time.Parse("-07:00", tz); err == nil {
|
||||
if _, offset := t.Zone(); offset >= -maxOffsetSeconds && offset <= maxOffsetSeconds {
|
||||
return time.FixedZone(tz, offset)
|
||||
}
|
||||
return time.UTC
|
||||
}
|
||||
|
||||
// Also try without colon (e.g., "-0800")
|
||||
t, err = time.Parse("-0700", tz)
|
||||
if err == nil {
|
||||
_, offset := t.Zone()
|
||||
return time.FixedZone(tz, offset)
|
||||
if t, err := time.Parse("-0700", tz); err == nil {
|
||||
if _, offset := t.Zone(); offset >= -maxOffsetSeconds && offset <= maxOffsetSeconds {
|
||||
return time.FixedZone(tz, offset)
|
||||
}
|
||||
return time.UTC
|
||||
}
|
||||
|
||||
// Default to UTC
|
||||
|
||||
@@ -252,7 +252,8 @@ func TestAuthToken_BeforeCreate_GeneratesKey(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, token.Key)
|
||||
assert.Len(t, token.Key, 40) // 20 bytes = 40 hex chars
|
||||
assert.Len(t, token.Key, 64) // SHA-256 hex hash (audit C1)
|
||||
assert.Len(t, token.Plaintext, 40) // raw 20-byte token, returned to the client
|
||||
assert.False(t, token.Created.IsZero())
|
||||
}
|
||||
|
||||
|
||||
@@ -43,6 +43,9 @@ type UserSubscription struct {
|
||||
// In-App Purchase data (Apple / Google)
|
||||
AppleReceiptData *string `gorm:"column:apple_receipt_data;type:text" json:"-"`
|
||||
GooglePurchaseToken *string `gorm:"column:google_purchase_token;type:text" json:"-"`
|
||||
// AppleOriginalTransactionID binds an Apple subscription to one account
|
||||
// (audit C5/C13). A partial unique index enforces one-account-per-txn.
|
||||
AppleOriginalTransactionID *string `gorm:"column:apple_original_transaction_id;type:text" json:"-"`
|
||||
|
||||
// Stripe data (web subscriptions)
|
||||
StripeCustomerID *string `gorm:"column:stripe_customer_id;size:255" json:"-"`
|
||||
|
||||
+55
-20
@@ -2,7 +2,10 @@ package models
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
@@ -37,14 +40,16 @@ func (User) TableName() string {
|
||||
return "auth_user"
|
||||
}
|
||||
|
||||
// BcryptCost is the bcrypt work factor for password and code hashing.
|
||||
// 12 (audit M2) is stronger than bcrypt.DefaultCost (10).
|
||||
const BcryptCost = 12
|
||||
|
||||
// SetPassword hashes and sets the password
|
||||
func (u *User) SetPassword(password string) error {
|
||||
// Django uses PBKDF2_SHA256 by default, but we'll use bcrypt for Go
|
||||
// Note: This means passwords set by Django won't work with Go's check
|
||||
// For migration, you'd need to either:
|
||||
// 1. Force password reset for all users
|
||||
// 2. Implement Django's PBKDF2 hasher in Go
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
// Django uses PBKDF2_SHA256 by default, but we use bcrypt for Go.
|
||||
// Passwords set by Django won't verify with Go's bcrypt check — those
|
||||
// users must reset their password after migration.
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), BcryptCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -69,12 +74,22 @@ func (u *User) GetFullName() string {
|
||||
return u.Username
|
||||
}
|
||||
|
||||
// AuthToken represents the user_authtoken table
|
||||
// AuthToken represents the user_authtoken table.
|
||||
//
|
||||
// Audit C1: the Key column stores the SHA-256 hash of the token, never the
|
||||
// token itself. The raw token is handed to the client exactly once, at
|
||||
// creation, via the non-persisted Plaintext field — it is never stored or
|
||||
// logged. A database compromise therefore yields no usable session tokens.
|
||||
type AuthToken struct {
|
||||
Key string `gorm:"column:key;primaryKey;size:40" json:"key"`
|
||||
Key string `gorm:"column:key;primaryKey;size:64" json:"-"`
|
||||
UserID uint `gorm:"column:user_id;uniqueIndex;not null" json:"user_id"`
|
||||
Created time.Time `gorm:"column:created;autoCreateTime" json:"created"`
|
||||
|
||||
// Plaintext is the raw token value. It is never persisted (gorm:"-")
|
||||
// and is only populated on a freshly-created token so the caller can
|
||||
// return it to the client. On a token loaded from the DB it is "".
|
||||
Plaintext string `gorm:"-" json:"-"`
|
||||
|
||||
// Relations
|
||||
User User `gorm:"foreignKey:UserID" json:"-"`
|
||||
}
|
||||
@@ -84,10 +99,13 @@ func (AuthToken) TableName() string {
|
||||
return "user_authtoken"
|
||||
}
|
||||
|
||||
// BeforeCreate generates a token key if not provided
|
||||
// BeforeCreate generates a token if one is not already set, storing only
|
||||
// its hash in Key and the raw value in the non-persisted Plaintext field.
|
||||
func (t *AuthToken) BeforeCreate(tx *gorm.DB) error {
|
||||
if t.Key == "" {
|
||||
t.Key = generateToken()
|
||||
raw := generateToken()
|
||||
t.Plaintext = raw
|
||||
t.Key = HashToken(raw)
|
||||
}
|
||||
if t.Created.IsZero() {
|
||||
t.Created = time.Now().UTC()
|
||||
@@ -95,13 +113,23 @@ func (t *AuthToken) BeforeCreate(tx *gorm.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateToken creates a random 40-character hex token
|
||||
// generateToken creates a random 40-character hex token (the raw value).
|
||||
func generateToken() string {
|
||||
b := make([]byte, 20)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// HashToken returns the at-rest representation of an auth token: the
|
||||
// hex-encoded SHA-256 hash. Auth tokens are 160-bit random values, so a
|
||||
// fast deterministic hash is appropriate — there is nothing to brute-force,
|
||||
// and determinism preserves the single indexed-lookup query in the auth
|
||||
// middleware. The raw token is never stored.
|
||||
func HashToken(raw string) string {
|
||||
sum := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// GetOrCreate gets an existing token or creates a new one for the user
|
||||
func GetOrCreateToken(tx *gorm.DB, userID uint) (*AuthToken, error) {
|
||||
var token AuthToken
|
||||
@@ -160,15 +188,22 @@ func (c *ConfirmationCode) IsValid() bool {
|
||||
return !c.IsUsed && time.Now().UTC().Before(c.ExpiresAt)
|
||||
}
|
||||
|
||||
// GenerateCode creates a random 6-digit code
|
||||
// GenerateConfirmationCode creates a uniformly-random 6-digit code using
|
||||
// rejection sampling on crypto/rand (audit H4 — removes the modulo bias of
|
||||
// the previous implementation).
|
||||
func GenerateConfirmationCode() string {
|
||||
b := make([]byte, 3)
|
||||
rand.Read(b)
|
||||
// Convert to 6-digit number
|
||||
num := int(b[0])<<16 | int(b[1])<<8 | int(b[2])
|
||||
return string(rune('0'+num%10)) + string(rune('0'+(num/10)%10)) +
|
||||
string(rune('0'+(num/100)%10)) + string(rune('0'+(num/1000)%10)) +
|
||||
string(rune('0'+(num/10000)%10)) + string(rune('0'+(num/100000)%10))
|
||||
for {
|
||||
var b [4]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
continue
|
||||
}
|
||||
// 4294000000 is the largest multiple of 1e6 <= MaxUint32; rejecting
|
||||
// the tail above it makes n % 1000000 perfectly uniform.
|
||||
n := binary.BigEndian.Uint32(b[:])
|
||||
if n < 4294000000 {
|
||||
return fmt.Sprintf("%06d", n%1000000)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PasswordResetCode represents the user_passwordresetcode table
|
||||
@@ -193,7 +228,7 @@ func (PasswordResetCode) TableName() string {
|
||||
|
||||
// SetCode hashes and stores the reset code
|
||||
func (p *PasswordResetCode) SetCode(code string) error {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(code), BcryptCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
@@ -194,6 +195,60 @@ func (r *ResidenceRepository) HasAccess(residenceID, userID uint) (bool, error)
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// JoinWithShareCode atomically redeems a one-time share code (audit C9/H9):
|
||||
// it locks the share-code row, re-checks validity under the lock, adds the
|
||||
// user to the residence, and deactivates the code — all in one transaction.
|
||||
// Concurrent redemptions of the same code serialize on the row lock; the
|
||||
// loser sees is_active=false and is rejected. A failure to deactivate aborts
|
||||
// the whole join. Returns gorm.ErrRecordNotFound for an unknown, inactive, or
|
||||
// expired code so the caller can map every case to one generic error.
|
||||
func (r *ResidenceRepository) JoinWithShareCode(code string, userID uint) (residenceID uint, alreadyMember bool, err error) {
|
||||
err = r.db.Transaction(func(tx *gorm.DB) error {
|
||||
var sc models.ResidenceShareCode
|
||||
if e := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
||||
Where("code = ?", code).First(&sc).Error; e != nil {
|
||||
return e
|
||||
}
|
||||
if !sc.IsActive || (sc.ExpiresAt != nil && time.Now().UTC().After(*sc.ExpiresAt)) {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
residenceID = sc.ResidenceID
|
||||
|
||||
// Already a member (owner or shared user)?
|
||||
var accessCount int64
|
||||
if e := tx.Raw(`
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT 1 FROM residence_residence
|
||||
WHERE id = ? AND owner_id = ? AND is_active = true
|
||||
UNION
|
||||
SELECT 1 FROM residence_residence_users
|
||||
WHERE residence_id = ? AND user_id = ?
|
||||
) ac
|
||||
`, sc.ResidenceID, userID, sc.ResidenceID, userID).Scan(&accessCount).Error; e != nil {
|
||||
return e
|
||||
}
|
||||
if accessCount > 0 {
|
||||
alreadyMember = true
|
||||
return nil
|
||||
}
|
||||
|
||||
if e := tx.Exec(
|
||||
"INSERT INTO residence_residence_users (residence_id, user_id) VALUES (?, ?) ON CONFLICT DO NOTHING",
|
||||
sc.ResidenceID, userID,
|
||||
).Error; e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
// One-time use: deactivate the code. A failure here aborts the join.
|
||||
if e := tx.Model(&models.ResidenceShareCode{}).
|
||||
Where("id = ?", sc.ID).Update("is_active", false).Error; e != nil {
|
||||
return e
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return residenceID, alreadyMember, err
|
||||
}
|
||||
|
||||
// IsOwner checks if a user is the owner of a residence
|
||||
func (r *ResidenceRepository) IsOwner(residenceID, userID uint) (bool, error) {
|
||||
var count int64
|
||||
|
||||
@@ -151,6 +151,28 @@ func (r *SubscriptionRepository) FindByAppleReceiptContains(transactionID string
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// FindByAppleOriginalTransactionID finds a subscription by the Apple original
|
||||
// transaction ID (audit C5/C13). Exact match on an indexed column — replaces
|
||||
// the LIKE scan in FindByAppleReceiptContains for both replay detection and
|
||||
// webhook user lookup.
|
||||
func (r *SubscriptionRepository) FindByAppleOriginalTransactionID(originalTransactionID string) (*models.UserSubscription, error) {
|
||||
var sub models.UserSubscription
|
||||
err := r.db.Where("apple_original_transaction_id = ?", originalTransactionID).First(&sub).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sub, nil
|
||||
}
|
||||
|
||||
// UpdateAppleOriginalTransactionID binds an Apple original transaction ID to a
|
||||
// user's subscription. A partial unique index enforces one account per
|
||||
// transaction (audit C5) — a second account claiming the same ID fails here.
|
||||
func (r *SubscriptionRepository) UpdateAppleOriginalTransactionID(userID uint, originalTransactionID string) error {
|
||||
return r.db.Model(&models.UserSubscription{}).
|
||||
Where("user_id = ?", userID).
|
||||
Update("apple_original_transaction_id", originalTransactionID).Error
|
||||
}
|
||||
|
||||
// FindByGoogleToken finds a subscription by Google purchase token
|
||||
// Used by webhooks to find the user associated with a purchase
|
||||
func (r *SubscriptionRepository) FindByGoogleToken(purchaseToken string) (*models.UserSubscription, error) {
|
||||
|
||||
@@ -226,3 +226,48 @@ func TestUpdateExpiresAt(t *testing.T) {
|
||||
require.NotNil(t, updated.ExpiresAt)
|
||||
assert.WithinDuration(t, newExpiry, *updated.ExpiresAt, time.Second, "expires_at should be updated")
|
||||
}
|
||||
|
||||
// TestSubscriptionRepo_IAPTransactionReplayRejected is the regression test for
|
||||
// audit C5/C6: an in-app-purchase transaction (an Apple original transaction
|
||||
// ID or a Google purchase token) may be bound to exactly one account. Without
|
||||
// that guarantee a valid receipt could be replayed against a second account
|
||||
// to grant Pro for free. The guarantee is the pair of partial unique indexes
|
||||
// added by migration 000004; AutoMigrate does not create them, so this test
|
||||
// recreates them verbatim to exercise the same DB-level enforcement.
|
||||
func TestSubscriptionRepo_IAPTransactionReplayRejected(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
require.NoError(t, db.Exec(`CREATE UNIQUE INDEX uq_subscription_apple_original_txn `+
|
||||
`ON subscription_usersubscription (apple_original_transaction_id) `+
|
||||
`WHERE apple_original_transaction_id IS NOT NULL AND apple_original_transaction_id <> ''`).Error)
|
||||
require.NoError(t, db.Exec(`CREATE UNIQUE INDEX uq_subscription_google_purchase_token `+
|
||||
`ON subscription_usersubscription (google_purchase_token) `+
|
||||
`WHERE google_purchase_token IS NOT NULL AND google_purchase_token <> ''`).Error)
|
||||
|
||||
repo := NewSubscriptionRepository(db)
|
||||
userA := testutil.CreateTestUser(t, db, "iapusera", "iapa@test.com", "password")
|
||||
userB := testutil.CreateTestUser(t, db, "iapuserb", "iapb@test.com", "password")
|
||||
require.NoError(t, db.Create(&models.UserSubscription{UserID: userA.ID, Tier: models.TierFree}).Error)
|
||||
require.NoError(t, db.Create(&models.UserSubscription{UserID: userB.ID, Tier: models.TierFree}).Error)
|
||||
|
||||
t.Run("apple transaction cannot be claimed by a second account", func(t *testing.T) {
|
||||
require.NoError(t, repo.UpdateAppleOriginalTransactionID(userA.ID, "apple-original-txn-1"),
|
||||
"the first account binding the transaction must succeed")
|
||||
err := repo.UpdateAppleOriginalTransactionID(userB.ID, "apple-original-txn-1")
|
||||
require.Error(t, err,
|
||||
"replaying account A's Apple transaction onto account B must be rejected (C5)")
|
||||
})
|
||||
|
||||
t.Run("google purchase token cannot be claimed by a second account", func(t *testing.T) {
|
||||
require.NoError(t, repo.UpdatePurchaseToken(userA.ID, "google-purchase-token-1"),
|
||||
"the first account binding the token must succeed")
|
||||
err := repo.UpdatePurchaseToken(userB.ID, "google-purchase-token-1")
|
||||
require.Error(t, err,
|
||||
"replaying account A's Google purchase token onto account B must be rejected (C6)")
|
||||
})
|
||||
|
||||
t.Run("re-binding the same transaction to the same account is allowed", func(t *testing.T) {
|
||||
// A renewal re-submitting the same transaction for its owner must not
|
||||
// be rejected — the partial unique index excludes the row's own value.
|
||||
require.NoError(t, repo.UpdateAppleOriginalTransactionID(userA.ID, "apple-original-txn-1"))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -174,10 +174,12 @@ func (r *UserRepository) GetOrCreateToken(userID uint) (*models.AuthToken, error
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// FindTokenByKey looks up an auth token by its key value.
|
||||
func (r *UserRepository) FindTokenByKey(key string) (*models.AuthToken, error) {
|
||||
// FindTokenByKey looks up an auth token by its raw key value. The raw token
|
||||
// is hashed (audit C1) before the indexed lookup, since only the hash is
|
||||
// stored.
|
||||
func (r *UserRepository) FindTokenByKey(rawKey string) (*models.AuthToken, error) {
|
||||
var token models.AuthToken
|
||||
if err := r.db.Where("key = ?", key).First(&token).Error; err != nil {
|
||||
if err := r.db.Where("key = ?", models.HashToken(rawKey)).First(&token).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrTokenNotFound
|
||||
}
|
||||
@@ -195,9 +197,46 @@ func (r *UserRepository) CreateToken(userID uint) (*models.AuthToken, error) {
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// DeleteToken deletes an auth token
|
||||
// CreateFreshToken issues a new auth token for the user, replacing any
|
||||
// existing one. Because tokens are stored hashed (audit C1) the server
|
||||
// cannot re-issue a previously-minted token's plaintext, so every login
|
||||
// mints a fresh token. The returned token's Plaintext field carries the
|
||||
// raw value to hand to the client; it is never persisted.
|
||||
//
|
||||
// It also returns the stored hashes of the token rows it deleted, so the
|
||||
// caller can evict those entries from the Redis token cache (audit MEDIUM-1).
|
||||
// Without that, a prior (e.g. stolen) token keeps authenticating via a cache
|
||||
// hit for up to the cache TTL even though its DB row is gone.
|
||||
func (r *UserRepository) CreateFreshToken(userID uint) (*models.AuthToken, []string, error) {
|
||||
var token models.AuthToken
|
||||
var oldHashes []string
|
||||
err := r.db.Transaction(func(tx *gorm.DB) error {
|
||||
var old []models.AuthToken
|
||||
if err := tx.Where("user_id = ?", userID).Find(&old).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
oldHashes = make([]string, 0, len(old))
|
||||
for i := range old {
|
||||
if old[i].Key != "" {
|
||||
oldHashes = append(oldHashes, old[i].Key)
|
||||
}
|
||||
}
|
||||
if err := tx.Where("user_id = ?", userID).Delete(&models.AuthToken{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
token = models.AuthToken{UserID: userID}
|
||||
return tx.Create(&token).Error
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return &token, oldHashes, nil
|
||||
}
|
||||
|
||||
// DeleteToken deletes an auth token by its raw key value. The raw token is
|
||||
// hashed (audit C1) before the lookup, since only the hash is stored.
|
||||
func (r *UserRepository) DeleteToken(token string) error {
|
||||
result := r.db.Where("key = ?", token).Delete(&models.AuthToken{})
|
||||
result := r.db.Where("key = ?", models.HashToken(token)).Delete(&models.AuthToken{})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ func TestUserRepository_FindTokenByKey(t *testing.T) {
|
||||
token, err := repo.GetOrCreateToken(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindTokenByKey(token.Key)
|
||||
found, err := repo.FindTokenByKey(token.Plaintext)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token.Key, found.Key)
|
||||
assert.Equal(t, user.ID, found.UserID)
|
||||
@@ -128,10 +128,10 @@ func TestUserRepository_DeleteToken(t *testing.T) {
|
||||
token, err := repo.GetOrCreateToken(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.DeleteToken(token.Key)
|
||||
err = repo.DeleteToken(token.Plaintext)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindTokenByKey(token.Key)
|
||||
_, err = repo.FindTokenByKey(token.Plaintext)
|
||||
assert.ErrorIs(t, err, ErrTokenNotFound)
|
||||
}
|
||||
|
||||
|
||||
@@ -75,10 +75,13 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
// responses are unaffected — they don't load any assets, so any CSP is fine.
|
||||
// frame-ancestors stays 'none' to block clickjacking.
|
||||
e.Use(middleware.SecureWithConfig(middleware.SecureConfig{
|
||||
XSSProtection: "1; mode=block",
|
||||
// XSSProtection deliberately empty (audit L7): the X-XSS-Protection
|
||||
// header is deprecated and has itself caused XSS in legacy browsers.
|
||||
XSSProtection: "",
|
||||
ContentTypeNosniff: "nosniff",
|
||||
XFrameOptions: "SAMEORIGIN",
|
||||
HSTSMaxAge: 31536000,
|
||||
HSTSMaxAge: 63072000, // 2 years — preload-eligible (audit L5/CODE-L3)
|
||||
HSTSPreloadEnabled: true,
|
||||
ReferrerPolicy: "strict-origin-when-cross-origin",
|
||||
ContentSecurityPolicy: "default-src 'self'; " +
|
||||
"style-src 'self' https://fonts.googleapis.com; " +
|
||||
@@ -86,6 +89,8 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
"img-src 'self' data:; " +
|
||||
"script-src 'self'; " +
|
||||
"connect-src 'self'; " +
|
||||
"object-src 'none'; " + // audit L8 — disable plugins/embeds
|
||||
"base-uri 'self'; " + // audit L8 — block <base> hijacking
|
||||
"frame-ancestors 'none'",
|
||||
}))
|
||||
e.Use(middleware.BodyLimitWithConfig(middleware.BodyLimitConfig{
|
||||
@@ -136,9 +141,20 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
// labeled by route pattern, method, and status code.
|
||||
e.Use(prom.HTTPMiddleware())
|
||||
|
||||
// /metrics endpoint exposed for vmagent scrape. No auth — bound to
|
||||
// the cluster network only; not exposed via Cloudflare.
|
||||
e.GET("/metrics", prom.Handler())
|
||||
// /metrics endpoint for the in-cluster vmagent scrape (audit LIVE-L1).
|
||||
// vmagent scrapes api pods directly (pod-to-pod), so its requests carry
|
||||
// no X-Forwarded-For. Any request that DOES carry one reached us through
|
||||
// Traefik/Cloudflare — i.e. the public internet — and is refused with a
|
||||
// 404. The api pod port is not exposed outside the cluster, so a request
|
||||
// cannot reach /metrics without going through Traefik, and Traefik always
|
||||
// appends X-Forwarded-For — the check cannot be bypassed.
|
||||
metricsHandler := prom.Handler()
|
||||
e.GET("/metrics", func(c echo.Context) error {
|
||||
if c.Request().Header.Get("X-Forwarded-For") != "" {
|
||||
return echo.NewHTTPError(http.StatusNotFound)
|
||||
}
|
||||
return metricsHandler(c)
|
||||
})
|
||||
|
||||
// Serve landing page static files (if static directory is configured)
|
||||
staticDir := cfg.Server.StaticDir
|
||||
@@ -204,6 +220,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
// Wire Redis cache for residence-ID lookups across the four services that
|
||||
// read it on the request hot path. Cache is best-effort; nil cache is OK.
|
||||
if deps.Cache != nil {
|
||||
authService.SetCacheService(deps.Cache) // per-account login lockout (audit M5)
|
||||
residenceService.SetCacheService(deps.Cache)
|
||||
taskService.SetCacheService(deps.Cache)
|
||||
contractorService.SetCacheService(deps.Cache)
|
||||
@@ -316,7 +333,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
|
||||
protected.Use(custommiddleware.TimezoneMiddleware())
|
||||
{
|
||||
setupProtectedAuthRoutes(protected, authHandler)
|
||||
setupResidenceRoutes(protected, residenceHandler)
|
||||
setupResidenceRoutes(protected, residenceHandler, authMiddleware.RequireVerified())
|
||||
setupTaskRoutes(protected, taskHandler)
|
||||
setupSuggestionRoutes(protected, suggestionHandler)
|
||||
setupContractorRoutes(protected, contractorHandler)
|
||||
@@ -583,7 +600,7 @@ func setupPublicDataRoutes(api *echo.Group, residenceHandler *handlers.Residence
|
||||
}
|
||||
|
||||
// setupResidenceRoutes configures residence routes
|
||||
func setupResidenceRoutes(api *echo.Group, residenceHandler *handlers.ResidenceHandler) {
|
||||
func setupResidenceRoutes(api *echo.Group, residenceHandler *handlers.ResidenceHandler, requireVerified echo.MiddlewareFunc) {
|
||||
residences := api.Group("/residences")
|
||||
{
|
||||
residences.GET("/", residenceHandler.ListResidences)
|
||||
@@ -598,8 +615,11 @@ func setupResidenceRoutes(api *echo.Group, residenceHandler *handlers.ResidenceH
|
||||
residences.DELETE("/:id/", residenceHandler.DeleteResidence)
|
||||
|
||||
residences.GET("/:id/share-code/", residenceHandler.GetShareCode)
|
||||
residences.POST("/:id/generate-share-code/", residenceHandler.GenerateShareCode)
|
||||
residences.POST("/:id/generate-share-package/", residenceHandler.GenerateSharePackage)
|
||||
// Audit LIVE-L19: generating a residence share code requires a
|
||||
// verified email — it blocks bad-faith unverified signups from
|
||||
// minting share codes.
|
||||
residences.POST("/:id/generate-share-code/", residenceHandler.GenerateShareCode, requireVerified)
|
||||
residences.POST("/:id/generate-share-package/", residenceHandler.GenerateSharePackage, requireVerified)
|
||||
residences.POST("/:id/generate-tasks-report/", residenceHandler.GenerateTasksReport)
|
||||
residences.GET("/:id/users/", residenceHandler.GetResidenceUsers)
|
||||
residences.DELETE("/:id/users/:user_id/", residenceHandler.RemoveResidenceUser)
|
||||
|
||||
@@ -75,9 +75,9 @@ func TestRefreshToken_FreshToken_ReturnsExisting(t *testing.T) {
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token.Key, resp.Token, "fresh token should return the same token")
|
||||
assert.Equal(t, token.Plaintext, resp.Token, "fresh token should return the same token")
|
||||
assert.Contains(t, resp.Message, "still valid")
|
||||
}
|
||||
|
||||
@@ -88,23 +88,25 @@ func TestRefreshToken_InRenewalWindow_ReturnsNewToken(t *testing.T) {
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, token.Key, resp.Token, "should return a new token")
|
||||
assert.NotEqual(t, token.Plaintext, resp.Token, "should return a new token")
|
||||
assert.Contains(t, resp.Message, "refreshed")
|
||||
|
||||
// Verify old token was deleted
|
||||
var count int64
|
||||
// The DB stores the SHA-256 hash, so query by token.Key (the hash).
|
||||
db.Model(&models.AuthToken{}).Where("key = ?", token.Key).Count(&count)
|
||||
assert.Equal(t, int64(0), count, "old token should be deleted")
|
||||
|
||||
// Verify new token exists in DB
|
||||
db.Model(&models.AuthToken{}).Where("key = ?", resp.Token).Count(&count)
|
||||
// resp.Token is the raw token; the DB stores its hash.
|
||||
db.Model(&models.AuthToken{}).Where("key = ?", models.HashToken(resp.Token)).Count(&count)
|
||||
assert.Equal(t, int64(1), count, "new token should exist in DB")
|
||||
|
||||
// Verify new token belongs to the same user
|
||||
var newToken models.AuthToken
|
||||
require.NoError(t, db.Where("key = ?", resp.Token).First(&newToken).Error)
|
||||
require.NoError(t, db.Where("key = ?", models.HashToken(resp.Token)).First(&newToken).Error)
|
||||
assert.Equal(t, user.ID, newToken.UserID)
|
||||
}
|
||||
|
||||
@@ -115,7 +117,7 @@ func TestRefreshToken_ExpiredToken_Returns401(t *testing.T) {
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
assert.Contains(t, err.Error(), "error.token_expired")
|
||||
@@ -130,9 +132,9 @@ func TestRefreshToken_AtExactBoundary60Days(t *testing.T) {
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, token.Key, resp.Token, "token at 61 days should be refreshed")
|
||||
assert.NotEqual(t, token.Plaintext, resp.Token, "token at 61 days should be refreshed")
|
||||
}
|
||||
|
||||
func TestRefreshToken_InvalidToken_Returns401(t *testing.T) {
|
||||
@@ -155,7 +157,7 @@ func TestRefreshToken_WrongUser_Returns401(t *testing.T) {
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
// Try to refresh with a different user ID
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID+999)
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID+999)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||
@@ -168,7 +170,7 @@ func TestRefreshToken_FreshTokenAt59Days_ReturnsExisting(t *testing.T) {
|
||||
|
||||
svc := newTestAuthService(db)
|
||||
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
|
||||
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token.Key, resp.Token, "token at 59 days should NOT be refreshed")
|
||||
assert.Equal(t, token.Plaintext, resp.Token, "token at 59 days should NOT be refreshed")
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package services
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -36,13 +37,32 @@ var (
|
||||
ErrGoogleSignInFailed = errors.New("Google Sign In failed")
|
||||
)
|
||||
|
||||
// Per-account login lockout (audit M5, hardened per MEDIUM-3).
|
||||
const (
|
||||
// maxLoginFailureIPs is how many DISTINCT source IPs may fail to log in to
|
||||
// one account within the window before that account is locked. Counting
|
||||
// distinct IPs (not raw attempts) means a single attacker who knows a
|
||||
// victim's email cannot lock the victim out by spamming failures — only a
|
||||
// genuinely distributed credential-stuffing attack reaches this threshold.
|
||||
maxLoginFailureIPs = 5
|
||||
// loginLockWindow is how long the failed-IP set persists; it is refreshed
|
||||
// on each failure so an active attack keeps the window open.
|
||||
loginLockWindow = 15 * time.Minute
|
||||
)
|
||||
|
||||
// AuthService handles authentication business logic
|
||||
type AuthService struct {
|
||||
userRepo *repositories.UserRepository
|
||||
notificationRepo *repositories.NotificationRepository
|
||||
cache *CacheService
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// SetCacheService wires Redis for per-account login-failure tracking (M5).
|
||||
func (s *AuthService) SetCacheService(cache *CacheService) {
|
||||
s.cache = cache
|
||||
}
|
||||
|
||||
// NewAuthService creates a new auth service
|
||||
func NewAuthService(userRepo *repositories.UserRepository, cfg *config.Config) *AuthService {
|
||||
return &AuthService{
|
||||
@@ -56,34 +76,89 @@ func (s *AuthService) SetNotificationRepository(notificationRepo *repositories.N
|
||||
s.notificationRepo = notificationRepo
|
||||
}
|
||||
|
||||
// Login authenticates a user and returns a token
|
||||
func (s *AuthService) Login(ctx context.Context, req *requests.LoginRequest) (*responses.LoginResponse, error) {
|
||||
// dummyPasswordHash is a valid bcrypt hash used to keep login response time
|
||||
// constant when the account does not exist (audit LIVE-L11). It is computed
|
||||
// once at startup; the plaintext it hashes is irrelevant and never used.
|
||||
var dummyPasswordHash = func() string {
|
||||
h, err := bcrypt.GenerateFromPassword([]byte("honeydue-login-timing-equalizer"), models.BcryptCost)
|
||||
if err != nil {
|
||||
return "" // CompareHashAndPassword against "" always fails — safe
|
||||
}
|
||||
return string(h)
|
||||
}()
|
||||
|
||||
// freshToken mints a new auth token for the user and evicts any prior token's
|
||||
// Redis cache entry (audit MEDIUM-1). Without the eviction a re-login would
|
||||
// not actually kill a previously-issued token until the cache TTL lapsed — a
|
||||
// stolen token would keep working for up to 5 minutes after the victim
|
||||
// re-authenticates. A cache-eviction failure is logged, not fatal: the token
|
||||
// row is already gone, so the stale entry simply ages out on its own.
|
||||
func (s *AuthService) freshToken(ctx context.Context, userID uint) (*models.AuthToken, error) {
|
||||
token, oldHashes, err := s.userRepo.WithContext(ctx).CreateFreshToken(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.cache != nil && len(oldHashes) > 0 {
|
||||
if cErr := s.cache.InvalidateAuthTokenHashes(ctx, oldHashes...); cErr != nil {
|
||||
log.Warn().Err(cErr).Uint("user_id", userID).
|
||||
Msg("failed to evict prior auth-token cache entries on re-login")
|
||||
}
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Login authenticates a user and returns a token. clientIP is the request's
|
||||
// source IP (echo c.RealIP()), used for the distributed-attack lockout.
|
||||
func (s *AuthService) Login(ctx context.Context, req *requests.LoginRequest, clientIP string) (*responses.LoginResponse, error) {
|
||||
// Find user by username or email
|
||||
identifier := req.Username
|
||||
if identifier == "" {
|
||||
identifier = req.Email
|
||||
}
|
||||
lockKey := strings.ToLower(strings.TrimSpace(identifier))
|
||||
|
||||
// Audit M5 (hardened per MEDIUM-3): per-account lockout keyed on the set
|
||||
// of distinct source IPs that have failed. Once enough distinct IPs have
|
||||
// failed for one account within the window, reject — this still catches
|
||||
// distributed credential stuffing, without letting a single attacker lock
|
||||
// a victim out by spamming failed logins from one IP.
|
||||
if s.cache != nil && lockKey != "" {
|
||||
if n, cErr := s.cache.LoginFailureIPCount(ctx, lockKey); cErr == nil && n >= maxLoginFailureIPs {
|
||||
return nil, apperrors.TooManyRequests("error.too_many_login_attempts")
|
||||
}
|
||||
}
|
||||
|
||||
user, err := s.userRepo.WithContext(ctx).FindByUsernameOrEmail(identifier)
|
||||
if err != nil {
|
||||
if errors.Is(err, repositories.ErrUserNotFound) {
|
||||
return nil, apperrors.Unauthorized("error.invalid_credentials")
|
||||
}
|
||||
if err != nil && !errors.Is(err, repositories.ErrUserNotFound) {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Check if user is active
|
||||
if !user.IsActive {
|
||||
return nil, apperrors.Unauthorized("error.account_inactive")
|
||||
// Constant-time login (audit LIVE-L11): always run a bcrypt comparison,
|
||||
// even when the account does not exist or is inactive, so response
|
||||
// timing never reveals which emails are real accounts. Compare against
|
||||
// the user's hash when available, otherwise a fixed dummy hash.
|
||||
passwordHash := dummyPasswordHash
|
||||
if user != nil {
|
||||
passwordHash = user.Password
|
||||
}
|
||||
passwordOK := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(req.Password)) == nil
|
||||
|
||||
// Verify password
|
||||
if !user.CheckPassword(req.Password) {
|
||||
// One generic error for not-found, inactive, and wrong-password
|
||||
// (audit L1) — none of them disclose which condition failed.
|
||||
if user == nil || !user.IsActive || !passwordOK {
|
||||
if s.cache != nil && lockKey != "" {
|
||||
_, _ = s.cache.RegisterLoginFailure(ctx, lockKey, clientIP, loginLockWindow)
|
||||
}
|
||||
return nil, apperrors.Unauthorized("error.invalid_credentials")
|
||||
}
|
||||
|
||||
// Successful authentication — clear the failure counter (audit M5).
|
||||
if s.cache != nil && lockKey != "" {
|
||||
_ = s.cache.ClearLoginFailures(ctx, lockKey)
|
||||
}
|
||||
|
||||
// Get or create auth token
|
||||
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
|
||||
token, err := s.freshToken(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
@@ -95,7 +170,7 @@ func (s *AuthService) Login(ctx context.Context, req *requests.LoginRequest) (*r
|
||||
}
|
||||
|
||||
return &responses.LoginResponse{
|
||||
Token: token.Key,
|
||||
Token: token.Plaintext,
|
||||
User: responses.NewUserResponse(user),
|
||||
}, nil
|
||||
}
|
||||
@@ -176,13 +251,13 @@ func (s *AuthService) Register(ctx context.Context, req *requests.RegisterReques
|
||||
}
|
||||
|
||||
// Create auth token (outside transaction since token generation is idempotent)
|
||||
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
|
||||
token, err := s.freshToken(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, "", apperrors.Internal(err)
|
||||
}
|
||||
|
||||
return &responses.RegisterResponse{
|
||||
Token: token.Key,
|
||||
Token: token.Plaintext,
|
||||
User: responses.NewUserResponse(user),
|
||||
Message: "Registration successful. Please check your email to verify your account.",
|
||||
}, code, nil
|
||||
@@ -243,7 +318,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, tokenKey string, userID
|
||||
}
|
||||
|
||||
return &responses.RefreshTokenResponse{
|
||||
Token: newToken.Key,
|
||||
Token: newToken.Plaintext,
|
||||
Message: "Token refreshed successfully.",
|
||||
}, nil
|
||||
}
|
||||
@@ -390,26 +465,26 @@ func (s *AuthService) VerifyEmail(ctx context.Context, userID uint, code string)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find and validate confirmation code
|
||||
confirmCode, err := s.userRepo.WithContext(ctx).FindConfirmationCode(userID, code)
|
||||
if err != nil {
|
||||
if errors.Is(err, repositories.ErrCodeNotFound) {
|
||||
// Audit M4: validate the code, consume it, and flip the verified flag in
|
||||
// one transaction so the three writes commit or roll back together.
|
||||
txErr := s.userRepo.WithContext(ctx).Transaction(func(txRepo *repositories.UserRepository) error {
|
||||
confirmCode, err := txRepo.FindConfirmationCode(userID, code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := txRepo.MarkConfirmationCodeUsed(confirmCode.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
return txRepo.SetProfileVerified(userID, true)
|
||||
})
|
||||
if txErr != nil {
|
||||
if errors.Is(txErr, repositories.ErrCodeNotFound) {
|
||||
return apperrors.BadRequest("error.invalid_verification_code")
|
||||
}
|
||||
if errors.Is(err, repositories.ErrCodeExpired) {
|
||||
if errors.Is(txErr, repositories.ErrCodeExpired) {
|
||||
return apperrors.BadRequest("error.verification_code_expired")
|
||||
}
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Mark code as used
|
||||
if err := s.userRepo.WithContext(ctx).MarkConfirmationCodeUsed(confirmCode.ID); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Set profile as verified
|
||||
if err := s.userRepo.WithContext(ctx).SetProfileVerified(userID, true); err != nil {
|
||||
return apperrors.Internal(err)
|
||||
return apperrors.Internal(txErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -476,7 +551,7 @@ func (s *AuthService) ForgotPassword(ctx context.Context, email string) (string,
|
||||
expiresAt := time.Now().UTC().Add(s.cfg.Security.PasswordResetExpiry)
|
||||
|
||||
// Hash the code before storing
|
||||
codeHash, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
|
||||
codeHash, err := bcrypt.GenerateFromPassword([]byte(code), models.BcryptCost)
|
||||
if err != nil {
|
||||
return "", nil, apperrors.Internal(err)
|
||||
}
|
||||
@@ -596,7 +671,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
||||
}
|
||||
|
||||
// Get or create token
|
||||
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
|
||||
token, err := s.freshToken(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
@@ -605,7 +680,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
||||
_ = s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID)
|
||||
|
||||
return &responses.AppleSignInResponse{
|
||||
Token: token.Key,
|
||||
Token: token.Plaintext,
|
||||
User: responses.NewUserResponse(user),
|
||||
IsNewUser: false,
|
||||
}, nil
|
||||
@@ -638,7 +713,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
||||
_ = s.userRepo.WithContext(ctx).SetProfileVerified(existingUser.ID, true)
|
||||
|
||||
// Get or create token
|
||||
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(existingUser.ID)
|
||||
token, err := s.freshToken(ctx, existingUser.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
@@ -653,7 +728,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
||||
}
|
||||
|
||||
return &responses.AppleSignInResponse{
|
||||
Token: token.Key,
|
||||
Token: token.Plaintext,
|
||||
User: responses.NewUserResponse(existingUser),
|
||||
IsNewUser: false,
|
||||
}, nil
|
||||
@@ -704,7 +779,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
||||
}
|
||||
|
||||
// Create token
|
||||
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
|
||||
token, err := s.freshToken(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
@@ -716,7 +791,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
|
||||
}
|
||||
|
||||
return &responses.AppleSignInResponse{
|
||||
Token: token.Key,
|
||||
Token: token.Plaintext,
|
||||
User: responses.NewUserResponse(user),
|
||||
IsNewUser: true,
|
||||
}, nil
|
||||
@@ -749,7 +824,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
||||
}
|
||||
|
||||
// Get or create token
|
||||
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
|
||||
token, err := s.freshToken(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
@@ -758,7 +833,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
||||
_ = s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID)
|
||||
|
||||
return &responses.GoogleSignInResponse{
|
||||
Token: token.Key,
|
||||
Token: token.Plaintext,
|
||||
User: responses.NewUserResponse(user),
|
||||
IsNewUser: false,
|
||||
}, nil
|
||||
@@ -794,7 +869,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
||||
}
|
||||
|
||||
// Get or create token
|
||||
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(existingUser.ID)
|
||||
token, err := s.freshToken(ctx, existingUser.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
@@ -809,7 +884,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
||||
}
|
||||
|
||||
return &responses.GoogleSignInResponse{
|
||||
Token: token.Key,
|
||||
Token: token.Plaintext,
|
||||
User: responses.NewUserResponse(existingUser),
|
||||
IsNewUser: false,
|
||||
}, nil
|
||||
@@ -861,7 +936,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
||||
}
|
||||
|
||||
// Create token
|
||||
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
|
||||
token, err := s.freshToken(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
@@ -873,7 +948,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
||||
}
|
||||
|
||||
return &responses.GoogleSignInResponse{
|
||||
Token: token.Key,
|
||||
Token: token.Plaintext,
|
||||
User: responses.NewUserResponse(user),
|
||||
IsNewUser: true,
|
||||
}, nil
|
||||
@@ -882,14 +957,19 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
|
||||
// Helper functions
|
||||
|
||||
func generateSixDigitCode() string {
|
||||
b := make([]byte, 4)
|
||||
rand.Read(b)
|
||||
num := int(b[0])<<24 | int(b[1])<<16 | int(b[2])<<8 | int(b[3])
|
||||
if num < 0 {
|
||||
num = -num
|
||||
// Uniform 000000–999999 via rejection sampling on crypto/rand,
|
||||
// removing the modulo bias of `n % 1000000` (audit H4).
|
||||
for {
|
||||
var b [4]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
continue
|
||||
}
|
||||
// 4294000000 is the largest multiple of 1e6 <= MaxUint32.
|
||||
n := binary.BigEndian.Uint32(b[:])
|
||||
if n < 4294000000 {
|
||||
return fmt.Sprintf("%06d", n%1000000)
|
||||
}
|
||||
}
|
||||
code := num % 1000000
|
||||
return fmt.Sprintf("%06d", code)
|
||||
}
|
||||
|
||||
func generateResetToken() string {
|
||||
|
||||
@@ -54,7 +54,7 @@ func TestAuthService_Login(t *testing.T) {
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
resp, err := service.Login(context.Background(), req)
|
||||
resp, err := service.Login(context.Background(), req, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.Token)
|
||||
assert.Equal(t, "testuser", resp.User.Username)
|
||||
@@ -75,7 +75,7 @@ func TestAuthService_Login_ByEmail(t *testing.T) {
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
resp, err := service.Login(context.Background(), req)
|
||||
resp, err := service.Login(context.Background(), req, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.Token)
|
||||
}
|
||||
@@ -95,7 +95,7 @@ func TestAuthService_Login_InvalidCredentials(t *testing.T) {
|
||||
Password: "WrongPassword1",
|
||||
}
|
||||
|
||||
_, err := service.Login(context.Background(), req)
|
||||
_, err := service.Login(context.Background(), req, "")
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
@@ -112,7 +112,7 @@ func TestAuthService_Login_UserNotFound(t *testing.T) {
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
_, err := service.Login(context.Background(), req)
|
||||
_, err := service.Login(context.Background(), req, "")
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
@@ -134,8 +134,10 @@ func TestAuthService_Login_InactiveUser(t *testing.T) {
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
_, err := service.Login(context.Background(), req)
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.account_inactive")
|
||||
_, err := service.Login(context.Background(), req, "")
|
||||
// Audit L1: inactive accounts return the same generic error as bad
|
||||
// credentials so login does not disclose which accounts exist.
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
// === Register ===
|
||||
@@ -443,7 +445,7 @@ func TestAuthService_ResetPassword(t *testing.T) {
|
||||
Username: "testuser",
|
||||
Password: "NewPassword123",
|
||||
}
|
||||
loginResp, err := service.Login(context.Background(), loginReq)
|
||||
loginResp, err := service.Login(context.Background(), loginReq, "")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, loginResp.Token)
|
||||
}
|
||||
@@ -472,7 +474,7 @@ func TestAuthService_Logout(t *testing.T) {
|
||||
Username: "testuser",
|
||||
Password: "Password123",
|
||||
}
|
||||
loginResp, err := service.Login(context.Background(), loginReq)
|
||||
loginResp, err := service.Login(context.Background(), loginReq, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Logout
|
||||
@@ -659,7 +661,7 @@ func TestAuthService_Login_EmptyPassword(t *testing.T) {
|
||||
Password: "",
|
||||
}
|
||||
|
||||
_, err := service.Login(context.Background(), req)
|
||||
_, err := service.Login(context.Background(), req, "")
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
// CacheService provides Redis caching functionality
|
||||
@@ -139,22 +140,25 @@ const (
|
||||
TokenCacheTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
// authTokenCacheKey returns the Redis key for an auth token. The raw token
|
||||
// is hashed (audit C1) so the plaintext token never appears in a Redis key.
|
||||
func authTokenCacheKey(token string) string {
|
||||
return AuthTokenPrefix + models.HashToken(token)
|
||||
}
|
||||
|
||||
// CacheAuthToken caches a user ID for a token
|
||||
func (c *CacheService) CacheAuthToken(ctx context.Context, token string, userID uint) error {
|
||||
key := AuthTokenPrefix + token
|
||||
return c.SetString(ctx, key, fmt.Sprintf("%d", userID), TokenCacheTTL)
|
||||
return c.SetString(ctx, authTokenCacheKey(token), fmt.Sprintf("%d", userID), TokenCacheTTL)
|
||||
}
|
||||
|
||||
// CacheAuthTokenWithCreated caches a user ID and token creation time for a token
|
||||
func (c *CacheService) CacheAuthTokenWithCreated(ctx context.Context, token string, userID uint, createdUnix int64) error {
|
||||
key := AuthTokenPrefix + token
|
||||
return c.SetString(ctx, key, fmt.Sprintf("%d|%d", userID, createdUnix), TokenCacheTTL)
|
||||
return c.SetString(ctx, authTokenCacheKey(token), fmt.Sprintf("%d|%d", userID, createdUnix), TokenCacheTTL)
|
||||
}
|
||||
|
||||
// GetCachedAuthToken gets a cached user ID for a token
|
||||
func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (uint, error) {
|
||||
key := AuthTokenPrefix + token
|
||||
val, err := c.GetString(ctx, key)
|
||||
val, err := c.GetString(ctx, authTokenCacheKey(token))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -167,8 +171,7 @@ func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (ui
|
||||
// GetCachedAuthTokenWithCreated gets a cached user ID and token creation time.
|
||||
// Returns userID, createdUnix, error. createdUnix is 0 if not stored (legacy format).
|
||||
func (c *CacheService) GetCachedAuthTokenWithCreated(ctx context.Context, token string) (uint, int64, error) {
|
||||
key := AuthTokenPrefix + token
|
||||
val, err := c.GetString(ctx, key)
|
||||
val, err := c.GetString(ctx, authTokenCacheKey(token))
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
@@ -184,8 +187,62 @@ func (c *CacheService) GetCachedAuthTokenWithCreated(ctx context.Context, token
|
||||
|
||||
// InvalidateAuthToken removes a cached token
|
||||
func (c *CacheService) InvalidateAuthToken(ctx context.Context, token string) error {
|
||||
key := AuthTokenPrefix + token
|
||||
return c.Delete(ctx, key)
|
||||
return c.Delete(ctx, authTokenCacheKey(token))
|
||||
}
|
||||
|
||||
// InvalidateAuthTokenHashes removes cached entries for already-hashed token
|
||||
// keys. Unlike InvalidateAuthToken (which hashes a plaintext), this takes the
|
||||
// stored hash directly — used to evict a user's prior token on re-login
|
||||
// (audit MEDIUM-1), where the server no longer has the plaintext.
|
||||
func (c *CacheService) InvalidateAuthTokenHashes(ctx context.Context, hashes ...string) error {
|
||||
keys := make([]string, 0, len(hashes))
|
||||
for _, h := range hashes {
|
||||
if h != "" {
|
||||
keys = append(keys, AuthTokenPrefix+h)
|
||||
}
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.Delete(ctx, keys...)
|
||||
}
|
||||
|
||||
// --- Per-account login-failure tracking (audit M5) ---
|
||||
|
||||
const loginFailPrefix = "login_fail:"
|
||||
|
||||
// RegisterLoginFailure records a failed login for an account from a given
|
||||
// source IP, and returns the number of DISTINCT source IPs that have failed
|
||||
// for this account within the window. Tracking distinct IPs as a set rather
|
||||
// than a raw counter (audit MEDIUM-3) means one attacker, from one IP, cannot
|
||||
// run the count up and lock a victim out by knowing only their email — a
|
||||
// single IP is bounded by the per-IP edge/app rate limiters instead. A
|
||||
// genuinely distributed credential-stuffing attack still trips the lockout.
|
||||
func (c *CacheService) RegisterLoginFailure(ctx context.Context, identifier, ip string, window time.Duration) (int64, error) {
|
||||
key := loginFailPrefix + identifier
|
||||
member := ip
|
||||
if member == "" {
|
||||
member = "unknown"
|
||||
}
|
||||
if err := c.client.SAdd(ctx, key, member).Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// Refresh the TTL on each failure: an active attack keeps the window
|
||||
// open, while a quiet account ages out `window` after its last failure.
|
||||
_ = c.client.Expire(ctx, key, window).Err()
|
||||
return c.client.SCard(ctx, key).Result()
|
||||
}
|
||||
|
||||
// LoginFailureIPCount returns how many distinct source IPs have failed to log
|
||||
// in to this account within the window (audit MEDIUM-3). SCard on a missing
|
||||
// key returns 0.
|
||||
func (c *CacheService) LoginFailureIPCount(ctx context.Context, identifier string) (int64, error) {
|
||||
return c.client.SCard(ctx, loginFailPrefix+identifier).Result()
|
||||
}
|
||||
|
||||
// ClearLoginFailures resets the failed-login IP set after a successful login.
|
||||
func (c *CacheService) ClearLoginFailures(ctx context.Context, identifier string) error {
|
||||
return c.client.Del(ctx, loginFailPrefix+identifier).Err()
|
||||
}
|
||||
|
||||
// Static data cache helpers
|
||||
|
||||
@@ -296,9 +296,14 @@ func (s *ContractorService) ToggleFavorite(ctx context.Context, contractorID, us
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Re-fetch the contractor to get the updated state with all relations
|
||||
// Re-fetch to get the updated state with all relations. Audit M12: if the
|
||||
// contractor was deleted concurrently between the toggle and this read,
|
||||
// surface a clean 404 instead of a 500.
|
||||
contractor, err = s.contractorRepo.WithContext(ctx).FindByID(contractorID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, apperrors.NotFound("error.contractor_not_found")
|
||||
}
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// FileOwnershipService checks whether a user owns a file referenced by URL.
|
||||
// It queries task completion images, document files, and document images
|
||||
// to determine ownership through residence access.
|
||||
// FileOwnershipService checks whether a user has access to a file referenced
|
||||
// by URL. It queries task completion images, document files, and document
|
||||
// images, resolving access through residence ownership or membership.
|
||||
type FileOwnershipService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
@@ -17,16 +17,31 @@ func NewFileOwnershipService(db *gorm.DB) *FileOwnershipService {
|
||||
return &FileOwnershipService{db: db}
|
||||
}
|
||||
|
||||
// IsFileOwnedByUser checks if the given file URL belongs to a record
|
||||
// that the user has access to (via residence membership).
|
||||
// accessibleResidenceIDs returns a subquery of residence IDs the user can
|
||||
// access: residences they own (residence_residence.owner_id) UNION residences
|
||||
// they are a member of (residence_residence_users).
|
||||
//
|
||||
// Audit C7: the previous queries joined residence_residence_users only, so a
|
||||
// residence owner who was not also a member of the join table could not pass
|
||||
// the ownership check for files in their own property.
|
||||
func (s *FileOwnershipService) accessibleResidenceIDs(userID uint) *gorm.DB {
|
||||
return s.db.Raw(`
|
||||
SELECT id FROM residence_residence WHERE owner_id = ?
|
||||
UNION
|
||||
SELECT residence_id FROM residence_residence_users WHERE user_id = ?
|
||||
`, userID, userID)
|
||||
}
|
||||
|
||||
// IsFileOwnedByUser checks if the given file URL belongs to a record in a
|
||||
// residence the user owns or is a member of.
|
||||
func (s *FileOwnershipService) IsFileOwnedByUser(fileURL string, userID uint) (bool, error) {
|
||||
// Check task completion images: image_url -> completion -> task -> residence -> user access
|
||||
// Task completion images: image_url -> completion -> task -> residence.
|
||||
var completionImageCount int64
|
||||
err := s.db.Model(&models.TaskCompletionImage{}).
|
||||
Joins("JOIN task_taskcompletion ON task_taskcompletion.id = task_taskcompletionimage.completion_id").
|
||||
Joins("JOIN task_task ON task_task.id = task_taskcompletion.task_id").
|
||||
Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_task.residence_id").
|
||||
Where("task_taskcompletionimage.image_url = ? AND residence_residence_users.user_id = ?", fileURL, userID).
|
||||
Where("task_taskcompletionimage.image_url = ?", fileURL).
|
||||
Where("task_task.residence_id IN (?)", s.accessibleResidenceIDs(userID)).
|
||||
Count(&completionImageCount).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -35,11 +50,11 @@ func (s *FileOwnershipService) IsFileOwnedByUser(fileURL string, userID uint) (b
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Check document files: file_url -> document -> residence -> user access
|
||||
// Document files: file_url -> document -> residence.
|
||||
var documentCount int64
|
||||
err = s.db.Model(&models.Document{}).
|
||||
Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_document.residence_id").
|
||||
Where("task_document.file_url = ? AND residence_residence_users.user_id = ?", fileURL, userID).
|
||||
Where("task_document.file_url = ?", fileURL).
|
||||
Where("task_document.residence_id IN (?)", s.accessibleResidenceIDs(userID)).
|
||||
Count(&documentCount).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -48,12 +63,12 @@ func (s *FileOwnershipService) IsFileOwnedByUser(fileURL string, userID uint) (b
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Check document images: image_url -> document_image -> document -> residence -> user access
|
||||
// Document images: image_url -> document_image -> document -> residence.
|
||||
var documentImageCount int64
|
||||
err = s.db.Model(&models.DocumentImage{}).
|
||||
Joins("JOIN task_document ON task_document.id = task_documentimage.document_id").
|
||||
Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_document.residence_id").
|
||||
Where("task_documentimage.image_url = ? AND residence_residence_users.user_id = ?", fileURL, userID).
|
||||
Where("task_documentimage.image_url = ?", fileURL).
|
||||
Where("task_document.residence_id IN (?)", s.accessibleResidenceIDs(userID)).
|
||||
Count(&documentImageCount).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
||||
@@ -2,132 +2,306 @@ package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
googleTokenInfoURL = "https://oauth2.googleapis.com/tokeninfo"
|
||||
// googleKeysURL is Google's JWKS endpoint for ID-token signature verification.
|
||||
googleKeysURL = "https://www.googleapis.com/oauth2/v3/certs"
|
||||
googleKeysCacheTTL = 24 * time.Hour
|
||||
googleKeysCacheKey = "google:public_keys"
|
||||
)
|
||||
|
||||
// googleIssuers is the set of valid `iss` claim values for a Google ID token.
|
||||
var googleIssuers = map[string]bool{
|
||||
"accounts.google.com": true,
|
||||
"https://accounts.google.com": true,
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidGoogleToken = errors.New("invalid Google ID token")
|
||||
ErrGoogleTokenExpired = errors.New("Google ID token has expired")
|
||||
ErrInvalidGoogleAudience = errors.New("invalid Google token audience")
|
||||
ErrInvalidGoogleIssuer = errors.New("invalid Google token issuer")
|
||||
ErrGoogleKeyNotFound = errors.New("Google public key not found")
|
||||
)
|
||||
|
||||
// GoogleTokenInfo represents the response from Google's token info endpoint
|
||||
type GoogleTokenInfo struct {
|
||||
Sub string `json:"sub"` // Unique Google user ID
|
||||
Email string `json:"email"` // User's email
|
||||
EmailVerified string `json:"email_verified"` // "true" or "false"
|
||||
Name string `json:"name"` // Full name
|
||||
GivenName string `json:"given_name"` // First name
|
||||
FamilyName string `json:"family_name"` // Last name
|
||||
Picture string `json:"picture"` // Profile picture URL
|
||||
Aud string `json:"aud"` // Audience (client ID)
|
||||
Azp string `json:"azp"` // Authorized party
|
||||
Exp string `json:"exp"` // Expiration time
|
||||
Iss string `json:"iss"` // Issuer
|
||||
// GoogleJWKS represents Google's JSON Web Key Set.
|
||||
type GoogleJWKS struct {
|
||||
Keys []GoogleJWK `json:"keys"`
|
||||
}
|
||||
|
||||
// IsEmailVerified returns whether the email is verified
|
||||
// GoogleJWK represents a single JSON Web Key from Google.
|
||||
type GoogleJWK struct {
|
||||
Kty string `json:"kty"` // Key type (RSA)
|
||||
Kid string `json:"kid"` // Key ID
|
||||
Use string `json:"use"` // Key use (sig)
|
||||
Alg string `json:"alg"` // Algorithm (RS256)
|
||||
N string `json:"n"` // RSA modulus
|
||||
E string `json:"e"` // RSA exponent
|
||||
}
|
||||
|
||||
// GoogleTokenClaims represents the claims in a Google ID token JWT.
|
||||
type GoogleTokenClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
Email string `json:"email,omitempty"`
|
||||
EmailVerified bool `json:"email_verified,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
GivenName string `json:"given_name,omitempty"`
|
||||
FamilyName string `json:"family_name,omitempty"`
|
||||
Picture string `json:"picture,omitempty"`
|
||||
Azp string `json:"azp,omitempty"` // Authorized party
|
||||
}
|
||||
|
||||
// GoogleTokenInfo is the verified, caller-facing view of a Google ID token.
|
||||
type GoogleTokenInfo struct {
|
||||
Sub string // Unique Google user ID
|
||||
Email string
|
||||
EmailVerified string // "true" or "false" — string for caller compatibility
|
||||
Name string
|
||||
GivenName string
|
||||
FamilyName string
|
||||
Picture string
|
||||
Aud string
|
||||
Azp string
|
||||
Iss string
|
||||
}
|
||||
|
||||
// IsEmailVerified returns whether the email is verified.
|
||||
func (t *GoogleTokenInfo) IsEmailVerified() bool {
|
||||
return t.EmailVerified == "true"
|
||||
}
|
||||
|
||||
// GoogleAuthService handles Google Sign In token verification
|
||||
// GoogleAuthService handles Google Sign In token verification.
|
||||
type GoogleAuthService struct {
|
||||
cache *CacheService
|
||||
config *config.Config
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewGoogleAuthService creates a new Google auth service
|
||||
// NewGoogleAuthService creates a new Google auth service.
|
||||
func NewGoogleAuthService(cache *CacheService, cfg *config.Config) *GoogleAuthService {
|
||||
return &GoogleAuthService{
|
||||
cache: cache,
|
||||
config: cfg,
|
||||
client: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
client: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyIDToken verifies a Google ID token and returns the token info
|
||||
// VerifyIDToken verifies a Google ID token locally (audit C2/C3): it checks
|
||||
// the RS256 signature against Google's published JWKS and the iss, aud, and
|
||||
// exp claims. It never sends the token to a third-party endpoint, so it no
|
||||
// longer depends on the deprecated tokeninfo service and never leaks the
|
||||
// token in a request URL.
|
||||
func (s *GoogleAuthService) VerifyIDToken(ctx context.Context, idToken string) (*GoogleTokenInfo, error) {
|
||||
// Call Google's tokeninfo endpoint to verify the token
|
||||
url := fmt.Sprintf("%s?id_token=%s", googleTokenInfoURL, idToken)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to verify token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// Parse the token header to get the key ID.
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, ErrInvalidGoogleToken
|
||||
}
|
||||
|
||||
var tokenInfo GoogleTokenInfo
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode token info: %w", err)
|
||||
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode token header: %w", err)
|
||||
}
|
||||
var header struct {
|
||||
Kid string `json:"kid"`
|
||||
Alg string `json:"alg"`
|
||||
}
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token header: %w", err)
|
||||
}
|
||||
|
||||
// Verify the audience matches our client ID(s)
|
||||
if !s.verifyAudience(tokenInfo.Aud, tokenInfo.Azp) {
|
||||
publicKey, err := s.getPublicKey(ctx, header.Kid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse and verify the signature. jwt v5 validates exp/iat/nbf automatically.
|
||||
token, err := jwt.ParseWithClaims(idToken, &GoogleTokenClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return publicKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
return nil, ErrGoogleTokenExpired
|
||||
}
|
||||
return nil, fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*GoogleTokenClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, ErrInvalidGoogleToken
|
||||
}
|
||||
|
||||
// Verify the issuer (audit C3).
|
||||
if !googleIssuers[claims.Issuer] {
|
||||
return nil, ErrInvalidGoogleIssuer
|
||||
}
|
||||
|
||||
// Verify the audience matches one of our configured client IDs.
|
||||
if !s.verifyAudience(claims.Audience, claims.Azp) {
|
||||
return nil, ErrInvalidGoogleAudience
|
||||
}
|
||||
|
||||
// Verify the token is not expired (tokeninfo endpoint already checks this,
|
||||
// but we double-check for security)
|
||||
if tokenInfo.Sub == "" {
|
||||
if claims.Subject == "" {
|
||||
return nil, ErrInvalidGoogleToken
|
||||
}
|
||||
|
||||
return &tokenInfo, nil
|
||||
emailVerified := "false"
|
||||
if claims.EmailVerified {
|
||||
emailVerified = "true"
|
||||
}
|
||||
aud := ""
|
||||
if len(claims.Audience) > 0 {
|
||||
aud = claims.Audience[0]
|
||||
}
|
||||
return &GoogleTokenInfo{
|
||||
Sub: claims.Subject,
|
||||
Email: claims.Email,
|
||||
EmailVerified: emailVerified,
|
||||
Name: claims.Name,
|
||||
GivenName: claims.GivenName,
|
||||
FamilyName: claims.FamilyName,
|
||||
Picture: claims.Picture,
|
||||
Aud: aud,
|
||||
Azp: claims.Azp,
|
||||
Iss: claims.Issuer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// verifyAudience checks if the token audience matches our client ID(s).
|
||||
// In production (non-debug), an empty clientID causes verification to fail
|
||||
// rather than silently bypassing the check.
|
||||
func (s *GoogleAuthService) verifyAudience(aud, azp string) bool {
|
||||
// verifyAudience checks the token audience against our configured client IDs.
|
||||
// In production (non-debug) an empty client ID fails verification rather than
|
||||
// silently bypassing the check.
|
||||
func (s *GoogleAuthService) verifyAudience(audience jwt.ClaimStrings, azp string) bool {
|
||||
clientID := s.config.GoogleAuth.ClientID
|
||||
if clientID == "" {
|
||||
if s.config.Server.Debug {
|
||||
// In debug mode only, skip audience verification for local development
|
||||
// In debug mode only, skip audience verification for local development.
|
||||
return s.config.Server.Debug
|
||||
}
|
||||
|
||||
candidates := []string{clientID}
|
||||
if id := s.config.GoogleAuth.AndroidClientID; id != "" {
|
||||
candidates = append(candidates, id)
|
||||
}
|
||||
if id := s.config.GoogleAuth.IOSClientID; id != "" {
|
||||
candidates = append(candidates, id)
|
||||
}
|
||||
|
||||
for _, want := range candidates {
|
||||
if azp == want {
|
||||
return true
|
||||
}
|
||||
// In production, missing client ID means we cannot verify the audience
|
||||
return false
|
||||
for _, aud := range audience {
|
||||
if aud == want {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check both aud and azp (Android vs iOS may use different values)
|
||||
if aud == clientID || azp == clientID {
|
||||
return true
|
||||
}
|
||||
|
||||
// Also check Android client ID if configured
|
||||
androidClientID := s.config.GoogleAuth.AndroidClientID
|
||||
if androidClientID != "" && (aud == androidClientID || azp == androidClientID) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Also check iOS client ID if configured
|
||||
iosClientID := s.config.GoogleAuth.IOSClientID
|
||||
if iosClientID != "" && (aud == iosClientID || azp == iosClientID) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// getPublicKey returns the RSA public key for the given key ID, using a
|
||||
// Redis-cached copy of Google's JWKS and re-fetching once on a cache miss
|
||||
// (Google rotates signing keys roughly daily).
|
||||
func (s *GoogleAuthService) getPublicKey(ctx context.Context, kid string) (*rsa.PublicKey, error) {
|
||||
keys, err := s.getCachedKeys(ctx)
|
||||
if err != nil || keys == nil {
|
||||
keys, err = s.fetchGooglePublicKeys(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if pubKey, ok := keys[kid]; ok {
|
||||
return pubKey, nil
|
||||
}
|
||||
|
||||
// Cache miss for this kid — keys may have rotated; fetch fresh.
|
||||
keys, err = s.fetchGooglePublicKeys(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pubKey, ok := keys[kid]; ok {
|
||||
return pubKey, nil
|
||||
}
|
||||
return nil, ErrGoogleKeyNotFound
|
||||
}
|
||||
|
||||
// getCachedKeys retrieves cached Google public keys from Redis.
|
||||
func (s *GoogleAuthService) getCachedKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
|
||||
if s.cache == nil {
|
||||
return nil, nil
|
||||
}
|
||||
data, err := s.cache.GetString(ctx, googleKeysCacheKey)
|
||||
if err != nil || data == "" {
|
||||
return nil, nil
|
||||
}
|
||||
var jwks GoogleJWKS
|
||||
if err := json.Unmarshal([]byte(data), &jwks); err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.parseJWKS(&jwks), nil
|
||||
}
|
||||
|
||||
// fetchGooglePublicKeys fetches Google's JWKS and caches it.
|
||||
func (s *GoogleAuthService) fetchGooglePublicKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, googleKeysURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch Google keys: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Google keys endpoint returned status %d", resp.StatusCode)
|
||||
}
|
||||
var jwks GoogleJWKS
|
||||
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode Google keys: %w", err)
|
||||
}
|
||||
if s.cache != nil {
|
||||
keysJSON, _ := json.Marshal(jwks)
|
||||
_ = s.cache.SetString(ctx, googleKeysCacheKey, string(keysJSON), googleKeysCacheTTL)
|
||||
}
|
||||
return s.parseJWKS(&jwks), nil
|
||||
}
|
||||
|
||||
// parseJWKS converts Google's JWKS into a map of RSA public keys by key ID.
|
||||
func (s *GoogleAuthService) parseJWKS(jwks *GoogleJWKS) map[string]*rsa.PublicKey {
|
||||
keys := make(map[string]*rsa.PublicKey)
|
||||
for _, key := range jwks.Keys {
|
||||
if key.Kty != "RSA" {
|
||||
continue
|
||||
}
|
||||
nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
e := 0
|
||||
for _, b := range eBytes {
|
||||
e = e<<8 + int(b)
|
||||
}
|
||||
keys[key.Kid] = &rsa.PublicKey{N: new(big.Int).SetBytes(nBytes), E: e}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
@@ -68,13 +68,14 @@ type AppleTransactionInfo struct {
|
||||
|
||||
// AppleValidationResult contains the result of Apple receipt validation
|
||||
type AppleValidationResult struct {
|
||||
Valid bool
|
||||
TransactionID string
|
||||
ProductID string
|
||||
ExpiresAt time.Time
|
||||
IsTrialPeriod bool
|
||||
AutoRenewEnabled bool
|
||||
Environment string
|
||||
Valid bool
|
||||
TransactionID string
|
||||
OriginalTransactionID string // stable across renewals — the replay key
|
||||
ProductID string
|
||||
ExpiresAt time.Time
|
||||
IsTrialPeriod bool
|
||||
AutoRenewEnabled bool
|
||||
Environment string
|
||||
}
|
||||
|
||||
// GoogleValidationResult contains the result of Google token validation
|
||||
@@ -95,6 +96,21 @@ func NewAppleIAPClient(cfg config.AppleIAPConfig) (*AppleIAPClient, error) {
|
||||
return nil, ErrIAPNotConfigured
|
||||
}
|
||||
|
||||
// Audit H5 (relaxed per MEDIUM-2): refuse to load the IAP signing key from
|
||||
// a world-accessible file — a leaked .p8 lets an attacker forge App Store
|
||||
// Server API requests. The original "0600 or stricter" check is
|
||||
// incompatible with a Kubernetes Secret volume: the kubelet widens secret
|
||||
// files to 0440 once fsGroup is set, so 0600 is unattainable for a
|
||||
// non-root container. Group access is scoped to the pod's fsGroup; the
|
||||
// real exposure is the "other" bits, so reject only those.
|
||||
info, err := os.Stat(cfg.KeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to stat Apple IAP key: %w", err)
|
||||
}
|
||||
if perm := info.Mode().Perm(); perm&0o007 != 0 {
|
||||
return nil, fmt.Errorf("Apple IAP key %s is world-accessible (permissions %#o); remove other-rwx bits", cfg.KeyPath, perm)
|
||||
}
|
||||
|
||||
// Read the private key
|
||||
keyData, err := os.ReadFile(cfg.KeyPath)
|
||||
if err != nil {
|
||||
@@ -215,11 +231,12 @@ func (c *AppleIAPClient) ValidateTransaction(ctx context.Context, transactionID
|
||||
expiresAt := time.Unix(transactionInfo.ExpiresDate/1000, 0)
|
||||
|
||||
return &AppleValidationResult{
|
||||
Valid: true,
|
||||
TransactionID: transactionInfo.TransactionID,
|
||||
ProductID: transactionInfo.ProductID,
|
||||
ExpiresAt: expiresAt,
|
||||
Environment: transactionInfo.Environment,
|
||||
Valid: true,
|
||||
TransactionID: transactionInfo.TransactionID,
|
||||
OriginalTransactionID: transactionInfo.OriginalTransactionID,
|
||||
ProductID: transactionInfo.ProductID,
|
||||
ExpiresAt: expiresAt,
|
||||
Environment: transactionInfo.Environment,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -243,11 +260,12 @@ func (c *AppleIAPClient) ValidateReceipt(ctx context.Context, receiptData string
|
||||
if err == nil {
|
||||
expiresAt := time.Unix(transactionInfo.ExpiresDate/1000, 0)
|
||||
return &AppleValidationResult{
|
||||
Valid: true,
|
||||
TransactionID: transactionInfo.TransactionID,
|
||||
ProductID: transactionInfo.ProductID,
|
||||
ExpiresAt: expiresAt,
|
||||
Environment: transactionInfo.Environment,
|
||||
Valid: true,
|
||||
TransactionID: transactionInfo.TransactionID,
|
||||
OriginalTransactionID: transactionInfo.OriginalTransactionID,
|
||||
ProductID: transactionInfo.ProductID,
|
||||
ExpiresAt: expiresAt,
|
||||
Environment: transactionInfo.Environment,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
@@ -317,11 +335,12 @@ func (c *AppleIAPClient) ValidateReceipt(ctx context.Context, receiptData string
|
||||
expiresAt := time.Unix(transactionInfo.ExpiresDate/1000, 0)
|
||||
|
||||
return &AppleValidationResult{
|
||||
Valid: true,
|
||||
TransactionID: transactionInfo.TransactionID,
|
||||
ProductID: transactionInfo.ProductID,
|
||||
ExpiresAt: expiresAt,
|
||||
Environment: transactionInfo.Environment,
|
||||
Valid: true,
|
||||
TransactionID: transactionInfo.TransactionID,
|
||||
OriginalTransactionID: transactionInfo.OriginalTransactionID,
|
||||
ProductID: transactionInfo.ProductID,
|
||||
ExpiresAt: expiresAt,
|
||||
Environment: transactionInfo.Environment,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -418,13 +437,14 @@ func (c *AppleIAPClient) validateLegacyReceiptWithSandbox(ctx context.Context, r
|
||||
}
|
||||
|
||||
return &AppleValidationResult{
|
||||
Valid: true,
|
||||
TransactionID: latestReceipt.TransactionID,
|
||||
ProductID: latestReceipt.ProductID,
|
||||
ExpiresAt: expiresAt,
|
||||
IsTrialPeriod: latestReceipt.IsTrialPeriod == "true",
|
||||
AutoRenewEnabled: autoRenew,
|
||||
Environment: legacyResponse.Environment,
|
||||
Valid: true,
|
||||
TransactionID: latestReceipt.TransactionID,
|
||||
OriginalTransactionID: latestReceipt.OriginalTransactionID,
|
||||
ProductID: latestReceipt.ProductID,
|
||||
ExpiresAt: expiresAt,
|
||||
IsTrialPeriod: latestReceipt.IsTrialPeriod == "true",
|
||||
AutoRenewEnabled: autoRenew,
|
||||
Environment: legacyResponse.Environment,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -308,7 +308,18 @@ func (s *NotificationService) registerAPNSDevice(ctx context.Context, userID uin
|
||||
// Check if device exists
|
||||
existing, err := s.notificationRepo.WithContext(ctx).FindAPNSDeviceByToken(req.RegistrationID)
|
||||
if err == nil {
|
||||
// Update existing device
|
||||
// Audit C8 / LOW-3: APNs device tokens are recycled across devices,
|
||||
// app reinstalls and OS reassignments, so a token already bound to a
|
||||
// different account is a stale binding — not a hijack. Reassign it to
|
||||
// the current (authenticated) registrant rather than reject: a 409
|
||||
// here would lock the legitimate new owner of a recycled token out of
|
||||
// push entirely. The reassignment is logged as a security-relevant
|
||||
// event so a genuine token-takeover attempt is still traceable.
|
||||
if existing.UserID != nil && *existing.UserID != userID {
|
||||
log.Warn().Uint("user_id", userID).Uint("previous_owner_id", *existing.UserID).
|
||||
Msg("APNS device token reassigned to a new account")
|
||||
}
|
||||
// Update existing device — reassign to the current user
|
||||
existing.UserID = &userID
|
||||
existing.Active = true
|
||||
existing.Name = req.Name
|
||||
@@ -337,7 +348,18 @@ func (s *NotificationService) registerGCMDevice(ctx context.Context, userID uint
|
||||
// Check if device exists
|
||||
existing, err := s.notificationRepo.WithContext(ctx).FindGCMDeviceByToken(req.RegistrationID)
|
||||
if err == nil {
|
||||
// Update existing device
|
||||
// Audit C8 / LOW-3: FCM device tokens are recycled across devices,
|
||||
// app reinstalls and OS reassignments, so a token already bound to a
|
||||
// different account is a stale binding — not a hijack. Reassign it to
|
||||
// the current (authenticated) registrant rather than reject: a 409
|
||||
// here would lock the legitimate new owner of a recycled token out of
|
||||
// push entirely. The reassignment is logged as a security-relevant
|
||||
// event so a genuine token-takeover attempt is still traceable.
|
||||
if existing.UserID != nil && *existing.UserID != userID {
|
||||
log.Warn().Uint("user_id", userID).Uint("previous_owner_id", *existing.UserID).
|
||||
Msg("GCM device token reassigned to a new account")
|
||||
}
|
||||
// Update existing device — reassign to the current user
|
||||
existing.UserID = &userID
|
||||
existing.Active = true
|
||||
existing.Name = req.Name
|
||||
|
||||
@@ -559,30 +559,22 @@ func (s *ResidenceService) GenerateSharePackage(ctx context.Context, residenceID
|
||||
}, nil
|
||||
}
|
||||
|
||||
// JoinWithCode allows a user to join a residence using a share code
|
||||
// JoinWithCode allows a user to join a residence using a share code.
|
||||
// Audit C9/H9: the code lookup, membership add, and one-time-code
|
||||
// deactivation run as a single locked transaction in the repository, so a
|
||||
// code can never be redeemed twice and a deactivation failure aborts the join.
|
||||
func (s *ResidenceService) JoinWithCode(ctx context.Context, code string, userID uint) (*responses.JoinResidenceResponse, error) {
|
||||
// Find the share code
|
||||
shareCode, err := s.residenceRepo.WithContext(ctx).FindShareCodeByCode(code)
|
||||
residenceID, alreadyMember, err := s.residenceRepo.WithContext(ctx).JoinWithShareCode(code, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, apperrors.NotFound("error.share_code_invalid")
|
||||
}
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Check if already a member
|
||||
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(shareCode.ResidenceID, userID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
if hasAccess {
|
||||
if alreadyMember {
|
||||
return nil, apperrors.Conflict("error.user_already_member")
|
||||
}
|
||||
|
||||
// Add user to residence
|
||||
if err := s.residenceRepo.WithContext(ctx).AddUser(shareCode.ResidenceID, userID); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
if s.cache != nil {
|
||||
// The joining user's residence-IDs cache is now stale, and their
|
||||
// subscription status now reflects an extra residence with all of its
|
||||
@@ -591,15 +583,8 @@ func (s *ResidenceService) JoinWithCode(ctx context.Context, code string, userID
|
||||
_ = s.cache.InvalidateSubscriptionStatusForUsers(ctx, userID)
|
||||
}
|
||||
|
||||
// Mark share code as used (one-time use)
|
||||
if err := s.residenceRepo.WithContext(ctx).DeactivateShareCode(shareCode.ID); err != nil {
|
||||
// Log the error but don't fail the join - the user has already been added
|
||||
// The code will just be usable by others until it expires
|
||||
log.Error().Err(err).Uint("code_id", shareCode.ID).Msg("Failed to deactivate share code after join")
|
||||
}
|
||||
|
||||
// Get the residence with full details
|
||||
residence, err := s.residenceRepo.WithContext(ctx).FindByID(shareCode.ResidenceID)
|
||||
residence, err := s.residenceRepo.WithContext(ctx).FindByID(residenceID)
|
||||
if err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
@@ -399,99 +399,135 @@ func (s *SubscriptionService) GetActivePromotions(ctx context.Context, userID ui
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ProcessApplePurchase processes an Apple IAP purchase
|
||||
// Supports both StoreKit 1 (receiptData) and StoreKit 2 (transactionID)
|
||||
// ProcessApplePurchase processes an Apple IAP purchase.
|
||||
// Supports both StoreKit 1 (receiptData) and StoreKit 2 (transactionID).
|
||||
func (s *SubscriptionService) ProcessApplePurchase(ctx context.Context, userID uint, receiptData string, transactionID string) (*SubscriptionResponse, error) {
|
||||
// Store receipt/transaction data
|
||||
dataToStore := receiptData
|
||||
if dataToStore == "" {
|
||||
dataToStore = transactionID
|
||||
}
|
||||
if err := s.subscriptionRepo.WithContext(ctx).UpdateReceiptData(userID, dataToStore); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Apple IAP client must be configured to validate purchases.
|
||||
// Without server-side validation, we cannot trust client-provided receipts.
|
||||
// Apple IAP client must be configured — without server-side validation
|
||||
// we cannot trust client-provided receipts.
|
||||
if s.appleClient == nil {
|
||||
log.Error().Uint("user_id", userID).Msg("Apple IAP validation not configured, rejecting purchase")
|
||||
return nil, apperrors.BadRequest("error.iap_validation_not_configured")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
// Validation is a network call to Apple; detach from the request context
|
||||
// so a client disconnect cannot abort an in-flight grant.
|
||||
vctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var result *AppleValidationResult
|
||||
var err error
|
||||
|
||||
// Prefer transaction ID (StoreKit 2), fall back to receipt data (StoreKit 1)
|
||||
// Prefer transaction ID (StoreKit 2), fall back to receipt data (StoreKit 1).
|
||||
if transactionID != "" {
|
||||
result, err = s.appleClient.ValidateTransaction(ctx, transactionID)
|
||||
result, err = s.appleClient.ValidateTransaction(vctx, transactionID)
|
||||
} else if receiptData != "" {
|
||||
result, err = s.appleClient.ValidateReceipt(ctx, receiptData)
|
||||
result, err = s.appleClient.ValidateReceipt(vctx, receiptData)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Validation failed -- do NOT fall through to grant Pro.
|
||||
log.Error().Err(err).Uint("user_id", userID).Msg("Apple validation failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
return nil, apperrors.BadRequest("error.no_receipt_or_transaction")
|
||||
}
|
||||
|
||||
// Audit C5/C10: replay protection. A validated transaction may only ever
|
||||
// be bound to one account — re-submitting a valid receipt against a
|
||||
// second account must not grant Pro for free. The partial unique index
|
||||
// on apple_original_transaction_id is the backstop for the check/store
|
||||
// race below.
|
||||
if result.OriginalTransactionID != "" {
|
||||
existing, lookupErr := s.subscriptionRepo.WithContext(vctx).FindByAppleOriginalTransactionID(result.OriginalTransactionID)
|
||||
switch {
|
||||
case lookupErr == nil && existing != nil && existing.UserID != userID:
|
||||
log.Warn().Uint("user_id", userID).Uint("bound_user_id", existing.UserID).
|
||||
Msg("Apple purchase rejected — transaction already claimed by another account")
|
||||
return nil, apperrors.Forbidden("error.iap_transaction_already_claimed")
|
||||
case lookupErr != nil && !errors.Is(lookupErr, gorm.ErrRecordNotFound):
|
||||
return nil, apperrors.Internal(lookupErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Persist the receipt blob and the replay key.
|
||||
dataToStore := receiptData
|
||||
if dataToStore == "" {
|
||||
dataToStore = transactionID
|
||||
}
|
||||
if err := s.subscriptionRepo.WithContext(vctx).UpdateReceiptData(userID, dataToStore); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
if result.OriginalTransactionID != "" {
|
||||
if err := s.subscriptionRepo.WithContext(vctx).UpdateAppleOriginalTransactionID(userID, result.OriginalTransactionID); err != nil {
|
||||
// The unique index rejected the bind — a concurrent request
|
||||
// claimed the same transaction first.
|
||||
log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to bind Apple transaction ID")
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
}
|
||||
|
||||
expiresAt := result.ExpiresAt
|
||||
log.Info().Uint("user_id", userID).Str("product", result.ProductID).Time("expires", result.ExpiresAt).Str("env", result.Environment).Msg("Apple purchase validated")
|
||||
|
||||
// Upgrade to Pro with the validated expiration
|
||||
if err := s.subscriptionRepo.WithContext(ctx).UpgradeToPro(userID, expiresAt, "ios"); err != nil {
|
||||
if err := s.subscriptionRepo.WithContext(vctx).UpgradeToPro(userID, expiresAt, "ios"); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Tier flipped — drop cached SubscriptionStatusResponse so the next call
|
||||
// returns Pro immediately instead of stale Free.
|
||||
if s.cache != nil {
|
||||
_ = s.cache.InvalidateSubscriptionStatusForUsers(ctx, userID)
|
||||
_ = s.cache.InvalidateSubscriptionStatusForUsers(vctx, userID)
|
||||
}
|
||||
|
||||
return s.GetSubscription(ctx, userID)
|
||||
return s.GetSubscription(vctx, userID)
|
||||
}
|
||||
|
||||
// ProcessGooglePurchase processes a Google Play purchase
|
||||
// productID is optional but helps validate the specific subscription
|
||||
func (s *SubscriptionService) ProcessGooglePurchase(ctx context.Context, userID uint, purchaseToken string, productID string) (*SubscriptionResponse, error) {
|
||||
// Store purchase token first
|
||||
if err := s.subscriptionRepo.WithContext(ctx).UpdatePurchaseToken(userID, purchaseToken); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Google IAP client must be configured to validate purchases.
|
||||
// Without server-side validation, we cannot trust client-provided tokens.
|
||||
// Google IAP client must be configured — without server-side validation
|
||||
// we cannot trust client-provided tokens.
|
||||
if s.googleClient == nil {
|
||||
log.Error().Uint("user_id", userID).Msg("Google IAP validation not configured, rejecting purchase")
|
||||
return nil, apperrors.BadRequest("error.iap_validation_not_configured")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
// Audit C6/C10: replay protection — a purchase token may only ever be
|
||||
// bound to one account. The partial unique index on google_purchase_token
|
||||
// is the backstop for the check/store race.
|
||||
if purchaseToken != "" {
|
||||
existing, lookupErr := s.subscriptionRepo.WithContext(ctx).FindByGoogleToken(purchaseToken)
|
||||
switch {
|
||||
case lookupErr == nil && existing != nil && existing.UserID != userID:
|
||||
log.Warn().Uint("user_id", userID).Uint("bound_user_id", existing.UserID).
|
||||
Msg("Google purchase rejected — token already claimed by another account")
|
||||
return nil, apperrors.Forbidden("error.iap_transaction_already_claimed")
|
||||
case lookupErr != nil && !errors.Is(lookupErr, gorm.ErrRecordNotFound):
|
||||
return nil, apperrors.Internal(lookupErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Store the purchase token (the replay key).
|
||||
if err := s.subscriptionRepo.WithContext(ctx).UpdatePurchaseToken(userID, purchaseToken); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
// Validation is a network call; detach from the request context.
|
||||
vctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var result *GoogleValidationResult
|
||||
var err error
|
||||
|
||||
// If productID is provided, use it directly; otherwise try known IDs
|
||||
// If productID is provided, use it directly; otherwise try known IDs.
|
||||
if productID != "" {
|
||||
result, err = s.googleClient.ValidateSubscription(ctx, productID, purchaseToken)
|
||||
result, err = s.googleClient.ValidateSubscription(vctx, productID, purchaseToken)
|
||||
} else {
|
||||
result, err = s.googleClient.ValidatePurchaseToken(ctx, purchaseToken, KnownSubscriptionIDs)
|
||||
result, err = s.googleClient.ValidatePurchaseToken(vctx, purchaseToken, KnownSubscriptionIDs)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Validation failed -- do NOT fall through to grant Pro.
|
||||
log.Error().Err(err).Uint("user_id", userID).Msg("Google purchase validation failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
return nil, apperrors.BadRequest("error.no_purchase_token")
|
||||
}
|
||||
@@ -499,24 +535,23 @@ func (s *SubscriptionService) ProcessGooglePurchase(ctx context.Context, userID
|
||||
expiresAt := result.ExpiresAt
|
||||
log.Info().Uint("user_id", userID).Str("product", result.ProductID).Time("expires", result.ExpiresAt).Bool("auto_renew", result.AutoRenewing).Msg("Google purchase validated")
|
||||
|
||||
// Acknowledge the subscription if not already acknowledged
|
||||
// Acknowledge the subscription if not already acknowledged.
|
||||
if !result.AcknowledgedState {
|
||||
if err := s.googleClient.AcknowledgeSubscription(ctx, result.ProductID, purchaseToken); err != nil {
|
||||
if err := s.googleClient.AcknowledgeSubscription(vctx, result.ProductID, purchaseToken); err != nil {
|
||||
log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to acknowledge Google subscription")
|
||||
// Don't fail the purchase, just log the warning
|
||||
// Don't fail the purchase, just log the warning.
|
||||
}
|
||||
}
|
||||
|
||||
// Upgrade to Pro with the validated expiration
|
||||
if err := s.subscriptionRepo.WithContext(ctx).UpgradeToPro(userID, expiresAt, "android"); err != nil {
|
||||
if err := s.subscriptionRepo.WithContext(vctx).UpgradeToPro(userID, expiresAt, "android"); err != nil {
|
||||
return nil, apperrors.Internal(err)
|
||||
}
|
||||
|
||||
if s.cache != nil {
|
||||
_ = s.cache.InvalidateSubscriptionStatusForUsers(ctx, userID)
|
||||
_ = s.cache.InvalidateSubscriptionStatusForUsers(vctx, userID)
|
||||
}
|
||||
|
||||
return s.GetSubscription(ctx, userID)
|
||||
return s.GetSubscription(vctx, userID)
|
||||
}
|
||||
|
||||
// CancelSubscription cancels a subscription (downgrades to free at end of period)
|
||||
|
||||
Reference in New Issue
Block a user