// Package database owns the process-wide GORM connection to PostgreSQL and
// provides small helpers (transactions, pagination, auto-migration) on top of it.
package database

import (
	"fmt"
	"time"

	"github.com/rs/zerolog/log"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"

	"github.com/treytartt/mycrib-api/internal/config"
	"github.com/treytartt/mycrib-api/internal/models"
)

// ErrNotConnected is returned by helpers that need the package-level
// connection when Connect has not yet completed successfully.
var ErrNotConnected = fmt.Errorf("database: not connected")

// db is the package-level handle. It is assigned only by a successful Connect.
var db *gorm.DB

// Connect establishes a connection to the PostgreSQL database described by
// cfg, applies the connection-pool settings from cfg, and verifies the
// connection with a ping.
//
// When debug is true GORM logs at Info level; otherwise its logger is silent.
//
// On success the handle is stored as the package-level instance (see Get) and
// returned. On failure a nil handle and a wrapped error are returned, and the
// package-level instance is left untouched, so a failed (re)connect never
// publishes an unusable handle or replaces a working one.
func Connect(cfg *config.DatabaseConfig, debug bool) (*gorm.DB, error) {
	if cfg == nil {
		return nil, fmt.Errorf("failed to connect to database: nil config")
	}

	// Configure GORM logger.
	logLevel := logger.Silent
	if debug {
		logLevel = logger.Info
	}

	gormConfig := &gorm.Config{
		Logger:      logger.Default.LogMode(logLevel),
		NowFunc:     func() time.Time { return time.Now().UTC() },
		PrepareStmt: true, // cache prepared statements
	}

	// Open into a local first; the global is only set once the handle is proven usable.
	conn, err := gorm.Open(postgres.Open(cfg.DSN()), gormConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}

	// Get the underlying sql.DB for connection-pool settings.
	sqlDB, err := conn.DB()
	if err != nil {
		return nil, fmt.Errorf("failed to get underlying sql.DB: %w", err)
	}

	sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
	sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
	sqlDB.SetConnMaxLifetime(cfg.MaxLifetime)

	// Test the connection.
	if err := sqlDB.Ping(); err != nil {
		// Best-effort cleanup so the pool is not leaked; the ping error is
		// the one worth reporting, so the Close error is deliberately dropped.
		_ = sqlDB.Close()
		return nil, fmt.Errorf("failed to ping database: %w", err)
	}

	db = conn

	log.Info().
		Str("host", cfg.Host).
		Int("port", cfg.Port).
		Str("database", cfg.Database).
		Msg("Connected to PostgreSQL database")

	return db, nil
}

// Get returns the package-level database instance. It is nil until Connect
// has succeeded.
func Get() *gorm.DB {
	return db
}

// Close closes the underlying connection pool of the package-level instance.
// It is a no-op returning nil when no connection has been established.
func Close() error {
	if db == nil {
		return nil
	}

	sqlDB, err := db.DB()
	if err != nil {
		return err
	}
	return sqlDB.Close()
}

// WithTransaction executes fn within a database transaction on the
// package-level instance and returns the result of that transaction.
// It returns ErrNotConnected instead of panicking when Connect has not
// succeeded yet.
func WithTransaction(fn func(tx *gorm.DB) error) error {
	if db == nil {
		return ErrNotConnected
	}
	return db.Transaction(fn)
}

// Paginate returns a GORM scope that applies OFFSET/LIMIT for the given
// 1-based page and page size.
//
// A page <= 0 is treated as page 1. A pageSize <= 0 falls back to 100, and a
// pageSize above 1000 is capped at 1000.
func Paginate(page, pageSize int) func(db *gorm.DB) *gorm.DB {
	const (
		defaultPageSize = 100
		maxPageSize     = 1000
	)

	// Normalise once, up front, so the returned scope never writes to shared
	// captured variables when it is invoked.
	if page <= 0 {
		page = 1
	}
	switch {
	case pageSize <= 0:
		pageSize = defaultPageSize
	case pageSize > maxPageSize:
		pageSize = maxPageSize
	}
	offset := (page - 1) * pageSize

	return func(db *gorm.DB) *gorm.DB {
		return db.Offset(offset).Limit(pageSize)
	}
}

// Migrate runs GORM auto-migration for all models on the package-level
// instance. It returns ErrNotConnected when Connect has not succeeded yet,
// and a wrapped error when the migration itself fails.
func Migrate() error {
	if db == nil {
		return ErrNotConnected
	}

	log.Info().Msg("Running database migrations...")

	// Migrate all models in order (respecting foreign key constraints).
	err := db.AutoMigrate(
		// Lookup tables first (no foreign keys)
		&models.ResidenceType{},
		&models.TaskCategory{},
		&models.TaskPriority{},
		&models.TaskFrequency{},
		&models.TaskStatus{},
		&models.ContractorSpecialty{},

		// User and auth tables
		&models.User{},
		&models.AuthToken{},
		&models.UserProfile{},
		&models.ConfirmationCode{},
		&models.PasswordResetCode{},

		// Main entity tables (order matters for foreign keys!)
		&models.Residence{},
		&models.ResidenceShareCode{},
		&models.Contractor{}, // Contractor before Task (Task references Contractor)
		&models.Task{},
		&models.TaskCompletion{},
		&models.Document{},

		// Notification tables
		&models.Notification{},
		&models.NotificationPreference{},
		&models.APNSDevice{},
		&models.GCMDevice{},

		// Subscription tables
		&models.SubscriptionSettings{},
		&models.UserSubscription{},
		&models.UpgradeTrigger{},
		&models.FeatureBenefit{},
		&models.Promotion{},
		&models.TierLimits{},
	)
	if err != nil {
		return fmt.Errorf("failed to run migrations: %w", err)
	}

	log.Info().Msg("Database migrations completed successfully")
	return nil
}