diff --git a/.gitignore b/.gitignore index 2436489..620f8c7 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,8 @@ bin/ /api /worker /admin +/admin-reset +/notif-diag !admin/ *.exe *.exe~ diff --git a/cmd/admin-reset/main.go b/cmd/admin-reset/main.go new file mode 100644 index 0000000..dfe79ea --- /dev/null +++ b/cmd/admin-reset/main.go @@ -0,0 +1,257 @@ +// admin-reset is a one-off CLI for resetting an admin_users row's password. +// +// It reads DB connection settings from environment variables (the same names +// the API uses), looks up the admin user by email, prompts for a new password +// twice (no echo), bcrypts it, and updates the row. Safe to keep in the repo +// — running it requires DB credentials. +// +// Usage: +// +// # load env (host, user, db, sslmode) and password from secrets file +// set -a && source deploy/prod.env && set +a +// go run ./cmd/admin-reset +// +// # or with a non-default secrets path / different admin +// go run ./cmd/admin-reset --password-file path/to/postgres_password.txt +// go run ./cmd/admin-reset --email someone@example.com +package main + +import ( + "bufio" + "errors" + "flag" + "fmt" + "os" + "strconv" + "strings" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "golang.org/x/crypto/bcrypt" + "golang.org/x/term" + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/logger" + + "github.com/treytartt/honeydue-api/internal/models" +) + +const minPasswordLen = 12 + +func main() { + email := flag.String("email", "admin@myhoneydue.com", "Admin email to reset") + passwordFile := flag.String("password-file", "deploy/secrets/postgres_password.txt", + "Path to file containing POSTGRES_PASSWORD (used if env var is empty)") + list := flag.Bool("list", false, "List all rows in admin_users and exit (no changes)") + verify := flag.Bool("verify", false, "Prompt for a password and check it against the stored hash; no changes") + newEmail := flag.String("new-email", "", "If set: rename the matched admin's email to this value and exit (no password change)") + flag.Parse() + + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}) + + dsn, host, err := buildDSN(*passwordFile) + if err != nil { + log.Fatal().Err(err).Msg("failed to build database DSN") + } + + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + log.Fatal().Err(err).Msg("failed to connect to database") + } + + if *list { + var admins []models.AdminUser + if err := db.Order("id").Find(&admins).Error; err != nil { + log.Fatal().Err(err).Msg("failed to list admin users") + } + fmt.Fprintf(os.Stderr, "DB host: %s\n%d admin user(s):\n\n", host, len(admins)) + fmt.Fprintf(os.Stderr, "%-4s %-40s %-12s %-6s %s\n", "ID", "EMAIL", "ROLE", "ACTIVE", "LAST_LOGIN") + for _, a := range admins { + last := "-" + if a.LastLogin != nil { + last = a.LastLogin.Format(time.RFC3339) + } + fmt.Fprintf(os.Stderr, "%-4d %-40s %-12s %-6t %s\n", a.ID, a.Email, a.Role, a.IsActive, last) + } + return + } + + // Mirror the live API's case-insensitive lookup so --verify reflects what + // /api/admin/auth/login actually does. The reset path uses the same query + // for consistency. + var admin models.AdminUser + if err := db.Where("LOWER(email) = LOWER(?)", *email).First(&admin).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + log.Fatal().Str("email", *email).Msg("admin user not found (try --list to see existing rows)") + } + log.Fatal().Err(err).Msg("failed to look up admin user") + } + + if *newEmail != "" { + target := strings.TrimSpace(*newEmail) + if target == "" || !strings.Contains(target, "@") { + log.Fatal().Str("new_email", *newEmail).Msg("--new-email must be a valid email address") + } + if strings.EqualFold(target, admin.Email) { + fmt.Fprintf(os.Stderr, "No change — current email already matches %q\n", target) + return + } + // Catch the unique-index conflict early with a clear message instead of a Postgres error. + var collisionCount int64 + if err := db.Model(&models.AdminUser{}). + Where("LOWER(email) = LOWER(?) AND id <> ?", target, admin.ID). + Count(&collisionCount).Error; err != nil { + log.Fatal().Err(err).Msg("failed to check for email collision") + } + if collisionCount > 0 { + log.Fatal().Str("new_email", target).Msg("another admin row already uses this email — aborting") + } + + fmt.Fprintf(os.Stderr, "Renaming admin email: %s → %s (id=%d)\n", admin.Email, target, admin.ID) + fmt.Fprintf(os.Stderr, "DB host: %s\n\n", host) + res := db.Model(&models.AdminUser{}). + Where("id = ?", admin.ID). + Updates(map[string]any{ + "email": target, + "updated_at": time.Now().UTC(), + }) + if res.Error != nil { + log.Fatal().Err(res.Error).Msg("failed to rename admin email") + } + if res.RowsAffected != 1 { + log.Fatal().Int64("rows", res.RowsAffected).Msg("expected exactly 1 row updated") + } + fmt.Fprintf(os.Stderr, "OK — email is now %s\n", target) + return + } + + if *verify { + fmt.Fprintf(os.Stderr, "Verifying password for: %s (id=%d, role=%s, active=%t)\n", + admin.Email, admin.ID, admin.Role, admin.IsActive) + fmt.Fprintf(os.Stderr, "DB host: %s\n\n", host) + + pw, err := readPassword("Password: ") + if err != nil { + log.Fatal().Err(err).Msg("failed to read password") + } + if admin.CheckPassword(pw) { + fmt.Fprintln(os.Stderr, "PASS — bcrypt hash matches the supplied password") + if !admin.IsActive { + fmt.Fprintln(os.Stderr, "WARNING: is_active = false — login will still be rejected with \"Account is disabled\"") + } + } else { + fmt.Fprintln(os.Stderr, "FAIL — bcrypt hash does NOT match the supplied password") + os.Exit(1) + } + return + } + + fmt.Fprintf(os.Stderr, "Resetting password for: %s (id=%d, role=%s, active=%t)\n", + admin.Email, admin.ID, admin.Role, admin.IsActive) + fmt.Fprintf(os.Stderr, "DB host: %s\n\n", host) + + pw1, err := readPassword("New password: ") + if err != nil { + log.Fatal().Err(err).Msg("failed to read password") + } + if len(pw1) < minPasswordLen { + log.Fatal().Int("min", minPasswordLen).Msg("password too short") + } + + pw2, err := readPassword("Confirm password: ") + if err != nil { + log.Fatal().Err(err).Msg("failed to read password") + } + if pw1 != pw2 { + log.Fatal().Msg("passwords do not match") + } + + hash, err := bcrypt.GenerateFromPassword([]byte(pw1), bcrypt.DefaultCost) + if err != nil { + log.Fatal().Err(err).Msg("failed to hash password") + } + + res := db.Model(&models.AdminUser{}). + Where("id = ?", admin.ID). + Updates(map[string]any{ + "password": string(hash), + "updated_at": time.Now().UTC(), + }) + if res.Error != nil { + log.Fatal().Err(res.Error).Msg("failed to update admin user") + } + if res.RowsAffected != 1 { + log.Fatal().Int64("rows", res.RowsAffected).Msg("expected exactly 1 row updated") + } + + fmt.Fprintf(os.Stderr, "\nOK — password reset for %s\n", admin.Email) +} + +func buildDSN(passwordFile string) (dsn, host string, err error) { + host = os.Getenv("DB_HOST") + user := os.Getenv("POSTGRES_USER") + dbname := os.Getenv("POSTGRES_DB") + sslmode := os.Getenv("DB_SSLMODE") + if sslmode == "" { + sslmode = "require" + } + + port := 5432 + if s := os.Getenv("DB_PORT"); s != "" { + p, perr := strconv.Atoi(s) + if perr != nil { + return "", "", fmt.Errorf("invalid DB_PORT %q: %w", s, perr) + } + port = p + } + + password := os.Getenv("POSTGRES_PASSWORD") + if password == "" && passwordFile != "" { + b, rerr := os.ReadFile(passwordFile) + if rerr != nil { + return "", "", fmt.Errorf("POSTGRES_PASSWORD not set and could not read %s: %w", passwordFile, rerr) + } + password = strings.TrimRight(string(b), "\r\n") + } + + missing := []string{} + if host == "" { + missing = append(missing, "DB_HOST") + } + if user == "" { + missing = append(missing, "POSTGRES_USER") + } + if dbname == "" { + missing = append(missing, "POSTGRES_DB") + } + if password == "" { + missing = append(missing, "POSTGRES_PASSWORD") + } + if len(missing) > 0 { + return "", "", fmt.Errorf("missing required env vars: %s", strings.Join(missing, ", ")) + } + + dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", + host, port, user, password, dbname, sslmode) + return dsn, host, nil +} + +func readPassword(prompt string) (string, error) { + fmt.Fprint(os.Stderr, prompt) + if term.IsTerminal(int(os.Stdin.Fd())) { + b, err := term.ReadPassword(int(os.Stdin.Fd())) + fmt.Fprintln(os.Stderr) + if err != nil { + return "", err + } + return strings.TrimRight(string(b), "\r\n"), nil + } + s, err := bufio.NewReader(os.Stdin).ReadString('\n') + if err != nil { + return "", err + } + return strings.TrimRight(s, "\r\n"), nil +} diff --git a/cmd/notif-diag/main.go b/cmd/notif-diag/main.go new file mode 100644 index 0000000..7f66dba --- /dev/null +++ b/cmd/notif-diag/main.go @@ -0,0 +1,333 @@ +// notif-diag is a CLI for inspecting and (optionally) cleaning up stuck +// notification rows. Default mode is read-only — runs SELECTs and prints a +// summary. With --mark-failed-as-sent, marks pending rows that already have a +// recorded error as sent (cosmetic — no retry, no resend). +// +// Usage: +// +// set -a && source deploy/prod.env && set +a +// go run ./cmd/notif-diag # diagnose +// go run ./cmd/notif-diag --mark-failed-as-sent --yes # clean up errored backlog +package main + +import ( + "bufio" + "fmt" + "os" + "strconv" + "strings" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +func main() { + passwordFile := stringFlag("password-file", "deploy/secrets/postgres_password.txt", + "Path to file containing POSTGRES_PASSWORD (used if env var is empty)") + markFailed := boolFlag("mark-failed-as-sent", + "Mark every pending row with a non-empty error_message as sent. Cosmetic only — does not retry the push.") + yes := boolFlag("yes", "Skip the interactive confirmation prompt for destructive actions.") + + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}) + + dsn, host, err := buildDSN(*passwordFile) + if err != nil { + log.Fatal().Err(err).Msg("failed to build database DSN") + } + + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + log.Fatal().Err(err).Msg("failed to connect to database") + } + + fmt.Printf("DB host: %s\n", host) + fmt.Println(strings.Repeat("=", 80)) + + overallTotals(db) + pendingByType(db) + recentPending(db) + deviceCounts(db) + + if *markFailed { + markFailedAsSent(db, *yes) + } +} + +// markFailedAsSent updates pending rows whose error_message is non-empty, +// flipping them to sent=true with sent_at=updated_at. This is purely cosmetic: +// it removes them from the "pending" count so dashboards and the diag tool +// don't keep flagging an old, unfixable backlog. It does NOT re-send anything. +func markFailedAsSent(db *gorm.DB, skipPrompt bool) { + var candidate int64 + if err := db.Raw(` + SELECT COUNT(*) FROM notifications_notification + WHERE sent = false AND error_message IS NOT NULL AND error_message <> '' + `).Scan(&candidate).Error; err != nil { + log.Fatal().Err(err).Msg("failed to count cleanup candidates") + } + + fmt.Printf("\n# Cleanup candidate count: %d\n", candidate) + if candidate == 0 { + fmt.Println(" (nothing to clean up)") + return + } + fmt.Println(" These rows have a recorded send error and will never be retried.") + fmt.Println(" Marking them sent=true is cosmetic — it just prevents them from") + fmt.Println(" showing up as pending in admin dashboards going forward.") + + if !skipPrompt { + fmt.Printf("\nProceed? Type 'yes' to update %d rows: ", candidate) + s, err := bufio.NewReader(os.Stdin).ReadString('\n') + if err != nil { + log.Fatal().Err(err).Msg("failed to read confirmation") + } + if strings.TrimSpace(s) != "yes" { + fmt.Println("Aborted.") + return + } + } + + res := db.Exec(` + UPDATE notifications_notification + SET sent = true, sent_at = COALESCE(updated_at, NOW()) + WHERE sent = false AND error_message IS NOT NULL AND error_message <> '' + `) + if res.Error != nil { + log.Fatal().Err(res.Error).Msg("failed to update rows") + } + fmt.Printf("OK — updated %d rows.\n", res.RowsAffected) +} + +// overallTotals shows the high-level sent/pending/read split. +func overallTotals(db *gorm.DB) { + type row struct { + Total int64 + Sent int64 + Pending int64 + Read int64 + Errored int64 + } + var r row + db.Raw(` + SELECT + COUNT(*) AS total, + COUNT(*) FILTER (WHERE sent = true) AS sent, + COUNT(*) FILTER (WHERE sent = false) AS pending, + COUNT(*) FILTER (WHERE read = true) AS read, + COUNT(*) FILTER (WHERE error_message IS NOT NULL AND error_message <> '') AS errored + FROM notifications_notification + `).Scan(&r) + + fmt.Println("\n# Overall notification counts") + fmt.Printf(" total: %d\n", r.Total) + fmt.Printf(" sent: %d\n", r.Sent) + fmt.Printf(" pending: %d\n", r.Pending) + fmt.Printf(" read: %d\n", r.Read) + fmt.Printf(" errored: %d (rows with non-empty error_message)\n", r.Errored) +} + +// pendingByType breaks the pending rows down by type and age. +func pendingByType(db *gorm.DB) { + type row struct { + NotificationType string + PendingCount int64 + Oldest *time.Time + Newest *time.Time + WithErrors int64 + Last24h int64 + Last7d int64 + } + var rows []row + db.Raw(` + SELECT + notification_type, + COUNT(*) AS pending_count, + MIN(created_at) AS oldest, + MAX(created_at) AS newest, + COUNT(*) FILTER (WHERE error_message IS NOT NULL AND error_message <> '') AS with_errors, + COUNT(*) FILTER (WHERE created_at > NOW() - INTERVAL '24 hours') AS last_24h, + COUNT(*) FILTER (WHERE created_at > NOW() - INTERVAL '7 days') AS last_7d + FROM notifications_notification + WHERE sent = false + GROUP BY notification_type + ORDER BY MAX(created_at) DESC NULLS LAST + `).Scan(&rows) + + fmt.Println("\n# Pending rows by type") + if len(rows) == 0 { + fmt.Println(" (no pending notifications)") + return + } + fmt.Printf(" %-22s %7s %7s %7s %7s %-19s %-19s\n", + "TYPE", "PENDING", "ERRORED", "LAST24H", "LAST7D", "OLDEST", "NEWEST") + for _, r := range rows { + fmt.Printf(" %-22s %7d %7d %7d %7d %-19s %-19s\n", + r.NotificationType, r.PendingCount, r.WithErrors, r.Last24h, r.Last7d, + fmtTime(r.Oldest), fmtTime(r.Newest)) + } +} + +// recentPending shows the 5 most recent pending rows with full detail. +func recentPending(db *gorm.DB) { + type row struct { + ID uint + UserID uint + NotificationType string + Title string + Body string + ErrorMessage string + CreatedAt time.Time + } + var rows []row + db.Raw(` + SELECT id, user_id, notification_type, title, body, COALESCE(error_message, '') AS error_message, created_at + FROM notifications_notification + WHERE sent = false + ORDER BY created_at DESC + LIMIT 5 + `).Scan(&rows) + + fmt.Println("\n# 5 most recent pending notifications") + if len(rows) == 0 { + fmt.Println(" (none)") + return + } + for _, r := range rows { + errPart := "" + if r.ErrorMessage != "" { + errPart = fmt.Sprintf("\n error: %s", r.ErrorMessage) + } + fmt.Printf(" [%d] user=%d %s %s%s\n title: %s\n body: %s\n", + r.ID, r.UserID, r.CreatedAt.Format("2006-01-02 15:04:05"), r.NotificationType, errPart, + truncate(r.Title, 100), truncate(r.Body, 100)) + } +} + +// deviceCounts shows how many push devices are registered (active vs inactive). +func deviceCounts(db *gorm.DB) { + type row struct { + Total int64 + Active int64 + WithUser int64 + DistinctUsers int64 + } + + fmt.Println("\n# Registered push devices") + for _, t := range []struct { + label string + table string + }{ + {"APNs (iOS)", "push_notifications_apnsdevice"}, + {"GCM (Android)", "push_notifications_gcmdevice"}, + } { + var r row + err := db.Raw(fmt.Sprintf(` + SELECT + COUNT(*) AS total, + COUNT(*) FILTER (WHERE active = true) AS active, + COUNT(*) FILTER (WHERE user_id IS NOT NULL) AS with_user, + COUNT(DISTINCT user_id) AS distinct_users + FROM %s + `, t.table)).Scan(&r).Error + if err != nil { + fmt.Printf(" %-15s ERROR: %v\n", t.label, err) + continue + } + fmt.Printf(" %-15s total=%-5d active=%-5d with_user=%-5d distinct_users=%d\n", + t.label, r.Total, r.Active, r.WithUser, r.DistinctUsers) + } +} + +func buildDSN(passwordFile string) (dsn, host string, err error) { + host = os.Getenv("DB_HOST") + user := os.Getenv("POSTGRES_USER") + dbname := os.Getenv("POSTGRES_DB") + sslmode := os.Getenv("DB_SSLMODE") + if sslmode == "" { + sslmode = "require" + } + + port := 5432 + if s := os.Getenv("DB_PORT"); s != "" { + p, perr := strconv.Atoi(s) + if perr != nil { + return "", "", fmt.Errorf("invalid DB_PORT %q: %w", s, perr) + } + port = p + } + + password := os.Getenv("POSTGRES_PASSWORD") + if password == "" && passwordFile != "" { + b, rerr := os.ReadFile(passwordFile) + if rerr != nil { + return "", "", fmt.Errorf("POSTGRES_PASSWORD not set and could not read %s: %w", passwordFile, rerr) + } + password = strings.TrimRight(string(b), "\r\n") + } + + missing := []string{} + if host == "" { + missing = append(missing, "DB_HOST") + } + if user == "" { + missing = append(missing, "POSTGRES_USER") + } + if dbname == "" { + missing = append(missing, "POSTGRES_DB") + } + if password == "" { + missing = append(missing, "POSTGRES_PASSWORD") + } + if len(missing) > 0 { + return "", "", fmt.Errorf("missing required env vars: %s", strings.Join(missing, ", ")) + } + + dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", + host, port, user, password, dbname, sslmode) + return dsn, host, nil +} + +// stringFlag is a tiny stand-in for flag.String to keep imports lean — using it +// also dodges flag-package quirks when this file is rebuilt with go run. +func stringFlag(name, def, _usage string) *string { + v := def + prefix := "--" + name + "=" + for _, a := range os.Args[1:] { + if strings.HasPrefix(a, prefix) { + v = strings.TrimPrefix(a, prefix) + } + } + return &v +} + +// boolFlag is true if --name is present in os.Args (no value form). +func boolFlag(name, _usage string) *bool { + want := "--" + name + v := false + for _, a := range os.Args[1:] { + if a == want { + v = true + } + } + return &v +} + +func fmtTime(t *time.Time) string { + if t == nil { + return "-" + } + return t.Format("2006-01-02 15:04:05") +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "…" +} diff --git a/go.mod b/go.mod index c23edc7..a47f4e8 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( go.opentelemetry.io/otel/sdk v1.43.0 golang.org/x/crypto v0.49.0 golang.org/x/oauth2 v0.35.0 + golang.org/x/term v0.41.0 golang.org/x/text v0.35.0 golang.org/x/time v0.15.0 google.golang.org/api v0.257.0 diff --git a/go.sum b/go.sum index e226c09..87a37e5 100644 --- a/go.sum +++ b/go.sum @@ -266,6 +266,8 @@ golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= +golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=