Compare commits

...

2 Commits

Author SHA1 Message Date
Trey t cb1dc383b4 tools: add admin-reset and notif-diag operational CLIs
Backend CI / Test (push) Has been cancelled
Backend CI / Contract Tests (push) Has been cancelled
Backend CI / Build (push) Has been cancelled
Backend CI / Lint (push) Has been cancelled
Backend CI / Secret Scanning (push) Has been cancelled
Two small Go CLIs for production ops that previously required ad-hoc
psql or kubectl gymnastics. Both load DB credentials from prod.env-style
env vars and read POSTGRES_PASSWORD from deploy/secrets/postgres_password.txt
by default, so the workflow is `set -a && source deploy/prod.env && set +a`
followed by go run.

cmd/admin-reset/main.go:
  --list                  print all admin_users rows
  --verify --email X      bcrypt-check a password against the stored hash
                          using the same case-insensitive lookup the live
                          /api/admin/auth/login endpoint uses
  --new-email Y           rename an admin's email (with unique-index check)
  default (--email X)     prompt for a new password twice (no echo, min 12
                          chars), bcrypt at DefaultCost, update the row

cmd/notif-diag/main.go:
  default                 print pending/sent counts, breakdown by type and
                          age, the 5 most recent pending rows with their
                          error_message, and registered APNs/FCM device
                          counts
  --mark-failed-as-sent   cosmetic cleanup — UPDATE pending rows that have
                          a recorded error to sent=true,
                          sent_at=COALESCE(updated_at, NOW())
  --yes                   skip the interactive confirmation prompt

Both bypass internal/config.Load() entirely so they don't need
SECRET_KEY or other unrelated env vars to run. .gitignore excludes the
build artifacts at /admin-reset and /notif-diag.

go.mod adds golang.org/x/term v0.41.0 (promoted from indirect to direct)
for no-echo password input in admin-reset.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-01 08:36:13 -07:00
Trey t 8fce568532 fix(config): replace sync.Once reset-from-Do with mutex
Load()'s validation-failure path reassigned cfgOnce = sync.Once{} from
inside Do(). When Do() returned and tried to unlock the original mutex,
the Once struct had already been replaced with a fresh one whose mutex
was unlocked, panicking with "sync: unlock of unlocked mutex" on every
boot where any required env var was missing or invalid.

Replaced the Once with a plain sync.Mutex around a nil-check on the
package-level cfg, building the candidate into a local first and only
assigning to cfg after validate() succeeds. Same caching semantics, no
race, and a failed Load() leaves cfg nil so the next caller retries
cleanly.

Also documented AppleAuthConfig.TeamID as currently dead — it's loaded
from APPLE_TEAM_ID but no service reads it. Wire-up point noted for
when Sign in with Apple revocation/refresh is added.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-01 08:35:54 -07:00
7 changed files with 759 additions and 160 deletions
+2
View File
@@ -8,6 +8,8 @@ bin/
/api
/worker
/admin
/admin-reset
/notif-diag
!admin/
*.exe
*.exe~
+257
View File
@@ -0,0 +1,257 @@
// admin-reset is a one-off CLI for resetting an admin_users row's password.
//
// It reads DB connection settings from environment variables (the same names
// the API uses), looks up the admin user by email, prompts for a new password
// twice (no echo), bcrypts it, and updates the row. Safe to keep in the repo
// — running it requires DB credentials.
//
// Usage:
//
// # load env (host, user, db, sslmode) and password from secrets file
// set -a && source deploy/prod.env && set +a
// go run ./cmd/admin-reset
//
// # or with a non-default secrets path / different admin
// go run ./cmd/admin-reset --password-file path/to/postgres_password.txt
// go run ./cmd/admin-reset --email someone@example.com
package main
import (
"bufio"
"errors"
"flag"
"fmt"
"os"
"strconv"
"strings"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/bcrypt"
"golang.org/x/term"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/treytartt/honeydue-api/internal/models"
)
const minPasswordLen = 12
func main() {
email := flag.String("email", "admin@myhoneydue.com", "Admin email to reset")
passwordFile := flag.String("password-file", "deploy/secrets/postgres_password.txt",
"Path to file containing POSTGRES_PASSWORD (used if env var is empty)")
list := flag.Bool("list", false, "List all rows in admin_users and exit (no changes)")
verify := flag.Bool("verify", false, "Prompt for a password and check it against the stored hash; no changes")
newEmail := flag.String("new-email", "", "If set: rename the matched admin's email to this value and exit (no password change)")
flag.Parse()
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339})
dsn, host, err := buildDSN(*passwordFile)
if err != nil {
log.Fatal().Err(err).Msg("failed to build database DSN")
}
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
log.Fatal().Err(err).Msg("failed to connect to database")
}
if *list {
var admins []models.AdminUser
if err := db.Order("id").Find(&admins).Error; err != nil {
log.Fatal().Err(err).Msg("failed to list admin users")
}
fmt.Fprintf(os.Stderr, "DB host: %s\n%d admin user(s):\n\n", host, len(admins))
fmt.Fprintf(os.Stderr, "%-4s %-40s %-12s %-6s %s\n", "ID", "EMAIL", "ROLE", "ACTIVE", "LAST_LOGIN")
for _, a := range admins {
last := "-"
if a.LastLogin != nil {
last = a.LastLogin.Format(time.RFC3339)
}
fmt.Fprintf(os.Stderr, "%-4d %-40s %-12s %-6t %s\n", a.ID, a.Email, a.Role, a.IsActive, last)
}
return
}
// Mirror the live API's case-insensitive lookup so --verify reflects what
// /api/admin/auth/login actually does. The reset path uses the same query
// for consistency.
var admin models.AdminUser
if err := db.Where("LOWER(email) = LOWER(?)", *email).First(&admin).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Fatal().Str("email", *email).Msg("admin user not found (try --list to see existing rows)")
}
log.Fatal().Err(err).Msg("failed to look up admin user")
}
if *newEmail != "" {
target := strings.TrimSpace(*newEmail)
if target == "" || !strings.Contains(target, "@") {
log.Fatal().Str("new_email", *newEmail).Msg("--new-email must be a valid email address")
}
if strings.EqualFold(target, admin.Email) {
fmt.Fprintf(os.Stderr, "No change — current email already matches %q\n", target)
return
}
// Catch the unique-index conflict early with a clear message instead of a Postgres error.
var collisionCount int64
if err := db.Model(&models.AdminUser{}).
Where("LOWER(email) = LOWER(?) AND id <> ?", target, admin.ID).
Count(&collisionCount).Error; err != nil {
log.Fatal().Err(err).Msg("failed to check for email collision")
}
if collisionCount > 0 {
log.Fatal().Str("new_email", target).Msg("another admin row already uses this email — aborting")
}
fmt.Fprintf(os.Stderr, "Renaming admin email: %s → %s (id=%d)\n", admin.Email, target, admin.ID)
fmt.Fprintf(os.Stderr, "DB host: %s\n\n", host)
res := db.Model(&models.AdminUser{}).
Where("id = ?", admin.ID).
Updates(map[string]any{
"email": target,
"updated_at": time.Now().UTC(),
})
if res.Error != nil {
log.Fatal().Err(res.Error).Msg("failed to rename admin email")
}
if res.RowsAffected != 1 {
log.Fatal().Int64("rows", res.RowsAffected).Msg("expected exactly 1 row updated")
}
fmt.Fprintf(os.Stderr, "OK — email is now %s\n", target)
return
}
if *verify {
fmt.Fprintf(os.Stderr, "Verifying password for: %s (id=%d, role=%s, active=%t)\n",
admin.Email, admin.ID, admin.Role, admin.IsActive)
fmt.Fprintf(os.Stderr, "DB host: %s\n\n", host)
pw, err := readPassword("Password: ")
if err != nil {
log.Fatal().Err(err).Msg("failed to read password")
}
if admin.CheckPassword(pw) {
fmt.Fprintln(os.Stderr, "PASS — bcrypt hash matches the supplied password")
if !admin.IsActive {
fmt.Fprintln(os.Stderr, "WARNING: is_active = false — login will still be rejected with \"Account is disabled\"")
}
} else {
fmt.Fprintln(os.Stderr, "FAIL — bcrypt hash does NOT match the supplied password")
os.Exit(1)
}
return
}
fmt.Fprintf(os.Stderr, "Resetting password for: %s (id=%d, role=%s, active=%t)\n",
admin.Email, admin.ID, admin.Role, admin.IsActive)
fmt.Fprintf(os.Stderr, "DB host: %s\n\n", host)
pw1, err := readPassword("New password: ")
if err != nil {
log.Fatal().Err(err).Msg("failed to read password")
}
if len(pw1) < minPasswordLen {
log.Fatal().Int("min", minPasswordLen).Msg("password too short")
}
pw2, err := readPassword("Confirm password: ")
if err != nil {
log.Fatal().Err(err).Msg("failed to read password")
}
if pw1 != pw2 {
log.Fatal().Msg("passwords do not match")
}
hash, err := bcrypt.GenerateFromPassword([]byte(pw1), bcrypt.DefaultCost)
if err != nil {
log.Fatal().Err(err).Msg("failed to hash password")
}
res := db.Model(&models.AdminUser{}).
Where("id = ?", admin.ID).
Updates(map[string]any{
"password": string(hash),
"updated_at": time.Now().UTC(),
})
if res.Error != nil {
log.Fatal().Err(res.Error).Msg("failed to update admin user")
}
if res.RowsAffected != 1 {
log.Fatal().Int64("rows", res.RowsAffected).Msg("expected exactly 1 row updated")
}
fmt.Fprintf(os.Stderr, "\nOK — password reset for %s\n", admin.Email)
}
func buildDSN(passwordFile string) (dsn, host string, err error) {
host = os.Getenv("DB_HOST")
user := os.Getenv("POSTGRES_USER")
dbname := os.Getenv("POSTGRES_DB")
sslmode := os.Getenv("DB_SSLMODE")
if sslmode == "" {
sslmode = "require"
}
port := 5432
if s := os.Getenv("DB_PORT"); s != "" {
p, perr := strconv.Atoi(s)
if perr != nil {
return "", "", fmt.Errorf("invalid DB_PORT %q: %w", s, perr)
}
port = p
}
password := os.Getenv("POSTGRES_PASSWORD")
if password == "" && passwordFile != "" {
b, rerr := os.ReadFile(passwordFile)
if rerr != nil {
return "", "", fmt.Errorf("POSTGRES_PASSWORD not set and could not read %s: %w", passwordFile, rerr)
}
password = strings.TrimRight(string(b), "\r\n")
}
missing := []string{}
if host == "" {
missing = append(missing, "DB_HOST")
}
if user == "" {
missing = append(missing, "POSTGRES_USER")
}
if dbname == "" {
missing = append(missing, "POSTGRES_DB")
}
if password == "" {
missing = append(missing, "POSTGRES_PASSWORD")
}
if len(missing) > 0 {
return "", "", fmt.Errorf("missing required env vars: %s", strings.Join(missing, ", "))
}
dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
host, port, user, password, dbname, sslmode)
return dsn, host, nil
}
func readPassword(prompt string) (string, error) {
fmt.Fprint(os.Stderr, prompt)
if term.IsTerminal(int(os.Stdin.Fd())) {
b, err := term.ReadPassword(int(os.Stdin.Fd()))
fmt.Fprintln(os.Stderr)
if err != nil {
return "", err
}
return strings.TrimRight(string(b), "\r\n"), nil
}
s, err := bufio.NewReader(os.Stdin).ReadString('\n')
if err != nil {
return "", err
}
return strings.TrimRight(s, "\r\n"), nil
}
+333
View File
@@ -0,0 +1,333 @@
// notif-diag is a CLI for inspecting and (optionally) cleaning up stuck
// notification rows. Default mode is read-only — runs SELECTs and prints a
// summary. With --mark-failed-as-sent, marks pending rows that already have a
// recorded error as sent (cosmetic — no retry, no resend).
//
// Usage:
//
// set -a && source deploy/prod.env && set +a
// go run ./cmd/notif-diag # diagnose
// go run ./cmd/notif-diag --mark-failed-as-sent --yes # clean up errored backlog
package main
import (
"bufio"
"fmt"
"os"
"strconv"
"strings"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func main() {
passwordFile := stringFlag("password-file", "deploy/secrets/postgres_password.txt",
"Path to file containing POSTGRES_PASSWORD (used if env var is empty)")
markFailed := boolFlag("mark-failed-as-sent",
"Mark every pending row with a non-empty error_message as sent. Cosmetic only — does not retry the push.")
yes := boolFlag("yes", "Skip the interactive confirmation prompt for destructive actions.")
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339})
dsn, host, err := buildDSN(*passwordFile)
if err != nil {
log.Fatal().Err(err).Msg("failed to build database DSN")
}
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
log.Fatal().Err(err).Msg("failed to connect to database")
}
fmt.Printf("DB host: %s\n", host)
fmt.Println(strings.Repeat("=", 80))
overallTotals(db)
pendingByType(db)
recentPending(db)
deviceCounts(db)
if *markFailed {
markFailedAsSent(db, *yes)
}
}
// markFailedAsSent updates pending rows whose error_message is non-empty,
// flipping them to sent=true with sent_at=updated_at. This is purely cosmetic:
// it removes them from the "pending" count so dashboards and the diag tool
// don't keep flagging an old, unfixable backlog. It does NOT re-send anything.
func markFailedAsSent(db *gorm.DB, skipPrompt bool) {
var candidate int64
if err := db.Raw(`
SELECT COUNT(*) FROM notifications_notification
WHERE sent = false AND error_message IS NOT NULL AND error_message <> ''
`).Scan(&candidate).Error; err != nil {
log.Fatal().Err(err).Msg("failed to count cleanup candidates")
}
fmt.Printf("\n# Cleanup candidate count: %d\n", candidate)
if candidate == 0 {
fmt.Println(" (nothing to clean up)")
return
}
fmt.Println(" These rows have a recorded send error and will never be retried.")
fmt.Println(" Marking them sent=true is cosmetic — it just prevents them from")
fmt.Println(" showing up as pending in admin dashboards going forward.")
if !skipPrompt {
fmt.Printf("\nProceed? Type 'yes' to update %d rows: ", candidate)
s, err := bufio.NewReader(os.Stdin).ReadString('\n')
if err != nil {
log.Fatal().Err(err).Msg("failed to read confirmation")
}
if strings.TrimSpace(s) != "yes" {
fmt.Println("Aborted.")
return
}
}
res := db.Exec(`
UPDATE notifications_notification
SET sent = true, sent_at = COALESCE(updated_at, NOW())
WHERE sent = false AND error_message IS NOT NULL AND error_message <> ''
`)
if res.Error != nil {
log.Fatal().Err(res.Error).Msg("failed to update rows")
}
fmt.Printf("OK — updated %d rows.\n", res.RowsAffected)
}
// overallTotals shows the high-level sent/pending/read split.
func overallTotals(db *gorm.DB) {
type row struct {
Total int64
Sent int64
Pending int64
Read int64
Errored int64
}
var r row
db.Raw(`
SELECT
COUNT(*) AS total,
COUNT(*) FILTER (WHERE sent = true) AS sent,
COUNT(*) FILTER (WHERE sent = false) AS pending,
COUNT(*) FILTER (WHERE read = true) AS read,
COUNT(*) FILTER (WHERE error_message IS NOT NULL AND error_message <> '') AS errored
FROM notifications_notification
`).Scan(&r)
fmt.Println("\n# Overall notification counts")
fmt.Printf(" total: %d\n", r.Total)
fmt.Printf(" sent: %d\n", r.Sent)
fmt.Printf(" pending: %d\n", r.Pending)
fmt.Printf(" read: %d\n", r.Read)
fmt.Printf(" errored: %d (rows with non-empty error_message)\n", r.Errored)
}
// pendingByType breaks the pending rows down by type and age.
func pendingByType(db *gorm.DB) {
type row struct {
NotificationType string
PendingCount int64
Oldest *time.Time
Newest *time.Time
WithErrors int64
Last24h int64
Last7d int64
}
var rows []row
db.Raw(`
SELECT
notification_type,
COUNT(*) AS pending_count,
MIN(created_at) AS oldest,
MAX(created_at) AS newest,
COUNT(*) FILTER (WHERE error_message IS NOT NULL AND error_message <> '') AS with_errors,
COUNT(*) FILTER (WHERE created_at > NOW() - INTERVAL '24 hours') AS last_24h,
COUNT(*) FILTER (WHERE created_at > NOW() - INTERVAL '7 days') AS last_7d
FROM notifications_notification
WHERE sent = false
GROUP BY notification_type
ORDER BY MAX(created_at) DESC NULLS LAST
`).Scan(&rows)
fmt.Println("\n# Pending rows by type")
if len(rows) == 0 {
fmt.Println(" (no pending notifications)")
return
}
fmt.Printf(" %-22s %7s %7s %7s %7s %-19s %-19s\n",
"TYPE", "PENDING", "ERRORED", "LAST24H", "LAST7D", "OLDEST", "NEWEST")
for _, r := range rows {
fmt.Printf(" %-22s %7d %7d %7d %7d %-19s %-19s\n",
r.NotificationType, r.PendingCount, r.WithErrors, r.Last24h, r.Last7d,
fmtTime(r.Oldest), fmtTime(r.Newest))
}
}
// recentPending shows the 5 most recent pending rows with full detail.
func recentPending(db *gorm.DB) {
type row struct {
ID uint
UserID uint
NotificationType string
Title string
Body string
ErrorMessage string
CreatedAt time.Time
}
var rows []row
db.Raw(`
SELECT id, user_id, notification_type, title, body, COALESCE(error_message, '') AS error_message, created_at
FROM notifications_notification
WHERE sent = false
ORDER BY created_at DESC
LIMIT 5
`).Scan(&rows)
fmt.Println("\n# 5 most recent pending notifications")
if len(rows) == 0 {
fmt.Println(" (none)")
return
}
for _, r := range rows {
errPart := ""
if r.ErrorMessage != "" {
errPart = fmt.Sprintf("\n error: %s", r.ErrorMessage)
}
fmt.Printf(" [%d] user=%d %s %s%s\n title: %s\n body: %s\n",
r.ID, r.UserID, r.CreatedAt.Format("2006-01-02 15:04:05"), r.NotificationType, errPart,
truncate(r.Title, 100), truncate(r.Body, 100))
}
}
// deviceCounts shows how many push devices are registered (active vs inactive).
func deviceCounts(db *gorm.DB) {
type row struct {
Total int64
Active int64
WithUser int64
DistinctUsers int64
}
fmt.Println("\n# Registered push devices")
for _, t := range []struct {
label string
table string
}{
{"APNs (iOS)", "push_notifications_apnsdevice"},
{"GCM (Android)", "push_notifications_gcmdevice"},
} {
var r row
err := db.Raw(fmt.Sprintf(`
SELECT
COUNT(*) AS total,
COUNT(*) FILTER (WHERE active = true) AS active,
COUNT(*) FILTER (WHERE user_id IS NOT NULL) AS with_user,
COUNT(DISTINCT user_id) AS distinct_users
FROM %s
`, t.table)).Scan(&r).Error
if err != nil {
fmt.Printf(" %-15s ERROR: %v\n", t.label, err)
continue
}
fmt.Printf(" %-15s total=%-5d active=%-5d with_user=%-5d distinct_users=%d\n",
t.label, r.Total, r.Active, r.WithUser, r.DistinctUsers)
}
}
func buildDSN(passwordFile string) (dsn, host string, err error) {
host = os.Getenv("DB_HOST")
user := os.Getenv("POSTGRES_USER")
dbname := os.Getenv("POSTGRES_DB")
sslmode := os.Getenv("DB_SSLMODE")
if sslmode == "" {
sslmode = "require"
}
port := 5432
if s := os.Getenv("DB_PORT"); s != "" {
p, perr := strconv.Atoi(s)
if perr != nil {
return "", "", fmt.Errorf("invalid DB_PORT %q: %w", s, perr)
}
port = p
}
password := os.Getenv("POSTGRES_PASSWORD")
if password == "" && passwordFile != "" {
b, rerr := os.ReadFile(passwordFile)
if rerr != nil {
return "", "", fmt.Errorf("POSTGRES_PASSWORD not set and could not read %s: %w", passwordFile, rerr)
}
password = strings.TrimRight(string(b), "\r\n")
}
missing := []string{}
if host == "" {
missing = append(missing, "DB_HOST")
}
if user == "" {
missing = append(missing, "POSTGRES_USER")
}
if dbname == "" {
missing = append(missing, "POSTGRES_DB")
}
if password == "" {
missing = append(missing, "POSTGRES_PASSWORD")
}
if len(missing) > 0 {
return "", "", fmt.Errorf("missing required env vars: %s", strings.Join(missing, ", "))
}
dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
host, port, user, password, dbname, sslmode)
return dsn, host, nil
}
// stringFlag is a tiny stand-in for flag.String to keep imports lean — using it
// also dodges flag-package quirks when this file is rebuilt with go run.
func stringFlag(name, def, _usage string) *string {
v := def
prefix := "--" + name + "="
for _, a := range os.Args[1:] {
if strings.HasPrefix(a, prefix) {
v = strings.TrimPrefix(a, prefix)
}
}
return &v
}
// boolFlag is true if --name is present in os.Args (no value form).
func boolFlag(name, _usage string) *bool {
want := "--" + name
v := false
for _, a := range os.Args[1:] {
if a == want {
v = true
}
}
return &v
}
func fmtTime(t *time.Time) string {
if t == nil {
return "-"
}
return t.Format("2006-01-02 15:04:05")
}
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "…"
}
+1
View File
@@ -29,6 +29,7 @@ require (
go.opentelemetry.io/otel/sdk v1.43.0
golang.org/x/crypto v0.49.0
golang.org/x/oauth2 v0.35.0
golang.org/x/term v0.41.0
golang.org/x/text v0.35.0
golang.org/x/time v0.15.0
google.golang.org/api v0.257.0
+2
View File
@@ -266,6 +266,8 @@ golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
+162 -158
View File
@@ -89,8 +89,12 @@ type PushConfig struct {
}
type AppleAuthConfig struct {
ClientID string // Bundle ID (e.g., com.tt.honeyDue.honeyDueDev)
TeamID string // Apple Developer Team ID
ClientID string // Bundle ID, used as the `aud` claim in Sign in with Apple identity tokens
// TeamID is currently unused — services/apple_auth.go validates identity tokens
// against ClientID + Apple's JWKS only, with no server-to-server REST calls.
// Wire this in if/when token revocation or refresh-token exchange is added,
// since both require signing a client_secret JWT with team_id + key_id.
TeamID string
}
type GoogleAuthConfig struct {
@@ -178,8 +182,8 @@ type FeatureFlags struct {
}
var (
cfg *Config
cfgOnce sync.Once
cfg *Config
cfgMu sync.Mutex
)
// knownWeakSecretKeys contains well-known default or placeholder secret keys
@@ -192,163 +196,163 @@ var knownWeakSecretKeys = map[string]bool{
"change-me-in-production-secret-key-12345": true,
}
// Load reads configuration from environment variables
// Load reads configuration from environment variables.
//
// Caches the result so repeated calls are cheap. On validation failure, the
// cache stays nil so a subsequent call (after env is corrected) can retry. The
// previous implementation used sync.Once with an in-Do reset of the Once
// itself, which races and panics with "sync: unlock of unlocked mutex".
func Load() (*Config, error) {
var loadErr error
cfgOnce.Do(func() {
viper.SetEnvPrefix("")
viper.AutomaticEnv()
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
// Set defaults
setDefaults()
// Parse DATABASE_URL if set (Dokku-style)
dbConfig := DatabaseConfig{
Host: viper.GetString("DB_HOST"),
Port: viper.GetInt("DB_PORT"),
User: viper.GetString("POSTGRES_USER"),
Password: viper.GetString("POSTGRES_PASSWORD"),
Database: viper.GetString("POSTGRES_DB"),
SSLMode: viper.GetString("DB_SSLMODE"),
MaxOpenConns: viper.GetInt("DB_MAX_OPEN_CONNS"),
MaxIdleConns: viper.GetInt("DB_MAX_IDLE_CONNS"),
MaxLifetime: viper.GetDuration("DB_MAX_LIFETIME"),
MaxIdleTime: viper.GetDuration("DB_MAX_IDLE_TIME"),
}
// Override with DATABASE_URL if present (F-16: log warning on parse failure)
if databaseURL := viper.GetString("DATABASE_URL"); databaseURL != "" {
parsed, err := parseDatabaseURL(databaseURL)
if err != nil {
maskedURL := MaskURLCredentials(databaseURL)
fmt.Printf("WARNING: Failed to parse DATABASE_URL (%s): %v — falling back to individual DB_* env vars\n", maskedURL, err)
} else {
dbConfig.Host = parsed.Host
dbConfig.Port = parsed.Port
dbConfig.User = parsed.User
dbConfig.Password = parsed.Password
dbConfig.Database = parsed.Database
if parsed.SSLMode != "" {
dbConfig.SSLMode = parsed.SSLMode
}
}
}
cfg = &Config{
Server: ServerConfig{
Port: viper.GetInt("PORT"),
Debug: viper.GetBool("DEBUG"),
DebugFixedCodes: viper.GetBool("DEBUG_FIXED_CODES"),
AllowedHosts: strings.Split(viper.GetString("ALLOWED_HOSTS"), ","),
CorsAllowedOrigins: parseCorsOrigins(viper.GetString("CORS_ALLOWED_ORIGINS")),
Timezone: viper.GetString("TIMEZONE"),
StaticDir: viper.GetString("STATIC_DIR"),
BaseURL: viper.GetString("BASE_URL"),
},
Database: dbConfig,
Redis: RedisConfig{
URL: viper.GetString("REDIS_URL"),
Password: viper.GetString("REDIS_PASSWORD"),
DB: viper.GetInt("REDIS_DB"),
},
Email: EmailConfig{
Host: viper.GetString("EMAIL_HOST"),
Port: viper.GetInt("EMAIL_PORT"),
User: viper.GetString("EMAIL_HOST_USER"),
Password: viper.GetString("EMAIL_HOST_PASSWORD"),
From: viper.GetString("DEFAULT_FROM_EMAIL"),
UseTLS: viper.GetBool("EMAIL_USE_TLS"),
},
Push: PushConfig{
APNSKeyPath: viper.GetString("APNS_AUTH_KEY_PATH"),
APNSKeyID: viper.GetString("APNS_AUTH_KEY_ID"),
APNSTeamID: viper.GetString("APNS_TEAM_ID"),
APNSTopic: viper.GetString("APNS_TOPIC"),
APNSSandbox: viper.GetBool("APNS_USE_SANDBOX"),
APNSProduction: viper.GetBool("APNS_PRODUCTION"),
FCMProjectID: viper.GetString("FCM_PROJECT_ID"),
FCMServiceAccountPath: viper.GetString("FCM_SERVICE_ACCOUNT_PATH"),
FCMServiceAccountJSON: viper.GetString("FCM_SERVICE_ACCOUNT_JSON"),
FCMServerKey: viper.GetString("FCM_SERVER_KEY"),
},
Worker: WorkerConfig{
TaskReminderHour: viper.GetInt("TASK_REMINDER_HOUR"),
OverdueReminderHour: viper.GetInt("OVERDUE_REMINDER_HOUR"),
DailyNotifHour: viper.GetInt("DAILY_DIGEST_HOUR"),
},
Security: SecurityConfig{
SecretKey: viper.GetString("SECRET_KEY"),
TokenCacheTTL: 5 * time.Minute,
PasswordResetExpiry: 15 * time.Minute,
ConfirmationExpiry: 24 * time.Hour,
MaxPasswordResetRate: 3,
TokenExpiryDays: viper.GetInt("TOKEN_EXPIRY_DAYS"),
TokenRefreshDays: viper.GetInt("TOKEN_REFRESH_DAYS"),
},
Storage: StorageConfig{
UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"),
BaseURL: viper.GetString("STORAGE_BASE_URL"),
S3Endpoint: viper.GetString("B2_ENDPOINT"),
S3KeyID: viper.GetString("B2_KEY_ID"),
S3AppKey: viper.GetString("B2_APP_KEY"),
S3Bucket: viper.GetString("B2_BUCKET_NAME"),
S3UseSSL: viper.GetString("STORAGE_USE_SSL") == "" || viper.GetBool("STORAGE_USE_SSL"),
S3Region: viper.GetString("B2_REGION"),
MaxFileSize: viper.GetInt64("STORAGE_MAX_FILE_SIZE"),
AllowedTypes: viper.GetString("STORAGE_ALLOWED_TYPES"),
EncryptionKey: viper.GetString("STORAGE_ENCRYPTION_KEY"),
},
AppleAuth: AppleAuthConfig{
ClientID: viper.GetString("APPLE_CLIENT_ID"),
TeamID: viper.GetString("APPLE_TEAM_ID"),
},
GoogleAuth: GoogleAuthConfig{
ClientID: viper.GetString("GOOGLE_CLIENT_ID"),
AndroidClientID: viper.GetString("GOOGLE_ANDROID_CLIENT_ID"),
IOSClientID: viper.GetString("GOOGLE_IOS_CLIENT_ID"),
},
AppleIAP: AppleIAPConfig{
KeyPath: viper.GetString("APPLE_IAP_KEY_PATH"),
KeyID: viper.GetString("APPLE_IAP_KEY_ID"),
IssuerID: viper.GetString("APPLE_IAP_ISSUER_ID"),
BundleID: viper.GetString("APPLE_IAP_BUNDLE_ID"),
Sandbox: viper.GetBool("APPLE_IAP_SANDBOX"),
},
GoogleIAP: GoogleIAPConfig{
ServiceAccountPath: viper.GetString("GOOGLE_IAP_SERVICE_ACCOUNT_PATH"),
PackageName: viper.GetString("GOOGLE_IAP_PACKAGE_NAME"),
},
Stripe: StripeConfig{
SecretKey: viper.GetString("STRIPE_SECRET_KEY"),
WebhookSecret: viper.GetString("STRIPE_WEBHOOK_SECRET"),
PriceMonthly: viper.GetString("STRIPE_PRICE_MONTHLY"),
PriceYearly: viper.GetString("STRIPE_PRICE_YEARLY"),
},
Features: FeatureFlags{
PushEnabled: viper.GetBool("FEATURE_PUSH_ENABLED"),
EmailEnabled: viper.GetBool("FEATURE_EMAIL_ENABLED"),
WebhooksEnabled: viper.GetBool("FEATURE_WEBHOOKS_ENABLED"),
OnboardingEmailsEnabled: viper.GetBool("FEATURE_ONBOARDING_EMAILS_ENABLED"),
PDFReportsEnabled: viper.GetBool("FEATURE_PDF_REPORTS_ENABLED"),
WorkerEnabled: viper.GetBool("FEATURE_WORKER_ENABLED"),
},
}
// Validate required fields
if err := validate(cfg); err != nil {
loadErr = err
// Reset so a subsequent call can retry after env is fixed
cfg = nil
cfgOnce = sync.Once{}
}
})
if loadErr != nil {
return nil, loadErr
cfgMu.Lock()
defer cfgMu.Unlock()
if cfg != nil {
return cfg, nil
}
viper.SetEnvPrefix("")
viper.AutomaticEnv()
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
// Set defaults
setDefaults()
// Parse DATABASE_URL if set (Dokku-style)
dbConfig := DatabaseConfig{
Host: viper.GetString("DB_HOST"),
Port: viper.GetInt("DB_PORT"),
User: viper.GetString("POSTGRES_USER"),
Password: viper.GetString("POSTGRES_PASSWORD"),
Database: viper.GetString("POSTGRES_DB"),
SSLMode: viper.GetString("DB_SSLMODE"),
MaxOpenConns: viper.GetInt("DB_MAX_OPEN_CONNS"),
MaxIdleConns: viper.GetInt("DB_MAX_IDLE_CONNS"),
MaxLifetime: viper.GetDuration("DB_MAX_LIFETIME"),
MaxIdleTime: viper.GetDuration("DB_MAX_IDLE_TIME"),
}
// Override with DATABASE_URL if present (F-16: log warning on parse failure)
if databaseURL := viper.GetString("DATABASE_URL"); databaseURL != "" {
parsed, err := parseDatabaseURL(databaseURL)
if err != nil {
maskedURL := MaskURLCredentials(databaseURL)
fmt.Printf("WARNING: Failed to parse DATABASE_URL (%s): %v — falling back to individual DB_* env vars\n", maskedURL, err)
} else {
dbConfig.Host = parsed.Host
dbConfig.Port = parsed.Port
dbConfig.User = parsed.User
dbConfig.Password = parsed.Password
dbConfig.Database = parsed.Database
if parsed.SSLMode != "" {
dbConfig.SSLMode = parsed.SSLMode
}
}
}
c := &Config{
Server: ServerConfig{
Port: viper.GetInt("PORT"),
Debug: viper.GetBool("DEBUG"),
DebugFixedCodes: viper.GetBool("DEBUG_FIXED_CODES"),
AllowedHosts: strings.Split(viper.GetString("ALLOWED_HOSTS"), ","),
CorsAllowedOrigins: parseCorsOrigins(viper.GetString("CORS_ALLOWED_ORIGINS")),
Timezone: viper.GetString("TIMEZONE"),
StaticDir: viper.GetString("STATIC_DIR"),
BaseURL: viper.GetString("BASE_URL"),
},
Database: dbConfig,
Redis: RedisConfig{
URL: viper.GetString("REDIS_URL"),
Password: viper.GetString("REDIS_PASSWORD"),
DB: viper.GetInt("REDIS_DB"),
},
Email: EmailConfig{
Host: viper.GetString("EMAIL_HOST"),
Port: viper.GetInt("EMAIL_PORT"),
User: viper.GetString("EMAIL_HOST_USER"),
Password: viper.GetString("EMAIL_HOST_PASSWORD"),
From: viper.GetString("DEFAULT_FROM_EMAIL"),
UseTLS: viper.GetBool("EMAIL_USE_TLS"),
},
Push: PushConfig{
APNSKeyPath: viper.GetString("APNS_AUTH_KEY_PATH"),
APNSKeyID: viper.GetString("APNS_AUTH_KEY_ID"),
APNSTeamID: viper.GetString("APNS_TEAM_ID"),
APNSTopic: viper.GetString("APNS_TOPIC"),
APNSSandbox: viper.GetBool("APNS_USE_SANDBOX"),
APNSProduction: viper.GetBool("APNS_PRODUCTION"),
FCMProjectID: viper.GetString("FCM_PROJECT_ID"),
FCMServiceAccountPath: viper.GetString("FCM_SERVICE_ACCOUNT_PATH"),
FCMServiceAccountJSON: viper.GetString("FCM_SERVICE_ACCOUNT_JSON"),
FCMServerKey: viper.GetString("FCM_SERVER_KEY"),
},
Worker: WorkerConfig{
TaskReminderHour: viper.GetInt("TASK_REMINDER_HOUR"),
OverdueReminderHour: viper.GetInt("OVERDUE_REMINDER_HOUR"),
DailyNotifHour: viper.GetInt("DAILY_DIGEST_HOUR"),
},
Security: SecurityConfig{
SecretKey: viper.GetString("SECRET_KEY"),
TokenCacheTTL: 5 * time.Minute,
PasswordResetExpiry: 15 * time.Minute,
ConfirmationExpiry: 24 * time.Hour,
MaxPasswordResetRate: 3,
TokenExpiryDays: viper.GetInt("TOKEN_EXPIRY_DAYS"),
TokenRefreshDays: viper.GetInt("TOKEN_REFRESH_DAYS"),
},
Storage: StorageConfig{
UploadDir: viper.GetString("STORAGE_UPLOAD_DIR"),
BaseURL: viper.GetString("STORAGE_BASE_URL"),
S3Endpoint: viper.GetString("B2_ENDPOINT"),
S3KeyID: viper.GetString("B2_KEY_ID"),
S3AppKey: viper.GetString("B2_APP_KEY"),
S3Bucket: viper.GetString("B2_BUCKET_NAME"),
S3UseSSL: viper.GetString("STORAGE_USE_SSL") == "" || viper.GetBool("STORAGE_USE_SSL"),
S3Region: viper.GetString("B2_REGION"),
MaxFileSize: viper.GetInt64("STORAGE_MAX_FILE_SIZE"),
AllowedTypes: viper.GetString("STORAGE_ALLOWED_TYPES"),
EncryptionKey: viper.GetString("STORAGE_ENCRYPTION_KEY"),
},
AppleAuth: AppleAuthConfig{
ClientID: viper.GetString("APPLE_CLIENT_ID"),
TeamID: viper.GetString("APPLE_TEAM_ID"),
},
GoogleAuth: GoogleAuthConfig{
ClientID: viper.GetString("GOOGLE_CLIENT_ID"),
AndroidClientID: viper.GetString("GOOGLE_ANDROID_CLIENT_ID"),
IOSClientID: viper.GetString("GOOGLE_IOS_CLIENT_ID"),
},
AppleIAP: AppleIAPConfig{
KeyPath: viper.GetString("APPLE_IAP_KEY_PATH"),
KeyID: viper.GetString("APPLE_IAP_KEY_ID"),
IssuerID: viper.GetString("APPLE_IAP_ISSUER_ID"),
BundleID: viper.GetString("APPLE_IAP_BUNDLE_ID"),
Sandbox: viper.GetBool("APPLE_IAP_SANDBOX"),
},
GoogleIAP: GoogleIAPConfig{
ServiceAccountPath: viper.GetString("GOOGLE_IAP_SERVICE_ACCOUNT_PATH"),
PackageName: viper.GetString("GOOGLE_IAP_PACKAGE_NAME"),
},
Stripe: StripeConfig{
SecretKey: viper.GetString("STRIPE_SECRET_KEY"),
WebhookSecret: viper.GetString("STRIPE_WEBHOOK_SECRET"),
PriceMonthly: viper.GetString("STRIPE_PRICE_MONTHLY"),
PriceYearly: viper.GetString("STRIPE_PRICE_YEARLY"),
},
Features: FeatureFlags{
PushEnabled: viper.GetBool("FEATURE_PUSH_ENABLED"),
EmailEnabled: viper.GetBool("FEATURE_EMAIL_ENABLED"),
WebhooksEnabled: viper.GetBool("FEATURE_WEBHOOKS_ENABLED"),
OnboardingEmailsEnabled: viper.GetBool("FEATURE_ONBOARDING_EMAILS_ENABLED"),
PDFReportsEnabled: viper.GetBool("FEATURE_PDF_REPORTS_ENABLED"),
WorkerEnabled: viper.GetBool("FEATURE_WORKER_ENABLED"),
},
}
if err := validate(c); err != nil {
// Leave cfg nil so the next Load() retries after env is corrected.
return nil, err
}
cfg = c
return cfg, nil
}
+2 -2
View File
@@ -1,7 +1,6 @@
package config
import (
"sync"
"testing"
"github.com/spf13/viper"
@@ -11,8 +10,9 @@ import (
// resetConfigState resets the package-level singleton so each test starts fresh.
func resetConfigState() {
cfgMu.Lock()
cfg = nil
cfgOnce = sync.Once{}
cfgMu.Unlock()
viper.Reset()
}