fix(security): remediate 2026-05-12 audit findings (Stages 2–5)
Backend CI / Test (push) Has been cancelled
Backend CI / Contract Tests (push) Has been cancelled
Backend CI / Lint (push) Has been cancelled
Backend CI / Secret Scanning (push) Has been cancelled
Backend CI / Build (push) Has been cancelled

Remediation of the 2026-05-12/13 audits (78 findings + cluster gaps),
tracked in deploy-k3s/SECURITY.md, plus fixes from two independent
post-remediation reviews.

Auth & sessions:
- SHA-256 hashed auth-token storage (C1); prior-token cache eviction on
  re-login (MEDIUM-1)
- local Google JWKS verification, iss/aud/exp checks (C2/C3)
- constant-time login + generic errors (L1/LIVE-L11/LIVE-L13)
- per-account login lockout keyed on distinct source IPs (M5/MEDIUM-3)
- verified-email gating, login rate limiting (LIVE-L19, H1-H3)

IAP & webhooks:
- Apple/Google cross-account replay protection (C5/C6/C10/C13, H5/H6)
- migrations 000003-000006 (token hashing, IAP replay, audit_log +
  webhook_event_log table creation, append-only audit log)

Authorization & races:
- file-ownership owner-OR-member fix (C7), atomic share-code join
  (C9/H9), device-token reassignment (C8/LOW-3)

Secrets & deploy:
- secrets file-mounted at /etc/honeydue/secrets, not env (F8); Redis
  password out of the ConfigMap (HIGH-1); B2 keys reconciled
- digest-pinned images, admin ingress hardening, CSP/HSTS, /metrics
  lockdown; kubeconfig 0600, etcd secrets-encryption, fail2ban +
  unattended-upgrades at provision; secret-rotation runbook

Build, vet, and the full test suite (incl. -race) pass; the goose
migration chain is verified against PostgreSQL 16.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Trey t
2026-05-16 22:28:33 -05:00
parent 2004f9c5b2
commit c77ff07ce9
59 changed files with 2819 additions and 1245 deletions
+70 -5
View File
@@ -1,6 +1,7 @@
package config
import (
"crypto/rand"
"encoding/hex"
"fmt"
"net/url"
@@ -216,6 +217,11 @@ func Load() (*Config, error) {
// Set defaults
setDefaults()
// Audit F8: overlay file-mounted secrets onto Viper. No-op when the
// directory is absent (local/dev), so this is safe to ship before the
// manifests mount honeydue-secrets as a volume.
loadFileSecrets()
// Parse DATABASE_URL if set (Dokku-style)
dbConfig := DatabaseConfig{
Host: viper.GetString("DB_HOST"),
@@ -432,14 +438,67 @@ func isWeakSecretKey(key string) bool {
return knownWeakSecretKeys[strings.ToLower(strings.TrimSpace(key))]
}
// loadFileSecrets overlays file-mounted secrets onto Viper (audit F8). When
// the honeydue-secrets Secret is mounted as a volume at /etc/honeydue/secrets
// each key is a file; reading the value here and viper.Set-ing it (highest
// Viper precedence) keeps the secret out of the process environment
// (/proc/<pid>/environ), which plain env-var injection cannot. When the
// directory is absent it is a silent no-op and env vars are used as before.
func loadFileSecrets() {
dir := os.Getenv("HONEYDUE_SECRETS_DIR")
if dir == "" {
dir = "/etc/honeydue/secrets"
}
for _, k := range []string{
"POSTGRES_PASSWORD", "SECRET_KEY", "EMAIL_HOST_PASSWORD", "FCM_SERVER_KEY",
"REDIS_PASSWORD", "B2_KEY_ID", "B2_APP_KEY", "OBS_INGEST_TOKEN", "OBS_TRACES_URL",
} {
b, err := os.ReadFile(dir + "/" + k)
if err != nil {
continue
}
if v := strings.TrimSpace(string(b)); v != "" {
viper.Set(k, v)
}
}
}
// SecretValue resolves a configuration value that is not part of the typed
// Config struct. It reads through Viper, so a value supplied via a file-mounted
// secret (audit F8, loaded by loadFileSecrets) is found just like an env var.
//
// Must be called after Load(). Used by cmd/api and cmd/worker for the
// observability endpoints, which are needed before the full Config is wired
// and would otherwise be read with os.Getenv — which misses file-mounted
// secrets entirely once F8 removes them from the process environment.
func SecretValue(key string) string {
return viper.GetString(key)
}
// randomHexKey returns a cryptographically secure random hex string
// representing n random bytes (2n hex characters).
func randomHexKey(n int) (string, error) {
b := make([]byte, n)
if _, err := rand.Read(b); err != nil {
return "", err
}
return hex.EncodeToString(b), nil
}
func validate(cfg *Config) error {
// S-08: Validate SECRET_KEY against known weak defaults
// M8: SECRET_KEY validation — no static fallback secret in the binary.
if cfg.Security.SecretKey == "" {
if cfg.Server.Debug {
// In debug mode, use a default key with a warning for local development
cfg.Security.SecretKey = "change-me-in-production-secret-key-12345"
fmt.Println("WARNING: SECRET_KEY not set, using default (debug mode only)")
fmt.Println("WARNING: *** DO NOT USE THIS DEFAULT KEY IN PRODUCTION ***")
// Debug only: generate a random key per boot. Tokens signed with
// it do not survive a restart, which is acceptable for local dev
// and far safer than a well-known hardcoded fallback.
randomKey, err := randomHexKey(32)
if err != nil {
return fmt.Errorf("failed to generate ephemeral debug SECRET_KEY: %w", err)
}
cfg.Security.SecretKey = randomKey
fmt.Println("WARNING: SECRET_KEY not set, generated an ephemeral random key (debug mode only)")
fmt.Println("WARNING: tokens will not survive a restart — set SECRET_KEY for stable local sessions")
} else {
// In production, refuse to start without a proper secret key
return fmt.Errorf("FATAL: SECRET_KEY environment variable is required in production (DEBUG=false)")
@@ -452,6 +511,12 @@ func validate(cfg *Config) error {
}
}
// C4: fixed confirmation codes ("123456") must never be enabled outside
// debug — with DEBUG=false they are a full authentication bypass.
if cfg.Server.DebugFixedCodes && !cfg.Server.Debug {
return fmt.Errorf("FATAL: DEBUG_FIXED_CODES is enabled with DEBUG=false — fixed confirmation codes must never run in production")
}
// Database password might come from DATABASE_URL, don't require it separately
// The actual connection will fail if credentials are wrong
+31 -2
View File
@@ -106,8 +106,10 @@ func TestLoad_Validation_MissingSecretKey_DebugMode(t *testing.T) {
c, err := Load()
require.NoError(t, err)
// In debug mode, a default key is assigned
assert.Equal(t, "change-me-in-production-secret-key-12345", c.Security.SecretKey)
// Audit M8: in debug mode an ephemeral random key is generated per boot
// (no static fallback). It must be a non-empty 64-char hex string.
assert.Len(t, c.Security.SecretKey, 64)
assert.NotEqual(t, "change-me-in-production-secret-key-12345", c.Security.SecretKey)
}
func TestLoad_Validation_WeakSecretKey_Production(t *testing.T) {
@@ -133,6 +135,33 @@ func TestLoad_Validation_WeakSecretKey_DebugMode(t *testing.T) {
assert.Equal(t, "secret", c.Security.SecretKey)
}
// Audit C4: DEBUG_FIXED_CODES makes confirmation codes a fixed "123456" — a
// full authentication bypass. With DEBUG=false, validate() must refuse to boot
// rather than ship that bypass to production.
func TestLoad_Validation_DebugFixedCodes_Production(t *testing.T) {
// validate() directly — avoids the sync.Once issue Load() has on failure.
cfg := &Config{
Server: ServerConfig{Debug: false, DebugFixedCodes: true},
Security: SecurityConfig{SecretKey: "a-strong-secret-key-for-tests"},
}
err := validate(cfg)
require.Error(t, err)
assert.Contains(t, err.Error(), "DEBUG_FIXED_CODES")
}
// With DEBUG=true the fixed codes are an intended local-dev convenience, so
// the same combination must NOT error.
func TestLoad_Validation_DebugFixedCodes_DebugMode(t *testing.T) {
cfg := &Config{
Server: ServerConfig{Debug: true, DebugFixedCodes: true},
Security: SecurityConfig{SecretKey: "a-strong-secret-key-for-tests"},
}
err := validate(cfg)
require.NoError(t, err)
}
func TestLoad_Validation_EncryptionKey_Valid(t *testing.T) {
resetConfigState()
t.Setenv("SECRET_KEY", "a-strong-secret-key-for-tests")
+45 -28
View File
@@ -1,6 +1,7 @@
package handlers
import (
"context"
"errors"
"net/http"
@@ -55,8 +56,15 @@ func (h *AuthHandler) SetAuditService(auditService *services.AuditService) {
h.auditService = auditService
}
// noStore marks a response as non-cacheable (audit L2) — auth responses
// carry tokens and user data that must never sit in any cache.
func noStore(c echo.Context) {
c.Response().Header().Set("Cache-Control", "no-store")
}
// Login handles POST /api/auth/login/
func (h *AuthHandler) Login(c echo.Context) error {
noStore(c)
var req requests.LoginRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
@@ -65,9 +73,11 @@ func (h *AuthHandler) Login(c echo.Context) error {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
}
response, err := h.authService.Login(c.Request().Context(), &req)
response, err := h.authService.Login(c.Request().Context(), &req, c.RealIP())
if err != nil {
log.Debug().Err(err).Str("identifier", req.Username).Msg("Login failed")
log.Debug().Err(err).Str("identifier", req.Username).
Str("ip", c.RealIP()).Str("user_agent", c.Request().UserAgent()).
Msg("Login failed")
if h.auditService != nil {
h.auditService.LogEvent(c, nil, services.AuditEventLoginFailed, map[string]interface{}{
"identifier": req.Username,
@@ -86,6 +96,7 @@ func (h *AuthHandler) Login(c echo.Context) error {
// Register handles POST /api/auth/register/
func (h *AuthHandler) Register(c echo.Context) error {
noStore(c)
var req requests.RegisterRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
@@ -157,6 +168,7 @@ func (h *AuthHandler) Logout(c echo.Context) error {
// CurrentUser handles GET /api/auth/me/
func (h *AuthHandler) CurrentUser(c echo.Context) error {
noStore(c)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
@@ -276,31 +288,7 @@ func (h *AuthHandler) ForgotPassword(c echo.Context) error {
return c.JSON(http.StatusBadRequest, validator.FormatValidationErrors(err))
}
code, user, err := h.authService.ForgotPassword(c.Request().Context(), req.Email)
if err != nil {
var appErr *apperrors.AppError
if errors.As(err, &appErr) && appErr.Code == http.StatusTooManyRequests {
// Only reveal rate limit errors
return err
}
log.Error().Err(err).Str("email", req.Email).Msg("Forgot password failed")
// Don't reveal other errors to prevent email enumeration
}
// Send password reset email (async) - only if user found
if h.emailService != nil && code != "" && user != nil {
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in password reset email goroutine")
}
}()
if err := h.emailService.SendPasswordResetEmail(user.Email, user.FirstName, code); err != nil {
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send password reset email")
}
}()
}
noStore(c)
if h.auditService != nil {
h.auditService.LogEvent(c, nil, services.AuditEventPasswordReset, map[string]interface{}{
@@ -308,7 +296,33 @@ func (h *AuthHandler) ForgotPassword(c echo.Context) error {
})
}
// Always return success to prevent email enumeration
// Audit LIVE-L13: run the user lookup, code generation, and email send
// entirely in the background, then return the generic response
// immediately. This makes the response time identical whether or not
// the email belongs to a real account, defeating timing-based user
// enumeration. context.Background() is used because the request context
// is cancelled the moment this handler returns. Per-account rate
// limiting still runs inside the service; the edge auth-rate-limit
// middleware covers per-IP abuse.
email := req.Email
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", email).Msg("Panic in forgot-password goroutine")
}
}()
code, user, err := h.authService.ForgotPassword(context.Background(), email)
if err != nil || code == "" || user == nil {
return
}
if h.emailService != nil {
if sendErr := h.emailService.SendPasswordResetEmail(user.Email, user.FirstName, code); sendErr != nil {
log.Error().Err(sendErr).Str("email", user.Email).Msg("Failed to send password reset email")
}
}
}()
// Always return success to prevent email enumeration.
return c.JSON(http.StatusOK, responses.ForgotPasswordResponse{
Message: "Password reset email sent",
})
@@ -365,6 +379,7 @@ func (h *AuthHandler) ResetPassword(c echo.Context) error {
// AppleSignIn handles POST /api/auth/apple-sign-in/
func (h *AuthHandler) AppleSignIn(c echo.Context) error {
noStore(c)
var req requests.AppleSignInRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
@@ -412,6 +427,7 @@ func (h *AuthHandler) AppleSignIn(c echo.Context) error {
// GoogleSignIn handles POST /api/auth/google-sign-in/
func (h *AuthHandler) GoogleSignIn(c echo.Context) error {
noStore(c)
var req requests.GoogleSignInRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
@@ -459,6 +475,7 @@ func (h *AuthHandler) GoogleSignIn(c echo.Context) error {
// RefreshToken handles POST /api/auth/refresh/
func (h *AuthHandler) RefreshToken(c echo.Context) error {
noStore(c)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
+2 -2
View File
@@ -650,14 +650,14 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
authGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Set("auth_user", user)
c.Set("auth_token", authToken.Key)
c.Set("auth_token", authToken.Plaintext) // raw token — repo hashes for lookup (audit C1)
return next(c)
}
})
authGroup.POST("/refresh/", handler.RefreshToken)
t.Run("successful refresh", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", "/api/auth/refresh/", nil, authToken.Key)
w := testutil.MakeRequest(e, "POST", "/api/auth/refresh/", nil, authToken.Plaintext)
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
+20 -3
View File
@@ -37,6 +37,23 @@ func NewMediaHandler(
}
}
// safeContentDisposition builds an inline Content-Disposition header value
// with a sanitized filename (audit M1). Control characters (including CR/LF),
// double-quote and backslash are stripped so an attacker-controlled upload
// filename cannot inject additional response headers (CWE-113).
func safeContentDisposition(filename string) string {
cleaned := strings.Map(func(r rune) rune {
if r < 0x20 || r == 0x7f || r == '"' || r == '\\' {
return -1
}
return r
}, filename)
if cleaned == "" {
cleaned = "download"
}
return `inline; filename="` + cleaned + `"`
}
// ServeDocument serves a document file with access control
// GET /api/media/document/:id
func (h *MediaHandler) ServeDocument(c echo.Context) error {
@@ -71,7 +88,7 @@ func (h *MediaHandler) ServeDocument(c echo.Context) error {
// Set caching and disposition headers
c.Response().Header().Set("Cache-Control", "private, max-age=3600")
if doc.FileName != "" {
c.Response().Header().Set("Content-Disposition", "inline; filename=\""+doc.FileName+"\"")
c.Response().Header().Set("Content-Disposition", safeContentDisposition(doc.FileName))
}
return c.Blob(http.StatusOK, mimeType, data)
}
@@ -114,7 +131,7 @@ func (h *MediaHandler) ServeDocumentImage(c echo.Context) error {
}
c.Response().Header().Set("Cache-Control", "private, max-age=3600")
c.Response().Header().Set("Content-Disposition", "inline; filename=\""+filepath.Base(img.ImageURL)+"\"")
c.Response().Header().Set("Content-Disposition", safeContentDisposition(filepath.Base(img.ImageURL)))
return c.Blob(http.StatusOK, mimeType, data)
}
@@ -162,7 +179,7 @@ func (h *MediaHandler) ServeCompletionImage(c echo.Context) error {
}
c.Response().Header().Set("Cache-Control", "private, max-age=3600")
c.Response().Header().Set("Content-Disposition", "inline; filename=\""+filepath.Base(img.ImageURL)+"\"")
c.Response().Header().Set("Content-Disposition", safeContentDisposition(filepath.Base(img.ImageURL)))
return c.Blob(http.StatusOK, mimeType, data)
}
@@ -8,6 +8,7 @@ import (
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
"math/big"
@@ -20,6 +21,7 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"github.com/treytartt/honeydue-api/internal/config"
"github.com/treytartt/honeydue-api/internal/models"
@@ -165,9 +167,13 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
if notification.NotificationUUID != "" {
alreadyProcessed, err := h.webhookEventRepo.HasProcessed("apple", notification.NotificationUUID)
if err != nil {
log.Error().Err(err).Msg("Apple Webhook: Failed to check dedup")
// Continue processing on dedup check failure (fail-open)
} else if alreadyProcessed {
// Audit H6: fail closed. A dedup-check failure must not let a
// possibly-duplicate event through (duplicate refunds/grants).
// Return 500 so Apple retries once the DB is healthy.
log.Error().Err(err).Msg("Apple Webhook: dedup check failed — returning 500 for retry")
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "dedup check failed"})
}
if alreadyProcessed {
log.Info().Str("uuid", notification.NotificationUUID).Msg("Apple Webhook: Duplicate event, skipping")
return c.JSON(http.StatusOK, map[string]interface{}{"status": "duplicate"})
}
@@ -352,10 +358,24 @@ func (h *SubscriptionWebhookHandler) processAppleNotification(
}
func (h *SubscriptionWebhookHandler) findUserByAppleTransaction(originalTransactionID string) (*models.User, error) {
// Look up user subscription by stored receipt data
subscription, err := h.subscriptionRepo.FindByAppleReceiptContains(originalTransactionID)
// Audit C13: exact match on the indexed apple_original_transaction_id
// column. Falls back to the legacy escaped-LIKE scan over
// apple_receipt_data only for subscriptions created before that column
// existed (and thus not yet populated).
subscription, err := h.subscriptionRepo.FindByAppleOriginalTransactionID(originalTransactionID)
if err != nil {
return nil, err
// Only fall back to the legacy substring scan when the exact-match
// column genuinely had no row (a subscription created before the
// column existed). A real DB error must propagate — masking it as
// "not found" could bind the webhook to the wrong account via the
// LIKE scan, or silently drop a legitimate event.
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
subscription, err = h.subscriptionRepo.FindByAppleReceiptContains(originalTransactionID)
if err != nil {
return nil, err
}
}
user, err := h.userRepo.FindByID(subscription.UserID)
@@ -566,9 +586,12 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
if messageID != "" {
alreadyProcessed, err := h.webhookEventRepo.HasProcessed("google", messageID)
if err != nil {
log.Error().Err(err).Msg("Google Webhook: Failed to check dedup")
// Continue processing on dedup check failure (fail-open)
} else if alreadyProcessed {
// Audit H6: fail closed — see the Apple handler. Return 500 so
// Google Pub/Sub redelivers once the DB is healthy.
log.Error().Err(err).Msg("Google Webhook: dedup check failed — returning 500 for retry")
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "dedup check failed"})
}
if alreadyProcessed {
log.Info().Str("message_id", messageID).Msg("Google Webhook: Duplicate event, skipping")
return c.JSON(http.StatusOK, map[string]interface{}{"status": "duplicate"})
}
+29 -1
View File
@@ -169,6 +169,34 @@ func (m *AuthMiddleware) OptionalTokenAuth() echo.MiddlewareFunc {
}
}
// RequireVerified returns middleware that rejects users whose email is not
// verified (audit LIVE-L19). Apply it after TokenAuth to gate sensitive
// actions — e.g. generating residence share codes — behind proof that the
// account actually controls its email address.
func (m *AuthMiddleware) RequireVerified() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
user := GetAuthUser(c)
if user == nil {
return apperrors.Unauthorized("error.not_authenticated")
}
var verified bool
err := m.db.WithContext(c.Request().Context()).
Model(&models.UserProfile{}).
Where("user_id = ?", user.ID).
Select("verified").
Scan(&verified).Error
if err != nil {
return apperrors.Internal(err)
}
if !verified {
return apperrors.Forbidden("error.email_not_verified")
}
return next(c)
}
}
}
// extractToken extracts the token from the Authorization header
func extractToken(c echo.Context) (string, error) {
authHeader := c.Request().Header.Get("Authorization")
@@ -297,7 +325,7 @@ func (m *AuthMiddleware) getUserFromDatabaseWithToken(ctx context.Context, token
u.last_login AS u_last_login
`).
Joins("INNER JOIN auth_user u ON u.id = t.user_id").
Where("t.key = ?", token).
Where("t.key = ?", models.HashToken(token)). // audit C1: only the hash is stored
Limit(1).
Scan(&row).Error
if err != nil || row.Key == "" {
+3 -3
View File
@@ -65,7 +65,7 @@ func TestTokenAuth_RejectsExpiredToken(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Key)
req.Header.Set("Authorization", "Token "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
@@ -86,7 +86,7 @@ func TestTokenAuth_AcceptsValidToken(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Key)
req.Header.Set("Authorization", "Token "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
@@ -112,7 +112,7 @@ func TestTokenAuth_AcceptsTokenAtBoundary(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Key)
req.Header.Set("Authorization", "Token "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
+7 -7
View File
@@ -21,7 +21,7 @@ func TestTokenAuth_BearerScheme_Accepted(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Bearer "+token.Key)
req.Header.Set("Authorization", "Bearer "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
@@ -46,7 +46,7 @@ func TestTokenAuth_InvalidScheme_Rejected(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Basic "+token.Key)
req.Header.Set("Authorization", "Basic "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
@@ -110,7 +110,7 @@ func TestTokenAuth_InactiveUser_Rejected(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Key)
req.Header.Set("Authorization", "Token "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
@@ -156,7 +156,7 @@ func TestOptionalTokenAuth_ValidToken_SetsUser(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Key)
req.Header.Set("Authorization", "Token "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
@@ -182,7 +182,7 @@ func TestOptionalTokenAuth_ExpiredToken_IgnoresUser(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Key)
req.Header.Set("Authorization", "Token "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
@@ -242,7 +242,7 @@ func TestNewAuthMiddlewareWithConfig_CustomExpiryDays(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Key)
req.Header.Set("Authorization", "Token "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
@@ -270,7 +270,7 @@ func TestNewAuthMiddlewareWithConfig_ExpiredWithCustomExpiry(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
req.Header.Set("Authorization", "Token "+token.Key)
req.Header.Set("Authorization", "Token "+token.Plaintext)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
+14 -12
View File
@@ -99,21 +99,23 @@ func parseTimezone(tz string) *time.Location {
return loc
}
// Try parsing as UTC offset (e.g., "-08:00", "+05:30")
// We parse a reference time with the given offset to extract the offset value
t, err := time.Parse("-07:00", tz)
if err == nil {
// time.Parse returns a time, we need to extract the offset
// The parsed time will have the offset embedded
_, offset := t.Zone()
return time.FixedZone(tz, offset)
// Try parsing as a UTC offset (e.g., "-08:00", "+05:30"). Audit H8:
// reject absurd offsets — real timezones are within ±14h of UTC — so a
// crafted X-Timezone header cannot shift date math arbitrarily.
const maxOffsetSeconds = 14 * 3600
if t, err := time.Parse("-07:00", tz); err == nil {
if _, offset := t.Zone(); offset >= -maxOffsetSeconds && offset <= maxOffsetSeconds {
return time.FixedZone(tz, offset)
}
return time.UTC
}
// Also try without colon (e.g., "-0800")
t, err = time.Parse("-0700", tz)
if err == nil {
_, offset := t.Zone()
return time.FixedZone(tz, offset)
if t, err := time.Parse("-0700", tz); err == nil {
if _, offset := t.Zone(); offset >= -maxOffsetSeconds && offset <= maxOffsetSeconds {
return time.FixedZone(tz, offset)
}
return time.UTC
}
// Default to UTC
+2 -1
View File
@@ -252,7 +252,8 @@ func TestAuthToken_BeforeCreate_GeneratesKey(t *testing.T) {
require.NoError(t, err)
assert.NotEmpty(t, token.Key)
assert.Len(t, token.Key, 40) // 20 bytes = 40 hex chars
assert.Len(t, token.Key, 64) // SHA-256 hex hash (audit C1)
assert.Len(t, token.Plaintext, 40) // raw 20-byte token, returned to the client
assert.False(t, token.Created.IsZero())
}
+3
View File
@@ -43,6 +43,9 @@ type UserSubscription struct {
// In-App Purchase data (Apple / Google)
AppleReceiptData *string `gorm:"column:apple_receipt_data;type:text" json:"-"`
GooglePurchaseToken *string `gorm:"column:google_purchase_token;type:text" json:"-"`
// AppleOriginalTransactionID binds an Apple subscription to one account
// (audit C5/C13). A partial unique index enforces one-account-per-txn.
AppleOriginalTransactionID *string `gorm:"column:apple_original_transaction_id;type:text" json:"-"`
// Stripe data (web subscriptions)
StripeCustomerID *string `gorm:"column:stripe_customer_id;size:255" json:"-"`
+55 -20
View File
@@ -2,7 +2,10 @@ package models
import (
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"time"
"golang.org/x/crypto/bcrypt"
@@ -37,14 +40,16 @@ func (User) TableName() string {
return "auth_user"
}
// BcryptCost is the bcrypt work factor for password and code hashing.
// 12 (audit M2) is stronger than bcrypt.DefaultCost (10).
const BcryptCost = 12
// SetPassword hashes and sets the password
func (u *User) SetPassword(password string) error {
// Django uses PBKDF2_SHA256 by default, but we'll use bcrypt for Go
// Note: This means passwords set by Django won't work with Go's check
// For migration, you'd need to either:
// 1. Force password reset for all users
// 2. Implement Django's PBKDF2 hasher in Go
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
// Django uses PBKDF2_SHA256 by default, but we use bcrypt for Go.
// Passwords set by Django won't verify with Go's bcrypt check — those
// users must reset their password after migration.
hash, err := bcrypt.GenerateFromPassword([]byte(password), BcryptCost)
if err != nil {
return err
}
@@ -69,12 +74,22 @@ func (u *User) GetFullName() string {
return u.Username
}
// AuthToken represents the user_authtoken table
// AuthToken represents the user_authtoken table.
//
// Audit C1: the Key column stores the SHA-256 hash of the token, never the
// token itself. The raw token is handed to the client exactly once, at
// creation, via the non-persisted Plaintext field — it is never stored or
// logged. A database compromise therefore yields no usable session tokens.
type AuthToken struct {
Key string `gorm:"column:key;primaryKey;size:40" json:"key"`
Key string `gorm:"column:key;primaryKey;size:64" json:"-"`
UserID uint `gorm:"column:user_id;uniqueIndex;not null" json:"user_id"`
Created time.Time `gorm:"column:created;autoCreateTime" json:"created"`
// Plaintext is the raw token value. It is never persisted (gorm:"-")
// and is only populated on a freshly-created token so the caller can
// return it to the client. On a token loaded from the DB it is "".
Plaintext string `gorm:"-" json:"-"`
// Relations
User User `gorm:"foreignKey:UserID" json:"-"`
}
@@ -84,10 +99,13 @@ func (AuthToken) TableName() string {
return "user_authtoken"
}
// BeforeCreate generates a token key if not provided
// BeforeCreate generates a token if one is not already set, storing only
// its hash in Key and the raw value in the non-persisted Plaintext field.
func (t *AuthToken) BeforeCreate(tx *gorm.DB) error {
if t.Key == "" {
t.Key = generateToken()
raw := generateToken()
t.Plaintext = raw
t.Key = HashToken(raw)
}
if t.Created.IsZero() {
t.Created = time.Now().UTC()
@@ -95,13 +113,23 @@ func (t *AuthToken) BeforeCreate(tx *gorm.DB) error {
return nil
}
// generateToken creates a random 40-character hex token
// generateToken creates a random 40-character hex token (the raw value).
func generateToken() string {
b := make([]byte, 20)
rand.Read(b)
return hex.EncodeToString(b)
}
// HashToken returns the at-rest representation of an auth token: the
// hex-encoded SHA-256 hash. Auth tokens are 160-bit random values, so a
// fast deterministic hash is appropriate — there is nothing to brute-force,
// and determinism preserves the single indexed-lookup query in the auth
// middleware. The raw token is never stored.
func HashToken(raw string) string {
sum := sha256.Sum256([]byte(raw))
return hex.EncodeToString(sum[:])
}
// GetOrCreate gets an existing token or creates a new one for the user
func GetOrCreateToken(tx *gorm.DB, userID uint) (*AuthToken, error) {
var token AuthToken
@@ -160,15 +188,22 @@ func (c *ConfirmationCode) IsValid() bool {
return !c.IsUsed && time.Now().UTC().Before(c.ExpiresAt)
}
// GenerateCode creates a random 6-digit code
// GenerateConfirmationCode creates a uniformly-random 6-digit code using
// rejection sampling on crypto/rand (audit H4 — removes the modulo bias of
// the previous implementation).
func GenerateConfirmationCode() string {
b := make([]byte, 3)
rand.Read(b)
// Convert to 6-digit number
num := int(b[0])<<16 | int(b[1])<<8 | int(b[2])
return string(rune('0'+num%10)) + string(rune('0'+(num/10)%10)) +
string(rune('0'+(num/100)%10)) + string(rune('0'+(num/1000)%10)) +
string(rune('0'+(num/10000)%10)) + string(rune('0'+(num/100000)%10))
for {
var b [4]byte
if _, err := rand.Read(b[:]); err != nil {
continue
}
// 4294000000 is the largest multiple of 1e6 <= MaxUint32; rejecting
// the tail above it makes n % 1000000 perfectly uniform.
n := binary.BigEndian.Uint32(b[:])
if n < 4294000000 {
return fmt.Sprintf("%06d", n%1000000)
}
}
}
// PasswordResetCode represents the user_passwordresetcode table
@@ -193,7 +228,7 @@ func (PasswordResetCode) TableName() string {
// SetCode hashes and stores the reset code
func (p *PasswordResetCode) SetCode(code string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
hash, err := bcrypt.GenerateFromPassword([]byte(code), BcryptCost)
if err != nil {
return err
}
+55
View File
@@ -9,6 +9,7 @@ import (
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/treytartt/honeydue-api/internal/models"
)
@@ -194,6 +195,60 @@ func (r *ResidenceRepository) HasAccess(residenceID, userID uint) (bool, error)
return count > 0, nil
}
// JoinWithShareCode atomically redeems a one-time share code (audit C9/H9):
// it locks the share-code row, re-checks validity under the lock, adds the
// user to the residence, and deactivates the code — all in one transaction.
// Concurrent redemptions of the same code serialize on the row lock; the
// loser sees is_active=false and is rejected. A failure to deactivate aborts
// the whole join. Returns gorm.ErrRecordNotFound for an unknown, inactive, or
// expired code so the caller can map every case to one generic error.
func (r *ResidenceRepository) JoinWithShareCode(code string, userID uint) (residenceID uint, alreadyMember bool, err error) {
err = r.db.Transaction(func(tx *gorm.DB) error {
var sc models.ResidenceShareCode
if e := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
Where("code = ?", code).First(&sc).Error; e != nil {
return e
}
if !sc.IsActive || (sc.ExpiresAt != nil && time.Now().UTC().After(*sc.ExpiresAt)) {
return gorm.ErrRecordNotFound
}
residenceID = sc.ResidenceID
// Already a member (owner or shared user)?
var accessCount int64
if e := tx.Raw(`
SELECT COUNT(*) FROM (
SELECT 1 FROM residence_residence
WHERE id = ? AND owner_id = ? AND is_active = true
UNION
SELECT 1 FROM residence_residence_users
WHERE residence_id = ? AND user_id = ?
) ac
`, sc.ResidenceID, userID, sc.ResidenceID, userID).Scan(&accessCount).Error; e != nil {
return e
}
if accessCount > 0 {
alreadyMember = true
return nil
}
if e := tx.Exec(
"INSERT INTO residence_residence_users (residence_id, user_id) VALUES (?, ?) ON CONFLICT DO NOTHING",
sc.ResidenceID, userID,
).Error; e != nil {
return e
}
// One-time use: deactivate the code. A failure here aborts the join.
if e := tx.Model(&models.ResidenceShareCode{}).
Where("id = ?", sc.ID).Update("is_active", false).Error; e != nil {
return e
}
return nil
})
return residenceID, alreadyMember, err
}
// IsOwner checks if a user is the owner of a residence
func (r *ResidenceRepository) IsOwner(residenceID, userID uint) (bool, error) {
var count int64
@@ -151,6 +151,28 @@ func (r *SubscriptionRepository) FindByAppleReceiptContains(transactionID string
return &sub, nil
}
// FindByAppleOriginalTransactionID finds a subscription by the Apple original
// transaction ID (audit C5/C13). Exact match on an indexed column — replaces
// the LIKE scan in FindByAppleReceiptContains for both replay detection and
// webhook user lookup.
func (r *SubscriptionRepository) FindByAppleOriginalTransactionID(originalTransactionID string) (*models.UserSubscription, error) {
var sub models.UserSubscription
err := r.db.Where("apple_original_transaction_id = ?", originalTransactionID).First(&sub).Error
if err != nil {
return nil, err
}
return &sub, nil
}
// UpdateAppleOriginalTransactionID binds an Apple original transaction ID to a
// user's subscription. A partial unique index enforces one account per
// transaction (audit C5) — a second account claiming the same ID fails here.
func (r *SubscriptionRepository) UpdateAppleOriginalTransactionID(userID uint, originalTransactionID string) error {
return r.db.Model(&models.UserSubscription{}).
Where("user_id = ?", userID).
Update("apple_original_transaction_id", originalTransactionID).Error
}
// FindByGoogleToken finds a subscription by Google purchase token
// Used by webhooks to find the user associated with a purchase
func (r *SubscriptionRepository) FindByGoogleToken(purchaseToken string) (*models.UserSubscription, error) {
@@ -226,3 +226,48 @@ func TestUpdateExpiresAt(t *testing.T) {
require.NotNil(t, updated.ExpiresAt)
assert.WithinDuration(t, newExpiry, *updated.ExpiresAt, time.Second, "expires_at should be updated")
}
// TestSubscriptionRepo_IAPTransactionReplayRejected is the regression test for
// audit C5/C6: an in-app-purchase transaction (an Apple original transaction
// ID or a Google purchase token) may be bound to exactly one account. Without
// that guarantee a valid receipt could be replayed against a second account
// to grant Pro for free. The guarantee is the pair of partial unique indexes
// added by migration 000004; AutoMigrate does not create them, so this test
// recreates them verbatim to exercise the same DB-level enforcement.
func TestSubscriptionRepo_IAPTransactionReplayRejected(t *testing.T) {
db := testutil.SetupTestDB(t)
require.NoError(t, db.Exec(`CREATE UNIQUE INDEX uq_subscription_apple_original_txn `+
`ON subscription_usersubscription (apple_original_transaction_id) `+
`WHERE apple_original_transaction_id IS NOT NULL AND apple_original_transaction_id <> ''`).Error)
require.NoError(t, db.Exec(`CREATE UNIQUE INDEX uq_subscription_google_purchase_token `+
`ON subscription_usersubscription (google_purchase_token) `+
`WHERE google_purchase_token IS NOT NULL AND google_purchase_token <> ''`).Error)
repo := NewSubscriptionRepository(db)
userA := testutil.CreateTestUser(t, db, "iapusera", "iapa@test.com", "password")
userB := testutil.CreateTestUser(t, db, "iapuserb", "iapb@test.com", "password")
require.NoError(t, db.Create(&models.UserSubscription{UserID: userA.ID, Tier: models.TierFree}).Error)
require.NoError(t, db.Create(&models.UserSubscription{UserID: userB.ID, Tier: models.TierFree}).Error)
t.Run("apple transaction cannot be claimed by a second account", func(t *testing.T) {
require.NoError(t, repo.UpdateAppleOriginalTransactionID(userA.ID, "apple-original-txn-1"),
"the first account binding the transaction must succeed")
err := repo.UpdateAppleOriginalTransactionID(userB.ID, "apple-original-txn-1")
require.Error(t, err,
"replaying account A's Apple transaction onto account B must be rejected (C5)")
})
t.Run("google purchase token cannot be claimed by a second account", func(t *testing.T) {
require.NoError(t, repo.UpdatePurchaseToken(userA.ID, "google-purchase-token-1"),
"the first account binding the token must succeed")
err := repo.UpdatePurchaseToken(userB.ID, "google-purchase-token-1")
require.Error(t, err,
"replaying account A's Google purchase token onto account B must be rejected (C6)")
})
t.Run("re-binding the same transaction to the same account is allowed", func(t *testing.T) {
// A renewal re-submitting the same transaction for its owner must not
// be rejected — the partial unique index excludes the row's own value.
require.NoError(t, repo.UpdateAppleOriginalTransactionID(userA.ID, "apple-original-txn-1"))
})
}
+44 -5
View File
@@ -174,10 +174,12 @@ func (r *UserRepository) GetOrCreateToken(userID uint) (*models.AuthToken, error
return &token, nil
}
// FindTokenByKey looks up an auth token by its key value.
func (r *UserRepository) FindTokenByKey(key string) (*models.AuthToken, error) {
// FindTokenByKey looks up an auth token by its raw key value. The raw token
// is hashed (audit C1) before the indexed lookup, since only the hash is
// stored.
func (r *UserRepository) FindTokenByKey(rawKey string) (*models.AuthToken, error) {
var token models.AuthToken
if err := r.db.Where("key = ?", key).First(&token).Error; err != nil {
if err := r.db.Where("key = ?", models.HashToken(rawKey)).First(&token).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTokenNotFound
}
@@ -195,9 +197,46 @@ func (r *UserRepository) CreateToken(userID uint) (*models.AuthToken, error) {
return &token, nil
}
// DeleteToken deletes an auth token
// CreateFreshToken issues a new auth token for the user, replacing any
// existing one. Because tokens are stored hashed (audit C1) the server
// cannot re-issue a previously-minted token's plaintext, so every login
// mints a fresh token. The returned token's Plaintext field carries the
// raw value to hand to the client; it is never persisted.
//
// It also returns the stored hashes of the token rows it deleted, so the
// caller can evict those entries from the Redis token cache (audit MEDIUM-1).
// Without that, a prior (e.g. stolen) token keeps authenticating via a cache
// hit for up to the cache TTL even though its DB row is gone.
func (r *UserRepository) CreateFreshToken(userID uint) (*models.AuthToken, []string, error) {
var token models.AuthToken
var oldHashes []string
err := r.db.Transaction(func(tx *gorm.DB) error {
var old []models.AuthToken
if err := tx.Where("user_id = ?", userID).Find(&old).Error; err != nil {
return err
}
oldHashes = make([]string, 0, len(old))
for i := range old {
if old[i].Key != "" {
oldHashes = append(oldHashes, old[i].Key)
}
}
if err := tx.Where("user_id = ?", userID).Delete(&models.AuthToken{}).Error; err != nil {
return err
}
token = models.AuthToken{UserID: userID}
return tx.Create(&token).Error
})
if err != nil {
return nil, nil, err
}
return &token, oldHashes, nil
}
// DeleteToken deletes an auth token by its raw key value. The raw token is
// hashed (audit C1) before the lookup, since only the hash is stored.
func (r *UserRepository) DeleteToken(token string) error {
result := r.db.Where("key = ?", token).Delete(&models.AuthToken{})
result := r.db.Where("key = ?", models.HashToken(token)).Delete(&models.AuthToken{})
if result.Error != nil {
return result.Error
}
@@ -104,7 +104,7 @@ func TestUserRepository_FindTokenByKey(t *testing.T) {
token, err := repo.GetOrCreateToken(user.ID)
require.NoError(t, err)
found, err := repo.FindTokenByKey(token.Key)
found, err := repo.FindTokenByKey(token.Plaintext)
require.NoError(t, err)
assert.Equal(t, token.Key, found.Key)
assert.Equal(t, user.ID, found.UserID)
@@ -128,10 +128,10 @@ func TestUserRepository_DeleteToken(t *testing.T) {
token, err := repo.GetOrCreateToken(user.ID)
require.NoError(t, err)
err = repo.DeleteToken(token.Key)
err = repo.DeleteToken(token.Plaintext)
require.NoError(t, err)
_, err = repo.FindTokenByKey(token.Key)
_, err = repo.FindTokenByKey(token.Plaintext)
assert.ErrorIs(t, err, ErrTokenNotFound)
}
+29 -9
View File
@@ -75,10 +75,13 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
// responses are unaffected — they don't load any assets, so any CSP is fine.
// frame-ancestors stays 'none' to block clickjacking.
e.Use(middleware.SecureWithConfig(middleware.SecureConfig{
XSSProtection: "1; mode=block",
// XSSProtection deliberately empty (audit L7): the X-XSS-Protection
// header is deprecated and has itself caused XSS in legacy browsers.
XSSProtection: "",
ContentTypeNosniff: "nosniff",
XFrameOptions: "SAMEORIGIN",
HSTSMaxAge: 31536000,
HSTSMaxAge: 63072000, // 2 years — preload-eligible (audit L5/CODE-L3)
HSTSPreloadEnabled: true,
ReferrerPolicy: "strict-origin-when-cross-origin",
ContentSecurityPolicy: "default-src 'self'; " +
"style-src 'self' https://fonts.googleapis.com; " +
@@ -86,6 +89,8 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
"img-src 'self' data:; " +
"script-src 'self'; " +
"connect-src 'self'; " +
"object-src 'none'; " + // audit L8 — disable plugins/embeds
"base-uri 'self'; " + // audit L8 — block <base> hijacking
"frame-ancestors 'none'",
}))
e.Use(middleware.BodyLimitWithConfig(middleware.BodyLimitConfig{
@@ -136,9 +141,20 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
// labeled by route pattern, method, and status code.
e.Use(prom.HTTPMiddleware())
// /metrics endpoint exposed for vmagent scrape. No auth — bound to
// the cluster network only; not exposed via Cloudflare.
e.GET("/metrics", prom.Handler())
// /metrics endpoint for the in-cluster vmagent scrape (audit LIVE-L1).
// vmagent scrapes api pods directly (pod-to-pod), so its requests carry
// no X-Forwarded-For. Any request that DOES carry one reached us through
// Traefik/Cloudflare — i.e. the public internet — and is refused with a
// 404. The api pod port is not exposed outside the cluster, so a request
// cannot reach /metrics without going through Traefik, and Traefik always
// appends X-Forwarded-For — the check cannot be bypassed.
metricsHandler := prom.Handler()
e.GET("/metrics", func(c echo.Context) error {
if c.Request().Header.Get("X-Forwarded-For") != "" {
return echo.NewHTTPError(http.StatusNotFound)
}
return metricsHandler(c)
})
// Serve landing page static files (if static directory is configured)
staticDir := cfg.Server.StaticDir
@@ -204,6 +220,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
// Wire Redis cache for residence-ID lookups across the four services that
// read it on the request hot path. Cache is best-effort; nil cache is OK.
if deps.Cache != nil {
authService.SetCacheService(deps.Cache) // per-account login lockout (audit M5)
residenceService.SetCacheService(deps.Cache)
taskService.SetCacheService(deps.Cache)
contractorService.SetCacheService(deps.Cache)
@@ -316,7 +333,7 @@ func SetupRouter(deps *Dependencies) *echo.Echo {
protected.Use(custommiddleware.TimezoneMiddleware())
{
setupProtectedAuthRoutes(protected, authHandler)
setupResidenceRoutes(protected, residenceHandler)
setupResidenceRoutes(protected, residenceHandler, authMiddleware.RequireVerified())
setupTaskRoutes(protected, taskHandler)
setupSuggestionRoutes(protected, suggestionHandler)
setupContractorRoutes(protected, contractorHandler)
@@ -583,7 +600,7 @@ func setupPublicDataRoutes(api *echo.Group, residenceHandler *handlers.Residence
}
// setupResidenceRoutes configures residence routes
func setupResidenceRoutes(api *echo.Group, residenceHandler *handlers.ResidenceHandler) {
func setupResidenceRoutes(api *echo.Group, residenceHandler *handlers.ResidenceHandler, requireVerified echo.MiddlewareFunc) {
residences := api.Group("/residences")
{
residences.GET("/", residenceHandler.ListResidences)
@@ -598,8 +615,11 @@ func setupResidenceRoutes(api *echo.Group, residenceHandler *handlers.ResidenceH
residences.DELETE("/:id/", residenceHandler.DeleteResidence)
residences.GET("/:id/share-code/", residenceHandler.GetShareCode)
residences.POST("/:id/generate-share-code/", residenceHandler.GenerateShareCode)
residences.POST("/:id/generate-share-package/", residenceHandler.GenerateSharePackage)
// Audit LIVE-L19: generating a residence share code requires a
// verified email — it blocks bad-faith unverified signups from
// minting share codes.
residences.POST("/:id/generate-share-code/", residenceHandler.GenerateShareCode, requireVerified)
residences.POST("/:id/generate-share-package/", residenceHandler.GenerateSharePackage, requireVerified)
residences.POST("/:id/generate-tasks-report/", residenceHandler.GenerateTasksReport)
residences.GET("/:id/users/", residenceHandler.GetResidenceUsers)
residences.DELETE("/:id/users/:user_id/", residenceHandler.RemoveResidenceUser)
+14 -12
View File
@@ -75,9 +75,9 @@ func TestRefreshToken_FreshToken_ReturnsExisting(t *testing.T) {
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
require.NoError(t, err)
assert.Equal(t, token.Key, resp.Token, "fresh token should return the same token")
assert.Equal(t, token.Plaintext, resp.Token, "fresh token should return the same token")
assert.Contains(t, resp.Message, "still valid")
}
@@ -88,23 +88,25 @@ func TestRefreshToken_InRenewalWindow_ReturnsNewToken(t *testing.T) {
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
require.NoError(t, err)
assert.NotEqual(t, token.Key, resp.Token, "should return a new token")
assert.NotEqual(t, token.Plaintext, resp.Token, "should return a new token")
assert.Contains(t, resp.Message, "refreshed")
// Verify old token was deleted
var count int64
// The DB stores the SHA-256 hash, so query by token.Key (the hash).
db.Model(&models.AuthToken{}).Where("key = ?", token.Key).Count(&count)
assert.Equal(t, int64(0), count, "old token should be deleted")
// Verify new token exists in DB
db.Model(&models.AuthToken{}).Where("key = ?", resp.Token).Count(&count)
// resp.Token is the raw token; the DB stores its hash.
db.Model(&models.AuthToken{}).Where("key = ?", models.HashToken(resp.Token)).Count(&count)
assert.Equal(t, int64(1), count, "new token should exist in DB")
// Verify new token belongs to the same user
var newToken models.AuthToken
require.NoError(t, db.Where("key = ?", resp.Token).First(&newToken).Error)
require.NoError(t, db.Where("key = ?", models.HashToken(resp.Token)).First(&newToken).Error)
assert.Equal(t, user.ID, newToken.UserID)
}
@@ -115,7 +117,7 @@ func TestRefreshToken_ExpiredToken_Returns401(t *testing.T) {
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
require.Error(t, err)
assert.Nil(t, resp)
assert.Contains(t, err.Error(), "error.token_expired")
@@ -130,9 +132,9 @@ func TestRefreshToken_AtExactBoundary60Days(t *testing.T) {
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
require.NoError(t, err)
assert.NotEqual(t, token.Key, resp.Token, "token at 61 days should be refreshed")
assert.NotEqual(t, token.Plaintext, resp.Token, "token at 61 days should be refreshed")
}
func TestRefreshToken_InvalidToken_Returns401(t *testing.T) {
@@ -155,7 +157,7 @@ func TestRefreshToken_WrongUser_Returns401(t *testing.T) {
svc := newTestAuthService(db)
// Try to refresh with a different user ID
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID+999)
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID+999)
require.Error(t, err)
assert.Nil(t, resp)
assert.Contains(t, err.Error(), "error.invalid_token")
@@ -168,7 +170,7 @@ func TestRefreshToken_FreshTokenAt59Days_ReturnsExisting(t *testing.T) {
svc := newTestAuthService(db)
resp, err := svc.RefreshToken(context.Background(), token.Key, user.ID)
resp, err := svc.RefreshToken(context.Background(), token.Plaintext, user.ID)
require.NoError(t, err)
assert.Equal(t, token.Key, resp.Token, "token at 59 days should NOT be refreshed")
assert.Equal(t, token.Plaintext, resp.Token, "token at 59 days should NOT be refreshed")
}
+132 -52
View File
@@ -3,6 +3,7 @@ package services
import (
"context"
"crypto/rand"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
@@ -36,13 +37,32 @@ var (
ErrGoogleSignInFailed = errors.New("Google Sign In failed")
)
// Per-account login lockout (audit M5, hardened per MEDIUM-3).
const (
// maxLoginFailureIPs is how many DISTINCT source IPs may fail to log in to
// one account within the window before that account is locked. Counting
// distinct IPs (not raw attempts) means a single attacker who knows a
// victim's email cannot lock the victim out by spamming failures — only a
// genuinely distributed credential-stuffing attack reaches this threshold.
maxLoginFailureIPs = 5
// loginLockWindow is how long the failed-IP set persists; it is refreshed
// on each failure so an active attack keeps the window open.
loginLockWindow = 15 * time.Minute
)
// AuthService handles authentication business logic
type AuthService struct {
userRepo *repositories.UserRepository
notificationRepo *repositories.NotificationRepository
cache *CacheService
cfg *config.Config
}
// SetCacheService wires Redis for per-account login-failure tracking (M5).
func (s *AuthService) SetCacheService(cache *CacheService) {
s.cache = cache
}
// NewAuthService creates a new auth service
func NewAuthService(userRepo *repositories.UserRepository, cfg *config.Config) *AuthService {
return &AuthService{
@@ -56,34 +76,89 @@ func (s *AuthService) SetNotificationRepository(notificationRepo *repositories.N
s.notificationRepo = notificationRepo
}
// Login authenticates a user and returns a token
func (s *AuthService) Login(ctx context.Context, req *requests.LoginRequest) (*responses.LoginResponse, error) {
// dummyPasswordHash is a valid bcrypt hash used to keep login response time
// constant when the account does not exist (audit LIVE-L11). It is computed
// once at startup; the plaintext it hashes is irrelevant and never used.
var dummyPasswordHash = func() string {
h, err := bcrypt.GenerateFromPassword([]byte("honeydue-login-timing-equalizer"), models.BcryptCost)
if err != nil {
return "" // CompareHashAndPassword against "" always fails — safe
}
return string(h)
}()
// freshToken mints a new auth token for the user and evicts any prior token's
// Redis cache entry (audit MEDIUM-1). Without the eviction a re-login would
// not actually kill a previously-issued token until the cache TTL lapsed — a
// stolen token would keep working for up to 5 minutes after the victim
// re-authenticates. A cache-eviction failure is logged, not fatal: the token
// row is already gone, so the stale entry simply ages out on its own.
func (s *AuthService) freshToken(ctx context.Context, userID uint) (*models.AuthToken, error) {
token, oldHashes, err := s.userRepo.WithContext(ctx).CreateFreshToken(userID)
if err != nil {
return nil, err
}
if s.cache != nil && len(oldHashes) > 0 {
if cErr := s.cache.InvalidateAuthTokenHashes(ctx, oldHashes...); cErr != nil {
log.Warn().Err(cErr).Uint("user_id", userID).
Msg("failed to evict prior auth-token cache entries on re-login")
}
}
return token, nil
}
// Login authenticates a user and returns a token. clientIP is the request's
// source IP (echo c.RealIP()), used for the distributed-attack lockout.
func (s *AuthService) Login(ctx context.Context, req *requests.LoginRequest, clientIP string) (*responses.LoginResponse, error) {
// Find user by username or email
identifier := req.Username
if identifier == "" {
identifier = req.Email
}
lockKey := strings.ToLower(strings.TrimSpace(identifier))
// Audit M5 (hardened per MEDIUM-3): per-account lockout keyed on the set
// of distinct source IPs that have failed. Once enough distinct IPs have
// failed for one account within the window, reject — this still catches
// distributed credential stuffing, without letting a single attacker lock
// a victim out by spamming failed logins from one IP.
if s.cache != nil && lockKey != "" {
if n, cErr := s.cache.LoginFailureIPCount(ctx, lockKey); cErr == nil && n >= maxLoginFailureIPs {
return nil, apperrors.TooManyRequests("error.too_many_login_attempts")
}
}
user, err := s.userRepo.WithContext(ctx).FindByUsernameOrEmail(identifier)
if err != nil {
if errors.Is(err, repositories.ErrUserNotFound) {
return nil, apperrors.Unauthorized("error.invalid_credentials")
}
if err != nil && !errors.Is(err, repositories.ErrUserNotFound) {
return nil, apperrors.Internal(err)
}
// Check if user is active
if !user.IsActive {
return nil, apperrors.Unauthorized("error.account_inactive")
// Constant-time login (audit LIVE-L11): always run a bcrypt comparison,
// even when the account does not exist or is inactive, so response
// timing never reveals which emails are real accounts. Compare against
// the user's hash when available, otherwise a fixed dummy hash.
passwordHash := dummyPasswordHash
if user != nil {
passwordHash = user.Password
}
passwordOK := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(req.Password)) == nil
// Verify password
if !user.CheckPassword(req.Password) {
// One generic error for not-found, inactive, and wrong-password
// (audit L1) — none of them disclose which condition failed.
if user == nil || !user.IsActive || !passwordOK {
if s.cache != nil && lockKey != "" {
_, _ = s.cache.RegisterLoginFailure(ctx, lockKey, clientIP, loginLockWindow)
}
return nil, apperrors.Unauthorized("error.invalid_credentials")
}
// Successful authentication — clear the failure counter (audit M5).
if s.cache != nil && lockKey != "" {
_ = s.cache.ClearLoginFailures(ctx, lockKey)
}
// Get or create auth token
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
token, err := s.freshToken(ctx, user.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -95,7 +170,7 @@ func (s *AuthService) Login(ctx context.Context, req *requests.LoginRequest) (*r
}
return &responses.LoginResponse{
Token: token.Key,
Token: token.Plaintext,
User: responses.NewUserResponse(user),
}, nil
}
@@ -176,13 +251,13 @@ func (s *AuthService) Register(ctx context.Context, req *requests.RegisterReques
}
// Create auth token (outside transaction since token generation is idempotent)
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
token, err := s.freshToken(ctx, user.ID)
if err != nil {
return nil, "", apperrors.Internal(err)
}
return &responses.RegisterResponse{
Token: token.Key,
Token: token.Plaintext,
User: responses.NewUserResponse(user),
Message: "Registration successful. Please check your email to verify your account.",
}, code, nil
@@ -243,7 +318,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, tokenKey string, userID
}
return &responses.RefreshTokenResponse{
Token: newToken.Key,
Token: newToken.Plaintext,
Message: "Token refreshed successfully.",
}, nil
}
@@ -390,26 +465,26 @@ func (s *AuthService) VerifyEmail(ctx context.Context, userID uint, code string)
return nil
}
// Find and validate confirmation code
confirmCode, err := s.userRepo.WithContext(ctx).FindConfirmationCode(userID, code)
if err != nil {
if errors.Is(err, repositories.ErrCodeNotFound) {
// Audit M4: validate the code, consume it, and flip the verified flag in
// one transaction so the three writes commit or roll back together.
txErr := s.userRepo.WithContext(ctx).Transaction(func(txRepo *repositories.UserRepository) error {
confirmCode, err := txRepo.FindConfirmationCode(userID, code)
if err != nil {
return err
}
if err := txRepo.MarkConfirmationCodeUsed(confirmCode.ID); err != nil {
return err
}
return txRepo.SetProfileVerified(userID, true)
})
if txErr != nil {
if errors.Is(txErr, repositories.ErrCodeNotFound) {
return apperrors.BadRequest("error.invalid_verification_code")
}
if errors.Is(err, repositories.ErrCodeExpired) {
if errors.Is(txErr, repositories.ErrCodeExpired) {
return apperrors.BadRequest("error.verification_code_expired")
}
return apperrors.Internal(err)
}
// Mark code as used
if err := s.userRepo.WithContext(ctx).MarkConfirmationCodeUsed(confirmCode.ID); err != nil {
return apperrors.Internal(err)
}
// Set profile as verified
if err := s.userRepo.WithContext(ctx).SetProfileVerified(userID, true); err != nil {
return apperrors.Internal(err)
return apperrors.Internal(txErr)
}
return nil
@@ -476,7 +551,7 @@ func (s *AuthService) ForgotPassword(ctx context.Context, email string) (string,
expiresAt := time.Now().UTC().Add(s.cfg.Security.PasswordResetExpiry)
// Hash the code before storing
codeHash, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
codeHash, err := bcrypt.GenerateFromPassword([]byte(code), models.BcryptCost)
if err != nil {
return "", nil, apperrors.Internal(err)
}
@@ -596,7 +671,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
}
// Get or create token
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
token, err := s.freshToken(ctx, user.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -605,7 +680,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
_ = s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID)
return &responses.AppleSignInResponse{
Token: token.Key,
Token: token.Plaintext,
User: responses.NewUserResponse(user),
IsNewUser: false,
}, nil
@@ -638,7 +713,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
_ = s.userRepo.WithContext(ctx).SetProfileVerified(existingUser.ID, true)
// Get or create token
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(existingUser.ID)
token, err := s.freshToken(ctx, existingUser.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -653,7 +728,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
}
return &responses.AppleSignInResponse{
Token: token.Key,
Token: token.Plaintext,
User: responses.NewUserResponse(existingUser),
IsNewUser: false,
}, nil
@@ -704,7 +779,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
}
// Create token
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
token, err := s.freshToken(ctx, user.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -716,7 +791,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
}
return &responses.AppleSignInResponse{
Token: token.Key,
Token: token.Plaintext,
User: responses.NewUserResponse(user),
IsNewUser: true,
}, nil
@@ -749,7 +824,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
}
// Get or create token
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
token, err := s.freshToken(ctx, user.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -758,7 +833,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
_ = s.userRepo.WithContext(ctx).UpdateLastLogin(user.ID)
return &responses.GoogleSignInResponse{
Token: token.Key,
Token: token.Plaintext,
User: responses.NewUserResponse(user),
IsNewUser: false,
}, nil
@@ -794,7 +869,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
}
// Get or create token
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(existingUser.ID)
token, err := s.freshToken(ctx, existingUser.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -809,7 +884,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
}
return &responses.GoogleSignInResponse{
Token: token.Key,
Token: token.Plaintext,
User: responses.NewUserResponse(existingUser),
IsNewUser: false,
}, nil
@@ -861,7 +936,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
}
// Create token
token, err := s.userRepo.WithContext(ctx).GetOrCreateToken(user.ID)
token, err := s.freshToken(ctx, user.ID)
if err != nil {
return nil, apperrors.Internal(err)
}
@@ -873,7 +948,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
}
return &responses.GoogleSignInResponse{
Token: token.Key,
Token: token.Plaintext,
User: responses.NewUserResponse(user),
IsNewUser: true,
}, nil
@@ -882,14 +957,19 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
// Helper functions
func generateSixDigitCode() string {
b := make([]byte, 4)
rand.Read(b)
num := int(b[0])<<24 | int(b[1])<<16 | int(b[2])<<8 | int(b[3])
if num < 0 {
num = -num
// Uniform 000000999999 via rejection sampling on crypto/rand,
// removing the modulo bias of `n % 1000000` (audit H4).
for {
var b [4]byte
if _, err := rand.Read(b[:]); err != nil {
continue
}
// 4294000000 is the largest multiple of 1e6 <= MaxUint32.
n := binary.BigEndian.Uint32(b[:])
if n < 4294000000 {
return fmt.Sprintf("%06d", n%1000000)
}
}
code := num % 1000000
return fmt.Sprintf("%06d", code)
}
func generateResetToken() string {
+11 -9
View File
@@ -54,7 +54,7 @@ func TestAuthService_Login(t *testing.T) {
Password: "Password123",
}
resp, err := service.Login(context.Background(), req)
resp, err := service.Login(context.Background(), req, "")
require.NoError(t, err)
assert.NotEmpty(t, resp.Token)
assert.Equal(t, "testuser", resp.User.Username)
@@ -75,7 +75,7 @@ func TestAuthService_Login_ByEmail(t *testing.T) {
Password: "Password123",
}
resp, err := service.Login(context.Background(), req)
resp, err := service.Login(context.Background(), req, "")
require.NoError(t, err)
assert.NotEmpty(t, resp.Token)
}
@@ -95,7 +95,7 @@ func TestAuthService_Login_InvalidCredentials(t *testing.T) {
Password: "WrongPassword1",
}
_, err := service.Login(context.Background(), req)
_, err := service.Login(context.Background(), req, "")
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
@@ -112,7 +112,7 @@ func TestAuthService_Login_UserNotFound(t *testing.T) {
Password: "Password123",
}
_, err := service.Login(context.Background(), req)
_, err := service.Login(context.Background(), req, "")
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
@@ -134,8 +134,10 @@ func TestAuthService_Login_InactiveUser(t *testing.T) {
Password: "Password123",
}
_, err := service.Login(context.Background(), req)
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.account_inactive")
_, err := service.Login(context.Background(), req, "")
// Audit L1: inactive accounts return the same generic error as bad
// credentials so login does not disclose which accounts exist.
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
// === Register ===
@@ -443,7 +445,7 @@ func TestAuthService_ResetPassword(t *testing.T) {
Username: "testuser",
Password: "NewPassword123",
}
loginResp, err := service.Login(context.Background(), loginReq)
loginResp, err := service.Login(context.Background(), loginReq, "")
require.NoError(t, err)
assert.NotEmpty(t, loginResp.Token)
}
@@ -472,7 +474,7 @@ func TestAuthService_Logout(t *testing.T) {
Username: "testuser",
Password: "Password123",
}
loginResp, err := service.Login(context.Background(), loginReq)
loginResp, err := service.Login(context.Background(), loginReq, "")
require.NoError(t, err)
// Logout
@@ -659,7 +661,7 @@ func TestAuthService_Login_EmptyPassword(t *testing.T) {
Password: "",
}
_, err := service.Login(context.Background(), req)
_, err := service.Login(context.Background(), req, "")
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
}
+67 -10
View File
@@ -12,6 +12,7 @@ import (
"github.com/rs/zerolog/log"
"github.com/treytartt/honeydue-api/internal/config"
"github.com/treytartt/honeydue-api/internal/models"
)
// CacheService provides Redis caching functionality
@@ -139,22 +140,25 @@ const (
TokenCacheTTL = 5 * time.Minute
)
// authTokenCacheKey returns the Redis key for an auth token. The raw token
// is hashed (audit C1) so the plaintext token never appears in a Redis key.
func authTokenCacheKey(token string) string {
return AuthTokenPrefix + models.HashToken(token)
}
// CacheAuthToken caches a user ID for a token
func (c *CacheService) CacheAuthToken(ctx context.Context, token string, userID uint) error {
key := AuthTokenPrefix + token
return c.SetString(ctx, key, fmt.Sprintf("%d", userID), TokenCacheTTL)
return c.SetString(ctx, authTokenCacheKey(token), fmt.Sprintf("%d", userID), TokenCacheTTL)
}
// CacheAuthTokenWithCreated caches a user ID and token creation time for a token
func (c *CacheService) CacheAuthTokenWithCreated(ctx context.Context, token string, userID uint, createdUnix int64) error {
key := AuthTokenPrefix + token
return c.SetString(ctx, key, fmt.Sprintf("%d|%d", userID, createdUnix), TokenCacheTTL)
return c.SetString(ctx, authTokenCacheKey(token), fmt.Sprintf("%d|%d", userID, createdUnix), TokenCacheTTL)
}
// GetCachedAuthToken gets a cached user ID for a token
func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (uint, error) {
key := AuthTokenPrefix + token
val, err := c.GetString(ctx, key)
val, err := c.GetString(ctx, authTokenCacheKey(token))
if err != nil {
return 0, err
}
@@ -167,8 +171,7 @@ func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (ui
// GetCachedAuthTokenWithCreated gets a cached user ID and token creation time.
// Returns userID, createdUnix, error. createdUnix is 0 if not stored (legacy format).
func (c *CacheService) GetCachedAuthTokenWithCreated(ctx context.Context, token string) (uint, int64, error) {
key := AuthTokenPrefix + token
val, err := c.GetString(ctx, key)
val, err := c.GetString(ctx, authTokenCacheKey(token))
if err != nil {
return 0, 0, err
}
@@ -184,8 +187,62 @@ func (c *CacheService) GetCachedAuthTokenWithCreated(ctx context.Context, token
// InvalidateAuthToken removes a cached token
func (c *CacheService) InvalidateAuthToken(ctx context.Context, token string) error {
key := AuthTokenPrefix + token
return c.Delete(ctx, key)
return c.Delete(ctx, authTokenCacheKey(token))
}
// InvalidateAuthTokenHashes removes cached entries for already-hashed token
// keys. Unlike InvalidateAuthToken (which hashes a plaintext), this takes the
// stored hash directly — used to evict a user's prior token on re-login
// (audit MEDIUM-1), where the server no longer has the plaintext.
func (c *CacheService) InvalidateAuthTokenHashes(ctx context.Context, hashes ...string) error {
keys := make([]string, 0, len(hashes))
for _, h := range hashes {
if h != "" {
keys = append(keys, AuthTokenPrefix+h)
}
}
if len(keys) == 0 {
return nil
}
return c.Delete(ctx, keys...)
}
// --- Per-account login-failure tracking (audit M5) ---
const loginFailPrefix = "login_fail:"
// RegisterLoginFailure records a failed login for an account from a given
// source IP, and returns the number of DISTINCT source IPs that have failed
// for this account within the window. Tracking distinct IPs as a set rather
// than a raw counter (audit MEDIUM-3) means one attacker, from one IP, cannot
// run the count up and lock a victim out by knowing only their email — a
// single IP is bounded by the per-IP edge/app rate limiters instead. A
// genuinely distributed credential-stuffing attack still trips the lockout.
func (c *CacheService) RegisterLoginFailure(ctx context.Context, identifier, ip string, window time.Duration) (int64, error) {
key := loginFailPrefix + identifier
member := ip
if member == "" {
member = "unknown"
}
if err := c.client.SAdd(ctx, key, member).Err(); err != nil {
return 0, err
}
// Refresh the TTL on each failure: an active attack keeps the window
// open, while a quiet account ages out `window` after its last failure.
_ = c.client.Expire(ctx, key, window).Err()
return c.client.SCard(ctx, key).Result()
}
// LoginFailureIPCount returns how many distinct source IPs have failed to log
// in to this account within the window (audit MEDIUM-3). SCard on a missing
// key returns 0.
func (c *CacheService) LoginFailureIPCount(ctx context.Context, identifier string) (int64, error) {
return c.client.SCard(ctx, loginFailPrefix+identifier).Result()
}
// ClearLoginFailures resets the failed-login IP set after a successful login.
func (c *CacheService) ClearLoginFailures(ctx context.Context, identifier string) error {
return c.client.Del(ctx, loginFailPrefix+identifier).Err()
}
// Static data cache helpers
+6 -1
View File
@@ -296,9 +296,14 @@ func (s *ContractorService) ToggleFavorite(ctx context.Context, contractorID, us
return nil, apperrors.Internal(err)
}
// Re-fetch the contractor to get the updated state with all relations
// Re-fetch to get the updated state with all relations. Audit M12: if the
// contractor was deleted concurrently between the toggle and this read,
// surface a clean 404 instead of a 500.
contractor, err = s.contractorRepo.WithContext(ctx).FindByID(contractorID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.contractor_not_found")
}
return nil, apperrors.Internal(err)
}
+29 -14
View File
@@ -5,9 +5,9 @@ import (
"gorm.io/gorm"
)
// FileOwnershipService checks whether a user owns a file referenced by URL.
// It queries task completion images, document files, and document images
// to determine ownership through residence access.
// FileOwnershipService checks whether a user has access to a file referenced
// by URL. It queries task completion images, document files, and document
// images, resolving access through residence ownership or membership.
type FileOwnershipService struct {
db *gorm.DB
}
@@ -17,16 +17,31 @@ func NewFileOwnershipService(db *gorm.DB) *FileOwnershipService {
return &FileOwnershipService{db: db}
}
// IsFileOwnedByUser checks if the given file URL belongs to a record
// that the user has access to (via residence membership).
// accessibleResidenceIDs returns a subquery of residence IDs the user can
// access: residences they own (residence_residence.owner_id) UNION residences
// they are a member of (residence_residence_users).
//
// Audit C7: the previous queries joined residence_residence_users only, so a
// residence owner who was not also a member of the join table could not pass
// the ownership check for files in their own property.
func (s *FileOwnershipService) accessibleResidenceIDs(userID uint) *gorm.DB {
return s.db.Raw(`
SELECT id FROM residence_residence WHERE owner_id = ?
UNION
SELECT residence_id FROM residence_residence_users WHERE user_id = ?
`, userID, userID)
}
// IsFileOwnedByUser checks if the given file URL belongs to a record in a
// residence the user owns or is a member of.
func (s *FileOwnershipService) IsFileOwnedByUser(fileURL string, userID uint) (bool, error) {
// Check task completion images: image_url -> completion -> task -> residence -> user access
// Task completion images: image_url -> completion -> task -> residence.
var completionImageCount int64
err := s.db.Model(&models.TaskCompletionImage{}).
Joins("JOIN task_taskcompletion ON task_taskcompletion.id = task_taskcompletionimage.completion_id").
Joins("JOIN task_task ON task_task.id = task_taskcompletion.task_id").
Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_task.residence_id").
Where("task_taskcompletionimage.image_url = ? AND residence_residence_users.user_id = ?", fileURL, userID).
Where("task_taskcompletionimage.image_url = ?", fileURL).
Where("task_task.residence_id IN (?)", s.accessibleResidenceIDs(userID)).
Count(&completionImageCount).Error
if err != nil {
return false, err
@@ -35,11 +50,11 @@ func (s *FileOwnershipService) IsFileOwnedByUser(fileURL string, userID uint) (b
return true, nil
}
// Check document files: file_url -> document -> residence -> user access
// Document files: file_url -> document -> residence.
var documentCount int64
err = s.db.Model(&models.Document{}).
Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_document.residence_id").
Where("task_document.file_url = ? AND residence_residence_users.user_id = ?", fileURL, userID).
Where("task_document.file_url = ?", fileURL).
Where("task_document.residence_id IN (?)", s.accessibleResidenceIDs(userID)).
Count(&documentCount).Error
if err != nil {
return false, err
@@ -48,12 +63,12 @@ func (s *FileOwnershipService) IsFileOwnedByUser(fileURL string, userID uint) (b
return true, nil
}
// Check document images: image_url -> document_image -> document -> residence -> user access
// Document images: image_url -> document_image -> document -> residence.
var documentImageCount int64
err = s.db.Model(&models.DocumentImage{}).
Joins("JOIN task_document ON task_document.id = task_documentimage.document_id").
Joins("JOIN residence_residence_users ON residence_residence_users.residence_id = task_document.residence_id").
Where("task_documentimage.image_url = ? AND residence_residence_users.user_id = ?", fileURL, userID).
Where("task_documentimage.image_url = ?", fileURL).
Where("task_document.residence_id IN (?)", s.accessibleResidenceIDs(userID)).
Count(&documentImageCount).Error
if err != nil {
return false, err
+245 -71
View File
@@ -2,132 +2,306 @@ package services
import (
"context"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/treytartt/honeydue-api/internal/config"
)
const (
googleTokenInfoURL = "https://oauth2.googleapis.com/tokeninfo"
// googleKeysURL is Google's JWKS endpoint for ID-token signature verification.
googleKeysURL = "https://www.googleapis.com/oauth2/v3/certs"
googleKeysCacheTTL = 24 * time.Hour
googleKeysCacheKey = "google:public_keys"
)
// googleIssuers is the set of valid `iss` claim values for a Google ID token.
var googleIssuers = map[string]bool{
"accounts.google.com": true,
"https://accounts.google.com": true,
}
var (
ErrInvalidGoogleToken = errors.New("invalid Google ID token")
ErrGoogleTokenExpired = errors.New("Google ID token has expired")
ErrInvalidGoogleAudience = errors.New("invalid Google token audience")
ErrInvalidGoogleIssuer = errors.New("invalid Google token issuer")
ErrGoogleKeyNotFound = errors.New("Google public key not found")
)
// GoogleTokenInfo represents the response from Google's token info endpoint
type GoogleTokenInfo struct {
Sub string `json:"sub"` // Unique Google user ID
Email string `json:"email"` // User's email
EmailVerified string `json:"email_verified"` // "true" or "false"
Name string `json:"name"` // Full name
GivenName string `json:"given_name"` // First name
FamilyName string `json:"family_name"` // Last name
Picture string `json:"picture"` // Profile picture URL
Aud string `json:"aud"` // Audience (client ID)
Azp string `json:"azp"` // Authorized party
Exp string `json:"exp"` // Expiration time
Iss string `json:"iss"` // Issuer
// GoogleJWKS represents Google's JSON Web Key Set.
type GoogleJWKS struct {
Keys []GoogleJWK `json:"keys"`
}
// IsEmailVerified returns whether the email is verified
// GoogleJWK represents a single JSON Web Key from Google.
type GoogleJWK struct {
Kty string `json:"kty"` // Key type (RSA)
Kid string `json:"kid"` // Key ID
Use string `json:"use"` // Key use (sig)
Alg string `json:"alg"` // Algorithm (RS256)
N string `json:"n"` // RSA modulus
E string `json:"e"` // RSA exponent
}
// GoogleTokenClaims represents the claims in a Google ID token JWT.
type GoogleTokenClaims struct {
jwt.RegisteredClaims
Email string `json:"email,omitempty"`
EmailVerified bool `json:"email_verified,omitempty"`
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
Picture string `json:"picture,omitempty"`
Azp string `json:"azp,omitempty"` // Authorized party
}
// GoogleTokenInfo is the verified, caller-facing view of a Google ID token.
type GoogleTokenInfo struct {
Sub string // Unique Google user ID
Email string
EmailVerified string // "true" or "false" — string for caller compatibility
Name string
GivenName string
FamilyName string
Picture string
Aud string
Azp string
Iss string
}
// IsEmailVerified returns whether the email is verified.
func (t *GoogleTokenInfo) IsEmailVerified() bool {
return t.EmailVerified == "true"
}
// GoogleAuthService handles Google Sign In token verification
// GoogleAuthService handles Google Sign In token verification.
type GoogleAuthService struct {
cache *CacheService
config *config.Config
client *http.Client
}
// NewGoogleAuthService creates a new Google auth service
// NewGoogleAuthService creates a new Google auth service.
func NewGoogleAuthService(cache *CacheService, cfg *config.Config) *GoogleAuthService {
return &GoogleAuthService{
cache: cache,
config: cfg,
client: &http.Client{
Timeout: 10 * time.Second,
},
client: &http.Client{Timeout: 10 * time.Second},
}
}
// VerifyIDToken verifies a Google ID token and returns the token info
// VerifyIDToken verifies a Google ID token locally (audit C2/C3): it checks
// the RS256 signature against Google's published JWKS and the iss, aud, and
// exp claims. It never sends the token to a third-party endpoint, so it no
// longer depends on the deprecated tokeninfo service and never leaks the
// token in a request URL.
func (s *GoogleAuthService) VerifyIDToken(ctx context.Context, idToken string) (*GoogleTokenInfo, error) {
// Call Google's tokeninfo endpoint to verify the token
url := fmt.Sprintf("%s?id_token=%s", googleTokenInfoURL, idToken)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := s.client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to verify token: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
// Parse the token header to get the key ID.
parts := strings.Split(idToken, ".")
if len(parts) != 3 {
return nil, ErrInvalidGoogleToken
}
var tokenInfo GoogleTokenInfo
if err := json.NewDecoder(resp.Body).Decode(&tokenInfo); err != nil {
return nil, fmt.Errorf("failed to decode token info: %w", err)
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return nil, fmt.Errorf("failed to decode token header: %w", err)
}
var header struct {
Kid string `json:"kid"`
Alg string `json:"alg"`
}
if err := json.Unmarshal(headerBytes, &header); err != nil {
return nil, fmt.Errorf("failed to parse token header: %w", err)
}
// Verify the audience matches our client ID(s)
if !s.verifyAudience(tokenInfo.Aud, tokenInfo.Azp) {
publicKey, err := s.getPublicKey(ctx, header.Kid)
if err != nil {
return nil, err
}
// Parse and verify the signature. jwt v5 validates exp/iat/nbf automatically.
token, err := jwt.ParseWithClaims(idToken, &GoogleTokenClaims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return publicKey, nil
})
if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
return nil, ErrGoogleTokenExpired
}
return nil, fmt.Errorf("failed to parse token: %w", err)
}
claims, ok := token.Claims.(*GoogleTokenClaims)
if !ok || !token.Valid {
return nil, ErrInvalidGoogleToken
}
// Verify the issuer (audit C3).
if !googleIssuers[claims.Issuer] {
return nil, ErrInvalidGoogleIssuer
}
// Verify the audience matches one of our configured client IDs.
if !s.verifyAudience(claims.Audience, claims.Azp) {
return nil, ErrInvalidGoogleAudience
}
// Verify the token is not expired (tokeninfo endpoint already checks this,
// but we double-check for security)
if tokenInfo.Sub == "" {
if claims.Subject == "" {
return nil, ErrInvalidGoogleToken
}
return &tokenInfo, nil
emailVerified := "false"
if claims.EmailVerified {
emailVerified = "true"
}
aud := ""
if len(claims.Audience) > 0 {
aud = claims.Audience[0]
}
return &GoogleTokenInfo{
Sub: claims.Subject,
Email: claims.Email,
EmailVerified: emailVerified,
Name: claims.Name,
GivenName: claims.GivenName,
FamilyName: claims.FamilyName,
Picture: claims.Picture,
Aud: aud,
Azp: claims.Azp,
Iss: claims.Issuer,
}, nil
}
// verifyAudience checks if the token audience matches our client ID(s).
// In production (non-debug), an empty clientID causes verification to fail
// rather than silently bypassing the check.
func (s *GoogleAuthService) verifyAudience(aud, azp string) bool {
// verifyAudience checks the token audience against our configured client IDs.
// In production (non-debug) an empty client ID fails verification rather than
// silently bypassing the check.
func (s *GoogleAuthService) verifyAudience(audience jwt.ClaimStrings, azp string) bool {
clientID := s.config.GoogleAuth.ClientID
if clientID == "" {
if s.config.Server.Debug {
// In debug mode only, skip audience verification for local development
// In debug mode only, skip audience verification for local development.
return s.config.Server.Debug
}
candidates := []string{clientID}
if id := s.config.GoogleAuth.AndroidClientID; id != "" {
candidates = append(candidates, id)
}
if id := s.config.GoogleAuth.IOSClientID; id != "" {
candidates = append(candidates, id)
}
for _, want := range candidates {
if azp == want {
return true
}
// In production, missing client ID means we cannot verify the audience
return false
for _, aud := range audience {
if aud == want {
return true
}
}
}
// Check both aud and azp (Android vs iOS may use different values)
if aud == clientID || azp == clientID {
return true
}
// Also check Android client ID if configured
androidClientID := s.config.GoogleAuth.AndroidClientID
if androidClientID != "" && (aud == androidClientID || azp == androidClientID) {
return true
}
// Also check iOS client ID if configured
iosClientID := s.config.GoogleAuth.IOSClientID
if iosClientID != "" && (aud == iosClientID || azp == iosClientID) {
return true
}
return false
}
// getPublicKey returns the RSA public key for the given key ID, using a
// Redis-cached copy of Google's JWKS and re-fetching once on a cache miss
// (Google rotates signing keys roughly daily).
func (s *GoogleAuthService) getPublicKey(ctx context.Context, kid string) (*rsa.PublicKey, error) {
keys, err := s.getCachedKeys(ctx)
if err != nil || keys == nil {
keys, err = s.fetchGooglePublicKeys(ctx)
if err != nil {
return nil, err
}
}
if pubKey, ok := keys[kid]; ok {
return pubKey, nil
}
// Cache miss for this kid — keys may have rotated; fetch fresh.
keys, err = s.fetchGooglePublicKeys(ctx)
if err != nil {
return nil, err
}
if pubKey, ok := keys[kid]; ok {
return pubKey, nil
}
return nil, ErrGoogleKeyNotFound
}
// getCachedKeys retrieves cached Google public keys from Redis.
func (s *GoogleAuthService) getCachedKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
if s.cache == nil {
return nil, nil
}
data, err := s.cache.GetString(ctx, googleKeysCacheKey)
if err != nil || data == "" {
return nil, nil
}
var jwks GoogleJWKS
if err := json.Unmarshal([]byte(data), &jwks); err != nil {
return nil, nil
}
return s.parseJWKS(&jwks), nil
}
// fetchGooglePublicKeys fetches Google's JWKS and caches it.
func (s *GoogleAuthService) fetchGooglePublicKeys(ctx context.Context) (map[string]*rsa.PublicKey, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, googleKeysURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := s.client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to fetch Google keys: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Google keys endpoint returned status %d", resp.StatusCode)
}
var jwks GoogleJWKS
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
return nil, fmt.Errorf("failed to decode Google keys: %w", err)
}
if s.cache != nil {
keysJSON, _ := json.Marshal(jwks)
_ = s.cache.SetString(ctx, googleKeysCacheKey, string(keysJSON), googleKeysCacheTTL)
}
return s.parseJWKS(&jwks), nil
}
// parseJWKS converts Google's JWKS into a map of RSA public keys by key ID.
func (s *GoogleAuthService) parseJWKS(jwks *GoogleJWKS) map[string]*rsa.PublicKey {
keys := make(map[string]*rsa.PublicKey)
for _, key := range jwks.Keys {
if key.Kty != "RSA" {
continue
}
nBytes, err := base64.RawURLEncoding.DecodeString(key.N)
if err != nil {
continue
}
eBytes, err := base64.RawURLEncoding.DecodeString(key.E)
if err != nil {
continue
}
e := 0
for _, b := range eBytes {
e = e<<8 + int(b)
}
keys[key.Kid] = &rsa.PublicKey{N: new(big.Int).SetBytes(nBytes), E: e}
}
return keys
}
+49 -29
View File
@@ -68,13 +68,14 @@ type AppleTransactionInfo struct {
// AppleValidationResult contains the result of Apple receipt validation
type AppleValidationResult struct {
Valid bool
TransactionID string
ProductID string
ExpiresAt time.Time
IsTrialPeriod bool
AutoRenewEnabled bool
Environment string
Valid bool
TransactionID string
OriginalTransactionID string // stable across renewals — the replay key
ProductID string
ExpiresAt time.Time
IsTrialPeriod bool
AutoRenewEnabled bool
Environment string
}
// GoogleValidationResult contains the result of Google token validation
@@ -95,6 +96,21 @@ func NewAppleIAPClient(cfg config.AppleIAPConfig) (*AppleIAPClient, error) {
return nil, ErrIAPNotConfigured
}
// Audit H5 (relaxed per MEDIUM-2): refuse to load the IAP signing key from
// a world-accessible file — a leaked .p8 lets an attacker forge App Store
// Server API requests. The original "0600 or stricter" check is
// incompatible with a Kubernetes Secret volume: the kubelet widens secret
// files to 0440 once fsGroup is set, so 0600 is unattainable for a
// non-root container. Group access is scoped to the pod's fsGroup; the
// real exposure is the "other" bits, so reject only those.
info, err := os.Stat(cfg.KeyPath)
if err != nil {
return nil, fmt.Errorf("failed to stat Apple IAP key: %w", err)
}
if perm := info.Mode().Perm(); perm&0o007 != 0 {
return nil, fmt.Errorf("Apple IAP key %s is world-accessible (permissions %#o); remove other-rwx bits", cfg.KeyPath, perm)
}
// Read the private key
keyData, err := os.ReadFile(cfg.KeyPath)
if err != nil {
@@ -215,11 +231,12 @@ func (c *AppleIAPClient) ValidateTransaction(ctx context.Context, transactionID
expiresAt := time.Unix(transactionInfo.ExpiresDate/1000, 0)
return &AppleValidationResult{
Valid: true,
TransactionID: transactionInfo.TransactionID,
ProductID: transactionInfo.ProductID,
ExpiresAt: expiresAt,
Environment: transactionInfo.Environment,
Valid: true,
TransactionID: transactionInfo.TransactionID,
OriginalTransactionID: transactionInfo.OriginalTransactionID,
ProductID: transactionInfo.ProductID,
ExpiresAt: expiresAt,
Environment: transactionInfo.Environment,
}, nil
}
@@ -243,11 +260,12 @@ func (c *AppleIAPClient) ValidateReceipt(ctx context.Context, receiptData string
if err == nil {
expiresAt := time.Unix(transactionInfo.ExpiresDate/1000, 0)
return &AppleValidationResult{
Valid: true,
TransactionID: transactionInfo.TransactionID,
ProductID: transactionInfo.ProductID,
ExpiresAt: expiresAt,
Environment: transactionInfo.Environment,
Valid: true,
TransactionID: transactionInfo.TransactionID,
OriginalTransactionID: transactionInfo.OriginalTransactionID,
ProductID: transactionInfo.ProductID,
ExpiresAt: expiresAt,
Environment: transactionInfo.Environment,
}, nil
}
}
@@ -317,11 +335,12 @@ func (c *AppleIAPClient) ValidateReceipt(ctx context.Context, receiptData string
expiresAt := time.Unix(transactionInfo.ExpiresDate/1000, 0)
return &AppleValidationResult{
Valid: true,
TransactionID: transactionInfo.TransactionID,
ProductID: transactionInfo.ProductID,
ExpiresAt: expiresAt,
Environment: transactionInfo.Environment,
Valid: true,
TransactionID: transactionInfo.TransactionID,
OriginalTransactionID: transactionInfo.OriginalTransactionID,
ProductID: transactionInfo.ProductID,
ExpiresAt: expiresAt,
Environment: transactionInfo.Environment,
}, nil
}
@@ -418,13 +437,14 @@ func (c *AppleIAPClient) validateLegacyReceiptWithSandbox(ctx context.Context, r
}
return &AppleValidationResult{
Valid: true,
TransactionID: latestReceipt.TransactionID,
ProductID: latestReceipt.ProductID,
ExpiresAt: expiresAt,
IsTrialPeriod: latestReceipt.IsTrialPeriod == "true",
AutoRenewEnabled: autoRenew,
Environment: legacyResponse.Environment,
Valid: true,
TransactionID: latestReceipt.TransactionID,
OriginalTransactionID: latestReceipt.OriginalTransactionID,
ProductID: latestReceipt.ProductID,
ExpiresAt: expiresAt,
IsTrialPeriod: latestReceipt.IsTrialPeriod == "true",
AutoRenewEnabled: autoRenew,
Environment: legacyResponse.Environment,
}, nil
}
+24 -2
View File
@@ -308,7 +308,18 @@ func (s *NotificationService) registerAPNSDevice(ctx context.Context, userID uin
// Check if device exists
existing, err := s.notificationRepo.WithContext(ctx).FindAPNSDeviceByToken(req.RegistrationID)
if err == nil {
// Update existing device
// Audit C8 / LOW-3: APNs device tokens are recycled across devices,
// app reinstalls and OS reassignments, so a token already bound to a
// different account is a stale binding — not a hijack. Reassign it to
// the current (authenticated) registrant rather than reject: a 409
// here would lock the legitimate new owner of a recycled token out of
// push entirely. The reassignment is logged as a security-relevant
// event so a genuine token-takeover attempt is still traceable.
if existing.UserID != nil && *existing.UserID != userID {
log.Warn().Uint("user_id", userID).Uint("previous_owner_id", *existing.UserID).
Msg("APNS device token reassigned to a new account")
}
// Update existing device — reassign to the current user
existing.UserID = &userID
existing.Active = true
existing.Name = req.Name
@@ -337,7 +348,18 @@ func (s *NotificationService) registerGCMDevice(ctx context.Context, userID uint
// Check if device exists
existing, err := s.notificationRepo.WithContext(ctx).FindGCMDeviceByToken(req.RegistrationID)
if err == nil {
// Update existing device
// Audit C8 / LOW-3: FCM device tokens are recycled across devices,
// app reinstalls and OS reassignments, so a token already bound to a
// different account is a stale binding — not a hijack. Reassign it to
// the current (authenticated) registrant rather than reject: a 409
// here would lock the legitimate new owner of a recycled token out of
// push entirely. The reassignment is logged as a security-relevant
// event so a genuine token-takeover attempt is still traceable.
if existing.UserID != nil && *existing.UserID != userID {
log.Warn().Uint("user_id", userID).Uint("previous_owner_id", *existing.UserID).
Msg("GCM device token reassigned to a new account")
}
// Update existing device — reassign to the current user
existing.UserID = &userID
existing.Active = true
existing.Name = req.Name
+7 -22
View File
@@ -559,30 +559,22 @@ func (s *ResidenceService) GenerateSharePackage(ctx context.Context, residenceID
}, nil
}
// JoinWithCode allows a user to join a residence using a share code
// JoinWithCode allows a user to join a residence using a share code.
// Audit C9/H9: the code lookup, membership add, and one-time-code
// deactivation run as a single locked transaction in the repository, so a
// code can never be redeemed twice and a deactivation failure aborts the join.
func (s *ResidenceService) JoinWithCode(ctx context.Context, code string, userID uint) (*responses.JoinResidenceResponse, error) {
// Find the share code
shareCode, err := s.residenceRepo.WithContext(ctx).FindShareCodeByCode(code)
residenceID, alreadyMember, err := s.residenceRepo.WithContext(ctx).JoinWithShareCode(code, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.share_code_invalid")
}
return nil, apperrors.Internal(err)
}
// Check if already a member
hasAccess, err := s.residenceRepo.WithContext(ctx).HasAccess(shareCode.ResidenceID, userID)
if err != nil {
return nil, apperrors.Internal(err)
}
if hasAccess {
if alreadyMember {
return nil, apperrors.Conflict("error.user_already_member")
}
// Add user to residence
if err := s.residenceRepo.WithContext(ctx).AddUser(shareCode.ResidenceID, userID); err != nil {
return nil, apperrors.Internal(err)
}
if s.cache != nil {
// The joining user's residence-IDs cache is now stale, and their
// subscription status now reflects an extra residence with all of its
@@ -591,15 +583,8 @@ func (s *ResidenceService) JoinWithCode(ctx context.Context, code string, userID
_ = s.cache.InvalidateSubscriptionStatusForUsers(ctx, userID)
}
// Mark share code as used (one-time use)
if err := s.residenceRepo.WithContext(ctx).DeactivateShareCode(shareCode.ID); err != nil {
// Log the error but don't fail the join - the user has already been added
// The code will just be usable by others until it expires
log.Error().Err(err).Uint("code_id", shareCode.ID).Msg("Failed to deactivate share code after join")
}
// Get the residence with full details
residence, err := s.residenceRepo.WithContext(ctx).FindByID(shareCode.ResidenceID)
residence, err := s.residenceRepo.WithContext(ctx).FindByID(residenceID)
if err != nil {
return nil, apperrors.Internal(err)
}
+80 -45
View File
@@ -399,99 +399,135 @@ func (s *SubscriptionService) GetActivePromotions(ctx context.Context, userID ui
return result, nil
}
// ProcessApplePurchase processes an Apple IAP purchase
// Supports both StoreKit 1 (receiptData) and StoreKit 2 (transactionID)
// ProcessApplePurchase processes an Apple IAP purchase.
// Supports both StoreKit 1 (receiptData) and StoreKit 2 (transactionID).
func (s *SubscriptionService) ProcessApplePurchase(ctx context.Context, userID uint, receiptData string, transactionID string) (*SubscriptionResponse, error) {
// Store receipt/transaction data
dataToStore := receiptData
if dataToStore == "" {
dataToStore = transactionID
}
if err := s.subscriptionRepo.WithContext(ctx).UpdateReceiptData(userID, dataToStore); err != nil {
return nil, apperrors.Internal(err)
}
// Apple IAP client must be configured to validate purchases.
// Without server-side validation, we cannot trust client-provided receipts.
// Apple IAP client must be configured — without server-side validation
// we cannot trust client-provided receipts.
if s.appleClient == nil {
log.Error().Uint("user_id", userID).Msg("Apple IAP validation not configured, rejecting purchase")
return nil, apperrors.BadRequest("error.iap_validation_not_configured")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// Validation is a network call to Apple; detach from the request context
// so a client disconnect cannot abort an in-flight grant.
vctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
var result *AppleValidationResult
var err error
// Prefer transaction ID (StoreKit 2), fall back to receipt data (StoreKit 1)
// Prefer transaction ID (StoreKit 2), fall back to receipt data (StoreKit 1).
if transactionID != "" {
result, err = s.appleClient.ValidateTransaction(ctx, transactionID)
result, err = s.appleClient.ValidateTransaction(vctx, transactionID)
} else if receiptData != "" {
result, err = s.appleClient.ValidateReceipt(ctx, receiptData)
result, err = s.appleClient.ValidateReceipt(vctx, receiptData)
}
if err != nil {
// Validation failed -- do NOT fall through to grant Pro.
log.Error().Err(err).Uint("user_id", userID).Msg("Apple validation failed")
return nil, err
}
if result == nil {
return nil, apperrors.BadRequest("error.no_receipt_or_transaction")
}
// Audit C5/C10: replay protection. A validated transaction may only ever
// be bound to one account — re-submitting a valid receipt against a
// second account must not grant Pro for free. The partial unique index
// on apple_original_transaction_id is the backstop for the check/store
// race below.
if result.OriginalTransactionID != "" {
existing, lookupErr := s.subscriptionRepo.WithContext(vctx).FindByAppleOriginalTransactionID(result.OriginalTransactionID)
switch {
case lookupErr == nil && existing != nil && existing.UserID != userID:
log.Warn().Uint("user_id", userID).Uint("bound_user_id", existing.UserID).
Msg("Apple purchase rejected — transaction already claimed by another account")
return nil, apperrors.Forbidden("error.iap_transaction_already_claimed")
case lookupErr != nil && !errors.Is(lookupErr, gorm.ErrRecordNotFound):
return nil, apperrors.Internal(lookupErr)
}
}
// Persist the receipt blob and the replay key.
dataToStore := receiptData
if dataToStore == "" {
dataToStore = transactionID
}
if err := s.subscriptionRepo.WithContext(vctx).UpdateReceiptData(userID, dataToStore); err != nil {
return nil, apperrors.Internal(err)
}
if result.OriginalTransactionID != "" {
if err := s.subscriptionRepo.WithContext(vctx).UpdateAppleOriginalTransactionID(userID, result.OriginalTransactionID); err != nil {
// The unique index rejected the bind — a concurrent request
// claimed the same transaction first.
log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to bind Apple transaction ID")
return nil, apperrors.Internal(err)
}
}
expiresAt := result.ExpiresAt
log.Info().Uint("user_id", userID).Str("product", result.ProductID).Time("expires", result.ExpiresAt).Str("env", result.Environment).Msg("Apple purchase validated")
// Upgrade to Pro with the validated expiration
if err := s.subscriptionRepo.WithContext(ctx).UpgradeToPro(userID, expiresAt, "ios"); err != nil {
if err := s.subscriptionRepo.WithContext(vctx).UpgradeToPro(userID, expiresAt, "ios"); err != nil {
return nil, apperrors.Internal(err)
}
// Tier flipped — drop cached SubscriptionStatusResponse so the next call
// returns Pro immediately instead of stale Free.
if s.cache != nil {
_ = s.cache.InvalidateSubscriptionStatusForUsers(ctx, userID)
_ = s.cache.InvalidateSubscriptionStatusForUsers(vctx, userID)
}
return s.GetSubscription(ctx, userID)
return s.GetSubscription(vctx, userID)
}
// ProcessGooglePurchase processes a Google Play purchase
// productID is optional but helps validate the specific subscription
func (s *SubscriptionService) ProcessGooglePurchase(ctx context.Context, userID uint, purchaseToken string, productID string) (*SubscriptionResponse, error) {
// Store purchase token first
if err := s.subscriptionRepo.WithContext(ctx).UpdatePurchaseToken(userID, purchaseToken); err != nil {
return nil, apperrors.Internal(err)
}
// Google IAP client must be configured to validate purchases.
// Without server-side validation, we cannot trust client-provided tokens.
// Google IAP client must be configured — without server-side validation
// we cannot trust client-provided tokens.
if s.googleClient == nil {
log.Error().Uint("user_id", userID).Msg("Google IAP validation not configured, rejecting purchase")
return nil, apperrors.BadRequest("error.iap_validation_not_configured")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// Audit C6/C10: replay protection — a purchase token may only ever be
// bound to one account. The partial unique index on google_purchase_token
// is the backstop for the check/store race.
if purchaseToken != "" {
existing, lookupErr := s.subscriptionRepo.WithContext(ctx).FindByGoogleToken(purchaseToken)
switch {
case lookupErr == nil && existing != nil && existing.UserID != userID:
log.Warn().Uint("user_id", userID).Uint("bound_user_id", existing.UserID).
Msg("Google purchase rejected — token already claimed by another account")
return nil, apperrors.Forbidden("error.iap_transaction_already_claimed")
case lookupErr != nil && !errors.Is(lookupErr, gorm.ErrRecordNotFound):
return nil, apperrors.Internal(lookupErr)
}
}
// Store the purchase token (the replay key).
if err := s.subscriptionRepo.WithContext(ctx).UpdatePurchaseToken(userID, purchaseToken); err != nil {
return nil, apperrors.Internal(err)
}
// Validation is a network call; detach from the request context.
vctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
var result *GoogleValidationResult
var err error
// If productID is provided, use it directly; otherwise try known IDs
// If productID is provided, use it directly; otherwise try known IDs.
if productID != "" {
result, err = s.googleClient.ValidateSubscription(ctx, productID, purchaseToken)
result, err = s.googleClient.ValidateSubscription(vctx, productID, purchaseToken)
} else {
result, err = s.googleClient.ValidatePurchaseToken(ctx, purchaseToken, KnownSubscriptionIDs)
result, err = s.googleClient.ValidatePurchaseToken(vctx, purchaseToken, KnownSubscriptionIDs)
}
if err != nil {
// Validation failed -- do NOT fall through to grant Pro.
log.Error().Err(err).Uint("user_id", userID).Msg("Google purchase validation failed")
return nil, err
}
if result == nil {
return nil, apperrors.BadRequest("error.no_purchase_token")
}
@@ -499,24 +535,23 @@ func (s *SubscriptionService) ProcessGooglePurchase(ctx context.Context, userID
expiresAt := result.ExpiresAt
log.Info().Uint("user_id", userID).Str("product", result.ProductID).Time("expires", result.ExpiresAt).Bool("auto_renew", result.AutoRenewing).Msg("Google purchase validated")
// Acknowledge the subscription if not already acknowledged
// Acknowledge the subscription if not already acknowledged.
if !result.AcknowledgedState {
if err := s.googleClient.AcknowledgeSubscription(ctx, result.ProductID, purchaseToken); err != nil {
if err := s.googleClient.AcknowledgeSubscription(vctx, result.ProductID, purchaseToken); err != nil {
log.Warn().Err(err).Uint("user_id", userID).Msg("Failed to acknowledge Google subscription")
// Don't fail the purchase, just log the warning
// Don't fail the purchase, just log the warning.
}
}
// Upgrade to Pro with the validated expiration
if err := s.subscriptionRepo.WithContext(ctx).UpgradeToPro(userID, expiresAt, "android"); err != nil {
if err := s.subscriptionRepo.WithContext(vctx).UpgradeToPro(userID, expiresAt, "android"); err != nil {
return nil, apperrors.Internal(err)
}
if s.cache != nil {
_ = s.cache.InvalidateSubscriptionStatusForUsers(ctx, userID)
_ = s.cache.InvalidateSubscriptionStatusForUsers(vctx, userID)
}
return s.GetSubscription(ctx, userID)
return s.GetSubscription(vctx, userID)
}
// CancelSubscription cancels a subscription (downgrades to free at end of period)