diff --git a/internal/config/config.go b/internal/config/config.go index 1348c7e..2e9ed76 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,7 +2,9 @@ package config import ( "fmt" + "net/url" "os" + "strconv" "strings" "time" @@ -96,6 +98,34 @@ func Load() (*Config, error) { // Set defaults setDefaults() + // Parse DATABASE_URL if set (Dokku-style) + dbConfig := DatabaseConfig{ + Host: viper.GetString("DB_HOST"), + Port: viper.GetInt("DB_PORT"), + User: viper.GetString("POSTGRES_USER"), + Password: viper.GetString("POSTGRES_PASSWORD"), + Database: viper.GetString("POSTGRES_DB"), + SSLMode: viper.GetString("DB_SSLMODE"), + MaxOpenConns: viper.GetInt("DB_MAX_OPEN_CONNS"), + MaxIdleConns: viper.GetInt("DB_MAX_IDLE_CONNS"), + MaxLifetime: viper.GetDuration("DB_MAX_LIFETIME"), + } + + // Override with DATABASE_URL if present + if databaseURL := viper.GetString("DATABASE_URL"); databaseURL != "" { + parsed, err := parseDatabaseURL(databaseURL) + if err == nil { + dbConfig.Host = parsed.Host + dbConfig.Port = parsed.Port + dbConfig.User = parsed.User + dbConfig.Password = parsed.Password + dbConfig.Database = parsed.Database + if parsed.SSLMode != "" { + dbConfig.SSLMode = parsed.SSLMode + } + } + } + cfg = &Config{ Server: ServerConfig{ Port: viper.GetInt("PORT"), @@ -103,17 +133,7 @@ func Load() (*Config, error) { AllowedHosts: strings.Split(viper.GetString("ALLOWED_HOSTS"), ","), Timezone: viper.GetString("TIMEZONE"), }, - Database: DatabaseConfig{ - Host: viper.GetString("DB_HOST"), - Port: viper.GetInt("DB_PORT"), - User: viper.GetString("POSTGRES_USER"), - Password: viper.GetString("POSTGRES_PASSWORD"), - Database: viper.GetString("POSTGRES_DB"), - SSLMode: viper.GetString("DB_SSLMODE"), - MaxOpenConns: viper.GetInt("DB_MAX_OPEN_CONNS"), - MaxIdleConns: viper.GetInt("DB_MAX_IDLE_CONNS"), - MaxLifetime: viper.GetDuration("DB_MAX_LIFETIME"), - }, + Database: dbConfig, Redis: RedisConfig{ URL: viper.GetString("REDIS_URL"), Password: viper.GetString("REDIS_PASSWORD"), @@ -239,3 +259,39 @@ func (p *PushConfig) ReadAPNSKey() (string, error) { return string(content), nil } + +// parseDatabaseURL parses a PostgreSQL URL into DatabaseConfig +// Format: postgres://user:password@host:port/database?sslmode=disable +func parseDatabaseURL(databaseURL string) (*DatabaseConfig, error) { + u, err := url.Parse(databaseURL) + if err != nil { + return nil, fmt.Errorf("failed to parse DATABASE_URL: %w", err) + } + + // Default port + port := 5432 + if u.Port() != "" { + port, err = strconv.Atoi(u.Port()) + if err != nil { + return nil, fmt.Errorf("invalid port in DATABASE_URL: %w", err) + } + } + + // Get password + password, _ := u.User.Password() + + // Get database name (remove leading slash) + database := strings.TrimPrefix(u.Path, "/") + + // Get sslmode from query params + sslMode := u.Query().Get("sslmode") + + return &DatabaseConfig{ + Host: u.Hostname(), + Port: port, + User: u.User.Username(), + Password: password, + Database: database, + SSLMode: sslMode, + }, nil +}