Initial commit: MyCrib API in Go

Complete rewrite of Django REST API to Go with:
- Gin web framework for HTTP routing
- GORM for database operations
- GoAdmin for admin panel
- Gorush integration for push notifications
- Redis for caching and job queues

Features implemented:
- User authentication (login, register, logout, password reset)
- Residence management (CRUD, sharing, share codes)
- Task management (CRUD, kanban board, completions)
- Contractor management (CRUD, specialties)
- Document management (CRUD, warranties)
- Notifications (preferences, push notifications)
- Subscription management (tiers, limits)

Infrastructure:
- Docker Compose for local development
- Database migrations and seed data
- Admin panel for data management

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Trey t
2025-11-26 20:07:16 -06:00
commit 1f12f3f62a
78 changed files with 13821 additions and 0 deletions

88
internal/admin/admin.go Normal file
View File

@@ -0,0 +1,88 @@
package admin
import (
"fmt"
_ "github.com/GoAdminGroup/go-admin/adapter/gin" // Gin adapter for GoAdmin
"github.com/GoAdminGroup/go-admin/engine"
"github.com/GoAdminGroup/go-admin/modules/config"
"github.com/GoAdminGroup/go-admin/modules/db"
"github.com/GoAdminGroup/go-admin/modules/language"
"github.com/GoAdminGroup/go-admin/plugins/admin"
"github.com/GoAdminGroup/go-admin/plugins/admin/modules/table"
"github.com/GoAdminGroup/go-admin/template"
"github.com/GoAdminGroup/go-admin/template/chartjs"
"github.com/GoAdminGroup/themes/adminlte"
"github.com/gin-gonic/gin"
appconfig "github.com/treytartt/mycrib-api/internal/config"
"github.com/treytartt/mycrib-api/internal/admin/tables"
)
// Setup initializes the GoAdmin panel
func Setup(r *gin.Engine, cfg *appconfig.Config) (*engine.Engine, error) {
eng := engine.Default()
// Register the AdminLTE theme
template.AddComp(chartjs.NewChart())
// Configure GoAdmin
adminConfig := config.Config{
Databases: config.DatabaseList{
"default": {
Host: cfg.Database.Host,
Port: fmt.Sprintf("%d", cfg.Database.Port),
User: cfg.Database.User,
Pwd: cfg.Database.Password,
Name: cfg.Database.Database,
MaxIdleConns: cfg.Database.MaxIdleConns,
MaxOpenConns: cfg.Database.MaxOpenConns,
Driver: db.DriverPostgresql,
},
},
UrlPrefix: "admin",
IndexUrl: "/",
Debug: cfg.Server.Debug,
Language: language.EN,
Theme: "adminlte",
Store: config.Store{
Path: "./uploads",
Prefix: "uploads",
},
Title: "MyCrib Admin",
Logo: "MyCrib",
MiniLogo: "MC",
BootstrapFilePath: "",
GoModFilePath: "",
ColorScheme: adminlte.ColorschemeSkinBlack,
Animation: config.PageAnimation{
Type: "fadeInUp",
},
}
// Add the admin plugin with generators
adminPlugin := admin.NewAdmin(GetTables())
// Initialize engine and add generators
if err := eng.AddConfig(&adminConfig).
AddGenerators(GetTables()).
AddPlugins(adminPlugin).
Use(r); err != nil {
return nil, err
}
// Add redirect for /admin to dashboard
r.GET("/admin", func(c *gin.Context) {
c.Redirect(302, "/admin/menu")
})
r.GET("/admin/", func(c *gin.Context) {
c.Redirect(302, "/admin/menu")
})
return eng, nil
}
// GetTables returns all table generators for the admin panel
func GetTables() table.GeneratorList {
return tables.Generators
}

View File

@@ -0,0 +1,73 @@
package tables
import (
"github.com/GoAdminGroup/go-admin/context"
"github.com/GoAdminGroup/go-admin/modules/db"
"github.com/GoAdminGroup/go-admin/plugins/admin/modules/table"
"github.com/GoAdminGroup/go-admin/template/types/form"
)
// GetContractorsTable returns the contractors table configuration
func GetContractorsTable(ctx *context.Context) table.Table {
contractors := table.NewDefaultTable(ctx, table.DefaultConfigWithDriver(db.DriverPostgresql))
info := contractors.GetInfo()
info.SetTable("task_contractor")
info.AddField("ID", "id", db.Int).FieldFilterable()
info.AddField("Name", "name", db.Varchar).FieldFilterable().FieldSortable()
info.AddField("Company", "company", db.Varchar).FieldFilterable()
info.AddField("Phone", "phone", db.Varchar).FieldFilterable()
info.AddField("Email", "email", db.Varchar).FieldFilterable()
info.AddField("Residence ID", "residence_id", db.Int).FieldFilterable()
info.AddField("Specialty ID", "specialty_id", db.Int).FieldFilterable()
info.AddField("Is Favorite", "is_favorite", db.Bool).FieldFilterable()
info.AddField("Notes", "notes", db.Text)
info.AddField("Created At", "created_at", db.Timestamp).FieldSortable()
info.AddField("Updated At", "updated_at", db.Timestamp).FieldSortable()
info.SetFilterFormLayout(form.LayoutThreeCol)
formList := contractors.GetForm()
formList.SetTable("task_contractor")
formList.AddField("ID", "id", db.Int, form.Default).FieldNotAllowAdd().FieldNotAllowEdit()
formList.AddField("Name", "name", db.Varchar, form.Text).FieldMust()
formList.AddField("Company", "company", db.Varchar, form.Text)
formList.AddField("Phone", "phone", db.Varchar, form.Text)
formList.AddField("Email", "email", db.Varchar, form.Email)
formList.AddField("Website", "website", db.Varchar, form.Url)
formList.AddField("Address", "address", db.Varchar, form.Text)
formList.AddField("City", "city", db.Varchar, form.Text)
formList.AddField("State", "state", db.Varchar, form.Text)
formList.AddField("Zip Code", "zip_code", db.Varchar, form.Text)
formList.AddField("Residence ID", "residence_id", db.Int, form.Number).FieldMust()
formList.AddField("Specialty ID", "specialty_id", db.Int, form.Number)
formList.AddField("Is Favorite", "is_favorite", db.Bool, form.Switch).FieldDefault("false")
formList.AddField("Notes", "notes", db.Text, form.TextArea)
return contractors
}
// GetContractorSpecialtiesTable returns the contractor specialties lookup table configuration
func GetContractorSpecialtiesTable(ctx *context.Context) table.Table {
specialties := table.NewDefaultTable(ctx, table.DefaultConfigWithDriver(db.DriverPostgresql))
info := specialties.GetInfo()
info.SetTable("task_contractorspecialty")
info.AddField("ID", "id", db.Int).FieldFilterable()
info.AddField("Name", "name", db.Varchar).FieldFilterable().FieldSortable()
info.AddField("Icon (iOS)", "icon_ios", db.Varchar)
info.AddField("Icon (Android)", "icon_android", db.Varchar)
info.AddField("Display Order", "display_order", db.Int).FieldSortable()
info.SetFilterFormLayout(form.LayoutThreeCol)
formList := specialties.GetForm()
formList.SetTable("task_contractorspecialty")
formList.AddField("ID", "id", db.Int, form.Default).FieldNotAllowAdd().FieldNotAllowEdit()
formList.AddField("Name", "name", db.Varchar, form.Text).FieldMust()
formList.AddField("Icon (iOS)", "icon_ios", db.Varchar, form.Text)
formList.AddField("Icon (Android)", "icon_android", db.Varchar, form.Text)
formList.AddField("Display Order", "display_order", db.Int, form.Number).FieldDefault("0")
return specialties
}

View File

@@ -0,0 +1,57 @@
package tables
import (
"github.com/GoAdminGroup/go-admin/context"
"github.com/GoAdminGroup/go-admin/modules/db"
"github.com/GoAdminGroup/go-admin/plugins/admin/modules/table"
"github.com/GoAdminGroup/go-admin/template/types"
"github.com/GoAdminGroup/go-admin/template/types/form"
)
// GetDocumentsTable returns the documents table configuration
func GetDocumentsTable(ctx *context.Context) table.Table {
documents := table.NewDefaultTable(ctx, table.DefaultConfigWithDriver(db.DriverPostgresql))
info := documents.GetInfo()
info.SetTable("task_document")
info.AddField("ID", "id", db.Int).FieldFilterable()
info.AddField("Title", "title", db.Varchar).FieldFilterable().FieldSortable()
info.AddField("Type", "document_type", db.Varchar).FieldFilterable()
info.AddField("Residence ID", "residence_id", db.Int).FieldFilterable()
info.AddField("File URL", "file_url", db.Varchar)
info.AddField("Is Active", "is_active", db.Bool).FieldFilterable()
info.AddField("Expiration Date", "expiration_date", db.Date).FieldSortable()
info.AddField("Created By ID", "created_by_id", db.Int).FieldFilterable()
info.AddField("Created At", "created_at", db.Timestamp).FieldSortable()
info.AddField("Updated At", "updated_at", db.Timestamp).FieldSortable()
info.SetFilterFormLayout(form.LayoutThreeCol)
formList := documents.GetForm()
formList.SetTable("task_document")
formList.AddField("ID", "id", db.Int, form.Default).FieldNotAllowAdd().FieldNotAllowEdit()
formList.AddField("Title", "title", db.Varchar, form.Text).FieldMust()
formList.AddField("Description", "description", db.Text, form.TextArea)
formList.AddField("Type", "document_type", db.Varchar, form.SelectSingle).
FieldOptions(types.FieldOptions{
{Value: "warranty", Text: "Warranty"},
{Value: "contract", Text: "Contract"},
{Value: "receipt", Text: "Receipt"},
{Value: "manual", Text: "Manual"},
{Value: "insurance", Text: "Insurance"},
{Value: "other", Text: "Other"},
})
formList.AddField("Residence ID", "residence_id", db.Int, form.Number).FieldMust()
formList.AddField("File URL", "file_url", db.Varchar, form.Url)
formList.AddField("Is Active", "is_active", db.Bool, form.Switch).FieldDefault("true")
formList.AddField("Expiration Date", "expiration_date", db.Date, form.Date)
formList.AddField("Purchase Date", "purchase_date", db.Date, form.Date)
formList.AddField("Purchase Amount", "purchase_amount", db.Decimal, form.Currency)
formList.AddField("Vendor", "vendor", db.Varchar, form.Text)
formList.AddField("Serial Number", "serial_number", db.Varchar, form.Text)
formList.AddField("Model Number", "model_number", db.Varchar, form.Text)
formList.AddField("Created By ID", "created_by_id", db.Int, form.Number)
formList.AddField("Notes", "notes", db.Text, form.TextArea)
return documents
}

View File

@@ -0,0 +1,51 @@
package tables
import (
"github.com/GoAdminGroup/go-admin/context"
"github.com/GoAdminGroup/go-admin/modules/db"
"github.com/GoAdminGroup/go-admin/plugins/admin/modules/table"
"github.com/GoAdminGroup/go-admin/template/types"
"github.com/GoAdminGroup/go-admin/template/types/form"
)
// GetNotificationsTable returns the notifications table configuration
func GetNotificationsTable(ctx *context.Context) table.Table {
notifications := table.NewDefaultTable(ctx, table.DefaultConfigWithDriver(db.DriverPostgresql))
info := notifications.GetInfo()
info.SetTable("notifications_notification")
info.AddField("ID", "id", db.Int).FieldFilterable()
info.AddField("User ID", "user_id", db.Int).FieldFilterable()
info.AddField("Title", "title", db.Varchar).FieldFilterable().FieldSortable()
info.AddField("Message", "message", db.Text)
info.AddField("Type", "notification_type", db.Varchar).FieldFilterable()
info.AddField("Is Read", "is_read", db.Bool).FieldFilterable()
info.AddField("Task ID", "task_id", db.Int).FieldFilterable()
info.AddField("Residence ID", "residence_id", db.Int).FieldFilterable()
info.AddField("Created At", "created_at", db.Timestamp).FieldSortable()
info.SetFilterFormLayout(form.LayoutThreeCol)
formList := notifications.GetForm()
formList.SetTable("notifications_notification")
formList.AddField("ID", "id", db.Int, form.Default).FieldNotAllowAdd().FieldNotAllowEdit()
formList.AddField("User ID", "user_id", db.Int, form.Number).FieldMust()
formList.AddField("Title", "title", db.Varchar, form.Text).FieldMust()
formList.AddField("Message", "message", db.Text, form.TextArea).FieldMust()
formList.AddField("Type", "notification_type", db.Varchar, form.SelectSingle).
FieldOptions(types.FieldOptions{
{Value: "task_assigned", Text: "Task Assigned"},
{Value: "task_completed", Text: "Task Completed"},
{Value: "task_due", Text: "Task Due"},
{Value: "task_overdue", Text: "Task Overdue"},
{Value: "residence_shared", Text: "Residence Shared"},
{Value: "system", Text: "System"},
})
formList.AddField("Is Read", "is_read", db.Bool, form.Switch).FieldDefault("false")
formList.AddField("Task ID", "task_id", db.Int, form.Number)
formList.AddField("Residence ID", "residence_id", db.Int, form.Number)
formList.AddField("Data JSON", "data", db.Text, form.TextArea)
formList.AddField("Action URL", "action_url", db.Varchar, form.Url)
return notifications
}

View File

@@ -0,0 +1,70 @@
package tables
import (
"github.com/GoAdminGroup/go-admin/context"
"github.com/GoAdminGroup/go-admin/modules/db"
"github.com/GoAdminGroup/go-admin/plugins/admin/modules/table"
"github.com/GoAdminGroup/go-admin/template/types/form"
)
// GetResidencesTable returns the residences table configuration
func GetResidencesTable(ctx *context.Context) table.Table {
residences := table.NewDefaultTable(ctx, table.DefaultConfigWithDriver(db.DriverPostgresql))
info := residences.GetInfo()
info.SetTable("residence_residence")
info.AddField("ID", "id", db.Int).FieldFilterable()
info.AddField("Name", "name", db.Varchar).FieldFilterable().FieldSortable()
info.AddField("Address", "address", db.Varchar).FieldFilterable()
info.AddField("City", "city", db.Varchar).FieldFilterable()
info.AddField("State", "state", db.Varchar).FieldFilterable()
info.AddField("Zip Code", "zip_code", db.Varchar).FieldFilterable()
info.AddField("Owner ID", "owner_id", db.Int).FieldFilterable()
info.AddField("Type ID", "residence_type_id", db.Int).FieldFilterable()
info.AddField("Is Active", "is_active", db.Bool).FieldFilterable()
info.AddField("Created At", "created_at", db.Timestamp).FieldSortable()
info.AddField("Updated At", "updated_at", db.Timestamp).FieldSortable()
info.SetFilterFormLayout(form.LayoutThreeCol)
formList := residences.GetForm()
formList.SetTable("residence_residence")
formList.AddField("ID", "id", db.Int, form.Default).FieldNotAllowAdd().FieldNotAllowEdit()
formList.AddField("Name", "name", db.Varchar, form.Text).FieldMust()
formList.AddField("Address", "address", db.Varchar, form.Text)
formList.AddField("City", "city", db.Varchar, form.Text)
formList.AddField("State", "state", db.Varchar, form.Text)
formList.AddField("Zip Code", "zip_code", db.Varchar, form.Text)
formList.AddField("Owner ID", "owner_id", db.Int, form.Number).FieldMust()
formList.AddField("Type ID", "residence_type_id", db.Int, form.Number)
formList.AddField("Is Active", "is_active", db.Bool, form.Switch).FieldDefault("true")
formList.AddField("Share Code", "share_code", db.Varchar, form.Text)
formList.AddField("Share Code Expires", "share_code_expires_at", db.Timestamp, form.Datetime)
return residences
}
// GetResidenceTypesTable returns the residence types lookup table configuration
func GetResidenceTypesTable(ctx *context.Context) table.Table {
types := table.NewDefaultTable(ctx, table.DefaultConfigWithDriver(db.DriverPostgresql))
info := types.GetInfo()
info.SetTable("residence_residencetype")
info.AddField("ID", "id", db.Int).FieldFilterable()
info.AddField("Name", "name", db.Varchar).FieldFilterable().FieldSortable()
info.AddField("Icon (iOS)", "icon_ios", db.Varchar)
info.AddField("Icon (Android)", "icon_android", db.Varchar)
info.AddField("Display Order", "display_order", db.Int).FieldSortable()
info.SetFilterFormLayout(form.LayoutThreeCol)
formList := types.GetForm()
formList.SetTable("residence_residencetype")
formList.AddField("ID", "id", db.Int, form.Default).FieldNotAllowAdd().FieldNotAllowEdit()
formList.AddField("Name", "name", db.Varchar, form.Text).FieldMust()
formList.AddField("Icon (iOS)", "icon_ios", db.Varchar, form.Text)
formList.AddField("Icon (Android)", "icon_android", db.Varchar, form.Text)
formList.AddField("Display Order", "display_order", db.Int, form.Number).FieldDefault("0")
return types
}

View File

@@ -0,0 +1,53 @@
package tables
import (
"github.com/GoAdminGroup/go-admin/context"
"github.com/GoAdminGroup/go-admin/modules/db"
"github.com/GoAdminGroup/go-admin/plugins/admin/modules/table"
"github.com/GoAdminGroup/go-admin/template/types"
"github.com/GoAdminGroup/go-admin/template/types/form"
)
// GetSubscriptionsTable returns the user subscriptions table configuration
func GetSubscriptionsTable(ctx *context.Context) table.Table {
subscriptions := table.NewDefaultTable(ctx, table.DefaultConfigWithDriver(db.DriverPostgresql))
info := subscriptions.GetInfo()
info.SetTable("subscription_usersubscription")
info.AddField("ID", "id", db.Int).FieldFilterable()
info.AddField("User ID", "user_id", db.Int).FieldFilterable()
info.AddField("Tier", "tier", db.Varchar).FieldFilterable()
info.AddField("Subscribed At", "subscribed_at", db.Timestamp).FieldSortable()
info.AddField("Expires At", "expires_at", db.Timestamp).FieldSortable()
info.AddField("Cancelled At", "cancelled_at", db.Timestamp).FieldSortable()
info.AddField("Auto Renew", "auto_renew", db.Bool).FieldFilterable()
info.AddField("Platform", "platform", db.Varchar).FieldFilterable()
info.AddField("Created At", "created_at", db.Timestamp).FieldSortable()
info.AddField("Updated At", "updated_at", db.Timestamp).FieldSortable()
info.SetFilterFormLayout(form.LayoutThreeCol)
formList := subscriptions.GetForm()
formList.SetTable("subscription_usersubscription")
formList.AddField("ID", "id", db.Int, form.Default).FieldNotAllowAdd().FieldNotAllowEdit()
formList.AddField("User ID", "user_id", db.Int, form.Number).FieldMust()
formList.AddField("Tier", "tier", db.Varchar, form.SelectSingle).
FieldOptions(types.FieldOptions{
{Value: "free", Text: "Free"},
{Value: "pro", Text: "Pro"},
}).FieldDefault("free")
formList.AddField("Subscribed At", "subscribed_at", db.Timestamp, form.Datetime)
formList.AddField("Expires At", "expires_at", db.Timestamp, form.Datetime)
formList.AddField("Cancelled At", "cancelled_at", db.Timestamp, form.Datetime)
formList.AddField("Auto Renew", "auto_renew", db.Bool, form.Switch).FieldDefault("true")
formList.AddField("Platform", "platform", db.Varchar, form.SelectSingle).
FieldOptions(types.FieldOptions{
{Value: "", Text: "None"},
{Value: "ios", Text: "iOS"},
{Value: "android", Text: "Android"},
})
formList.AddField("Apple Receipt Data", "apple_receipt_data", db.Text, form.TextArea)
formList.AddField("Google Purchase Token", "google_purchase_token", db.Text, form.TextArea)
return subscriptions
}

View File

@@ -0,0 +1,21 @@
package tables
import "github.com/GoAdminGroup/go-admin/plugins/admin/modules/table"
// Generators is a map of table generators
var Generators = map[string]table.Generator{
"users": GetUsersTable,
"residences": GetResidencesTable,
"tasks": GetTasksTable,
"task_completions": GetTaskCompletionsTable,
"contractors": GetContractorsTable,
"documents": GetDocumentsTable,
"notifications": GetNotificationsTable,
"user_subscriptions": GetSubscriptionsTable,
"task_categories": GetTaskCategoriesTable,
"task_priorities": GetTaskPrioritiesTable,
"task_statuses": GetTaskStatusesTable,
"task_frequencies": GetTaskFrequenciesTable,
"contractor_specialties": GetContractorSpecialtiesTable,
"residence_types": GetResidenceTypesTable,
}

View File

@@ -0,0 +1,187 @@
package tables
import (
"github.com/GoAdminGroup/go-admin/context"
"github.com/GoAdminGroup/go-admin/modules/db"
"github.com/GoAdminGroup/go-admin/plugins/admin/modules/table"
"github.com/GoAdminGroup/go-admin/template/types/form"
)
// GetTasksTable returns the tasks table configuration
func GetTasksTable(ctx *context.Context) table.Table {
tasks := table.NewDefaultTable(ctx, table.DefaultConfigWithDriver(db.DriverPostgresql))
info := tasks.GetInfo()
info.SetTable("task_task")
info.AddField("ID", "id", db.Int).FieldFilterable()
info.AddField("Title", "title", db.Varchar).FieldFilterable().FieldSortable()
info.AddField("Description", "description", db.Text)
info.AddField("Residence ID", "residence_id", db.Int).FieldFilterable()
info.AddField("Category ID", "category_id", db.Int).FieldFilterable()
info.AddField("Priority ID", "priority_id", db.Int).FieldFilterable()
info.AddField("Status ID", "status_id", db.Int).FieldFilterable()
info.AddField("Frequency ID", "frequency_id", db.Int).FieldFilterable()
info.AddField("Due Date", "due_date", db.Date).FieldFilterable().FieldSortable()
info.AddField("Created By ID", "created_by_id", db.Int).FieldFilterable()
info.AddField("Assigned To ID", "assigned_to_id", db.Int).FieldFilterable()
info.AddField("Is Recurring", "is_recurring", db.Bool).FieldFilterable()
info.AddField("Is Cancelled", "is_cancelled", db.Bool).FieldFilterable()
info.AddField("Is Archived", "is_archived", db.Bool).FieldFilterable()
info.AddField("Estimated Cost", "estimated_cost", db.Decimal)
info.AddField("Created At", "created_at", db.Timestamp).FieldSortable()
info.AddField("Updated At", "updated_at", db.Timestamp).FieldSortable()
info.SetFilterFormLayout(form.LayoutThreeCol)
formList := tasks.GetForm()
formList.SetTable("task_task")
formList.AddField("ID", "id", db.Int, form.Default).FieldNotAllowAdd().FieldNotAllowEdit()
formList.AddField("Title", "title", db.Varchar, form.Text).FieldMust()
formList.AddField("Description", "description", db.Text, form.TextArea)
formList.AddField("Residence ID", "residence_id", db.Int, form.Number).FieldMust()
formList.AddField("Category ID", "category_id", db.Int, form.Number)
formList.AddField("Priority ID", "priority_id", db.Int, form.Number)
formList.AddField("Status ID", "status_id", db.Int, form.Number)
formList.AddField("Frequency ID", "frequency_id", db.Int, form.Number)
formList.AddField("Due Date", "due_date", db.Date, form.Date)
formList.AddField("Created By ID", "created_by_id", db.Int, form.Number)
formList.AddField("Assigned To ID", "assigned_to_id", db.Int, form.Number)
formList.AddField("Contractor ID", "contractor_id", db.Int, form.Number)
formList.AddField("Is Recurring", "is_recurring", db.Bool, form.Switch).FieldDefault("false")
formList.AddField("Is Cancelled", "is_cancelled", db.Bool, form.Switch).FieldDefault("false")
formList.AddField("Is Archived", "is_archived", db.Bool, form.Switch).FieldDefault("false")
formList.AddField("Estimated Cost", "estimated_cost", db.Decimal, form.Currency)
formList.AddField("Notes", "notes", db.Text, form.TextArea)
return tasks
}
// GetTaskCompletionsTable returns the task completions table configuration
func GetTaskCompletionsTable(ctx *context.Context) table.Table {
completions := table.NewDefaultTable(ctx, table.DefaultConfigWithDriver(db.DriverPostgresql))
info := completions.GetInfo()
info.SetTable("task_taskcompletion")
info.AddField("ID", "id", db.Int).FieldFilterable()
info.AddField("Task ID", "task_id", db.Int).FieldFilterable()
info.AddField("User ID", "user_id", db.Int).FieldFilterable()
info.AddField("Completed At", "completed_at", db.Timestamp).FieldSortable()
info.AddField("Notes", "notes", db.Text)
info.AddField("Actual Cost", "actual_cost", db.Decimal)
info.AddField("Receipt URL", "receipt_url", db.Varchar)
info.AddField("Created At", "created_at", db.Timestamp).FieldSortable()
info.SetFilterFormLayout(form.LayoutThreeCol)
formList := completions.GetForm()
formList.SetTable("task_taskcompletion")
formList.AddField("ID", "id", db.Int, form.Default).FieldNotAllowAdd().FieldNotAllowEdit()
formList.AddField("Task ID", "task_id", db.Int, form.Number).FieldMust()
formList.AddField("User ID", "user_id", db.Int, form.Number).FieldMust()
formList.AddField("Completed At", "completed_at", db.Timestamp, form.Datetime)
formList.AddField("Notes", "notes", db.Text, form.TextArea)
formList.AddField("Actual Cost", "actual_cost", db.Decimal, form.Currency)
formList.AddField("Receipt URL", "receipt_url", db.Varchar, form.Url)
return completions
}
// GetTaskCategoriesTable returns the task categories lookup table configuration
func GetTaskCategoriesTable(ctx *context.Context) table.Table {
categories := table.NewDefaultTable(ctx, table.DefaultConfigWithDriver(db.DriverPostgresql))
info := categories.GetInfo()
info.SetTable("task_taskcategory")
info.AddField("ID", "id", db.Int).FieldFilterable()
info.AddField("Name", "name", db.Varchar).FieldFilterable().FieldSortable()
info.AddField("Icon (iOS)", "icon_ios", db.Varchar)
info.AddField("Icon (Android)", "icon_android", db.Varchar)
info.AddField("Color", "color", db.Varchar)
info.AddField("Display Order", "display_order", db.Int).FieldSortable()
info.SetFilterFormLayout(form.LayoutThreeCol)
formList := categories.GetForm()
formList.SetTable("task_taskcategory")
formList.AddField("ID", "id", db.Int, form.Default).FieldNotAllowAdd().FieldNotAllowEdit()
formList.AddField("Name", "name", db.Varchar, form.Text).FieldMust()
formList.AddField("Icon (iOS)", "icon_ios", db.Varchar, form.Text)
formList.AddField("Icon (Android)", "icon_android", db.Varchar, form.Text)
formList.AddField("Color", "color", db.Varchar, form.Color)
formList.AddField("Display Order", "display_order", db.Int, form.Number).FieldDefault("0")
return categories
}
// GetTaskPrioritiesTable returns the task priorities lookup table configuration
func GetTaskPrioritiesTable(ctx *context.Context) table.Table {
priorities := table.NewDefaultTable(ctx, table.DefaultConfigWithDriver(db.DriverPostgresql))
info := priorities.GetInfo()
info.SetTable("task_taskpriority")
info.AddField("ID", "id", db.Int).FieldFilterable()
info.AddField("Name", "name", db.Varchar).FieldFilterable().FieldSortable()
info.AddField("Level", "level", db.Int).FieldSortable()
info.AddField("Color", "color", db.Varchar)
info.AddField("Display Order", "display_order", db.Int).FieldSortable()
info.SetFilterFormLayout(form.LayoutThreeCol)
formList := priorities.GetForm()
formList.SetTable("task_taskpriority")
formList.AddField("ID", "id", db.Int, form.Default).FieldNotAllowAdd().FieldNotAllowEdit()
formList.AddField("Name", "name", db.Varchar, form.Text).FieldMust()
formList.AddField("Level", "level", db.Int, form.Number).FieldMust()
formList.AddField("Color", "color", db.Varchar, form.Color)
formList.AddField("Display Order", "display_order", db.Int, form.Number).FieldDefault("0")
return priorities
}
// GetTaskStatusesTable returns the task statuses lookup table configuration
func GetTaskStatusesTable(ctx *context.Context) table.Table {
statuses := table.NewDefaultTable(ctx, table.DefaultConfigWithDriver(db.DriverPostgresql))
info := statuses.GetInfo()
info.SetTable("task_taskstatus")
info.AddField("ID", "id", db.Int).FieldFilterable()
info.AddField("Name", "name", db.Varchar).FieldFilterable().FieldSortable()
info.AddField("Is Terminal", "is_terminal", db.Bool).FieldFilterable()
info.AddField("Color", "color", db.Varchar)
info.AddField("Display Order", "display_order", db.Int).FieldSortable()
info.SetFilterFormLayout(form.LayoutThreeCol)
formList := statuses.GetForm()
formList.SetTable("task_taskstatus")
formList.AddField("ID", "id", db.Int, form.Default).FieldNotAllowAdd().FieldNotAllowEdit()
formList.AddField("Name", "name", db.Varchar, form.Text).FieldMust()
formList.AddField("Is Terminal", "is_terminal", db.Bool, form.Switch).FieldDefault("false")
formList.AddField("Color", "color", db.Varchar, form.Color)
formList.AddField("Display Order", "display_order", db.Int, form.Number).FieldDefault("0")
return statuses
}
// GetTaskFrequenciesTable returns the task frequencies lookup table configuration
func GetTaskFrequenciesTable(ctx *context.Context) table.Table {
frequencies := table.NewDefaultTable(ctx, table.DefaultConfigWithDriver(db.DriverPostgresql))
info := frequencies.GetInfo()
info.SetTable("task_taskfrequency")
info.AddField("ID", "id", db.Int).FieldFilterable()
info.AddField("Name", "name", db.Varchar).FieldFilterable().FieldSortable()
info.AddField("Days", "days", db.Int).FieldSortable()
info.AddField("Display Order", "display_order", db.Int).FieldSortable()
info.SetFilterFormLayout(form.LayoutThreeCol)
formList := frequencies.GetForm()
formList.SetTable("task_taskfrequency")
formList.AddField("ID", "id", db.Int, form.Default).FieldNotAllowAdd().FieldNotAllowEdit()
formList.AddField("Name", "name", db.Varchar, form.Text).FieldMust()
formList.AddField("Days", "days", db.Int, form.Number)
formList.AddField("Display Order", "display_order", db.Int, form.Number).FieldDefault("0")
return frequencies
}

View File

@@ -0,0 +1,42 @@
package tables
import (
"github.com/GoAdminGroup/go-admin/context"
"github.com/GoAdminGroup/go-admin/modules/db"
"github.com/GoAdminGroup/go-admin/plugins/admin/modules/table"
"github.com/GoAdminGroup/go-admin/template/types/form"
)
// GetUsersTable returns the users table configuration
func GetUsersTable(ctx *context.Context) table.Table {
users := table.NewDefaultTable(ctx, table.DefaultConfigWithDriver(db.DriverPostgresql))
info := users.GetInfo()
info.SetTable("auth_user")
info.AddField("ID", "id", db.Int).FieldFilterable()
info.AddField("Email", "email", db.Varchar).FieldFilterable().FieldSortable()
info.AddField("First Name", "first_name", db.Varchar).FieldFilterable()
info.AddField("Last Name", "last_name", db.Varchar).FieldFilterable()
info.AddField("Is Active", "is_active", db.Bool).FieldFilterable()
info.AddField("Is Staff", "is_staff", db.Bool).FieldFilterable()
info.AddField("Is Superuser", "is_superuser", db.Bool).FieldFilterable()
info.AddField("Email Verified", "email_verified", db.Bool).FieldFilterable()
info.AddField("Date Joined", "date_joined", db.Timestamp).FieldSortable()
info.AddField("Last Login", "last_login", db.Timestamp).FieldSortable()
info.SetFilterFormLayout(form.LayoutThreeCol)
formList := users.GetForm()
formList.SetTable("auth_user")
formList.AddField("ID", "id", db.Int, form.Default).FieldNotAllowAdd().FieldNotAllowEdit()
formList.AddField("Email", "email", db.Varchar, form.Email).FieldMust()
formList.AddField("First Name", "first_name", db.Varchar, form.Text)
formList.AddField("Last Name", "last_name", db.Varchar, form.Text)
formList.AddField("Is Active", "is_active", db.Bool, form.Switch).FieldDefault("true")
formList.AddField("Is Staff", "is_staff", db.Bool, form.Switch).FieldDefault("false")
formList.AddField("Is Superuser", "is_superuser", db.Bool, form.Switch).FieldDefault("false")
formList.AddField("Email Verified", "email_verified", db.Bool, form.Switch).FieldDefault("false")
formList.AddField("Timezone", "timezone", db.Varchar, form.Text).FieldDefault("UTC")
return users
}

241
internal/config/config.go Normal file
View File

@@ -0,0 +1,241 @@
package config
import (
"fmt"
"os"
"strings"
"time"
"github.com/spf13/viper"
)
// Config holds all configuration for the application
type Config struct {
Server ServerConfig
Database DatabaseConfig
Redis RedisConfig
Email EmailConfig
Push PushConfig
Worker WorkerConfig
Security SecurityConfig
}
type ServerConfig struct {
Port int
Debug bool
AllowedHosts []string
Timezone string
}
type DatabaseConfig struct {
Host string
Port int
User string
Password string
Database string
SSLMode string
MaxOpenConns int
MaxIdleConns int
MaxLifetime time.Duration
}
type RedisConfig struct {
URL string
Password string
DB int
}
type EmailConfig struct {
Host string
Port int
User string
Password string
From string
UseTLS bool
}
type PushConfig struct {
// Gorush server URL
GorushURL string
// APNs (iOS)
APNSKeyPath string
APNSKeyID string
APNSTeamID string
APNSTopic string
APNSSandbox bool
// FCM (Android)
FCMServerKey string
}
type WorkerConfig struct {
// Scheduled job times (UTC)
TaskReminderHour int
TaskReminderMinute int
OverdueReminderHour int
DailyNotifHour int
}
type SecurityConfig struct {
SecretKey string
TokenCacheTTL time.Duration
PasswordResetExpiry time.Duration
ConfirmationExpiry time.Duration
MaxPasswordResetRate int // per hour
}
var cfg *Config
// Load reads configuration from environment variables
func Load() (*Config, error) {
viper.SetEnvPrefix("")
viper.AutomaticEnv()
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
// Set defaults
setDefaults()
cfg = &Config{
Server: ServerConfig{
Port: viper.GetInt("PORT"),
Debug: viper.GetBool("DEBUG"),
AllowedHosts: strings.Split(viper.GetString("ALLOWED_HOSTS"), ","),
Timezone: viper.GetString("TIMEZONE"),
},
Database: DatabaseConfig{
Host: viper.GetString("DB_HOST"),
Port: viper.GetInt("DB_PORT"),
User: viper.GetString("POSTGRES_USER"),
Password: viper.GetString("POSTGRES_PASSWORD"),
Database: viper.GetString("POSTGRES_DB"),
SSLMode: viper.GetString("DB_SSLMODE"),
MaxOpenConns: viper.GetInt("DB_MAX_OPEN_CONNS"),
MaxIdleConns: viper.GetInt("DB_MAX_IDLE_CONNS"),
MaxLifetime: viper.GetDuration("DB_MAX_LIFETIME"),
},
Redis: RedisConfig{
URL: viper.GetString("REDIS_URL"),
Password: viper.GetString("REDIS_PASSWORD"),
DB: viper.GetInt("REDIS_DB"),
},
Email: EmailConfig{
Host: viper.GetString("EMAIL_HOST"),
Port: viper.GetInt("EMAIL_PORT"),
User: viper.GetString("EMAIL_HOST_USER"),
Password: viper.GetString("EMAIL_HOST_PASSWORD"),
From: viper.GetString("DEFAULT_FROM_EMAIL"),
UseTLS: viper.GetBool("EMAIL_USE_TLS"),
},
Push: PushConfig{
GorushURL: viper.GetString("GORUSH_URL"),
APNSKeyPath: viper.GetString("APNS_AUTH_KEY_PATH"),
APNSKeyID: viper.GetString("APNS_AUTH_KEY_ID"),
APNSTeamID: viper.GetString("APNS_TEAM_ID"),
APNSTopic: viper.GetString("APNS_TOPIC"),
APNSSandbox: viper.GetBool("APNS_USE_SANDBOX"),
FCMServerKey: viper.GetString("FCM_SERVER_KEY"),
},
Worker: WorkerConfig{
TaskReminderHour: viper.GetInt("CELERY_BEAT_REMINDER_HOUR"),
TaskReminderMinute: viper.GetInt("CELERY_BEAT_REMINDER_MINUTE"),
OverdueReminderHour: 9, // 9:00 AM UTC
DailyNotifHour: 11, // 11:00 AM UTC
},
Security: SecurityConfig{
SecretKey: viper.GetString("SECRET_KEY"),
TokenCacheTTL: 5 * time.Minute,
PasswordResetExpiry: 15 * time.Minute,
ConfirmationExpiry: 24 * time.Hour,
MaxPasswordResetRate: 3,
},
}
// Validate required fields
if err := validate(cfg); err != nil {
return nil, err
}
return cfg, nil
}
// Get returns the current configuration
func Get() *Config {
return cfg
}
func setDefaults() {
// Server defaults
viper.SetDefault("PORT", 8000)
viper.SetDefault("DEBUG", false)
viper.SetDefault("ALLOWED_HOSTS", "localhost,127.0.0.1")
viper.SetDefault("TIMEZONE", "UTC")
// Database defaults
viper.SetDefault("DB_HOST", "localhost")
viper.SetDefault("DB_PORT", 5432)
viper.SetDefault("POSTGRES_USER", "postgres")
viper.SetDefault("POSTGRES_DB", "mycrib")
viper.SetDefault("DB_SSLMODE", "disable")
viper.SetDefault("DB_MAX_OPEN_CONNS", 25)
viper.SetDefault("DB_MAX_IDLE_CONNS", 10)
viper.SetDefault("DB_MAX_LIFETIME", 600*time.Second)
// Redis defaults
viper.SetDefault("REDIS_URL", "redis://localhost:6379/0")
viper.SetDefault("REDIS_DB", 0)
// Email defaults
viper.SetDefault("EMAIL_HOST", "smtp.gmail.com")
viper.SetDefault("EMAIL_PORT", 587)
viper.SetDefault("EMAIL_USE_TLS", true)
viper.SetDefault("DEFAULT_FROM_EMAIL", "MyCrib <noreply@mycrib.com>")
// Push notification defaults
viper.SetDefault("GORUSH_URL", "http://localhost:8088")
viper.SetDefault("APNS_TOPIC", "com.example.mycrib")
viper.SetDefault("APNS_USE_SANDBOX", true)
// Worker defaults
viper.SetDefault("CELERY_BEAT_REMINDER_HOUR", 20)
viper.SetDefault("CELERY_BEAT_REMINDER_MINUTE", 0)
}
func validate(cfg *Config) error {
if cfg.Security.SecretKey == "" {
// In development, use a default key
if cfg.Server.Debug {
cfg.Security.SecretKey = "development-secret-key-change-in-production"
} else {
return fmt.Errorf("SECRET_KEY is required in production")
}
}
if cfg.Database.Password == "" && !cfg.Server.Debug {
return fmt.Errorf("POSTGRES_PASSWORD is required")
}
return nil
}
// DSN returns the database connection string
func (d *DatabaseConfig) DSN() string {
return fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
d.Host, d.Port, d.User, d.Password, d.Database, d.SSLMode,
)
}
// ReadAPNSKey reads the APNs key from file if path is provided
func (p *PushConfig) ReadAPNSKey() (string, error) {
if p.APNSKeyPath == "" {
return "", nil
}
content, err := os.ReadFile(p.APNSKeyPath)
if err != nil {
return "", fmt.Errorf("failed to read APNs key: %w", err)
}
return string(content), nil
}

View File

@@ -0,0 +1,156 @@
package database
import (
"fmt"
"time"
"github.com/rs/zerolog/log"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/treytartt/mycrib-api/internal/config"
"github.com/treytartt/mycrib-api/internal/models"
)
var db *gorm.DB
// Connect establishes a connection to the PostgreSQL database
func Connect(cfg *config.DatabaseConfig, debug bool) (*gorm.DB, error) {
// Configure GORM logger
logLevel := logger.Silent
if debug {
logLevel = logger.Info
}
gormConfig := &gorm.Config{
Logger: logger.Default.LogMode(logLevel),
NowFunc: func() time.Time {
return time.Now().UTC()
},
PrepareStmt: true, // Cache prepared statements
}
// Connect to database
var err error
db, err = gorm.Open(postgres.Open(cfg.DSN()), gormConfig)
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
// Get underlying sql.DB for connection pool settings
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("failed to get underlying sql.DB: %w", err)
}
// Configure connection pool
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
sqlDB.SetConnMaxLifetime(cfg.MaxLifetime)
// Test connection
if err := sqlDB.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
log.Info().
Str("host", cfg.Host).
Int("port", cfg.Port).
Str("database", cfg.Database).
Msg("Connected to PostgreSQL database")
return db, nil
}
// Get returns the database instance
func Get() *gorm.DB {
return db
}
// Close closes the database connection
func Close() error {
if db != nil {
sqlDB, err := db.DB()
if err != nil {
return err
}
return sqlDB.Close()
}
return nil
}
// WithTransaction executes a function within a database transaction
func WithTransaction(fn func(tx *gorm.DB) error) error {
return db.Transaction(fn)
}
// Paginate returns a GORM scope for pagination
func Paginate(page, pageSize int) func(db *gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
if page <= 0 {
page = 1
}
if pageSize <= 0 {
pageSize = 100
}
if pageSize > 1000 {
pageSize = 1000
}
offset := (page - 1) * pageSize
return db.Offset(offset).Limit(pageSize)
}
}
// Migrate runs database migrations for all models
func Migrate() error {
log.Info().Msg("Running database migrations...")
// Migrate all models in order (respecting foreign key constraints)
err := db.AutoMigrate(
// Lookup tables first (no foreign keys)
&models.ResidenceType{},
&models.TaskCategory{},
&models.TaskPriority{},
&models.TaskFrequency{},
&models.TaskStatus{},
&models.ContractorSpecialty{},
// User and auth tables
&models.User{},
&models.AuthToken{},
&models.UserProfile{},
&models.ConfirmationCode{},
&models.PasswordResetCode{},
// Main entity tables (order matters for foreign keys!)
&models.Residence{},
&models.ResidenceShareCode{},
&models.Contractor{}, // Contractor before Task (Task references Contractor)
&models.Task{},
&models.TaskCompletion{},
&models.Document{},
// Notification tables
&models.Notification{},
&models.NotificationPreference{},
&models.APNSDevice{},
&models.GCMDevice{},
// Subscription tables
&models.SubscriptionSettings{},
&models.UserSubscription{},
&models.UpgradeTrigger{},
&models.FeatureBenefit{},
&models.Promotion{},
&models.TierLimits{},
)
if err != nil {
return fmt.Errorf("failed to run migrations: %w", err)
}
log.Info().Msg("Database migrations completed successfully")
return nil
}

View File

@@ -0,0 +1,51 @@
package requests
// LoginRequest represents the login request body
type LoginRequest struct {
Username string `json:"username" binding:"required_without=Email"`
Email string `json:"email" binding:"required_without=Username,omitempty,email"`
Password string `json:"password" binding:"required,min=1"`
}
// RegisterRequest represents the registration request body
type RegisterRequest struct {
Username string `json:"username" binding:"required,min=3,max=150"`
Email string `json:"email" binding:"required,email,max=254"`
Password string `json:"password" binding:"required,min=8"`
FirstName string `json:"first_name" binding:"max=150"`
LastName string `json:"last_name" binding:"max=150"`
}
// VerifyEmailRequest represents the email verification request body
type VerifyEmailRequest struct {
Code string `json:"code" binding:"required,len=6"`
}
// ForgotPasswordRequest represents the forgot password request body
type ForgotPasswordRequest struct {
Email string `json:"email" binding:"required,email"`
}
// VerifyResetCodeRequest represents the verify reset code request body
type VerifyResetCodeRequest struct {
Email string `json:"email" binding:"required,email"`
Code string `json:"code" binding:"required,len=6"`
}
// ResetPasswordRequest represents the reset password request body
type ResetPasswordRequest struct {
ResetToken string `json:"reset_token" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=8"`
}
// UpdateProfileRequest represents the profile update request body
type UpdateProfileRequest struct {
Email *string `json:"email" binding:"omitempty,email,max=254"`
FirstName *string `json:"first_name" binding:"omitempty,max=150"`
LastName *string `json:"last_name" binding:"omitempty,max=150"`
}
// ResendVerificationRequest represents the resend verification email request
type ResendVerificationRequest struct {
// No body needed - uses authenticated user's email
}

View File

@@ -0,0 +1,36 @@
package requests
// CreateContractorRequest represents the request to create a contractor
type CreateContractorRequest struct {
ResidenceID uint `json:"residence_id" binding:"required"`
Name string `json:"name" binding:"required,min=1,max=200"`
Company string `json:"company" binding:"max=200"`
Phone string `json:"phone" binding:"max=20"`
Email string `json:"email" binding:"omitempty,email,max=254"`
Website string `json:"website" binding:"max=200"`
Notes string `json:"notes"`
StreetAddress string `json:"street_address" binding:"max=255"`
City string `json:"city" binding:"max=100"`
StateProvince string `json:"state_province" binding:"max=100"`
PostalCode string `json:"postal_code" binding:"max=20"`
SpecialtyIDs []uint `json:"specialty_ids"`
Rating *float64 `json:"rating"`
IsFavorite *bool `json:"is_favorite"`
}
// UpdateContractorRequest represents the request to update a contractor
type UpdateContractorRequest struct {
Name *string `json:"name" binding:"omitempty,min=1,max=200"`
Company *string `json:"company" binding:"omitempty,max=200"`
Phone *string `json:"phone" binding:"omitempty,max=20"`
Email *string `json:"email" binding:"omitempty,email,max=254"`
Website *string `json:"website" binding:"omitempty,max=200"`
Notes *string `json:"notes"`
StreetAddress *string `json:"street_address" binding:"omitempty,max=255"`
City *string `json:"city" binding:"omitempty,max=100"`
StateProvince *string `json:"state_province" binding:"omitempty,max=100"`
PostalCode *string `json:"postal_code" binding:"omitempty,max=20"`
SpecialtyIDs []uint `json:"specialty_ids"`
Rating *float64 `json:"rating"`
IsFavorite *bool `json:"is_favorite"`
}

View File

@@ -0,0 +1,46 @@
package requests
import (
"time"
"github.com/shopspring/decimal"
"github.com/treytartt/mycrib-api/internal/models"
)
// CreateDocumentRequest represents the request to create a document
type CreateDocumentRequest struct {
ResidenceID uint `json:"residence_id" binding:"required"`
Title string `json:"title" binding:"required,min=1,max=200"`
Description string `json:"description"`
DocumentType models.DocumentType `json:"document_type"`
FileURL string `json:"file_url" binding:"max=500"`
FileName string `json:"file_name" binding:"max=255"`
FileSize *int64 `json:"file_size"`
MimeType string `json:"mime_type" binding:"max=100"`
PurchaseDate *time.Time `json:"purchase_date"`
ExpiryDate *time.Time `json:"expiry_date"`
PurchasePrice *decimal.Decimal `json:"purchase_price"`
Vendor string `json:"vendor" binding:"max=200"`
SerialNumber string `json:"serial_number" binding:"max=100"`
ModelNumber string `json:"model_number" binding:"max=100"`
TaskID *uint `json:"task_id"`
}
// UpdateDocumentRequest represents the request to update a document
type UpdateDocumentRequest struct {
Title *string `json:"title" binding:"omitempty,min=1,max=200"`
Description *string `json:"description"`
DocumentType *models.DocumentType `json:"document_type"`
FileURL *string `json:"file_url" binding:"omitempty,max=500"`
FileName *string `json:"file_name" binding:"omitempty,max=255"`
FileSize *int64 `json:"file_size"`
MimeType *string `json:"mime_type" binding:"omitempty,max=100"`
PurchaseDate *time.Time `json:"purchase_date"`
ExpiryDate *time.Time `json:"expiry_date"`
PurchasePrice *decimal.Decimal `json:"purchase_price"`
Vendor *string `json:"vendor" binding:"omitempty,max=200"`
SerialNumber *string `json:"serial_number" binding:"omitempty,max=100"`
ModelNumber *string `json:"model_number" binding:"omitempty,max=100"`
TaskID *uint `json:"task_id"`
}

View File

@@ -0,0 +1,59 @@
package requests
import (
"time"
"github.com/shopspring/decimal"
)
// CreateResidenceRequest represents the request to create a residence
type CreateResidenceRequest struct {
Name string `json:"name" binding:"required,min=1,max=200"`
PropertyTypeID *uint `json:"property_type_id"`
StreetAddress string `json:"street_address" binding:"max=255"`
ApartmentUnit string `json:"apartment_unit" binding:"max=50"`
City string `json:"city" binding:"max=100"`
StateProvince string `json:"state_province" binding:"max=100"`
PostalCode string `json:"postal_code" binding:"max=20"`
Country string `json:"country" binding:"max=100"`
Bedrooms *int `json:"bedrooms"`
Bathrooms *decimal.Decimal `json:"bathrooms"`
SquareFootage *int `json:"square_footage"`
LotSize *decimal.Decimal `json:"lot_size"`
YearBuilt *int `json:"year_built"`
Description string `json:"description"`
PurchaseDate *time.Time `json:"purchase_date"`
PurchasePrice *decimal.Decimal `json:"purchase_price"`
IsPrimary *bool `json:"is_primary"`
}
// UpdateResidenceRequest represents the request to update a residence
type UpdateResidenceRequest struct {
Name *string `json:"name" binding:"omitempty,min=1,max=200"`
PropertyTypeID *uint `json:"property_type_id"`
StreetAddress *string `json:"street_address" binding:"omitempty,max=255"`
ApartmentUnit *string `json:"apartment_unit" binding:"omitempty,max=50"`
City *string `json:"city" binding:"omitempty,max=100"`
StateProvince *string `json:"state_province" binding:"omitempty,max=100"`
PostalCode *string `json:"postal_code" binding:"omitempty,max=20"`
Country *string `json:"country" binding:"omitempty,max=100"`
Bedrooms *int `json:"bedrooms"`
Bathrooms *decimal.Decimal `json:"bathrooms"`
SquareFootage *int `json:"square_footage"`
LotSize *decimal.Decimal `json:"lot_size"`
YearBuilt *int `json:"year_built"`
Description *string `json:"description"`
PurchaseDate *time.Time `json:"purchase_date"`
PurchasePrice *decimal.Decimal `json:"purchase_price"`
IsPrimary *bool `json:"is_primary"`
}
// JoinWithCodeRequest represents the request to join a residence via share code
type JoinWithCodeRequest struct {
Code string `json:"code" binding:"required,len=6"`
}
// GenerateShareCodeRequest represents the request to generate a share code
type GenerateShareCodeRequest struct {
ExpiresInHours int `json:"expires_in_hours"` // Default: 24 hours
}

View File

@@ -0,0 +1,46 @@
package requests
import (
"time"
"github.com/shopspring/decimal"
)
// CreateTaskRequest represents the request to create a task
type CreateTaskRequest struct {
ResidenceID uint `json:"residence_id" binding:"required"`
Title string `json:"title" binding:"required,min=1,max=200"`
Description string `json:"description"`
CategoryID *uint `json:"category_id"`
PriorityID *uint `json:"priority_id"`
StatusID *uint `json:"status_id"`
FrequencyID *uint `json:"frequency_id"`
AssignedToID *uint `json:"assigned_to_id"`
DueDate *time.Time `json:"due_date"`
EstimatedCost *decimal.Decimal `json:"estimated_cost"`
ContractorID *uint `json:"contractor_id"`
}
// UpdateTaskRequest represents the request to update a task
type UpdateTaskRequest struct {
Title *string `json:"title" binding:"omitempty,min=1,max=200"`
Description *string `json:"description"`
CategoryID *uint `json:"category_id"`
PriorityID *uint `json:"priority_id"`
StatusID *uint `json:"status_id"`
FrequencyID *uint `json:"frequency_id"`
AssignedToID *uint `json:"assigned_to_id"`
DueDate *time.Time `json:"due_date"`
EstimatedCost *decimal.Decimal `json:"estimated_cost"`
ActualCost *decimal.Decimal `json:"actual_cost"`
ContractorID *uint `json:"contractor_id"`
}
// CreateTaskCompletionRequest represents the request to create a task completion
type CreateTaskCompletionRequest struct {
TaskID uint `json:"task_id" binding:"required"`
CompletedAt *time.Time `json:"completed_at"` // Defaults to now
Notes string `json:"notes"`
ActualCost *decimal.Decimal `json:"actual_cost"`
PhotoURL string `json:"photo_url"`
}

View File

@@ -0,0 +1,151 @@
package responses
import (
"time"
"github.com/treytartt/mycrib-api/internal/models"
)
// UserResponse represents a user in API responses
type UserResponse struct {
ID uint `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
FirstName string `json:"first_name"`
LastName string `json:"last_name"`
IsActive bool `json:"is_active"`
DateJoined time.Time `json:"date_joined"`
LastLogin *time.Time `json:"last_login,omitempty"`
}
// UserProfileResponse represents a user profile in API responses
type UserProfileResponse struct {
ID uint `json:"id"`
UserID uint `json:"user_id"`
Verified bool `json:"verified"`
Bio string `json:"bio"`
PhoneNumber string `json:"phone_number"`
DateOfBirth *time.Time `json:"date_of_birth,omitempty"`
ProfilePicture string `json:"profile_picture"`
}
// LoginResponse represents the login response
type LoginResponse struct {
Token string `json:"token"`
User UserResponse `json:"user"`
}
// RegisterResponse represents the registration response
type RegisterResponse struct {
Token string `json:"token"`
User UserResponse `json:"user"`
Message string `json:"message"`
}
// CurrentUserResponse represents the /auth/me/ response
type CurrentUserResponse struct {
ID uint `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
FirstName string `json:"first_name"`
LastName string `json:"last_name"`
IsActive bool `json:"is_active"`
DateJoined time.Time `json:"date_joined"`
LastLogin *time.Time `json:"last_login,omitempty"`
Profile *UserProfileResponse `json:"profile,omitempty"`
}
// VerifyEmailResponse represents the email verification response
type VerifyEmailResponse struct {
Message string `json:"message"`
Verified bool `json:"verified"`
}
// ForgotPasswordResponse represents the forgot password response
type ForgotPasswordResponse struct {
Message string `json:"message"`
}
// VerifyResetCodeResponse represents the verify reset code response
type VerifyResetCodeResponse struct {
Message string `json:"message"`
ResetToken string `json:"reset_token"`
}
// ResetPasswordResponse represents the reset password response
type ResetPasswordResponse struct {
Message string `json:"message"`
}
// MessageResponse represents a simple message response
type MessageResponse struct {
Message string `json:"message"`
}
// ErrorResponse represents an error response
type ErrorResponse struct {
Error string `json:"error"`
Details map[string]string `json:"details,omitempty"`
}
// NewUserResponse creates a UserResponse from a User model
func NewUserResponse(user *models.User) UserResponse {
return UserResponse{
ID: user.ID,
Username: user.Username,
Email: user.Email,
FirstName: user.FirstName,
LastName: user.LastName,
IsActive: user.IsActive,
DateJoined: user.DateJoined,
LastLogin: user.LastLogin,
}
}
// NewUserProfileResponse creates a UserProfileResponse from a UserProfile model
func NewUserProfileResponse(profile *models.UserProfile) *UserProfileResponse {
if profile == nil {
return nil
}
return &UserProfileResponse{
ID: profile.ID,
UserID: profile.UserID,
Verified: profile.Verified,
Bio: profile.Bio,
PhoneNumber: profile.PhoneNumber,
DateOfBirth: profile.DateOfBirth,
ProfilePicture: profile.ProfilePicture,
}
}
// NewCurrentUserResponse creates a CurrentUserResponse from a User model
func NewCurrentUserResponse(user *models.User) CurrentUserResponse {
return CurrentUserResponse{
ID: user.ID,
Username: user.Username,
Email: user.Email,
FirstName: user.FirstName,
LastName: user.LastName,
IsActive: user.IsActive,
DateJoined: user.DateJoined,
LastLogin: user.LastLogin,
Profile: NewUserProfileResponse(user.Profile),
}
}
// NewLoginResponse creates a LoginResponse
func NewLoginResponse(token string, user *models.User) LoginResponse {
return LoginResponse{
Token: token,
User: NewUserResponse(user),
}
}
// NewRegisterResponse creates a RegisterResponse
func NewRegisterResponse(token string, user *models.User) RegisterResponse {
return RegisterResponse{
Token: token,
User: NewUserResponse(user),
Message: "Registration successful. Please check your email to verify your account.",
}
}

View File

@@ -0,0 +1,139 @@
package responses
import (
"time"
"github.com/treytartt/mycrib-api/internal/models"
)
// ContractorSpecialtyResponse represents a contractor specialty
type ContractorSpecialtyResponse struct {
ID uint `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Icon string `json:"icon"`
DisplayOrder int `json:"display_order"`
}
// ContractorUserResponse represents a user in contractor context
type ContractorUserResponse struct {
ID uint `json:"id"`
Username string `json:"username"`
FirstName string `json:"first_name"`
LastName string `json:"last_name"`
}
// ContractorResponse represents a contractor in the API response
type ContractorResponse struct {
ID uint `json:"id"`
ResidenceID uint `json:"residence_id"`
CreatedByID uint `json:"created_by_id"`
CreatedBy *ContractorUserResponse `json:"created_by,omitempty"`
Name string `json:"name"`
Company string `json:"company"`
Phone string `json:"phone"`
Email string `json:"email"`
Website string `json:"website"`
Notes string `json:"notes"`
StreetAddress string `json:"street_address"`
City string `json:"city"`
StateProvince string `json:"state_province"`
PostalCode string `json:"postal_code"`
Specialties []ContractorSpecialtyResponse `json:"specialties"`
Rating *float64 `json:"rating"`
IsFavorite bool `json:"is_favorite"`
IsActive bool `json:"is_active"`
TaskCount int `json:"task_count,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// ContractorListResponse represents a paginated list of contractors
type ContractorListResponse struct {
Count int `json:"count"`
Next *string `json:"next"`
Previous *string `json:"previous"`
Results []ContractorResponse `json:"results"`
}
// ToggleFavoriteResponse represents the response after toggling favorite
type ToggleFavoriteResponse struct {
Message string `json:"message"`
IsFavorite bool `json:"is_favorite"`
}
// === Factory Functions ===
// NewContractorSpecialtyResponse creates a ContractorSpecialtyResponse from a model
func NewContractorSpecialtyResponse(s *models.ContractorSpecialty) ContractorSpecialtyResponse {
return ContractorSpecialtyResponse{
ID: s.ID,
Name: s.Name,
Description: s.Description,
Icon: s.Icon,
DisplayOrder: s.DisplayOrder,
}
}
// NewContractorUserResponse creates a ContractorUserResponse from a User model
func NewContractorUserResponse(u *models.User) *ContractorUserResponse {
if u == nil {
return nil
}
return &ContractorUserResponse{
ID: u.ID,
Username: u.Username,
FirstName: u.FirstName,
LastName: u.LastName,
}
}
// NewContractorResponse creates a ContractorResponse from a Contractor model
func NewContractorResponse(c *models.Contractor) ContractorResponse {
resp := ContractorResponse{
ID: c.ID,
ResidenceID: c.ResidenceID,
CreatedByID: c.CreatedByID,
Name: c.Name,
Company: c.Company,
Phone: c.Phone,
Email: c.Email,
Website: c.Website,
Notes: c.Notes,
StreetAddress: c.StreetAddress,
City: c.City,
StateProvince: c.StateProvince,
PostalCode: c.PostalCode,
Rating: c.Rating,
IsFavorite: c.IsFavorite,
IsActive: c.IsActive,
TaskCount: len(c.Tasks),
CreatedAt: c.CreatedAt,
UpdatedAt: c.UpdatedAt,
}
if c.CreatedBy.ID != 0 {
resp.CreatedBy = NewContractorUserResponse(&c.CreatedBy)
}
resp.Specialties = make([]ContractorSpecialtyResponse, len(c.Specialties))
for i, s := range c.Specialties {
resp.Specialties[i] = NewContractorSpecialtyResponse(&s)
}
return resp
}
// NewContractorListResponse creates a ContractorListResponse from a slice of contractors
func NewContractorListResponse(contractors []models.Contractor) ContractorListResponse {
results := make([]ContractorResponse, len(contractors))
for i, c := range contractors {
results[i] = NewContractorResponse(&c)
}
return ContractorListResponse{
Count: len(contractors),
Next: nil,
Previous: nil,
Results: results,
}
}

View File

@@ -0,0 +1,111 @@
package responses
import (
"time"
"github.com/shopspring/decimal"
"github.com/treytartt/mycrib-api/internal/models"
)
// DocumentUserResponse represents a user in document context
type DocumentUserResponse struct {
ID uint `json:"id"`
Username string `json:"username"`
FirstName string `json:"first_name"`
LastName string `json:"last_name"`
}
// DocumentResponse represents a document in the API response
type DocumentResponse struct {
ID uint `json:"id"`
ResidenceID uint `json:"residence_id"`
CreatedByID uint `json:"created_by_id"`
CreatedBy *DocumentUserResponse `json:"created_by,omitempty"`
Title string `json:"title"`
Description string `json:"description"`
DocumentType models.DocumentType `json:"document_type"`
FileURL string `json:"file_url"`
FileName string `json:"file_name"`
FileSize *int64 `json:"file_size"`
MimeType string `json:"mime_type"`
PurchaseDate *time.Time `json:"purchase_date"`
ExpiryDate *time.Time `json:"expiry_date"`
PurchasePrice *decimal.Decimal `json:"purchase_price"`
Vendor string `json:"vendor"`
SerialNumber string `json:"serial_number"`
ModelNumber string `json:"model_number"`
TaskID *uint `json:"task_id"`
IsActive bool `json:"is_active"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// DocumentListResponse represents a paginated list of documents
type DocumentListResponse struct {
Count int `json:"count"`
Next *string `json:"next"`
Previous *string `json:"previous"`
Results []DocumentResponse `json:"results"`
}
// === Factory Functions ===
// NewDocumentUserResponse creates a DocumentUserResponse from a User model
func NewDocumentUserResponse(u *models.User) *DocumentUserResponse {
if u == nil {
return nil
}
return &DocumentUserResponse{
ID: u.ID,
Username: u.Username,
FirstName: u.FirstName,
LastName: u.LastName,
}
}
// NewDocumentResponse creates a DocumentResponse from a Document model
func NewDocumentResponse(d *models.Document) DocumentResponse {
resp := DocumentResponse{
ID: d.ID,
ResidenceID: d.ResidenceID,
CreatedByID: d.CreatedByID,
Title: d.Title,
Description: d.Description,
DocumentType: d.DocumentType,
FileURL: d.FileURL,
FileName: d.FileName,
FileSize: d.FileSize,
MimeType: d.MimeType,
PurchaseDate: d.PurchaseDate,
ExpiryDate: d.ExpiryDate,
PurchasePrice: d.PurchasePrice,
Vendor: d.Vendor,
SerialNumber: d.SerialNumber,
ModelNumber: d.ModelNumber,
TaskID: d.TaskID,
IsActive: d.IsActive,
CreatedAt: d.CreatedAt,
UpdatedAt: d.UpdatedAt,
}
if d.CreatedBy.ID != 0 {
resp.CreatedBy = NewDocumentUserResponse(&d.CreatedBy)
}
return resp
}
// NewDocumentListResponse creates a DocumentListResponse from a slice of documents
func NewDocumentListResponse(documents []models.Document) DocumentListResponse {
results := make([]DocumentResponse, len(documents))
for i, d := range documents {
results[i] = NewDocumentResponse(&d)
}
return DocumentListResponse{
Count: len(documents),
Next: nil,
Previous: nil,
Results: results,
}
}

View File

@@ -0,0 +1,189 @@
package responses
import (
"time"
"github.com/shopspring/decimal"
"github.com/treytartt/mycrib-api/internal/models"
)
// ResidenceTypeResponse represents a residence type in the API response
type ResidenceTypeResponse struct {
ID uint `json:"id"`
Name string `json:"name"`
}
// ResidenceUserResponse represents a user with access to a residence
type ResidenceUserResponse struct {
ID uint `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
FirstName string `json:"first_name"`
LastName string `json:"last_name"`
}
// ResidenceResponse represents a residence in the API response
type ResidenceResponse struct {
ID uint `json:"id"`
OwnerID uint `json:"owner_id"`
Owner *ResidenceUserResponse `json:"owner,omitempty"`
Users []ResidenceUserResponse `json:"users,omitempty"`
Name string `json:"name"`
PropertyTypeID *uint `json:"property_type_id"`
PropertyType *ResidenceTypeResponse `json:"property_type,omitempty"`
StreetAddress string `json:"street_address"`
ApartmentUnit string `json:"apartment_unit"`
City string `json:"city"`
StateProvince string `json:"state_province"`
PostalCode string `json:"postal_code"`
Country string `json:"country"`
Bedrooms *int `json:"bedrooms"`
Bathrooms *decimal.Decimal `json:"bathrooms"`
SquareFootage *int `json:"square_footage"`
LotSize *decimal.Decimal `json:"lot_size"`
YearBuilt *int `json:"year_built"`
Description string `json:"description"`
PurchaseDate *time.Time `json:"purchase_date"`
PurchasePrice *decimal.Decimal `json:"purchase_price"`
IsPrimary bool `json:"is_primary"`
IsActive bool `json:"is_active"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// ResidenceListResponse represents the paginated list of residences
type ResidenceListResponse struct {
Count int `json:"count"`
Next *string `json:"next"`
Previous *string `json:"previous"`
Results []ResidenceResponse `json:"results"`
}
// ShareCodeResponse represents a share code in the API response
type ShareCodeResponse struct {
ID uint `json:"id"`
Code string `json:"code"`
ResidenceID uint `json:"residence_id"`
CreatedByID uint `json:"created_by_id"`
IsActive bool `json:"is_active"`
ExpiresAt *time.Time `json:"expires_at"`
CreatedAt time.Time `json:"created_at"`
}
// JoinResidenceResponse represents the response after joining a residence
type JoinResidenceResponse struct {
Message string `json:"message"`
Residence ResidenceResponse `json:"residence"`
}
// GenerateShareCodeResponse represents the response after generating a share code
type GenerateShareCodeResponse struct {
Message string `json:"message"`
ShareCode ShareCodeResponse `json:"share_code"`
}
// === Factory Functions ===
// NewResidenceUserResponse creates a ResidenceUserResponse from a User model
func NewResidenceUserResponse(user *models.User) *ResidenceUserResponse {
if user == nil {
return nil
}
return &ResidenceUserResponse{
ID: user.ID,
Username: user.Username,
Email: user.Email,
FirstName: user.FirstName,
LastName: user.LastName,
}
}
// NewResidenceTypeResponse creates a ResidenceTypeResponse from a ResidenceType model
func NewResidenceTypeResponse(rt *models.ResidenceType) *ResidenceTypeResponse {
if rt == nil {
return nil
}
return &ResidenceTypeResponse{
ID: rt.ID,
Name: rt.Name,
}
}
// NewResidenceResponse creates a ResidenceResponse from a Residence model
func NewResidenceResponse(residence *models.Residence) ResidenceResponse {
resp := ResidenceResponse{
ID: residence.ID,
OwnerID: residence.OwnerID,
Name: residence.Name,
PropertyTypeID: residence.PropertyTypeID,
StreetAddress: residence.StreetAddress,
ApartmentUnit: residence.ApartmentUnit,
City: residence.City,
StateProvince: residence.StateProvince,
PostalCode: residence.PostalCode,
Country: residence.Country,
Bedrooms: residence.Bedrooms,
Bathrooms: residence.Bathrooms,
SquareFootage: residence.SquareFootage,
LotSize: residence.LotSize,
YearBuilt: residence.YearBuilt,
Description: residence.Description,
PurchaseDate: residence.PurchaseDate,
PurchasePrice: residence.PurchasePrice,
IsPrimary: residence.IsPrimary,
IsActive: residence.IsActive,
CreatedAt: residence.CreatedAt,
UpdatedAt: residence.UpdatedAt,
}
// Include owner if loaded
if residence.Owner.ID != 0 {
resp.Owner = NewResidenceUserResponse(&residence.Owner)
}
// Include property type if loaded
if residence.PropertyType != nil {
resp.PropertyType = NewResidenceTypeResponse(residence.PropertyType)
}
// Include shared users if loaded
if len(residence.Users) > 0 {
resp.Users = make([]ResidenceUserResponse, len(residence.Users))
for i, user := range residence.Users {
resp.Users[i] = *NewResidenceUserResponse(&user)
}
} else {
resp.Users = []ResidenceUserResponse{}
}
return resp
}
// NewResidenceListResponse creates a paginated list response
func NewResidenceListResponse(residences []models.Residence) ResidenceListResponse {
results := make([]ResidenceResponse, len(residences))
for i, r := range residences {
results[i] = NewResidenceResponse(&r)
}
return ResidenceListResponse{
Count: len(residences),
Next: nil, // Pagination not implemented yet
Previous: nil,
Results: results,
}
}
// NewShareCodeResponse creates a ShareCodeResponse from a ResidenceShareCode model
func NewShareCodeResponse(sc *models.ResidenceShareCode) ShareCodeResponse {
return ShareCodeResponse{
ID: sc.ID,
Code: sc.Code,
ResidenceID: sc.ResidenceID,
CreatedByID: sc.CreatedByID,
IsActive: sc.IsActive,
ExpiresAt: sc.ExpiresAt,
CreatedAt: sc.CreatedAt,
}
}

View File

@@ -0,0 +1,324 @@
package responses
import (
"fmt"
"time"
"github.com/shopspring/decimal"
"github.com/treytartt/mycrib-api/internal/models"
)
// TaskCategoryResponse represents a task category
type TaskCategoryResponse struct {
ID uint `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Icon string `json:"icon"`
Color string `json:"color"`
DisplayOrder int `json:"display_order"`
}
// TaskPriorityResponse represents a task priority
type TaskPriorityResponse struct {
ID uint `json:"id"`
Name string `json:"name"`
Level int `json:"level"`
Color string `json:"color"`
DisplayOrder int `json:"display_order"`
}
// TaskStatusResponse represents a task status
type TaskStatusResponse struct {
ID uint `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Color string `json:"color"`
DisplayOrder int `json:"display_order"`
}
// TaskFrequencyResponse represents a task frequency
type TaskFrequencyResponse struct {
ID uint `json:"id"`
Name string `json:"name"`
Days *int `json:"days"`
DisplayOrder int `json:"display_order"`
}
// TaskUserResponse represents a user in task context
type TaskUserResponse struct {
ID uint `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
FirstName string `json:"first_name"`
LastName string `json:"last_name"`
}
// TaskCompletionResponse represents a task completion
type TaskCompletionResponse struct {
ID uint `json:"id"`
TaskID uint `json:"task_id"`
CompletedBy *TaskUserResponse `json:"completed_by,omitempty"`
CompletedAt time.Time `json:"completed_at"`
Notes string `json:"notes"`
ActualCost *decimal.Decimal `json:"actual_cost"`
PhotoURL string `json:"photo_url"`
CreatedAt time.Time `json:"created_at"`
}
// TaskResponse represents a task in the API response
type TaskResponse struct {
ID uint `json:"id"`
ResidenceID uint `json:"residence_id"`
CreatedByID uint `json:"created_by_id"`
CreatedBy *TaskUserResponse `json:"created_by,omitempty"`
AssignedToID *uint `json:"assigned_to_id"`
AssignedTo *TaskUserResponse `json:"assigned_to,omitempty"`
Title string `json:"title"`
Description string `json:"description"`
CategoryID *uint `json:"category_id"`
Category *TaskCategoryResponse `json:"category,omitempty"`
PriorityID *uint `json:"priority_id"`
Priority *TaskPriorityResponse `json:"priority,omitempty"`
StatusID *uint `json:"status_id"`
Status *TaskStatusResponse `json:"status,omitempty"`
FrequencyID *uint `json:"frequency_id"`
Frequency *TaskFrequencyResponse `json:"frequency,omitempty"`
DueDate *time.Time `json:"due_date"`
EstimatedCost *decimal.Decimal `json:"estimated_cost"`
ActualCost *decimal.Decimal `json:"actual_cost"`
ContractorID *uint `json:"contractor_id"`
IsCancelled bool `json:"is_cancelled"`
IsArchived bool `json:"is_archived"`
ParentTaskID *uint `json:"parent_task_id"`
Completions []TaskCompletionResponse `json:"completions,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// TaskListResponse represents a paginated list of tasks
type TaskListResponse struct {
Count int `json:"count"`
Next *string `json:"next"`
Previous *string `json:"previous"`
Results []TaskResponse `json:"results"`
}
// KanbanColumnResponse represents a kanban column
type KanbanColumnResponse struct {
Name string `json:"name"`
DisplayName string `json:"display_name"`
ButtonTypes []string `json:"button_types"`
Icons map[string]string `json:"icons"`
Color string `json:"color"`
Tasks []TaskResponse `json:"tasks"`
Count int `json:"count"`
}
// KanbanBoardResponse represents the kanban board
type KanbanBoardResponse struct {
Columns []KanbanColumnResponse `json:"columns"`
DaysThreshold int `json:"days_threshold"`
ResidenceID string `json:"residence_id"`
}
// TaskCompletionListResponse represents a list of completions
type TaskCompletionListResponse struct {
Count int `json:"count"`
Next *string `json:"next"`
Previous *string `json:"previous"`
Results []TaskCompletionResponse `json:"results"`
}
// === Factory Functions ===
// NewTaskCategoryResponse creates a TaskCategoryResponse from a model
func NewTaskCategoryResponse(c *models.TaskCategory) *TaskCategoryResponse {
if c == nil {
return nil
}
return &TaskCategoryResponse{
ID: c.ID,
Name: c.Name,
Description: c.Description,
Icon: c.Icon,
Color: c.Color,
DisplayOrder: c.DisplayOrder,
}
}
// NewTaskPriorityResponse creates a TaskPriorityResponse from a model
func NewTaskPriorityResponse(p *models.TaskPriority) *TaskPriorityResponse {
if p == nil {
return nil
}
return &TaskPriorityResponse{
ID: p.ID,
Name: p.Name,
Level: p.Level,
Color: p.Color,
DisplayOrder: p.DisplayOrder,
}
}
// NewTaskStatusResponse creates a TaskStatusResponse from a model
func NewTaskStatusResponse(s *models.TaskStatus) *TaskStatusResponse {
if s == nil {
return nil
}
return &TaskStatusResponse{
ID: s.ID,
Name: s.Name,
Description: s.Description,
Color: s.Color,
DisplayOrder: s.DisplayOrder,
}
}
// NewTaskFrequencyResponse creates a TaskFrequencyResponse from a model
func NewTaskFrequencyResponse(f *models.TaskFrequency) *TaskFrequencyResponse {
if f == nil {
return nil
}
return &TaskFrequencyResponse{
ID: f.ID,
Name: f.Name,
Days: f.Days,
DisplayOrder: f.DisplayOrder,
}
}
// NewTaskUserResponse creates a TaskUserResponse from a User model
func NewTaskUserResponse(u *models.User) *TaskUserResponse {
if u == nil {
return nil
}
return &TaskUserResponse{
ID: u.ID,
Username: u.Username,
Email: u.Email,
FirstName: u.FirstName,
LastName: u.LastName,
}
}
// NewTaskCompletionResponse creates a TaskCompletionResponse from a model
func NewTaskCompletionResponse(c *models.TaskCompletion) TaskCompletionResponse {
resp := TaskCompletionResponse{
ID: c.ID,
TaskID: c.TaskID,
CompletedAt: c.CompletedAt,
Notes: c.Notes,
ActualCost: c.ActualCost,
PhotoURL: c.PhotoURL,
CreatedAt: c.CreatedAt,
}
if c.CompletedBy.ID != 0 {
resp.CompletedBy = NewTaskUserResponse(&c.CompletedBy)
}
return resp
}
// NewTaskResponse creates a TaskResponse from a Task model
func NewTaskResponse(t *models.Task) TaskResponse {
resp := TaskResponse{
ID: t.ID,
ResidenceID: t.ResidenceID,
CreatedByID: t.CreatedByID,
Title: t.Title,
Description: t.Description,
CategoryID: t.CategoryID,
PriorityID: t.PriorityID,
StatusID: t.StatusID,
FrequencyID: t.FrequencyID,
AssignedToID: t.AssignedToID,
DueDate: t.DueDate,
EstimatedCost: t.EstimatedCost,
ActualCost: t.ActualCost,
ContractorID: t.ContractorID,
IsCancelled: t.IsCancelled,
IsArchived: t.IsArchived,
ParentTaskID: t.ParentTaskID,
CreatedAt: t.CreatedAt,
UpdatedAt: t.UpdatedAt,
}
if t.CreatedBy.ID != 0 {
resp.CreatedBy = NewTaskUserResponse(&t.CreatedBy)
}
if t.AssignedTo != nil {
resp.AssignedTo = NewTaskUserResponse(t.AssignedTo)
}
if t.Category != nil {
resp.Category = NewTaskCategoryResponse(t.Category)
}
if t.Priority != nil {
resp.Priority = NewTaskPriorityResponse(t.Priority)
}
if t.Status != nil {
resp.Status = NewTaskStatusResponse(t.Status)
}
if t.Frequency != nil {
resp.Frequency = NewTaskFrequencyResponse(t.Frequency)
}
resp.Completions = make([]TaskCompletionResponse, len(t.Completions))
for i, c := range t.Completions {
resp.Completions[i] = NewTaskCompletionResponse(&c)
}
return resp
}
// NewTaskListResponse creates a TaskListResponse from a slice of tasks
func NewTaskListResponse(tasks []models.Task) TaskListResponse {
results := make([]TaskResponse, len(tasks))
for i, t := range tasks {
results[i] = NewTaskResponse(&t)
}
return TaskListResponse{
Count: len(tasks),
Next: nil,
Previous: nil,
Results: results,
}
}
// NewKanbanBoardResponse creates a KanbanBoardResponse from a KanbanBoard model
func NewKanbanBoardResponse(board *models.KanbanBoard, residenceID uint) KanbanBoardResponse {
columns := make([]KanbanColumnResponse, len(board.Columns))
for i, col := range board.Columns {
tasks := make([]TaskResponse, len(col.Tasks))
for j, t := range col.Tasks {
tasks[j] = NewTaskResponse(&t)
}
columns[i] = KanbanColumnResponse{
Name: col.Name,
DisplayName: col.DisplayName,
ButtonTypes: col.ButtonTypes,
Icons: col.Icons,
Color: col.Color,
Tasks: tasks,
Count: col.Count,
}
}
return KanbanBoardResponse{
Columns: columns,
DaysThreshold: board.DaysThreshold,
ResidenceID: fmt.Sprintf("%d", residenceID),
}
}
// NewTaskCompletionListResponse creates a TaskCompletionListResponse
func NewTaskCompletionListResponse(completions []models.TaskCompletion) TaskCompletionListResponse {
results := make([]TaskCompletionResponse, len(completions))
for i, c := range completions {
results[i] = NewTaskCompletionResponse(&c)
}
return TaskCompletionListResponse{
Count: len(completions),
Next: nil,
Previous: nil,
Results: results,
}
}

View File

@@ -0,0 +1,364 @@
package handlers
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"github.com/treytartt/mycrib-api/internal/dto/requests"
"github.com/treytartt/mycrib-api/internal/dto/responses"
"github.com/treytartt/mycrib-api/internal/middleware"
"github.com/treytartt/mycrib-api/internal/services"
)
// AuthHandler handles authentication endpoints
type AuthHandler struct {
authService *services.AuthService
emailService *services.EmailService
cache *services.CacheService
}
// NewAuthHandler creates a new auth handler
func NewAuthHandler(authService *services.AuthService, emailService *services.EmailService, cache *services.CacheService) *AuthHandler {
return &AuthHandler{
authService: authService,
emailService: emailService,
cache: cache,
}
}
// Login handles POST /api/auth/login/
func (h *AuthHandler) Login(c *gin.Context) {
var req requests.LoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, responses.ErrorResponse{
Error: "Invalid request body",
Details: map[string]string{
"validation": err.Error(),
},
})
return
}
response, err := h.authService.Login(&req)
if err != nil {
status := http.StatusUnauthorized
message := "Invalid credentials"
if errors.Is(err, services.ErrUserInactive) {
message = "Account is inactive"
}
log.Debug().Err(err).Str("identifier", req.Username).Msg("Login failed")
c.JSON(status, responses.ErrorResponse{Error: message})
return
}
c.JSON(http.StatusOK, response)
}
// Register handles POST /api/auth/register/
func (h *AuthHandler) Register(c *gin.Context) {
var req requests.RegisterRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, responses.ErrorResponse{
Error: "Invalid request body",
Details: map[string]string{
"validation": err.Error(),
},
})
return
}
response, confirmationCode, err := h.authService.Register(&req)
if err != nil {
status := http.StatusBadRequest
message := err.Error()
if errors.Is(err, services.ErrUsernameTaken) {
message = "Username already taken"
} else if errors.Is(err, services.ErrEmailTaken) {
message = "Email already registered"
} else {
status = http.StatusInternalServerError
message = "Registration failed"
log.Error().Err(err).Msg("Registration failed")
}
c.JSON(status, responses.ErrorResponse{Error: message})
return
}
// Send welcome email with confirmation code (async)
if h.emailService != nil && confirmationCode != "" {
go func() {
if err := h.emailService.SendWelcomeEmail(req.Email, req.FirstName, confirmationCode); err != nil {
log.Error().Err(err).Str("email", req.Email).Msg("Failed to send welcome email")
}
}()
}
c.JSON(http.StatusCreated, response)
}
// Logout handles POST /api/auth/logout/
func (h *AuthHandler) Logout(c *gin.Context) {
token := middleware.GetAuthToken(c)
if token == "" {
c.JSON(http.StatusUnauthorized, responses.ErrorResponse{Error: "Not authenticated"})
return
}
// Invalidate token in database
if err := h.authService.Logout(token); err != nil {
log.Warn().Err(err).Msg("Failed to delete token from database")
}
// Invalidate token in cache
if h.cache != nil {
if err := h.cache.InvalidateAuthToken(c.Request.Context(), token); err != nil {
log.Warn().Err(err).Msg("Failed to invalidate token in cache")
}
}
c.JSON(http.StatusOK, responses.MessageResponse{Message: "Logged out successfully"})
}
// CurrentUser handles GET /api/auth/me/
func (h *AuthHandler) CurrentUser(c *gin.Context) {
user := middleware.MustGetAuthUser(c)
if user == nil {
return
}
response, err := h.authService.GetCurrentUser(user.ID)
if err != nil {
log.Error().Err(err).Uint("user_id", user.ID).Msg("Failed to get current user")
c.JSON(http.StatusInternalServerError, responses.ErrorResponse{Error: "Failed to get user"})
return
}
c.JSON(http.StatusOK, response)
}
// UpdateProfile handles PUT/PATCH /api/auth/profile/
func (h *AuthHandler) UpdateProfile(c *gin.Context) {
user := middleware.MustGetAuthUser(c)
if user == nil {
return
}
var req requests.UpdateProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, responses.ErrorResponse{
Error: "Invalid request body",
Details: map[string]string{
"validation": err.Error(),
},
})
return
}
response, err := h.authService.UpdateProfile(user.ID, &req)
if err != nil {
if errors.Is(err, services.ErrEmailTaken) {
c.JSON(http.StatusBadRequest, responses.ErrorResponse{Error: "Email already taken"})
return
}
log.Error().Err(err).Uint("user_id", user.ID).Msg("Failed to update profile")
c.JSON(http.StatusInternalServerError, responses.ErrorResponse{Error: "Failed to update profile"})
return
}
c.JSON(http.StatusOK, response)
}
// VerifyEmail handles POST /api/auth/verify-email/
func (h *AuthHandler) VerifyEmail(c *gin.Context) {
user := middleware.MustGetAuthUser(c)
if user == nil {
return
}
var req requests.VerifyEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, responses.ErrorResponse{
Error: "Invalid request body",
Details: map[string]string{
"validation": err.Error(),
},
})
return
}
err := h.authService.VerifyEmail(user.ID, req.Code)
if err != nil {
status := http.StatusBadRequest
message := err.Error()
if errors.Is(err, services.ErrInvalidCode) {
message = "Invalid verification code"
} else if errors.Is(err, services.ErrCodeExpired) {
message = "Verification code has expired"
} else if errors.Is(err, services.ErrAlreadyVerified) {
message = "Email already verified"
} else {
status = http.StatusInternalServerError
message = "Verification failed"
log.Error().Err(err).Uint("user_id", user.ID).Msg("Email verification failed")
}
c.JSON(status, responses.ErrorResponse{Error: message})
return
}
c.JSON(http.StatusOK, responses.VerifyEmailResponse{
Message: "Email verified successfully",
Verified: true,
})
}
// ResendVerification handles POST /api/auth/resend-verification/
func (h *AuthHandler) ResendVerification(c *gin.Context) {
user := middleware.MustGetAuthUser(c)
if user == nil {
return
}
code, err := h.authService.ResendVerificationCode(user.ID)
if err != nil {
if errors.Is(err, services.ErrAlreadyVerified) {
c.JSON(http.StatusBadRequest, responses.ErrorResponse{Error: "Email already verified"})
return
}
log.Error().Err(err).Uint("user_id", user.ID).Msg("Failed to resend verification")
c.JSON(http.StatusInternalServerError, responses.ErrorResponse{Error: "Failed to resend verification"})
return
}
// Send verification email (async)
if h.emailService != nil {
go func() {
if err := h.emailService.SendVerificationEmail(user.Email, user.FirstName, code); err != nil {
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send verification email")
}
}()
}
c.JSON(http.StatusOK, responses.MessageResponse{Message: "Verification email sent"})
}
// ForgotPassword handles POST /api/auth/forgot-password/
func (h *AuthHandler) ForgotPassword(c *gin.Context) {
var req requests.ForgotPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, responses.ErrorResponse{
Error: "Invalid request body",
Details: map[string]string{
"validation": err.Error(),
},
})
return
}
code, user, err := h.authService.ForgotPassword(req.Email)
if err != nil {
if errors.Is(err, services.ErrRateLimitExceeded) {
c.JSON(http.StatusTooManyRequests, responses.ErrorResponse{
Error: "Too many password reset requests. Please try again later.",
})
return
}
log.Error().Err(err).Str("email", req.Email).Msg("Forgot password failed")
// Don't reveal errors to prevent email enumeration
}
// Send password reset email (async) - only if user found
if h.emailService != nil && code != "" && user != nil {
go func() {
if err := h.emailService.SendPasswordResetEmail(user.Email, user.FirstName, code); err != nil {
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send password reset email")
}
}()
}
// Always return success to prevent email enumeration
c.JSON(http.StatusOK, responses.ForgotPasswordResponse{
Message: "If an account with that email exists, a password reset code has been sent.",
})
}
// VerifyResetCode handles POST /api/auth/verify-reset-code/
func (h *AuthHandler) VerifyResetCode(c *gin.Context) {
var req requests.VerifyResetCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, responses.ErrorResponse{
Error: "Invalid request body",
Details: map[string]string{
"validation": err.Error(),
},
})
return
}
resetToken, err := h.authService.VerifyResetCode(req.Email, req.Code)
if err != nil {
status := http.StatusBadRequest
message := "Invalid verification code"
if errors.Is(err, services.ErrCodeExpired) {
message = "Verification code has expired"
} else if errors.Is(err, services.ErrRateLimitExceeded) {
status = http.StatusTooManyRequests
message = "Too many attempts. Please request a new code."
}
c.JSON(status, responses.ErrorResponse{Error: message})
return
}
c.JSON(http.StatusOK, responses.VerifyResetCodeResponse{
Message: "Code verified successfully",
ResetToken: resetToken,
})
}
// ResetPassword handles POST /api/auth/reset-password/
func (h *AuthHandler) ResetPassword(c *gin.Context) {
var req requests.ResetPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, responses.ErrorResponse{
Error: "Invalid request body",
Details: map[string]string{
"validation": err.Error(),
},
})
return
}
err := h.authService.ResetPassword(req.ResetToken, req.NewPassword)
if err != nil {
status := http.StatusBadRequest
message := "Invalid or expired reset token"
if errors.Is(err, services.ErrInvalidResetToken) {
message = "Invalid or expired reset token"
} else {
status = http.StatusInternalServerError
message = "Password reset failed"
log.Error().Err(err).Msg("Password reset failed")
}
c.JSON(status, responses.ErrorResponse{Error: message})
return
}
c.JSON(http.StatusOK, responses.ResetPasswordResponse{
Message: "Password reset successfully. Please log in with your new password.",
})
}

View File

@@ -0,0 +1,192 @@
package handlers
import (
"errors"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/treytartt/mycrib-api/internal/dto/requests"
"github.com/treytartt/mycrib-api/internal/middleware"
"github.com/treytartt/mycrib-api/internal/models"
"github.com/treytartt/mycrib-api/internal/services"
)
// ContractorHandler handles contractor-related HTTP requests
type ContractorHandler struct {
contractorService *services.ContractorService
}
// NewContractorHandler creates a new contractor handler
func NewContractorHandler(contractorService *services.ContractorService) *ContractorHandler {
return &ContractorHandler{contractorService: contractorService}
}
// ListContractors handles GET /api/contractors/
func (h *ContractorHandler) ListContractors(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
response, err := h.contractorService.ListContractors(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, response)
}
// GetContractor handles GET /api/contractors/:id/
func (h *ContractorHandler) GetContractor(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid contractor ID"})
return
}
response, err := h.contractorService.GetContractor(uint(contractorID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrContractorNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrContractorAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, response)
}
// CreateContractor handles POST /api/contractors/
func (h *ContractorHandler) CreateContractor(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
var req requests.CreateContractorRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
response, err := h.contractorService.CreateContractor(&req, user.ID)
if err != nil {
if errors.Is(err, services.ErrResidenceAccessDenied) {
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, response)
}
// UpdateContractor handles PUT/PATCH /api/contractors/:id/
func (h *ContractorHandler) UpdateContractor(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid contractor ID"})
return
}
var req requests.UpdateContractorRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
response, err := h.contractorService.UpdateContractor(uint(contractorID), user.ID, &req)
if err != nil {
switch {
case errors.Is(err, services.ErrContractorNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrContractorAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, response)
}
// DeleteContractor handles DELETE /api/contractors/:id/
func (h *ContractorHandler) DeleteContractor(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid contractor ID"})
return
}
err = h.contractorService.DeleteContractor(uint(contractorID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrContractorNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrContractorAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, gin.H{"message": "Contractor deleted successfully"})
}
// ToggleFavorite handles POST /api/contractors/:id/toggle-favorite/
func (h *ContractorHandler) ToggleFavorite(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid contractor ID"})
return
}
response, err := h.contractorService.ToggleFavorite(uint(contractorID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrContractorNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrContractorAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, response)
}
// GetContractorTasks handles GET /api/contractors/:id/tasks/
func (h *ContractorHandler) GetContractorTasks(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid contractor ID"})
return
}
response, err := h.contractorService.GetContractorTasks(uint(contractorID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrContractorNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrContractorAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, response)
}
// GetSpecialties handles GET /api/contractors/specialties/
func (h *ContractorHandler) GetSpecialties(c *gin.Context) {
specialties, err := h.contractorService.GetSpecialties()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, specialties)
}

View File

@@ -0,0 +1,193 @@
package handlers
import (
"errors"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/treytartt/mycrib-api/internal/dto/requests"
"github.com/treytartt/mycrib-api/internal/middleware"
"github.com/treytartt/mycrib-api/internal/models"
"github.com/treytartt/mycrib-api/internal/services"
)
// DocumentHandler handles document-related HTTP requests
type DocumentHandler struct {
documentService *services.DocumentService
}
// NewDocumentHandler creates a new document handler
func NewDocumentHandler(documentService *services.DocumentService) *DocumentHandler {
return &DocumentHandler{documentService: documentService}
}
// ListDocuments handles GET /api/documents/
func (h *DocumentHandler) ListDocuments(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
response, err := h.documentService.ListDocuments(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, response)
}
// GetDocument handles GET /api/documents/:id/
func (h *DocumentHandler) GetDocument(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
return
}
response, err := h.documentService.GetDocument(uint(documentID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrDocumentNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrDocumentAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, response)
}
// ListWarranties handles GET /api/documents/warranties/
func (h *DocumentHandler) ListWarranties(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
response, err := h.documentService.ListWarranties(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, response)
}
// CreateDocument handles POST /api/documents/
func (h *DocumentHandler) CreateDocument(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
var req requests.CreateDocumentRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
response, err := h.documentService.CreateDocument(&req, user.ID)
if err != nil {
if errors.Is(err, services.ErrResidenceAccessDenied) {
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, response)
}
// UpdateDocument handles PUT/PATCH /api/documents/:id/
func (h *DocumentHandler) UpdateDocument(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
return
}
var req requests.UpdateDocumentRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
response, err := h.documentService.UpdateDocument(uint(documentID), user.ID, &req)
if err != nil {
switch {
case errors.Is(err, services.ErrDocumentNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrDocumentAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, response)
}
// DeleteDocument handles DELETE /api/documents/:id/
func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
return
}
err = h.documentService.DeleteDocument(uint(documentID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrDocumentNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrDocumentAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, gin.H{"message": "Document deleted successfully"})
}
// ActivateDocument handles POST /api/documents/:id/activate/
func (h *DocumentHandler) ActivateDocument(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
return
}
response, err := h.documentService.ActivateDocument(uint(documentID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrDocumentNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrDocumentAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, gin.H{"message": "Document activated", "document": response})
}
// DeactivateDocument handles POST /api/documents/:id/deactivate/
func (h *DocumentHandler) DeactivateDocument(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
return
}
response, err := h.documentService.DeactivateDocument(uint(documentID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrDocumentNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrDocumentAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, gin.H{"message": "Document deactivated", "document": response})
}

View File

@@ -0,0 +1,197 @@
package handlers
import (
"errors"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/treytartt/mycrib-api/internal/middleware"
"github.com/treytartt/mycrib-api/internal/models"
"github.com/treytartt/mycrib-api/internal/services"
)
// NotificationHandler handles notification-related HTTP requests
type NotificationHandler struct {
notificationService *services.NotificationService
}
// NewNotificationHandler creates a new notification handler
func NewNotificationHandler(notificationService *services.NotificationService) *NotificationHandler {
return &NotificationHandler{notificationService: notificationService}
}
// ListNotifications handles GET /api/notifications/
func (h *NotificationHandler) ListNotifications(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
limit := 50
offset := 0
if l := c.Query("limit"); l != "" {
if parsed, err := strconv.Atoi(l); err == nil && parsed > 0 {
limit = parsed
}
}
if o := c.Query("offset"); o != "" {
if parsed, err := strconv.Atoi(o); err == nil && parsed >= 0 {
offset = parsed
}
}
notifications, err := h.notificationService.GetNotifications(user.ID, limit, offset)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"count": len(notifications),
"results": notifications,
})
}
// GetUnreadCount handles GET /api/notifications/unread-count/
func (h *NotificationHandler) GetUnreadCount(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
count, err := h.notificationService.GetUnreadCount(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"unread_count": count})
}
// MarkAsRead handles POST /api/notifications/:id/read/
func (h *NotificationHandler) MarkAsRead(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
notificationID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid notification ID"})
return
}
err = h.notificationService.MarkAsRead(uint(notificationID), user.ID)
if err != nil {
if errors.Is(err, services.ErrNotificationNotFound) {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Notification marked as read"})
}
// MarkAllAsRead handles POST /api/notifications/mark-all-read/
func (h *NotificationHandler) MarkAllAsRead(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
err := h.notificationService.MarkAllAsRead(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "All notifications marked as read"})
}
// GetPreferences handles GET /api/notifications/preferences/
func (h *NotificationHandler) GetPreferences(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
prefs, err := h.notificationService.GetPreferences(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, prefs)
}
// UpdatePreferences handles PUT/PATCH /api/notifications/preferences/
func (h *NotificationHandler) UpdatePreferences(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
var req services.UpdatePreferencesRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
prefs, err := h.notificationService.UpdatePreferences(user.ID, &req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, prefs)
}
// RegisterDevice handles POST /api/notifications/devices/
func (h *NotificationHandler) RegisterDevice(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
var req services.RegisterDeviceRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
device, err := h.notificationService.RegisterDevice(user.ID, &req)
if err != nil {
if errors.Is(err, services.ErrInvalidPlatform) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, device)
}
// ListDevices handles GET /api/notifications/devices/
func (h *NotificationHandler) ListDevices(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
devices, err := h.notificationService.ListDevices(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, devices)
}
// DeleteDevice handles DELETE /api/notifications/devices/:id/
func (h *NotificationHandler) DeleteDevice(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
deviceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid device ID"})
return
}
platform := c.Query("platform")
if platform == "" {
platform = "ios" // Default to iOS
}
err = h.notificationService.DeleteDevice(uint(deviceID), platform, user.ID)
if err != nil {
if errors.Is(err, services.ErrInvalidPlatform) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Device removed"})
}

View File

@@ -0,0 +1,288 @@
package handlers
import (
"errors"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/treytartt/mycrib-api/internal/dto/requests"
"github.com/treytartt/mycrib-api/internal/middleware"
"github.com/treytartt/mycrib-api/internal/models"
"github.com/treytartt/mycrib-api/internal/services"
)
// ResidenceHandler handles residence-related HTTP requests
type ResidenceHandler struct {
residenceService *services.ResidenceService
}
// NewResidenceHandler creates a new residence handler
func NewResidenceHandler(residenceService *services.ResidenceService) *ResidenceHandler {
return &ResidenceHandler{
residenceService: residenceService,
}
}
// ListResidences handles GET /api/residences/
func (h *ResidenceHandler) ListResidences(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
response, err := h.residenceService.ListResidences(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, response)
}
// GetMyResidences handles GET /api/residences/my-residences/
func (h *ResidenceHandler) GetMyResidences(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
response, err := h.residenceService.GetMyResidences(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, response)
}
// GetResidence handles GET /api/residences/:id/
func (h *ResidenceHandler) GetResidence(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid residence ID"})
return
}
response, err := h.residenceService.GetResidence(uint(residenceID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrResidenceNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrResidenceAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, response)
}
// CreateResidence handles POST /api/residences/
func (h *ResidenceHandler) CreateResidence(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
var req requests.CreateResidenceRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
response, err := h.residenceService.CreateResidence(&req, user.ID)
if err != nil {
if errors.Is(err, services.ErrPropertiesLimitReached) {
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, response)
}
// UpdateResidence handles PUT/PATCH /api/residences/:id/
func (h *ResidenceHandler) UpdateResidence(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid residence ID"})
return
}
var req requests.UpdateResidenceRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
response, err := h.residenceService.UpdateResidence(uint(residenceID), user.ID, &req)
if err != nil {
switch {
case errors.Is(err, services.ErrResidenceNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrNotResidenceOwner):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, response)
}
// DeleteResidence handles DELETE /api/residences/:id/
func (h *ResidenceHandler) DeleteResidence(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid residence ID"})
return
}
err = h.residenceService.DeleteResidence(uint(residenceID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrResidenceNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrNotResidenceOwner):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, gin.H{"message": "Residence deleted successfully"})
}
// GenerateShareCode handles POST /api/residences/:id/generate-share-code/
func (h *ResidenceHandler) GenerateShareCode(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid residence ID"})
return
}
var req requests.GenerateShareCodeRequest
// Request body is optional
c.ShouldBindJSON(&req)
response, err := h.residenceService.GenerateShareCode(uint(residenceID), user.ID, req.ExpiresInHours)
if err != nil {
switch {
case errors.Is(err, services.ErrResidenceNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrNotResidenceOwner):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, response)
}
// JoinWithCode handles POST /api/residences/join-with-code/
func (h *ResidenceHandler) JoinWithCode(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
var req requests.JoinWithCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
response, err := h.residenceService.JoinWithCode(req.Code, user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrShareCodeInvalid):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrShareCodeExpired):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrUserAlreadyMember):
c.JSON(http.StatusConflict, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, response)
}
// GetResidenceUsers handles GET /api/residences/:id/users/
func (h *ResidenceHandler) GetResidenceUsers(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid residence ID"})
return
}
users, err := h.residenceService.GetResidenceUsers(uint(residenceID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrResidenceNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrResidenceAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, users)
}
// RemoveResidenceUser handles DELETE /api/residences/:id/users/:user_id/
func (h *ResidenceHandler) RemoveResidenceUser(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid residence ID"})
return
}
userIDToRemove, err := strconv.ParseUint(c.Param("user_id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
err = h.residenceService.RemoveUser(uint(residenceID), uint(userIDToRemove), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrResidenceNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrNotResidenceOwner):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrCannotRemoveOwner):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, gin.H{"message": "User removed from residence"})
}
// GetResidenceTypes handles GET /api/residences/types/
func (h *ResidenceHandler) GetResidenceTypes(c *gin.Context) {
types, err := h.residenceService.GetResidenceTypes()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, types)
}

View File

@@ -0,0 +1,176 @@
package handlers
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
"github.com/treytartt/mycrib-api/internal/middleware"
"github.com/treytartt/mycrib-api/internal/models"
"github.com/treytartt/mycrib-api/internal/services"
)
// SubscriptionHandler handles subscription-related HTTP requests
type SubscriptionHandler struct {
subscriptionService *services.SubscriptionService
}
// NewSubscriptionHandler creates a new subscription handler
func NewSubscriptionHandler(subscriptionService *services.SubscriptionService) *SubscriptionHandler {
return &SubscriptionHandler{subscriptionService: subscriptionService}
}
// GetSubscription handles GET /api/subscription/
func (h *SubscriptionHandler) GetSubscription(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
subscription, err := h.subscriptionService.GetSubscription(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, subscription)
}
// GetSubscriptionStatus handles GET /api/subscription/status/
func (h *SubscriptionHandler) GetSubscriptionStatus(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
status, err := h.subscriptionService.GetSubscriptionStatus(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, status)
}
// GetUpgradeTrigger handles GET /api/subscription/upgrade-trigger/:key/
func (h *SubscriptionHandler) GetUpgradeTrigger(c *gin.Context) {
key := c.Param("key")
trigger, err := h.subscriptionService.GetUpgradeTrigger(key)
if err != nil {
if errors.Is(err, services.ErrUpgradeTriggerNotFound) {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, trigger)
}
// GetFeatureBenefits handles GET /api/subscription/features/
func (h *SubscriptionHandler) GetFeatureBenefits(c *gin.Context) {
benefits, err := h.subscriptionService.GetFeatureBenefits()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, benefits)
}
// GetPromotions handles GET /api/subscription/promotions/
func (h *SubscriptionHandler) GetPromotions(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
promotions, err := h.subscriptionService.GetActivePromotions(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, promotions)
}
// ProcessPurchase handles POST /api/subscription/purchase/
func (h *SubscriptionHandler) ProcessPurchase(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
var req services.ProcessPurchaseRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
var subscription *services.SubscriptionResponse
var err error
switch req.Platform {
case "ios":
if req.ReceiptData == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "receipt_data is required for iOS"})
return
}
subscription, err = h.subscriptionService.ProcessApplePurchase(user.ID, req.ReceiptData)
case "android":
if req.PurchaseToken == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "purchase_token is required for Android"})
return
}
subscription, err = h.subscriptionService.ProcessGooglePurchase(user.ID, req.PurchaseToken)
}
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"message": "Subscription upgraded successfully",
"subscription": subscription,
})
}
// CancelSubscription handles POST /api/subscription/cancel/
func (h *SubscriptionHandler) CancelSubscription(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
subscription, err := h.subscriptionService.CancelSubscription(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"message": "Subscription cancelled. You will retain Pro benefits until the end of your billing period.",
"subscription": subscription,
})
}
// RestoreSubscription handles POST /api/subscription/restore/
func (h *SubscriptionHandler) RestoreSubscription(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
var req services.ProcessPurchaseRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Same logic as ProcessPurchase - validates receipt/token and restores
var subscription *services.SubscriptionResponse
var err error
switch req.Platform {
case "ios":
subscription, err = h.subscriptionService.ProcessApplePurchase(user.ID, req.ReceiptData)
case "android":
subscription, err = h.subscriptionService.ProcessGooglePurchase(user.ID, req.PurchaseToken)
}
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"message": "Subscription restored successfully",
"subscription": subscription,
})
}

View File

@@ -0,0 +1,414 @@
package handlers
import (
"errors"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/treytartt/mycrib-api/internal/dto/requests"
"github.com/treytartt/mycrib-api/internal/middleware"
"github.com/treytartt/mycrib-api/internal/models"
"github.com/treytartt/mycrib-api/internal/services"
)
// TaskHandler handles task-related HTTP requests
type TaskHandler struct {
taskService *services.TaskService
}
// NewTaskHandler creates a new task handler
func NewTaskHandler(taskService *services.TaskService) *TaskHandler {
return &TaskHandler{taskService: taskService}
}
// ListTasks handles GET /api/tasks/
func (h *TaskHandler) ListTasks(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
response, err := h.taskService.ListTasks(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, response)
}
// GetTask handles GET /api/tasks/:id/
func (h *TaskHandler) GetTask(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid task ID"})
return
}
response, err := h.taskService.GetTask(uint(taskID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrTaskNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrTaskAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, response)
}
// GetTasksByResidence handles GET /api/tasks/by-residence/:residence_id/
func (h *TaskHandler) GetTasksByResidence(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
residenceID, err := strconv.ParseUint(c.Param("residence_id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid residence ID"})
return
}
daysThreshold := 30
if d := c.Query("days_threshold"); d != "" {
if parsed, err := strconv.Atoi(d); err == nil {
daysThreshold = parsed
}
}
response, err := h.taskService.GetTasksByResidence(uint(residenceID), user.ID, daysThreshold)
if err != nil {
switch {
case errors.Is(err, services.ErrResidenceAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, response)
}
// CreateTask handles POST /api/tasks/
func (h *TaskHandler) CreateTask(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
var req requests.CreateTaskRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
response, err := h.taskService.CreateTask(&req, user.ID)
if err != nil {
if errors.Is(err, services.ErrResidenceAccessDenied) {
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, response)
}
// UpdateTask handles PUT/PATCH /api/tasks/:id/
func (h *TaskHandler) UpdateTask(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid task ID"})
return
}
var req requests.UpdateTaskRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
response, err := h.taskService.UpdateTask(uint(taskID), user.ID, &req)
if err != nil {
switch {
case errors.Is(err, services.ErrTaskNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrTaskAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, response)
}
// DeleteTask handles DELETE /api/tasks/:id/
func (h *TaskHandler) DeleteTask(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid task ID"})
return
}
err = h.taskService.DeleteTask(uint(taskID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrTaskNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrTaskAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, gin.H{"message": "Task deleted successfully"})
}
// MarkInProgress handles POST /api/tasks/:id/mark-in-progress/
func (h *TaskHandler) MarkInProgress(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid task ID"})
return
}
response, err := h.taskService.MarkInProgress(uint(taskID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrTaskNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrTaskAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, gin.H{"message": "Task marked as in progress", "task": response})
}
// CancelTask handles POST /api/tasks/:id/cancel/
func (h *TaskHandler) CancelTask(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid task ID"})
return
}
response, err := h.taskService.CancelTask(uint(taskID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrTaskNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrTaskAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrTaskAlreadyCancelled):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, gin.H{"message": "Task cancelled", "task": response})
}
// UncancelTask handles POST /api/tasks/:id/uncancel/
func (h *TaskHandler) UncancelTask(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid task ID"})
return
}
response, err := h.taskService.UncancelTask(uint(taskID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrTaskNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrTaskAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, gin.H{"message": "Task uncancelled", "task": response})
}
// ArchiveTask handles POST /api/tasks/:id/archive/
func (h *TaskHandler) ArchiveTask(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid task ID"})
return
}
response, err := h.taskService.ArchiveTask(uint(taskID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrTaskNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrTaskAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrTaskAlreadyArchived):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, gin.H{"message": "Task archived", "task": response})
}
// UnarchiveTask handles POST /api/tasks/:id/unarchive/
func (h *TaskHandler) UnarchiveTask(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid task ID"})
return
}
response, err := h.taskService.UnarchiveTask(uint(taskID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrTaskNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrTaskAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, gin.H{"message": "Task unarchived", "task": response})
}
// === Task Completions ===
// ListCompletions handles GET /api/task-completions/
func (h *TaskHandler) ListCompletions(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
response, err := h.taskService.ListCompletions(user.ID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, response)
}
// GetCompletion handles GET /api/task-completions/:id/
func (h *TaskHandler) GetCompletion(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
completionID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid completion ID"})
return
}
response, err := h.taskService.GetCompletion(uint(completionID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrCompletionNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrTaskAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, response)
}
// CreateCompletion handles POST /api/task-completions/
func (h *TaskHandler) CreateCompletion(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
var req requests.CreateTaskCompletionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
response, err := h.taskService.CreateCompletion(&req, user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrTaskNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrTaskAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusCreated, response)
}
// DeleteCompletion handles DELETE /api/task-completions/:id/
func (h *TaskHandler) DeleteCompletion(c *gin.Context) {
user := c.MustGet(middleware.AuthUserKey).(*models.User)
completionID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid completion ID"})
return
}
err = h.taskService.DeleteCompletion(uint(completionID), user.ID)
if err != nil {
switch {
case errors.Is(err, services.ErrCompletionNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
case errors.Is(err, services.ErrTaskAccessDenied):
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, gin.H{"message": "Completion deleted successfully"})
}
// === Lookups ===
// GetCategories handles GET /api/tasks/categories/
func (h *TaskHandler) GetCategories(c *gin.Context) {
categories, err := h.taskService.GetCategories()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, categories)
}
// GetPriorities handles GET /api/tasks/priorities/
func (h *TaskHandler) GetPriorities(c *gin.Context) {
priorities, err := h.taskService.GetPriorities()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, priorities)
}
// GetStatuses handles GET /api/tasks/statuses/
func (h *TaskHandler) GetStatuses(c *gin.Context) {
statuses, err := h.taskService.GetStatuses()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, statuses)
}
// GetFrequencies handles GET /api/tasks/frequencies/
func (h *TaskHandler) GetFrequencies(c *gin.Context) {
frequencies, err := h.taskService.GetFrequencies()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, frequencies)
}

236
internal/middleware/auth.go Normal file
View File

@@ -0,0 +1,236 @@
package middleware
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"github.com/treytartt/mycrib-api/internal/models"
"github.com/treytartt/mycrib-api/internal/services"
)
const (
// AuthUserKey is the key used to store the authenticated user in the context
AuthUserKey = "auth_user"
// AuthTokenKey is the key used to store the token in the context
AuthTokenKey = "auth_token"
// TokenCacheTTL is the duration to cache tokens in Redis
TokenCacheTTL = 5 * time.Minute
// TokenCachePrefix is the prefix for token cache keys
TokenCachePrefix = "auth_token_"
)
// AuthMiddleware provides token authentication middleware
type AuthMiddleware struct {
db *gorm.DB
cache *services.CacheService
}
// NewAuthMiddleware creates a new auth middleware instance
func NewAuthMiddleware(db *gorm.DB, cache *services.CacheService) *AuthMiddleware {
return &AuthMiddleware{
db: db,
cache: cache,
}
}
// TokenAuth returns a Gin middleware that validates token authentication
func (m *AuthMiddleware) TokenAuth() gin.HandlerFunc {
return func(c *gin.Context) {
// Extract token from Authorization header
token, err := extractToken(c)
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": err.Error(),
})
return
}
// Try to get user from cache first
user, err := m.getUserFromCache(c.Request.Context(), token)
if err == nil && user != nil {
// Cache hit - set user in context and continue
c.Set(AuthUserKey, user)
c.Set(AuthTokenKey, token)
c.Next()
return
}
// Cache miss - look up token in database
user, err = m.getUserFromDatabase(token)
if err != nil {
log.Debug().Err(err).Str("token", token[:8]+"...").Msg("Token authentication failed")
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "Invalid token",
})
return
}
// Cache the user ID for future requests
if cacheErr := m.cacheUserID(c.Request.Context(), token, user.ID); cacheErr != nil {
log.Warn().Err(cacheErr).Msg("Failed to cache user ID")
}
// Set user in context
c.Set(AuthUserKey, user)
c.Set(AuthTokenKey, token)
c.Next()
}
}
// OptionalTokenAuth returns middleware that authenticates if token is present but doesn't require it
func (m *AuthMiddleware) OptionalTokenAuth() gin.HandlerFunc {
return func(c *gin.Context) {
token, err := extractToken(c)
if err != nil {
// No token or invalid format - continue without user
c.Next()
return
}
// Try cache first
user, err := m.getUserFromCache(c.Request.Context(), token)
if err == nil && user != nil {
c.Set(AuthUserKey, user)
c.Set(AuthTokenKey, token)
c.Next()
return
}
// Try database
user, err = m.getUserFromDatabase(token)
if err == nil {
m.cacheUserID(c.Request.Context(), token, user.ID)
c.Set(AuthUserKey, user)
c.Set(AuthTokenKey, token)
}
c.Next()
}
}
// extractToken extracts the token from the Authorization header
func extractToken(c *gin.Context) (string, error) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
return "", fmt.Errorf("authorization header required")
}
// Support both "Token xxx" (Django style) and "Bearer xxx" formats
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 {
return "", fmt.Errorf("invalid authorization header format")
}
scheme := parts[0]
token := parts[1]
if scheme != "Token" && scheme != "Bearer" {
return "", fmt.Errorf("invalid authorization scheme: %s", scheme)
}
if token == "" {
return "", fmt.Errorf("token is empty")
}
return token, nil
}
// getUserFromCache tries to get user from Redis cache
func (m *AuthMiddleware) getUserFromCache(ctx context.Context, token string) (*models.User, error) {
if m.cache == nil {
return nil, fmt.Errorf("cache not available")
}
userID, err := m.cache.GetCachedAuthToken(ctx, token)
if err != nil {
if err == redis.Nil {
return nil, fmt.Errorf("token not in cache")
}
return nil, err
}
// Get user from database by ID
var user models.User
if err := m.db.First(&user, userID).Error; err != nil {
// User was deleted - invalidate cache
m.cache.InvalidateAuthToken(ctx, token)
return nil, err
}
// Check if user is active
if !user.IsActive {
m.cache.InvalidateAuthToken(ctx, token)
return nil, fmt.Errorf("user is inactive")
}
return &user, nil
}
// getUserFromDatabase looks up the token in the database
func (m *AuthMiddleware) getUserFromDatabase(token string) (*models.User, error) {
var authToken models.AuthToken
if err := m.db.Preload("User").Where("key = ?", token).First(&authToken).Error; err != nil {
return nil, fmt.Errorf("token not found")
}
// Check if user is active
if !authToken.User.IsActive {
return nil, fmt.Errorf("user is inactive")
}
return &authToken.User, nil
}
// cacheUserID caches the user ID for a token
func (m *AuthMiddleware) cacheUserID(ctx context.Context, token string, userID uint) error {
if m.cache == nil {
return nil
}
return m.cache.CacheAuthToken(ctx, token, userID)
}
// InvalidateToken removes a token from the cache
func (m *AuthMiddleware) InvalidateToken(ctx context.Context, token string) error {
if m.cache == nil {
return nil
}
return m.cache.InvalidateAuthToken(ctx, token)
}
// GetAuthUser retrieves the authenticated user from the Gin context
func GetAuthUser(c *gin.Context) *models.User {
user, exists := c.Get(AuthUserKey)
if !exists {
return nil
}
return user.(*models.User)
}
// GetAuthToken retrieves the auth token from the Gin context
func GetAuthToken(c *gin.Context) string {
token, exists := c.Get(AuthTokenKey)
if !exists {
return ""
}
return token.(string)
}
// MustGetAuthUser retrieves the authenticated user or aborts with 401
func MustGetAuthUser(c *gin.Context) *models.User {
user := GetAuthUser(c)
if user == nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "Authentication required",
})
return nil
}
return user
}

38
internal/models/base.go Normal file
View File

@@ -0,0 +1,38 @@
package models
import (
"time"
"gorm.io/gorm"
)
// BaseModel contains common columns for all tables with ID, CreatedAt, UpdatedAt
type BaseModel struct {
ID uint `gorm:"primaryKey" json:"id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// SoftDeleteModel extends BaseModel with soft delete support
type SoftDeleteModel struct {
BaseModel
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
// BeforeCreate sets timestamps before creating a record
func (b *BaseModel) BeforeCreate(tx *gorm.DB) error {
now := time.Now().UTC()
if b.CreatedAt.IsZero() {
b.CreatedAt = now
}
if b.UpdatedAt.IsZero() {
b.UpdatedAt = now
}
return nil
}
// BeforeUpdate sets updated_at before updating a record
func (b *BaseModel) BeforeUpdate(tx *gorm.DB) error {
b.UpdatedAt = time.Now().UTC()
return nil
}

View File

@@ -0,0 +1,53 @@
package models
// ContractorSpecialty represents the task_contractorspecialty table
type ContractorSpecialty struct {
BaseModel
Name string `gorm:"column:name;size:50;not null" json:"name"`
Description string `gorm:"column:description;type:text" json:"description"`
Icon string `gorm:"column:icon;size:50" json:"icon"`
DisplayOrder int `gorm:"column:display_order;default:0" json:"display_order"`
}
// TableName returns the table name for GORM
func (ContractorSpecialty) TableName() string {
return "task_contractorspecialty"
}
// Contractor represents the task_contractor table
type Contractor struct {
BaseModel
ResidenceID uint `gorm:"column:residence_id;index;not null" json:"residence_id"`
Residence Residence `gorm:"foreignKey:ResidenceID" json:"-"`
CreatedByID uint `gorm:"column:created_by_id;index;not null" json:"created_by_id"`
CreatedBy User `gorm:"foreignKey:CreatedByID" json:"created_by,omitempty"`
Name string `gorm:"column:name;size:200;not null" json:"name"`
Company string `gorm:"column:company;size:200" json:"company"`
Phone string `gorm:"column:phone;size:20" json:"phone"`
Email string `gorm:"column:email;size:254" json:"email"`
Website string `gorm:"column:website;size:200" json:"website"`
Notes string `gorm:"column:notes;type:text" json:"notes"`
// Address
StreetAddress string `gorm:"column:street_address;size:255" json:"street_address"`
City string `gorm:"column:city;size:100" json:"city"`
StateProvince string `gorm:"column:state_province;size:100" json:"state_province"`
PostalCode string `gorm:"column:postal_code;size:20" json:"postal_code"`
// Specialties (many-to-many)
Specialties []ContractorSpecialty `gorm:"many2many:task_contractor_specialties;" json:"specialties,omitempty"`
// Rating and favorites
Rating *float64 `gorm:"column:rating;type:decimal(2,1)" json:"rating"`
IsFavorite bool `gorm:"column:is_favorite;default:false" json:"is_favorite"`
IsActive bool `gorm:"column:is_active;default:true;index" json:"is_active"`
// Tasks associated with this contractor
Tasks []Task `gorm:"foreignKey:ContractorID" json:"tasks,omitempty"`
}
// TableName returns the table name for GORM
func (Contractor) TableName() string {
return "task_contractor"
}

View File

@@ -0,0 +1,75 @@
package models
import (
"time"
"github.com/shopspring/decimal"
)
// DocumentType represents the type of document
type DocumentType string
const (
DocumentTypeGeneral DocumentType = "general"
DocumentTypeWarranty DocumentType = "warranty"
DocumentTypeReceipt DocumentType = "receipt"
DocumentTypeContract DocumentType = "contract"
DocumentTypeInsurance DocumentType = "insurance"
DocumentTypeManual DocumentType = "manual"
)
// Document represents the task_document table
type Document struct {
BaseModel
ResidenceID uint `gorm:"column:residence_id;index;not null" json:"residence_id"`
Residence Residence `gorm:"foreignKey:ResidenceID" json:"-"`
CreatedByID uint `gorm:"column:created_by_id;index;not null" json:"created_by_id"`
CreatedBy User `gorm:"foreignKey:CreatedByID" json:"created_by,omitempty"`
Title string `gorm:"column:title;size:200;not null" json:"title"`
Description string `gorm:"column:description;type:text" json:"description"`
DocumentType DocumentType `gorm:"column:document_type;size:20;default:'general'" json:"document_type"`
// File information
FileURL string `gorm:"column:file_url;size:500" json:"file_url"`
FileName string `gorm:"column:file_name;size:255" json:"file_name"`
FileSize *int64 `gorm:"column:file_size" json:"file_size"`
MimeType string `gorm:"column:mime_type;size:100" json:"mime_type"`
// Warranty-specific fields
PurchaseDate *time.Time `gorm:"column:purchase_date;type:date" json:"purchase_date"`
ExpiryDate *time.Time `gorm:"column:expiry_date;type:date;index" json:"expiry_date"`
PurchasePrice *decimal.Decimal `gorm:"column:purchase_price;type:decimal(10,2)" json:"purchase_price"`
Vendor string `gorm:"column:vendor;size:200" json:"vendor"`
SerialNumber string `gorm:"column:serial_number;size:100" json:"serial_number"`
ModelNumber string `gorm:"column:model_number;size:100" json:"model_number"`
// Associated task (optional)
TaskID *uint `gorm:"column:task_id;index" json:"task_id"`
Task *Task `gorm:"foreignKey:TaskID" json:"task,omitempty"`
// State
IsActive bool `gorm:"column:is_active;default:true;index" json:"is_active"`
}
// TableName returns the table name for GORM
func (Document) TableName() string {
return "task_document"
}
// IsWarrantyExpiringSoon returns true if the warranty expires within the specified days
func (d *Document) IsWarrantyExpiringSoon(days int) bool {
if d.DocumentType != DocumentTypeWarranty || d.ExpiryDate == nil {
return false
}
threshold := time.Now().UTC().AddDate(0, 0, days)
return d.ExpiryDate.Before(threshold) && d.ExpiryDate.After(time.Now().UTC())
}
// IsWarrantyExpired returns true if the warranty has expired
func (d *Document) IsWarrantyExpired() bool {
if d.DocumentType != DocumentTypeWarranty || d.ExpiryDate == nil {
return false
}
return time.Now().UTC().After(*d.ExpiryDate)
}

View File

@@ -0,0 +1,123 @@
package models
import (
"time"
)
// NotificationPreference represents the notifications_notificationpreference table
type NotificationPreference struct {
BaseModel
UserID uint `gorm:"column:user_id;uniqueIndex;not null" json:"user_id"`
// Task notifications
TaskDueSoon bool `gorm:"column:task_due_soon;default:true" json:"task_due_soon"`
TaskOverdue bool `gorm:"column:task_overdue;default:true" json:"task_overdue"`
TaskCompleted bool `gorm:"column:task_completed;default:true" json:"task_completed"`
TaskAssigned bool `gorm:"column:task_assigned;default:true" json:"task_assigned"`
// Residence notifications
ResidenceShared bool `gorm:"column:residence_shared;default:true" json:"residence_shared"`
// Document notifications
WarrantyExpiring bool `gorm:"column:warranty_expiring;default:true" json:"warranty_expiring"`
}
// TableName returns the table name for GORM
func (NotificationPreference) TableName() string {
return "notifications_notificationpreference"
}
// NotificationType represents the type of notification
type NotificationType string
const (
NotificationTaskDueSoon NotificationType = "task_due_soon"
NotificationTaskOverdue NotificationType = "task_overdue"
NotificationTaskCompleted NotificationType = "task_completed"
NotificationTaskAssigned NotificationType = "task_assigned"
NotificationResidenceShared NotificationType = "residence_shared"
NotificationWarrantyExpiring NotificationType = "warranty_expiring"
)
// Notification represents the notifications_notification table
type Notification struct {
BaseModel
UserID uint `gorm:"column:user_id;index;not null" json:"user_id"`
User User `gorm:"foreignKey:UserID" json:"-"`
NotificationType NotificationType `gorm:"column:notification_type;size:50;not null" json:"notification_type"`
Title string `gorm:"column:title;size:200;not null" json:"title"`
Body string `gorm:"column:body;type:text;not null" json:"body"`
// Related object (optional)
TaskID *uint `gorm:"column:task_id" json:"task_id"`
// Task *Task `gorm:"foreignKey:TaskID" json:"task,omitempty"` // Uncomment when Task model is implemented
// Additional data (JSON)
Data string `gorm:"column:data;type:jsonb;default:'{}'" json:"data"`
// Delivery tracking
Sent bool `gorm:"column:sent;default:false" json:"sent"`
SentAt *time.Time `gorm:"column:sent_at" json:"sent_at"`
// Read tracking
Read bool `gorm:"column:read;default:false" json:"read"`
ReadAt *time.Time `gorm:"column:read_at" json:"read_at"`
// Error handling
ErrorMessage string `gorm:"column:error_message;type:text" json:"error_message,omitempty"`
}
// TableName returns the table name for GORM
func (Notification) TableName() string {
return "notifications_notification"
}
// MarkAsRead marks the notification as read
func (n *Notification) MarkAsRead() {
n.Read = true
now := time.Now().UTC()
n.ReadAt = &now
}
// MarkAsSent marks the notification as sent
func (n *Notification) MarkAsSent() {
n.Sent = true
now := time.Now().UTC()
n.SentAt = &now
}
// APNSDevice represents iOS devices for push notifications
type APNSDevice struct {
ID uint `gorm:"primaryKey" json:"id"`
Name string `gorm:"column:name;size:255" json:"name"`
Active bool `gorm:"column:active;default:true" json:"active"`
UserID *uint `gorm:"column:user_id;index" json:"user_id"`
User *User `gorm:"foreignKey:UserID" json:"-"`
DeviceID string `gorm:"column:device_id;size:255" json:"device_id"`
RegistrationID string `gorm:"column:registration_id;uniqueIndex;size:255" json:"registration_id"`
DateCreated time.Time `gorm:"column:date_created;autoCreateTime" json:"date_created"`
}
// TableName returns the table name for GORM
func (APNSDevice) TableName() string {
return "push_notifications_apnsdevice"
}
// GCMDevice represents Android devices for push notifications
type GCMDevice struct {
ID uint `gorm:"primaryKey" json:"id"`
Name string `gorm:"column:name;size:255" json:"name"`
Active bool `gorm:"column:active;default:true" json:"active"`
UserID *uint `gorm:"column:user_id;index" json:"user_id"`
User *User `gorm:"foreignKey:UserID" json:"-"`
DeviceID string `gorm:"column:device_id;size:255" json:"device_id"`
RegistrationID string `gorm:"column:registration_id;uniqueIndex;size:255" json:"registration_id"`
CloudMessageType string `gorm:"column:cloud_message_type;size:3;default:'FCM'" json:"cloud_message_type"`
DateCreated time.Time `gorm:"column:date_created;autoCreateTime" json:"date_created"`
}
// TableName returns the table name for GORM
func (GCMDevice) TableName() string {
return "push_notifications_gcmdevice"
}

View File

@@ -0,0 +1,105 @@
package models
import (
"time"
"github.com/shopspring/decimal"
)
// ResidenceType represents the residence_residencetype table
type ResidenceType struct {
BaseModel
Name string `gorm:"column:name;size:20;not null" json:"name"`
}
// TableName returns the table name for GORM
func (ResidenceType) TableName() string {
return "residence_residencetype"
}
// Residence represents the residence_residence table
type Residence struct {
BaseModel
OwnerID uint `gorm:"column:owner_id;index;not null" json:"owner_id"`
Owner User `gorm:"foreignKey:OwnerID" json:"owner,omitempty"`
Users []User `gorm:"many2many:residence_residence_users;" json:"users,omitempty"`
Name string `gorm:"column:name;size:200;not null" json:"name"`
PropertyTypeID *uint `gorm:"column:property_type_id" json:"property_type_id"`
PropertyType *ResidenceType `gorm:"foreignKey:PropertyTypeID" json:"property_type,omitempty"`
// Address
StreetAddress string `gorm:"column:street_address;size:255" json:"street_address"`
ApartmentUnit string `gorm:"column:apartment_unit;size:50" json:"apartment_unit"`
City string `gorm:"column:city;size:100" json:"city"`
StateProvince string `gorm:"column:state_province;size:100" json:"state_province"`
PostalCode string `gorm:"column:postal_code;size:20" json:"postal_code"`
Country string `gorm:"column:country;size:100;default:'USA'" json:"country"`
// Property Details
Bedrooms *int `gorm:"column:bedrooms" json:"bedrooms"`
Bathrooms *decimal.Decimal `gorm:"column:bathrooms;type:decimal(3,1)" json:"bathrooms"`
SquareFootage *int `gorm:"column:square_footage" json:"square_footage"`
LotSize *decimal.Decimal `gorm:"column:lot_size;type:decimal(10,2)" json:"lot_size"`
YearBuilt *int `gorm:"column:year_built" json:"year_built"`
Description string `gorm:"column:description;type:text" json:"description"`
PurchaseDate *time.Time `gorm:"column:purchase_date;type:date" json:"purchase_date"`
PurchasePrice *decimal.Decimal `gorm:"column:purchase_price;type:decimal(12,2)" json:"purchase_price"`
IsPrimary bool `gorm:"column:is_primary;default:true" json:"is_primary"`
IsActive bool `gorm:"column:is_active;default:true;index" json:"is_active"` // Soft delete flag
// Relations (to be implemented in Phase 3)
// Tasks []Task `gorm:"foreignKey:ResidenceID" json:"tasks,omitempty"`
// Documents []Document `gorm:"foreignKey:ResidenceID" json:"documents,omitempty"`
// ShareCodes []ResidenceShareCode `gorm:"foreignKey:ResidenceID" json:"-"`
}
// TableName returns the table name for GORM
func (Residence) TableName() string {
return "residence_residence"
}
// GetAllUsers returns all users with access to this residence (owner + shared users)
func (r *Residence) GetAllUsers() []User {
users := make([]User, 0, len(r.Users)+1)
users = append(users, r.Owner)
users = append(users, r.Users...)
return users
}
// HasAccess checks if a user has access to this residence
func (r *Residence) HasAccess(userID uint) bool {
if r.OwnerID == userID {
return true
}
for _, u := range r.Users {
if u.ID == userID {
return true
}
}
return false
}
// IsPrimaryOwner checks if a user is the primary owner
func (r *Residence) IsPrimaryOwner(userID uint) bool {
return r.OwnerID == userID
}
// ResidenceShareCode represents the residence_residencesharecode table
type ResidenceShareCode struct {
BaseModel
ResidenceID uint `gorm:"column:residence_id;index;not null" json:"residence_id"`
Residence Residence `gorm:"foreignKey:ResidenceID" json:"-"`
Code string `gorm:"column:code;uniqueIndex;size:6;not null" json:"code"`
CreatedByID uint `gorm:"column:created_by_id;not null" json:"created_by_id"`
CreatedBy User `gorm:"foreignKey:CreatedByID" json:"-"`
IsActive bool `gorm:"column:is_active;default:true;index" json:"is_active"`
ExpiresAt *time.Time `gorm:"column:expires_at" json:"expires_at"`
}
// TableName returns the table name for GORM
func (ResidenceShareCode) TableName() string {
return "residence_residencesharecode"
}

View File

@@ -0,0 +1,163 @@
package models
import (
"time"
)
// SubscriptionTier represents the subscription tier
type SubscriptionTier string
const (
TierFree SubscriptionTier = "free"
TierPro SubscriptionTier = "pro"
)
// SubscriptionSettings represents the subscription_subscriptionsettings table (singleton)
type SubscriptionSettings struct {
ID uint `gorm:"primaryKey" json:"id"`
EnableLimitations bool `gorm:"column:enable_limitations;default:false" json:"enable_limitations"`
}
// TableName returns the table name for GORM
func (SubscriptionSettings) TableName() string {
return "subscription_subscriptionsettings"
}
// UserSubscription represents the subscription_usersubscription table
type UserSubscription struct {
BaseModel
UserID uint `gorm:"column:user_id;uniqueIndex;not null" json:"user_id"`
Tier SubscriptionTier `gorm:"column:tier;size:10;default:'free'" json:"tier"`
// In-App Purchase data
AppleReceiptData *string `gorm:"column:apple_receipt_data;type:text" json:"-"`
GooglePurchaseToken *string `gorm:"column:google_purchase_token;type:text" json:"-"`
// Subscription dates
SubscribedAt *time.Time `gorm:"column:subscribed_at" json:"subscribed_at"`
ExpiresAt *time.Time `gorm:"column:expires_at" json:"expires_at"`
AutoRenew bool `gorm:"column:auto_renew;default:true" json:"auto_renew"`
// Tracking
CancelledAt *time.Time `gorm:"column:cancelled_at" json:"cancelled_at"`
Platform string `gorm:"column:platform;size:10" json:"platform"` // ios, android
}
// TableName returns the table name for GORM
func (UserSubscription) TableName() string {
return "subscription_usersubscription"
}
// IsActive returns true if the subscription is active (pro tier and not expired)
func (s *UserSubscription) IsActive() bool {
if s.Tier != TierPro {
return false
}
if s.ExpiresAt != nil && time.Now().UTC().After(*s.ExpiresAt) {
return false
}
return true
}
// IsPro returns true if the user has a pro subscription
func (s *UserSubscription) IsPro() bool {
return s.Tier == TierPro && s.IsActive()
}
// UpgradeTrigger represents the subscription_upgradetrigger table
type UpgradeTrigger struct {
BaseModel
TriggerKey string `gorm:"column:trigger_key;uniqueIndex;size:50;not null" json:"trigger_key"`
Title string `gorm:"column:title;size:200;not null" json:"title"`
Message string `gorm:"column:message;type:text;not null" json:"message"`
PromoHTML string `gorm:"column:promo_html;type:text" json:"promo_html"`
ButtonText string `gorm:"column:button_text;size:50;default:'Upgrade to Pro'" json:"button_text"`
IsActive bool `gorm:"column:is_active;default:true" json:"is_active"`
}
// TableName returns the table name for GORM
func (UpgradeTrigger) TableName() string {
return "subscription_upgradetrigger"
}
// FeatureBenefit represents the subscription_featurebenefit table
type FeatureBenefit struct {
BaseModel
FeatureName string `gorm:"column:feature_name;size:200;not null" json:"feature_name"`
FreeTierText string `gorm:"column:free_tier_text;size:200;not null" json:"free_tier_text"`
ProTierText string `gorm:"column:pro_tier_text;size:200;not null" json:"pro_tier_text"`
DisplayOrder int `gorm:"column:display_order;default:0" json:"display_order"`
IsActive bool `gorm:"column:is_active;default:true" json:"is_active"`
}
// TableName returns the table name for GORM
func (FeatureBenefit) TableName() string {
return "subscription_featurebenefit"
}
// Promotion represents the subscription_promotion table
type Promotion struct {
BaseModel
PromotionID string `gorm:"column:promotion_id;uniqueIndex;size:50;not null" json:"promotion_id"`
Title string `gorm:"column:title;size:200;not null" json:"title"`
Message string `gorm:"column:message;type:text;not null" json:"message"`
Link *string `gorm:"column:link;size:200" json:"link"`
StartDate time.Time `gorm:"column:start_date;not null" json:"start_date"`
EndDate time.Time `gorm:"column:end_date;not null" json:"end_date"`
TargetTier SubscriptionTier `gorm:"column:target_tier;size:10;default:'free'" json:"target_tier"`
IsActive bool `gorm:"column:is_active;default:true" json:"is_active"`
}
// TableName returns the table name for GORM
func (Promotion) TableName() string {
return "subscription_promotion"
}
// IsCurrentlyActive returns true if the promotion is currently active
func (p *Promotion) IsCurrentlyActive() bool {
if !p.IsActive {
return false
}
now := time.Now().UTC()
return now.After(p.StartDate) && now.Before(p.EndDate)
}
// TierLimits represents the subscription_tierlimits table
type TierLimits struct {
BaseModel
Tier SubscriptionTier `gorm:"column:tier;uniqueIndex;size:10;not null" json:"tier"`
PropertiesLimit *int `gorm:"column:properties_limit" json:"properties_limit"`
TasksLimit *int `gorm:"column:tasks_limit" json:"tasks_limit"`
ContractorsLimit *int `gorm:"column:contractors_limit" json:"contractors_limit"`
DocumentsLimit *int `gorm:"column:documents_limit" json:"documents_limit"`
}
// TableName returns the table name for GORM
func (TierLimits) TableName() string {
return "subscription_tierlimits"
}
// GetDefaultFreeLimits returns the default limits for the free tier
func GetDefaultFreeLimits() TierLimits {
one := 1
ten := 10
zero := 0
return TierLimits{
Tier: TierFree,
PropertiesLimit: &one,
TasksLimit: &ten,
ContractorsLimit: &zero,
DocumentsLimit: &zero,
}
}
// GetDefaultProLimits returns the default limits for the pro tier (unlimited)
func GetDefaultProLimits() TierLimits {
return TierLimits{
Tier: TierPro,
PropertiesLimit: nil, // nil = unlimited
TasksLimit: nil,
ContractorsLimit: nil,
DocumentsLimit: nil,
}
}

170
internal/models/task.go Normal file
View File

@@ -0,0 +1,170 @@
package models
import (
"time"
"github.com/shopspring/decimal"
)
// TaskCategory represents the task_taskcategory table
type TaskCategory struct {
BaseModel
Name string `gorm:"column:name;size:50;not null" json:"name"`
Description string `gorm:"column:description;type:text" json:"description"`
Icon string `gorm:"column:icon;size:50" json:"icon"`
Color string `gorm:"column:color;size:7" json:"color"` // Hex color
DisplayOrder int `gorm:"column:display_order;default:0" json:"display_order"`
}
// TableName returns the table name for GORM
func (TaskCategory) TableName() string {
return "task_taskcategory"
}
// TaskPriority represents the task_taskpriority table
type TaskPriority struct {
BaseModel
Name string `gorm:"column:name;size:20;not null" json:"name"`
Level int `gorm:"column:level;not null" json:"level"` // 1=low, 2=medium, 3=high, 4=urgent
Color string `gorm:"column:color;size:7" json:"color"`
DisplayOrder int `gorm:"column:display_order;default:0" json:"display_order"`
}
// TableName returns the table name for GORM
func (TaskPriority) TableName() string {
return "task_taskpriority"
}
// TaskStatus represents the task_taskstatus table
type TaskStatus struct {
BaseModel
Name string `gorm:"column:name;size:20;not null" json:"name"`
Description string `gorm:"column:description;type:text" json:"description"`
Color string `gorm:"column:color;size:7" json:"color"`
DisplayOrder int `gorm:"column:display_order;default:0" json:"display_order"`
}
// TableName returns the table name for GORM
func (TaskStatus) TableName() string {
return "task_taskstatus"
}
// TaskFrequency represents the task_taskfrequency table
type TaskFrequency struct {
BaseModel
Name string `gorm:"column:name;size:20;not null" json:"name"`
Days *int `gorm:"column:days" json:"days"` // Number of days between occurrences (nil = one-time)
DisplayOrder int `gorm:"column:display_order;default:0" json:"display_order"`
}
// TableName returns the table name for GORM
func (TaskFrequency) TableName() string {
return "task_taskfrequency"
}
// Task represents the task_task table
type Task struct {
BaseModel
ResidenceID uint `gorm:"column:residence_id;index;not null" json:"residence_id"`
Residence Residence `gorm:"foreignKey:ResidenceID" json:"residence,omitempty"`
CreatedByID uint `gorm:"column:created_by_id;index;not null" json:"created_by_id"`
CreatedBy User `gorm:"foreignKey:CreatedByID" json:"created_by,omitempty"`
AssignedToID *uint `gorm:"column:assigned_to_id;index" json:"assigned_to_id"`
AssignedTo *User `gorm:"foreignKey:AssignedToID" json:"assigned_to,omitempty"`
Title string `gorm:"column:title;size:200;not null" json:"title"`
Description string `gorm:"column:description;type:text" json:"description"`
CategoryID *uint `gorm:"column:category_id;index" json:"category_id"`
Category *TaskCategory `gorm:"foreignKey:CategoryID" json:"category,omitempty"`
PriorityID *uint `gorm:"column:priority_id;index" json:"priority_id"`
Priority *TaskPriority `gorm:"foreignKey:PriorityID" json:"priority,omitempty"`
StatusID *uint `gorm:"column:status_id;index" json:"status_id"`
Status *TaskStatus `gorm:"foreignKey:StatusID" json:"status,omitempty"`
FrequencyID *uint `gorm:"column:frequency_id;index" json:"frequency_id"`
Frequency *TaskFrequency `gorm:"foreignKey:FrequencyID" json:"frequency,omitempty"`
DueDate *time.Time `gorm:"column:due_date;type:date;index" json:"due_date"`
EstimatedCost *decimal.Decimal `gorm:"column:estimated_cost;type:decimal(10,2)" json:"estimated_cost"`
ActualCost *decimal.Decimal `gorm:"column:actual_cost;type:decimal(10,2)" json:"actual_cost"`
// Contractor association
ContractorID *uint `gorm:"column:contractor_id;index" json:"contractor_id"`
// Contractor *Contractor `gorm:"foreignKey:ContractorID" json:"contractor,omitempty"`
// State flags
IsCancelled bool `gorm:"column:is_cancelled;default:false;index" json:"is_cancelled"`
IsArchived bool `gorm:"column:is_archived;default:false;index" json:"is_archived"`
// Parent task for recurring tasks
ParentTaskID *uint `gorm:"column:parent_task_id;index" json:"parent_task_id"`
ParentTask *Task `gorm:"foreignKey:ParentTaskID" json:"parent_task,omitempty"`
// Completions
Completions []TaskCompletion `gorm:"foreignKey:TaskID" json:"completions,omitempty"`
}
// TableName returns the table name for GORM
func (Task) TableName() string {
return "task_task"
}
// IsOverdue returns true if the task is past its due date and not completed
func (t *Task) IsOverdue() bool {
if t.DueDate == nil || t.IsCancelled || t.IsArchived {
return false
}
// Check if there's a completion
if len(t.Completions) > 0 {
return false
}
return time.Now().UTC().After(*t.DueDate)
}
// IsDueSoon returns true if the task is due within the specified days
func (t *Task) IsDueSoon(days int) bool {
if t.DueDate == nil || t.IsCancelled || t.IsArchived {
return false
}
if len(t.Completions) > 0 {
return false
}
threshold := time.Now().UTC().AddDate(0, 0, days)
return t.DueDate.Before(threshold) && !t.IsOverdue()
}
// TaskCompletion represents the task_taskcompletion table
type TaskCompletion struct {
BaseModel
TaskID uint `gorm:"column:task_id;index;not null" json:"task_id"`
Task Task `gorm:"foreignKey:TaskID" json:"-"`
CompletedByID uint `gorm:"column:completed_by_id;index;not null" json:"completed_by_id"`
CompletedBy User `gorm:"foreignKey:CompletedByID" json:"completed_by,omitempty"`
CompletedAt time.Time `gorm:"column:completed_at;not null" json:"completed_at"`
Notes string `gorm:"column:notes;type:text" json:"notes"`
ActualCost *decimal.Decimal `gorm:"column:actual_cost;type:decimal(10,2)" json:"actual_cost"`
PhotoURL string `gorm:"column:photo_url;size:500" json:"photo_url"`
}
// TableName returns the table name for GORM
func (TaskCompletion) TableName() string {
return "task_taskcompletion"
}
// KanbanColumn represents a column in the kanban board
type KanbanColumn struct {
Name string `json:"name"`
DisplayName string `json:"display_name"`
ButtonTypes []string `json:"button_types"`
Icons map[string]string `json:"icons"`
Color string `json:"color"`
Tasks []Task `json:"tasks"`
Count int `json:"count"`
}
// KanbanBoard represents the full kanban board response
type KanbanBoard struct {
Columns []KanbanColumn `json:"columns"`
DaysThreshold int `json:"days_threshold"`
ResidenceID string `json:"residence_id"`
}

232
internal/models/user.go Normal file
View File

@@ -0,0 +1,232 @@
package models
import (
"crypto/rand"
"encoding/hex"
"time"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// User represents the auth_user table (Django's default User model)
type User struct {
ID uint `gorm:"primaryKey" json:"id"`
Password string `gorm:"column:password;size:128;not null" json:"-"`
LastLogin *time.Time `gorm:"column:last_login" json:"last_login,omitempty"`
IsSuperuser bool `gorm:"column:is_superuser;default:false" json:"is_superuser"`
Username string `gorm:"column:username;uniqueIndex;size:150;not null" json:"username"`
FirstName string `gorm:"column:first_name;size:150" json:"first_name"`
LastName string `gorm:"column:last_name;size:150" json:"last_name"`
Email string `gorm:"column:email;size:254" json:"email"`
IsStaff bool `gorm:"column:is_staff;default:false" json:"is_staff"`
IsActive bool `gorm:"column:is_active;default:true" json:"is_active"`
DateJoined time.Time `gorm:"column:date_joined;autoCreateTime" json:"date_joined"`
// Relations (not stored in auth_user table)
Profile *UserProfile `gorm:"foreignKey:UserID" json:"profile,omitempty"`
AuthToken *AuthToken `gorm:"foreignKey:UserID" json:"-"`
OwnedResidences []Residence `gorm:"foreignKey:OwnerID" json:"-"`
SharedResidences []Residence `gorm:"many2many:residence_residence_users;" json:"-"`
NotificationPref *NotificationPreference `gorm:"foreignKey:UserID" json:"-"`
Subscription *UserSubscription `gorm:"foreignKey:UserID" json:"-"`
}
// TableName returns the table name for GORM
func (User) TableName() string {
return "auth_user"
}
// SetPassword hashes and sets the password
func (u *User) SetPassword(password string) error {
// Django uses PBKDF2_SHA256 by default, but we'll use bcrypt for Go
// Note: This means passwords set by Django won't work with Go's check
// For migration, you'd need to either:
// 1. Force password reset for all users
// 2. Implement Django's PBKDF2 hasher in Go
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
u.Password = string(hash)
return nil
}
// CheckPassword verifies a password against the stored hash
func (u *User) CheckPassword(password string) bool {
err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
return err == nil
}
// GetFullName returns the user's full name
func (u *User) GetFullName() string {
if u.FirstName != "" && u.LastName != "" {
return u.FirstName + " " + u.LastName
}
if u.FirstName != "" {
return u.FirstName
}
return u.Username
}
// AuthToken represents the user_authtoken table
type AuthToken struct {
Key string `gorm:"column:key;primaryKey;size:40" json:"key"`
UserID uint `gorm:"column:user_id;uniqueIndex;not null" json:"user_id"`
Created time.Time `gorm:"column:created;autoCreateTime" json:"created"`
// Relations
User User `gorm:"foreignKey:UserID" json:"-"`
}
// TableName returns the table name for GORM
func (AuthToken) TableName() string {
return "user_authtoken"
}
// BeforeCreate generates a token key if not provided
func (t *AuthToken) BeforeCreate(tx *gorm.DB) error {
if t.Key == "" {
t.Key = generateToken()
}
if t.Created.IsZero() {
t.Created = time.Now().UTC()
}
return nil
}
// generateToken creates a random 40-character hex token
func generateToken() string {
b := make([]byte, 20)
rand.Read(b)
return hex.EncodeToString(b)
}
// GetOrCreate gets an existing token or creates a new one for the user
func GetOrCreateToken(tx *gorm.DB, userID uint) (*AuthToken, error) {
var token AuthToken
result := tx.Where("user_id = ?", userID).First(&token)
if result.Error == gorm.ErrRecordNotFound {
token = AuthToken{UserID: userID}
if err := tx.Create(&token).Error; err != nil {
return nil, err
}
} else if result.Error != nil {
return nil, result.Error
}
return &token, nil
}
// UserProfile represents the user_userprofile table
type UserProfile struct {
BaseModel
UserID uint `gorm:"column:user_id;uniqueIndex;not null" json:"user_id"`
Verified bool `gorm:"column:verified;default:false" json:"verified"`
Bio string `gorm:"column:bio;type:text" json:"bio"`
PhoneNumber string `gorm:"column:phone_number;size:15" json:"phone_number"`
DateOfBirth *time.Time `gorm:"column:date_of_birth;type:date" json:"date_of_birth,omitempty"`
ProfilePicture string `gorm:"column:profile_picture;size:100" json:"profile_picture"`
// Relations
User User `gorm:"foreignKey:UserID" json:"-"`
}
// TableName returns the table name for GORM
func (UserProfile) TableName() string {
return "user_userprofile"
}
// ConfirmationCode represents the user_confirmationcode table
type ConfirmationCode struct {
BaseModel
UserID uint `gorm:"column:user_id;index;not null" json:"user_id"`
Code string `gorm:"column:code;size:6;not null" json:"code"`
ExpiresAt time.Time `gorm:"column:expires_at;not null" json:"expires_at"`
IsUsed bool `gorm:"column:is_used;default:false" json:"is_used"`
// Relations
User User `gorm:"foreignKey:UserID" json:"-"`
}
// TableName returns the table name for GORM
func (ConfirmationCode) TableName() string {
return "user_confirmationcode"
}
// IsValid checks if the confirmation code is still valid
func (c *ConfirmationCode) IsValid() bool {
return !c.IsUsed && time.Now().UTC().Before(c.ExpiresAt)
}
// GenerateCode creates a random 6-digit code
func GenerateConfirmationCode() string {
b := make([]byte, 3)
rand.Read(b)
// Convert to 6-digit number
num := int(b[0])<<16 | int(b[1])<<8 | int(b[2])
return string(rune('0'+num%10)) + string(rune('0'+(num/10)%10)) +
string(rune('0'+(num/100)%10)) + string(rune('0'+(num/1000)%10)) +
string(rune('0'+(num/10000)%10)) + string(rune('0'+(num/100000)%10))
}
// PasswordResetCode represents the user_passwordresetcode table
type PasswordResetCode struct {
BaseModel
UserID uint `gorm:"column:user_id;index;not null" json:"user_id"`
CodeHash string `gorm:"column:code_hash;size:128;not null" json:"-"`
ResetToken string `gorm:"column:reset_token;uniqueIndex;size:64;not null" json:"reset_token"`
ExpiresAt time.Time `gorm:"column:expires_at;not null" json:"expires_at"`
Used bool `gorm:"column:used;default:false" json:"used"`
Attempts int `gorm:"column:attempts;default:0" json:"attempts"`
MaxAttempts int `gorm:"column:max_attempts;default:5" json:"max_attempts"`
// Relations
User User `gorm:"foreignKey:UserID" json:"-"`
}
// TableName returns the table name for GORM
func (PasswordResetCode) TableName() string {
return "user_passwordresetcode"
}
// SetCode hashes and stores the reset code
func (p *PasswordResetCode) SetCode(code string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
if err != nil {
return err
}
p.CodeHash = string(hash)
return nil
}
// CheckCode verifies a code against the stored hash
func (p *PasswordResetCode) CheckCode(code string) bool {
err := bcrypt.CompareHashAndPassword([]byte(p.CodeHash), []byte(code))
return err == nil
}
// IsValid checks if the reset code is still valid
func (p *PasswordResetCode) IsValid() bool {
return !p.Used && time.Now().UTC().Before(p.ExpiresAt) && p.Attempts < p.MaxAttempts
}
// IncrementAttempts increments the attempt counter
func (p *PasswordResetCode) IncrementAttempts(tx *gorm.DB) error {
p.Attempts++
return tx.Model(p).Update("attempts", p.Attempts).Error
}
// MarkAsUsed marks the code as used
func (p *PasswordResetCode) MarkAsUsed(tx *gorm.DB) error {
p.Used = true
return tx.Model(p).Update("used", true).Error
}
// GenerateResetToken creates a URL-safe token
func GenerateResetToken() string {
b := make([]byte, 32)
rand.Read(b)
return hex.EncodeToString(b)
}

199
internal/push/gorush.go Normal file
View File

@@ -0,0 +1,199 @@
package push
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/rs/zerolog/log"
"github.com/treytartt/mycrib-api/internal/config"
)
// Platform constants
const (
PlatformIOS = "ios"
PlatformAndroid = "android"
)
// GorushClient handles communication with Gorush server
type GorushClient struct {
baseURL string
httpClient *http.Client
config *config.PushConfig
}
// NewGorushClient creates a new Gorush client
func NewGorushClient(cfg *config.PushConfig) *GorushClient {
return &GorushClient{
baseURL: cfg.GorushURL,
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
config: cfg,
}
}
// PushNotification represents a push notification request
type PushNotification struct {
Tokens []string `json:"tokens"`
Platform int `json:"platform"` // 1 = iOS, 2 = Android
Message string `json:"message"`
Title string `json:"title,omitempty"`
Topic string `json:"topic,omitempty"` // iOS bundle ID
Badge *int `json:"badge,omitempty"` // iOS badge count
Sound string `json:"sound,omitempty"` // Notification sound
ContentAvailable bool `json:"content_available,omitempty"` // iOS background notification
MutableContent bool `json:"mutable_content,omitempty"` // iOS mutable content
Data map[string]string `json:"data,omitempty"` // Custom data payload
Priority string `json:"priority,omitempty"` // high or normal
ThreadID string `json:"thread_id,omitempty"` // iOS thread grouping
CollapseKey string `json:"collapse_key,omitempty"` // Android collapse key
}
// GorushRequest represents the full Gorush API request
type GorushRequest struct {
Notifications []PushNotification `json:"notifications"`
}
// GorushResponse represents the Gorush API response
type GorushResponse struct {
Counts int `json:"counts"`
Logs []GorushLog `json:"logs,omitempty"`
Success string `json:"success,omitempty"`
}
// GorushLog represents a log entry from Gorush
type GorushLog struct {
Type string `json:"type"`
Platform string `json:"platform"`
Token string `json:"token"`
Message string `json:"message"`
Error string `json:"error,omitempty"`
}
// SendToIOS sends a push notification to iOS devices
func (c *GorushClient) SendToIOS(ctx context.Context, tokens []string, title, message string, data map[string]string) error {
if len(tokens) == 0 {
return nil
}
notification := PushNotification{
Tokens: tokens,
Platform: 1, // iOS
Title: title,
Message: message,
Topic: c.config.APNSTopic,
Sound: "default",
MutableContent: true,
Data: data,
Priority: "high",
}
return c.send(ctx, notification)
}
// SendToAndroid sends a push notification to Android devices
func (c *GorushClient) SendToAndroid(ctx context.Context, tokens []string, title, message string, data map[string]string) error {
if len(tokens) == 0 {
return nil
}
notification := PushNotification{
Tokens: tokens,
Platform: 2, // Android
Title: title,
Message: message,
Data: data,
Priority: "high",
}
return c.send(ctx, notification)
}
// SendToAll sends a push notification to both iOS and Android devices
func (c *GorushClient) SendToAll(ctx context.Context, iosTokens, androidTokens []string, title, message string, data map[string]string) error {
var errs []error
if len(iosTokens) > 0 {
if err := c.SendToIOS(ctx, iosTokens, title, message, data); err != nil {
errs = append(errs, fmt.Errorf("iOS: %w", err))
}
}
if len(androidTokens) > 0 {
if err := c.SendToAndroid(ctx, androidTokens, title, message, data); err != nil {
errs = append(errs, fmt.Errorf("Android: %w", err))
}
}
if len(errs) > 0 {
return fmt.Errorf("push notification errors: %v", errs)
}
return nil
}
// send sends the notification to Gorush
func (c *GorushClient) send(ctx context.Context, notification PushNotification) error {
req := GorushRequest{
Notifications: []PushNotification{notification},
}
body, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+"/api/push", bytes.NewReader(body))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(httpReq)
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("gorush returned status %d", resp.StatusCode)
}
var gorushResp GorushResponse
if err := json.NewDecoder(resp.Body).Decode(&gorushResp); err != nil {
return fmt.Errorf("failed to decode response: %w", err)
}
log.Debug().
Int("counts", gorushResp.Counts).
Int("tokens", len(notification.Tokens)).
Msg("Push notification sent")
return nil
}
// HealthCheck checks if Gorush is healthy
func (c *GorushClient) HealthCheck(ctx context.Context) error {
httpReq, err := http.NewRequestWithContext(ctx, "GET", c.baseURL+"/api/stat/go", nil)
if err != nil {
return err
}
resp, err := c.httpClient.Do(httpReq)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("gorush health check failed: status %d", resp.StatusCode)
}
return nil
}

View File

@@ -0,0 +1,151 @@
package repositories
import (
"gorm.io/gorm"
"github.com/treytartt/mycrib-api/internal/models"
)
// ContractorRepository handles database operations for contractors
type ContractorRepository struct {
db *gorm.DB
}
// NewContractorRepository creates a new contractor repository
func NewContractorRepository(db *gorm.DB) *ContractorRepository {
return &ContractorRepository{db: db}
}
// FindByID finds a contractor by ID with preloaded relations
func (r *ContractorRepository) FindByID(id uint) (*models.Contractor, error) {
var contractor models.Contractor
err := r.db.Preload("CreatedBy").
Preload("Specialties").
Preload("Tasks").
Where("id = ? AND is_active = ?", id, true).
First(&contractor).Error
if err != nil {
return nil, err
}
return &contractor, nil
}
// FindByResidence finds all contractors for a residence
func (r *ContractorRepository) FindByResidence(residenceID uint) ([]models.Contractor, error) {
var contractors []models.Contractor
err := r.db.Preload("CreatedBy").
Preload("Specialties").
Where("residence_id = ? AND is_active = ?", residenceID, true).
Order("is_favorite DESC, name ASC").
Find(&contractors).Error
return contractors, err
}
// FindByUser finds all contractors accessible to a user
func (r *ContractorRepository) FindByUser(residenceIDs []uint) ([]models.Contractor, error) {
var contractors []models.Contractor
err := r.db.Preload("CreatedBy").
Preload("Specialties").
Preload("Residence").
Where("residence_id IN ? AND is_active = ?", residenceIDs, true).
Order("is_favorite DESC, name ASC").
Find(&contractors).Error
return contractors, err
}
// Create creates a new contractor
func (r *ContractorRepository) Create(contractor *models.Contractor) error {
return r.db.Create(contractor).Error
}
// Update updates a contractor
func (r *ContractorRepository) Update(contractor *models.Contractor) error {
return r.db.Save(contractor).Error
}
// Delete soft-deletes a contractor
func (r *ContractorRepository) Delete(id uint) error {
return r.db.Model(&models.Contractor{}).
Where("id = ?", id).
Update("is_active", false).Error
}
// ToggleFavorite toggles the favorite status of a contractor
func (r *ContractorRepository) ToggleFavorite(id uint) (bool, error) {
var contractor models.Contractor
if err := r.db.First(&contractor, id).Error; err != nil {
return false, err
}
newStatus := !contractor.IsFavorite
err := r.db.Model(&models.Contractor{}).
Where("id = ?", id).
Update("is_favorite", newStatus).Error
return newStatus, err
}
// GetTasksForContractor gets all tasks associated with a contractor
func (r *ContractorRepository) GetTasksForContractor(contractorID uint) ([]models.Task, error) {
var tasks []models.Task
err := r.db.Preload("Category").
Preload("Priority").
Preload("Status").
Where("contractor_id = ?", contractorID).
Order("due_date ASC NULLS LAST").
Find(&tasks).Error
return tasks, err
}
// SetSpecialties sets the specialties for a contractor
func (r *ContractorRepository) SetSpecialties(contractorID uint, specialtyIDs []uint) error {
var contractor models.Contractor
if err := r.db.First(&contractor, contractorID).Error; err != nil {
return err
}
// Clear existing specialties
if err := r.db.Model(&contractor).Association("Specialties").Clear(); err != nil {
return err
}
if len(specialtyIDs) == 0 {
return nil
}
// Add new specialties
var specialties []models.ContractorSpecialty
if err := r.db.Where("id IN ?", specialtyIDs).Find(&specialties).Error; err != nil {
return err
}
return r.db.Model(&contractor).Association("Specialties").Append(specialties)
}
// CountByResidence counts contractors in a residence
func (r *ContractorRepository) CountByResidence(residenceID uint) (int64, error) {
var count int64
err := r.db.Model(&models.Contractor{}).
Where("residence_id = ? AND is_active = ?", residenceID, true).
Count(&count).Error
return count, err
}
// === Specialty Operations ===
// GetAllSpecialties returns all contractor specialties
func (r *ContractorRepository) GetAllSpecialties() ([]models.ContractorSpecialty, error) {
var specialties []models.ContractorSpecialty
err := r.db.Order("display_order, name").Find(&specialties).Error
return specialties, err
}
// FindSpecialtyByID finds a specialty by ID
func (r *ContractorRepository) FindSpecialtyByID(id uint) (*models.ContractorSpecialty, error) {
var specialty models.ContractorSpecialty
err := r.db.First(&specialty, id).Error
if err != nil {
return nil, err
}
return &specialty, nil
}

View File

@@ -0,0 +1,125 @@
package repositories
import (
"time"
"gorm.io/gorm"
"github.com/treytartt/mycrib-api/internal/models"
)
// DocumentRepository handles database operations for documents
type DocumentRepository struct {
db *gorm.DB
}
// NewDocumentRepository creates a new document repository
func NewDocumentRepository(db *gorm.DB) *DocumentRepository {
return &DocumentRepository{db: db}
}
// FindByID finds a document by ID with preloaded relations
func (r *DocumentRepository) FindByID(id uint) (*models.Document, error) {
var document models.Document
err := r.db.Preload("CreatedBy").
Preload("Task").
Where("id = ? AND is_active = ?", id, true).
First(&document).Error
if err != nil {
return nil, err
}
return &document, nil
}
// FindByResidence finds all documents for a residence
func (r *DocumentRepository) FindByResidence(residenceID uint) ([]models.Document, error) {
var documents []models.Document
err := r.db.Preload("CreatedBy").
Where("residence_id = ? AND is_active = ?", residenceID, true).
Order("created_at DESC").
Find(&documents).Error
return documents, err
}
// FindByUser finds all documents accessible to a user
func (r *DocumentRepository) FindByUser(residenceIDs []uint) ([]models.Document, error) {
var documents []models.Document
err := r.db.Preload("CreatedBy").
Preload("Residence").
Where("residence_id IN ? AND is_active = ?", residenceIDs, true).
Order("created_at DESC").
Find(&documents).Error
return documents, err
}
// FindWarranties finds all warranty documents
func (r *DocumentRepository) FindWarranties(residenceIDs []uint) ([]models.Document, error) {
var documents []models.Document
err := r.db.Preload("CreatedBy").
Preload("Residence").
Where("residence_id IN ? AND is_active = ? AND document_type = ?",
residenceIDs, true, models.DocumentTypeWarranty).
Order("expiry_date ASC NULLS LAST").
Find(&documents).Error
return documents, err
}
// FindExpiringWarranties finds warranties expiring within the specified days
func (r *DocumentRepository) FindExpiringWarranties(residenceIDs []uint, days int) ([]models.Document, error) {
threshold := time.Now().UTC().AddDate(0, 0, days)
now := time.Now().UTC()
var documents []models.Document
err := r.db.Preload("CreatedBy").
Preload("Residence").
Where("residence_id IN ? AND is_active = ? AND document_type = ? AND expiry_date > ? AND expiry_date <= ?",
residenceIDs, true, models.DocumentTypeWarranty, now, threshold).
Order("expiry_date ASC").
Find(&documents).Error
return documents, err
}
// Create creates a new document
func (r *DocumentRepository) Create(document *models.Document) error {
return r.db.Create(document).Error
}
// Update updates a document
func (r *DocumentRepository) Update(document *models.Document) error {
return r.db.Save(document).Error
}
// Delete soft-deletes a document
func (r *DocumentRepository) Delete(id uint) error {
return r.db.Model(&models.Document{}).
Where("id = ?", id).
Update("is_active", false).Error
}
// Activate activates a document
func (r *DocumentRepository) Activate(id uint) error {
return r.db.Model(&models.Document{}).
Where("id = ?", id).
Update("is_active", true).Error
}
// Deactivate deactivates a document
func (r *DocumentRepository) Deactivate(id uint) error {
return r.db.Model(&models.Document{}).
Where("id = ?", id).
Update("is_active", false).Error
}
// CountByResidence counts documents in a residence
func (r *DocumentRepository) CountByResidence(residenceID uint) (int64, error) {
var count int64
err := r.db.Model(&models.Document{}).
Where("residence_id = ? AND is_active = ?", residenceID, true).
Count(&count).Error
return count, err
}
// FindByIDIncludingInactive finds a document by ID including inactive ones
func (r *DocumentRepository) FindByIDIncludingInactive(id uint, document *models.Document) error {
return r.db.Preload("CreatedBy").First(document, id).Error
}

View File

@@ -0,0 +1,265 @@
package repositories
import (
"time"
"gorm.io/gorm"
"github.com/treytartt/mycrib-api/internal/models"
)
// NotificationRepository handles database operations for notifications
type NotificationRepository struct {
db *gorm.DB
}
// NewNotificationRepository creates a new notification repository
func NewNotificationRepository(db *gorm.DB) *NotificationRepository {
return &NotificationRepository{db: db}
}
// === Notifications ===
// FindByID finds a notification by ID
func (r *NotificationRepository) FindByID(id uint) (*models.Notification, error) {
var notification models.Notification
err := r.db.First(&notification, id).Error
if err != nil {
return nil, err
}
return &notification, nil
}
// FindByUser finds all notifications for a user
func (r *NotificationRepository) FindByUser(userID uint, limit, offset int) ([]models.Notification, error) {
var notifications []models.Notification
query := r.db.Where("user_id = ?", userID).
Order("created_at DESC")
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
err := query.Find(&notifications).Error
return notifications, err
}
// Create creates a new notification
func (r *NotificationRepository) Create(notification *models.Notification) error {
return r.db.Create(notification).Error
}
// MarkAsRead marks a notification as read
func (r *NotificationRepository) MarkAsRead(id uint) error {
now := time.Now().UTC()
return r.db.Model(&models.Notification{}).
Where("id = ?", id).
Updates(map[string]interface{}{
"read": true,
"read_at": now,
}).Error
}
// MarkAllAsRead marks all notifications for a user as read
func (r *NotificationRepository) MarkAllAsRead(userID uint) error {
now := time.Now().UTC()
return r.db.Model(&models.Notification{}).
Where("user_id = ? AND read = ?", userID, false).
Updates(map[string]interface{}{
"read": true,
"read_at": now,
}).Error
}
// MarkAsSent marks a notification as sent
func (r *NotificationRepository) MarkAsSent(id uint) error {
now := time.Now().UTC()
return r.db.Model(&models.Notification{}).
Where("id = ?", id).
Updates(map[string]interface{}{
"sent": true,
"sent_at": now,
}).Error
}
// SetError sets an error message on a notification
func (r *NotificationRepository) SetError(id uint, errorMsg string) error {
return r.db.Model(&models.Notification{}).
Where("id = ?", id).
Update("error_message", errorMsg).Error
}
// CountUnread counts unread notifications for a user
func (r *NotificationRepository) CountUnread(userID uint) (int64, error) {
var count int64
err := r.db.Model(&models.Notification{}).
Where("user_id = ? AND read = ?", userID, false).
Count(&count).Error
return count, err
}
// GetPendingNotifications gets notifications that need to be sent
func (r *NotificationRepository) GetPendingNotifications(limit int) ([]models.Notification, error) {
var notifications []models.Notification
err := r.db.Where("sent = ?", false).
Order("created_at ASC").
Limit(limit).
Find(&notifications).Error
return notifications, err
}
// === Notification Preferences ===
// FindPreferencesByUser finds notification preferences for a user
func (r *NotificationRepository) FindPreferencesByUser(userID uint) (*models.NotificationPreference, error) {
var prefs models.NotificationPreference
err := r.db.Where("user_id = ?", userID).First(&prefs).Error
if err != nil {
return nil, err
}
return &prefs, nil
}
// CreatePreferences creates notification preferences for a user
func (r *NotificationRepository) CreatePreferences(prefs *models.NotificationPreference) error {
return r.db.Create(prefs).Error
}
// UpdatePreferences updates notification preferences
func (r *NotificationRepository) UpdatePreferences(prefs *models.NotificationPreference) error {
return r.db.Save(prefs).Error
}
// GetOrCreatePreferences gets or creates notification preferences for a user
func (r *NotificationRepository) GetOrCreatePreferences(userID uint) (*models.NotificationPreference, error) {
prefs, err := r.FindPreferencesByUser(userID)
if err == nil {
return prefs, nil
}
if err == gorm.ErrRecordNotFound {
prefs = &models.NotificationPreference{
UserID: userID,
TaskDueSoon: true,
TaskOverdue: true,
TaskCompleted: true,
TaskAssigned: true,
ResidenceShared: true,
WarrantyExpiring: true,
}
if err := r.CreatePreferences(prefs); err != nil {
return nil, err
}
return prefs, nil
}
return nil, err
}
// === Device Registration ===
// FindAPNSDeviceByToken finds an APNS device by registration token
func (r *NotificationRepository) FindAPNSDeviceByToken(token string) (*models.APNSDevice, error) {
var device models.APNSDevice
err := r.db.Where("registration_id = ?", token).First(&device).Error
if err != nil {
return nil, err
}
return &device, nil
}
// FindAPNSDevicesByUser finds all APNS devices for a user
func (r *NotificationRepository) FindAPNSDevicesByUser(userID uint) ([]models.APNSDevice, error) {
var devices []models.APNSDevice
err := r.db.Where("user_id = ? AND active = ?", userID, true).Find(&devices).Error
return devices, err
}
// CreateAPNSDevice creates a new APNS device
func (r *NotificationRepository) CreateAPNSDevice(device *models.APNSDevice) error {
return r.db.Create(device).Error
}
// UpdateAPNSDevice updates an APNS device
func (r *NotificationRepository) UpdateAPNSDevice(device *models.APNSDevice) error {
return r.db.Save(device).Error
}
// DeleteAPNSDevice deletes an APNS device
func (r *NotificationRepository) DeleteAPNSDevice(id uint) error {
return r.db.Delete(&models.APNSDevice{}, id).Error
}
// DeactivateAPNSDevice deactivates an APNS device
func (r *NotificationRepository) DeactivateAPNSDevice(id uint) error {
return r.db.Model(&models.APNSDevice{}).
Where("id = ?", id).
Update("active", false).Error
}
// FindGCMDeviceByToken finds a GCM device by registration token
func (r *NotificationRepository) FindGCMDeviceByToken(token string) (*models.GCMDevice, error) {
var device models.GCMDevice
err := r.db.Where("registration_id = ?", token).First(&device).Error
if err != nil {
return nil, err
}
return &device, nil
}
// FindGCMDevicesByUser finds all GCM devices for a user
func (r *NotificationRepository) FindGCMDevicesByUser(userID uint) ([]models.GCMDevice, error) {
var devices []models.GCMDevice
err := r.db.Where("user_id = ? AND active = ?", userID, true).Find(&devices).Error
return devices, err
}
// CreateGCMDevice creates a new GCM device
func (r *NotificationRepository) CreateGCMDevice(device *models.GCMDevice) error {
return r.db.Create(device).Error
}
// UpdateGCMDevice updates a GCM device
func (r *NotificationRepository) UpdateGCMDevice(device *models.GCMDevice) error {
return r.db.Save(device).Error
}
// DeleteGCMDevice deletes a GCM device
func (r *NotificationRepository) DeleteGCMDevice(id uint) error {
return r.db.Delete(&models.GCMDevice{}, id).Error
}
// DeactivateGCMDevice deactivates a GCM device
func (r *NotificationRepository) DeactivateGCMDevice(id uint) error {
return r.db.Model(&models.GCMDevice{}).
Where("id = ?", id).
Update("active", false).Error
}
// GetActiveTokensForUser gets all active push tokens for a user
func (r *NotificationRepository) GetActiveTokensForUser(userID uint) (iosTokens []string, androidTokens []string, err error) {
apnsDevices, err := r.FindAPNSDevicesByUser(userID)
if err != nil && err != gorm.ErrRecordNotFound {
return nil, nil, err
}
gcmDevices, err := r.FindGCMDevicesByUser(userID)
if err != nil && err != gorm.ErrRecordNotFound {
return nil, nil, err
}
iosTokens = make([]string, 0, len(apnsDevices))
for _, d := range apnsDevices {
iosTokens = append(iosTokens, d.RegistrationID)
}
androidTokens = make([]string, 0, len(gcmDevices))
for _, d := range gcmDevices {
androidTokens = append(androidTokens, d.RegistrationID)
}
return iosTokens, androidTokens, nil
}

View File

@@ -0,0 +1,310 @@
package repositories
import (
"crypto/rand"
"errors"
"math/big"
"time"
"gorm.io/gorm"
"github.com/treytartt/mycrib-api/internal/models"
)
// ResidenceRepository handles database operations for residences
type ResidenceRepository struct {
db *gorm.DB
}
// NewResidenceRepository creates a new residence repository
func NewResidenceRepository(db *gorm.DB) *ResidenceRepository {
return &ResidenceRepository{db: db}
}
// FindByID finds a residence by ID with preloaded relations
func (r *ResidenceRepository) FindByID(id uint) (*models.Residence, error) {
var residence models.Residence
err := r.db.Preload("Owner").
Preload("Users").
Preload("PropertyType").
Where("id = ? AND is_active = ?", id, true).
First(&residence).Error
if err != nil {
return nil, err
}
return &residence, nil
}
// FindByIDSimple finds a residence by ID without preloading (for quick checks)
func (r *ResidenceRepository) FindByIDSimple(id uint) (*models.Residence, error) {
var residence models.Residence
err := r.db.Where("id = ? AND is_active = ?", id, true).First(&residence).Error
if err != nil {
return nil, err
}
return &residence, nil
}
// FindByUser finds all residences accessible to a user (owned or shared)
func (r *ResidenceRepository) FindByUser(userID uint) ([]models.Residence, error) {
var residences []models.Residence
// Find residences where user is owner OR user is in the shared users list
err := r.db.Preload("Owner").
Preload("Users").
Preload("PropertyType").
Where("is_active = ?", true).
Where("owner_id = ? OR id IN (?)",
userID,
r.db.Table("residence_residence_users").Select("residence_id").Where("user_id = ?", userID),
).
Order("is_primary DESC, created_at DESC").
Find(&residences).Error
if err != nil {
return nil, err
}
return residences, nil
}
// FindOwnedByUser finds all residences owned by a user
func (r *ResidenceRepository) FindOwnedByUser(userID uint) ([]models.Residence, error) {
var residences []models.Residence
err := r.db.Preload("Owner").
Preload("Users").
Preload("PropertyType").
Where("owner_id = ? AND is_active = ?", userID, true).
Order("is_primary DESC, created_at DESC").
Find(&residences).Error
if err != nil {
return nil, err
}
return residences, nil
}
// Create creates a new residence
func (r *ResidenceRepository) Create(residence *models.Residence) error {
return r.db.Create(residence).Error
}
// Update updates a residence
func (r *ResidenceRepository) Update(residence *models.Residence) error {
return r.db.Save(residence).Error
}
// Delete soft-deletes a residence by setting is_active to false
func (r *ResidenceRepository) Delete(id uint) error {
return r.db.Model(&models.Residence{}).
Where("id = ?", id).
Update("is_active", false).Error
}
// AddUser adds a user to a residence's shared users
func (r *ResidenceRepository) AddUser(residenceID, userID uint) error {
// Using raw SQL for the many-to-many join table
return r.db.Exec(
"INSERT INTO residence_residence_users (residence_id, user_id) VALUES (?, ?) ON CONFLICT DO NOTHING",
residenceID, userID,
).Error
}
// RemoveUser removes a user from a residence's shared users
func (r *ResidenceRepository) RemoveUser(residenceID, userID uint) error {
return r.db.Exec(
"DELETE FROM residence_residence_users WHERE residence_id = ? AND user_id = ?",
residenceID, userID,
).Error
}
// GetResidenceUsers returns all users with access to a residence
func (r *ResidenceRepository) GetResidenceUsers(residenceID uint) ([]models.User, error) {
residence, err := r.FindByID(residenceID)
if err != nil {
return nil, err
}
users := make([]models.User, 0, len(residence.Users)+1)
users = append(users, residence.Owner)
users = append(users, residence.Users...)
return users, nil
}
// HasAccess checks if a user has access to a residence
func (r *ResidenceRepository) HasAccess(residenceID, userID uint) (bool, error) {
var count int64
// Check if user is owner
err := r.db.Model(&models.Residence{}).
Where("id = ? AND owner_id = ? AND is_active = ?", residenceID, userID, true).
Count(&count).Error
if err != nil {
return false, err
}
if count > 0 {
return true, nil
}
// Check if user is in shared users
err = r.db.Table("residence_residence_users").
Where("residence_id = ? AND user_id = ?", residenceID, userID).
Count(&count).Error
if err != nil {
return false, err
}
return count > 0, nil
}
// IsOwner checks if a user is the owner of a residence
func (r *ResidenceRepository) IsOwner(residenceID, userID uint) (bool, error) {
var count int64
err := r.db.Model(&models.Residence{}).
Where("id = ? AND owner_id = ? AND is_active = ?", residenceID, userID, true).
Count(&count).Error
if err != nil {
return false, err
}
return count > 0, nil
}
// CountByOwner counts residences owned by a user
func (r *ResidenceRepository) CountByOwner(userID uint) (int64, error) {
var count int64
err := r.db.Model(&models.Residence{}).
Where("owner_id = ? AND is_active = ?", userID, true).
Count(&count).Error
return count, err
}
// === Share Code Operations ===
// CreateShareCode creates a new share code for a residence
func (r *ResidenceRepository) CreateShareCode(residenceID, createdByID uint, expiresIn time.Duration) (*models.ResidenceShareCode, error) {
// Deactivate existing codes for this residence
err := r.db.Model(&models.ResidenceShareCode{}).
Where("residence_id = ? AND is_active = ?", residenceID, true).
Update("is_active", false).Error
if err != nil {
return nil, err
}
// Generate unique 6-character code
code, err := r.generateUniqueCode()
if err != nil {
return nil, err
}
expiresAt := time.Now().UTC().Add(expiresIn)
shareCode := &models.ResidenceShareCode{
ResidenceID: residenceID,
Code: code,
CreatedByID: createdByID,
IsActive: true,
ExpiresAt: &expiresAt,
}
if err := r.db.Create(shareCode).Error; err != nil {
return nil, err
}
return shareCode, nil
}
// FindShareCodeByCode finds an active share code by its code string
func (r *ResidenceRepository) FindShareCodeByCode(code string) (*models.ResidenceShareCode, error) {
var shareCode models.ResidenceShareCode
err := r.db.Preload("Residence").
Where("code = ? AND is_active = ?", code, true).
First(&shareCode).Error
if err != nil {
return nil, err
}
// Check if expired
if shareCode.ExpiresAt != nil && time.Now().UTC().After(*shareCode.ExpiresAt) {
return nil, errors.New("share code has expired")
}
return &shareCode, nil
}
// DeactivateShareCode deactivates a share code
func (r *ResidenceRepository) DeactivateShareCode(codeID uint) error {
return r.db.Model(&models.ResidenceShareCode{}).
Where("id = ?", codeID).
Update("is_active", false).Error
}
// GetActiveShareCode gets the active share code for a residence (if any)
func (r *ResidenceRepository) GetActiveShareCode(residenceID uint) (*models.ResidenceShareCode, error) {
var shareCode models.ResidenceShareCode
err := r.db.Where("residence_id = ? AND is_active = ?", residenceID, true).
First(&shareCode).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
// Check if expired
if shareCode.ExpiresAt != nil && time.Now().UTC().After(*shareCode.ExpiresAt) {
// Auto-deactivate expired code
r.DeactivateShareCode(shareCode.ID)
return nil, nil
}
return &shareCode, nil
}
// generateUniqueCode generates a unique 6-character alphanumeric code
func (r *ResidenceRepository) generateUniqueCode() (string, error) {
const charset = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" // Removed ambiguous chars: 0, O, I, 1
const codeLength = 6
maxAttempts := 10
for attempt := 0; attempt < maxAttempts; attempt++ {
code := make([]byte, codeLength)
for i := range code {
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
if err != nil {
return "", err
}
code[i] = charset[num.Int64()]
}
codeStr := string(code)
// Check if code already exists
var count int64
r.db.Model(&models.ResidenceShareCode{}).
Where("code = ? AND is_active = ?", codeStr, true).
Count(&count)
if count == 0 {
return codeStr, nil
}
}
return "", errors.New("failed to generate unique share code")
}
// === Residence Type Operations ===
// GetAllResidenceTypes returns all residence types
func (r *ResidenceRepository) GetAllResidenceTypes() ([]models.ResidenceType, error) {
var types []models.ResidenceType
err := r.db.Order("id").Find(&types).Error
return types, err
}
// FindResidenceTypeByID finds a residence type by ID
func (r *ResidenceRepository) FindResidenceTypeByID(id uint) (*models.ResidenceType, error) {
var residenceType models.ResidenceType
err := r.db.First(&residenceType, id).Error
if err != nil {
return nil, err
}
return &residenceType, nil
}

View File

@@ -0,0 +1,203 @@
package repositories
import (
"time"
"gorm.io/gorm"
"github.com/treytartt/mycrib-api/internal/models"
)
// SubscriptionRepository handles database operations for subscriptions
type SubscriptionRepository struct {
db *gorm.DB
}
// NewSubscriptionRepository creates a new subscription repository
func NewSubscriptionRepository(db *gorm.DB) *SubscriptionRepository {
return &SubscriptionRepository{db: db}
}
// === User Subscription ===
// FindByUserID finds a subscription by user ID
func (r *SubscriptionRepository) FindByUserID(userID uint) (*models.UserSubscription, error) {
var sub models.UserSubscription
err := r.db.Where("user_id = ?", userID).First(&sub).Error
if err != nil {
return nil, err
}
return &sub, nil
}
// GetOrCreate gets or creates a subscription for a user (defaults to free tier)
func (r *SubscriptionRepository) GetOrCreate(userID uint) (*models.UserSubscription, error) {
sub, err := r.FindByUserID(userID)
if err == nil {
return sub, nil
}
if err == gorm.ErrRecordNotFound {
sub = &models.UserSubscription{
UserID: userID,
Tier: models.TierFree,
AutoRenew: true,
}
if err := r.db.Create(sub).Error; err != nil {
return nil, err
}
return sub, nil
}
return nil, err
}
// Update updates a subscription
func (r *SubscriptionRepository) Update(sub *models.UserSubscription) error {
return r.db.Save(sub).Error
}
// UpgradeToPro upgrades a user to Pro tier
func (r *SubscriptionRepository) UpgradeToPro(userID uint, expiresAt time.Time, platform string) error {
now := time.Now().UTC()
return r.db.Model(&models.UserSubscription{}).
Where("user_id = ?", userID).
Updates(map[string]interface{}{
"tier": models.TierPro,
"subscribed_at": now,
"expires_at": expiresAt,
"cancelled_at": nil,
"platform": platform,
"auto_renew": true,
}).Error
}
// DowngradeToFree downgrades a user to Free tier
func (r *SubscriptionRepository) DowngradeToFree(userID uint) error {
now := time.Now().UTC()
return r.db.Model(&models.UserSubscription{}).
Where("user_id = ?", userID).
Updates(map[string]interface{}{
"tier": models.TierFree,
"cancelled_at": now,
"auto_renew": false,
}).Error
}
// SetAutoRenew sets the auto-renew flag
func (r *SubscriptionRepository) SetAutoRenew(userID uint, autoRenew bool) error {
return r.db.Model(&models.UserSubscription{}).
Where("user_id = ?", userID).
Update("auto_renew", autoRenew).Error
}
// UpdateReceiptData updates the Apple receipt data
func (r *SubscriptionRepository) UpdateReceiptData(userID uint, receiptData string) error {
return r.db.Model(&models.UserSubscription{}).
Where("user_id = ?", userID).
Update("apple_receipt_data", receiptData).Error
}
// UpdatePurchaseToken updates the Google purchase token
func (r *SubscriptionRepository) UpdatePurchaseToken(userID uint, token string) error {
return r.db.Model(&models.UserSubscription{}).
Where("user_id = ?", userID).
Update("google_purchase_token", token).Error
}
// === Tier Limits ===
// GetTierLimits gets the limits for a subscription tier
func (r *SubscriptionRepository) GetTierLimits(tier models.SubscriptionTier) (*models.TierLimits, error) {
var limits models.TierLimits
err := r.db.Where("tier = ?", tier).First(&limits).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
// Return defaults
if tier == models.TierFree {
defaults := models.GetDefaultFreeLimits()
return &defaults, nil
}
defaults := models.GetDefaultProLimits()
return &defaults, nil
}
return nil, err
}
return &limits, nil
}
// GetAllTierLimits gets all tier limits
func (r *SubscriptionRepository) GetAllTierLimits() ([]models.TierLimits, error) {
var limits []models.TierLimits
err := r.db.Find(&limits).Error
return limits, err
}
// === Subscription Settings (Singleton) ===
// GetSettings gets the subscription settings
func (r *SubscriptionRepository) GetSettings() (*models.SubscriptionSettings, error) {
var settings models.SubscriptionSettings
err := r.db.First(&settings).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
// Return default settings (limitations disabled)
return &models.SubscriptionSettings{
EnableLimitations: false,
}, nil
}
return nil, err
}
return &settings, nil
}
// === Upgrade Triggers ===
// GetUpgradeTrigger gets an upgrade trigger by key
func (r *SubscriptionRepository) GetUpgradeTrigger(key string) (*models.UpgradeTrigger, error) {
var trigger models.UpgradeTrigger
err := r.db.Where("trigger_key = ? AND is_active = ?", key, true).First(&trigger).Error
if err != nil {
return nil, err
}
return &trigger, nil
}
// GetAllUpgradeTriggers gets all active upgrade triggers
func (r *SubscriptionRepository) GetAllUpgradeTriggers() ([]models.UpgradeTrigger, error) {
var triggers []models.UpgradeTrigger
err := r.db.Where("is_active = ?", true).Find(&triggers).Error
return triggers, err
}
// === Feature Benefits ===
// GetFeatureBenefits gets all active feature benefits
func (r *SubscriptionRepository) GetFeatureBenefits() ([]models.FeatureBenefit, error) {
var benefits []models.FeatureBenefit
err := r.db.Where("is_active = ?", true).Order("display_order").Find(&benefits).Error
return benefits, err
}
// === Promotions ===
// GetActivePromotions gets all currently active promotions for a tier
func (r *SubscriptionRepository) GetActivePromotions(tier models.SubscriptionTier) ([]models.Promotion, error) {
now := time.Now().UTC()
var promotions []models.Promotion
err := r.db.Where("is_active = ? AND target_tier = ? AND start_date <= ? AND end_date >= ?",
true, tier, now, now).
Order("start_date DESC").
Find(&promotions).Error
return promotions, err
}
// GetPromotionByID gets a promotion by ID
func (r *SubscriptionRepository) GetPromotionByID(promotionID string) (*models.Promotion, error) {
var promotion models.Promotion
err := r.db.Where("promotion_id = ? AND is_active = ?", promotionID, true).First(&promotion).Error
if err != nil {
return nil, err
}
return &promotion, nil
}

View File

@@ -0,0 +1,347 @@
package repositories
import (
"time"
"gorm.io/gorm"
"github.com/treytartt/mycrib-api/internal/models"
)
// TaskRepository handles database operations for tasks
type TaskRepository struct {
db *gorm.DB
}
// NewTaskRepository creates a new task repository
func NewTaskRepository(db *gorm.DB) *TaskRepository {
return &TaskRepository{db: db}
}
// === Task CRUD ===
// FindByID finds a task by ID with preloaded relations
func (r *TaskRepository) FindByID(id uint) (*models.Task, error) {
var task models.Task
err := r.db.Preload("Residence").
Preload("CreatedBy").
Preload("AssignedTo").
Preload("Category").
Preload("Priority").
Preload("Status").
Preload("Frequency").
Preload("Completions").
Preload("Completions.CompletedBy").
First(&task, id).Error
if err != nil {
return nil, err
}
return &task, nil
}
// FindByResidence finds all tasks for a residence
func (r *TaskRepository) FindByResidence(residenceID uint) ([]models.Task, error) {
var tasks []models.Task
err := r.db.Preload("CreatedBy").
Preload("AssignedTo").
Preload("Category").
Preload("Priority").
Preload("Status").
Preload("Frequency").
Preload("Completions").
Where("residence_id = ?", residenceID).
Order("due_date ASC NULLS LAST, created_at DESC").
Find(&tasks).Error
return tasks, err
}
// FindByUser finds all tasks accessible to a user (across all their residences)
func (r *TaskRepository) FindByUser(userID uint, residenceIDs []uint) ([]models.Task, error) {
var tasks []models.Task
err := r.db.Preload("Residence").
Preload("CreatedBy").
Preload("AssignedTo").
Preload("Category").
Preload("Priority").
Preload("Status").
Preload("Frequency").
Preload("Completions").
Where("residence_id IN ?", residenceIDs).
Order("due_date ASC NULLS LAST, created_at DESC").
Find(&tasks).Error
return tasks, err
}
// Create creates a new task
func (r *TaskRepository) Create(task *models.Task) error {
return r.db.Create(task).Error
}
// Update updates a task
func (r *TaskRepository) Update(task *models.Task) error {
return r.db.Save(task).Error
}
// Delete hard-deletes a task
func (r *TaskRepository) Delete(id uint) error {
return r.db.Delete(&models.Task{}, id).Error
}
// === Task State Operations ===
// MarkInProgress marks a task as in progress
func (r *TaskRepository) MarkInProgress(id uint, statusID uint) error {
return r.db.Model(&models.Task{}).
Where("id = ?", id).
Update("status_id", statusID).Error
}
// Cancel cancels a task
func (r *TaskRepository) Cancel(id uint) error {
return r.db.Model(&models.Task{}).
Where("id = ?", id).
Update("is_cancelled", true).Error
}
// Uncancel uncancels a task
func (r *TaskRepository) Uncancel(id uint) error {
return r.db.Model(&models.Task{}).
Where("id = ?", id).
Update("is_cancelled", false).Error
}
// Archive archives a task
func (r *TaskRepository) Archive(id uint) error {
return r.db.Model(&models.Task{}).
Where("id = ?", id).
Update("is_archived", true).Error
}
// Unarchive unarchives a task
func (r *TaskRepository) Unarchive(id uint) error {
return r.db.Model(&models.Task{}).
Where("id = ?", id).
Update("is_archived", false).Error
}
// === Kanban Board ===
// GetKanbanData retrieves tasks organized for kanban display
func (r *TaskRepository) GetKanbanData(residenceID uint, daysThreshold int) (*models.KanbanBoard, error) {
var tasks []models.Task
err := r.db.Preload("CreatedBy").
Preload("AssignedTo").
Preload("Category").
Preload("Priority").
Preload("Status").
Preload("Frequency").
Preload("Completions").
Preload("Completions.CompletedBy").
Where("residence_id = ? AND is_archived = ?", residenceID, false).
Order("due_date ASC NULLS LAST, priority_id DESC, created_at DESC").
Find(&tasks).Error
if err != nil {
return nil, err
}
// Organize into columns
now := time.Now().UTC()
threshold := now.AddDate(0, 0, daysThreshold)
overdue := make([]models.Task, 0)
dueSoon := make([]models.Task, 0)
upcoming := make([]models.Task, 0)
inProgress := make([]models.Task, 0)
completed := make([]models.Task, 0)
cancelled := make([]models.Task, 0)
for _, task := range tasks {
if task.IsCancelled {
cancelled = append(cancelled, task)
continue
}
// Check if completed (has completions)
if len(task.Completions) > 0 {
completed = append(completed, task)
continue
}
// Check status for in-progress (status_id = 2 typically)
if task.Status != nil && task.Status.Name == "In Progress" {
inProgress = append(inProgress, task)
continue
}
// Check due date
if task.DueDate != nil {
if task.DueDate.Before(now) {
overdue = append(overdue, task)
} else if task.DueDate.Before(threshold) {
dueSoon = append(dueSoon, task)
} else {
upcoming = append(upcoming, task)
}
} else {
upcoming = append(upcoming, task)
}
}
columns := []models.KanbanColumn{
{
Name: "overdue_tasks",
DisplayName: "Overdue",
ButtonTypes: []string{"edit", "cancel", "mark_in_progress"},
Icons: map[string]string{"ios": "exclamationmark.triangle", "android": "Warning"},
Color: "#FF3B30",
Tasks: overdue,
Count: len(overdue),
},
{
Name: "due_soon_tasks",
DisplayName: "Due Soon",
ButtonTypes: []string{"edit", "complete", "mark_in_progress"},
Icons: map[string]string{"ios": "clock", "android": "Schedule"},
Color: "#FF9500",
Tasks: dueSoon,
Count: len(dueSoon),
},
{
Name: "upcoming_tasks",
DisplayName: "Upcoming",
ButtonTypes: []string{"edit", "cancel"},
Icons: map[string]string{"ios": "calendar", "android": "Event"},
Color: "#007AFF",
Tasks: upcoming,
Count: len(upcoming),
},
{
Name: "in_progress_tasks",
DisplayName: "In Progress",
ButtonTypes: []string{"edit", "complete"},
Icons: map[string]string{"ios": "hammer", "android": "Build"},
Color: "#5856D6",
Tasks: inProgress,
Count: len(inProgress),
},
{
Name: "completed_tasks",
DisplayName: "Completed",
ButtonTypes: []string{"view"},
Icons: map[string]string{"ios": "checkmark.circle", "android": "CheckCircle"},
Color: "#34C759",
Tasks: completed,
Count: len(completed),
},
{
Name: "cancelled_tasks",
DisplayName: "Cancelled",
ButtonTypes: []string{"uncancel", "delete"},
Icons: map[string]string{"ios": "xmark.circle", "android": "Cancel"},
Color: "#8E8E93",
Tasks: cancelled,
Count: len(cancelled),
},
}
return &models.KanbanBoard{
Columns: columns,
DaysThreshold: daysThreshold,
ResidenceID: string(rune(residenceID)),
}, nil
}
// === Lookup Operations ===
// GetAllCategories returns all task categories
func (r *TaskRepository) GetAllCategories() ([]models.TaskCategory, error) {
var categories []models.TaskCategory
err := r.db.Order("display_order, name").Find(&categories).Error
return categories, err
}
// GetAllPriorities returns all task priorities
func (r *TaskRepository) GetAllPriorities() ([]models.TaskPriority, error) {
var priorities []models.TaskPriority
err := r.db.Order("level").Find(&priorities).Error
return priorities, err
}
// GetAllStatuses returns all task statuses
func (r *TaskRepository) GetAllStatuses() ([]models.TaskStatus, error) {
var statuses []models.TaskStatus
err := r.db.Order("display_order").Find(&statuses).Error
return statuses, err
}
// GetAllFrequencies returns all task frequencies
func (r *TaskRepository) GetAllFrequencies() ([]models.TaskFrequency, error) {
var frequencies []models.TaskFrequency
err := r.db.Order("display_order").Find(&frequencies).Error
return frequencies, err
}
// FindStatusByName finds a status by name
func (r *TaskRepository) FindStatusByName(name string) (*models.TaskStatus, error) {
var status models.TaskStatus
err := r.db.Where("name = ?", name).First(&status).Error
if err != nil {
return nil, err
}
return &status, nil
}
// CountByResidence counts tasks in a residence
func (r *TaskRepository) CountByResidence(residenceID uint) (int64, error) {
var count int64
err := r.db.Model(&models.Task{}).
Where("residence_id = ? AND is_cancelled = ? AND is_archived = ?", residenceID, false, false).
Count(&count).Error
return count, err
}
// === Task Completion Operations ===
// CreateCompletion creates a new task completion
func (r *TaskRepository) CreateCompletion(completion *models.TaskCompletion) error {
return r.db.Create(completion).Error
}
// FindCompletionByID finds a completion by ID
func (r *TaskRepository) FindCompletionByID(id uint) (*models.TaskCompletion, error) {
var completion models.TaskCompletion
err := r.db.Preload("Task").
Preload("CompletedBy").
First(&completion, id).Error
if err != nil {
return nil, err
}
return &completion, nil
}
// FindCompletionsByTask finds all completions for a task
func (r *TaskRepository) FindCompletionsByTask(taskID uint) ([]models.TaskCompletion, error) {
var completions []models.TaskCompletion
err := r.db.Preload("CompletedBy").
Where("task_id = ?", taskID).
Order("completed_at DESC").
Find(&completions).Error
return completions, err
}
// FindCompletionsByUser finds all completions by a user
func (r *TaskRepository) FindCompletionsByUser(userID uint, residenceIDs []uint) ([]models.TaskCompletion, error) {
var completions []models.TaskCompletion
err := r.db.Preload("Task").
Preload("CompletedBy").
Joins("JOIN task_task ON task_task.id = task_taskcompletion.task_id").
Where("task_task.residence_id IN ?", residenceIDs).
Order("completed_at DESC").
Find(&completions).Error
return completions, err
}
// DeleteCompletion deletes a task completion
func (r *TaskRepository) DeleteCompletion(id uint) error {
return r.db.Delete(&models.TaskCompletion{}, id).Error
}

View File

@@ -0,0 +1,373 @@
package repositories
import (
"errors"
"strings"
"time"
"gorm.io/gorm"
"github.com/treytartt/mycrib-api/internal/models"
)
var (
ErrUserNotFound = errors.New("user not found")
ErrUserExists = errors.New("user already exists")
ErrInvalidToken = errors.New("invalid token")
ErrTokenNotFound = errors.New("token not found")
ErrCodeNotFound = errors.New("code not found")
ErrCodeExpired = errors.New("code expired")
ErrCodeUsed = errors.New("code already used")
ErrTooManyAttempts = errors.New("too many attempts")
ErrRateLimitExceeded = errors.New("rate limit exceeded")
)
// UserRepository handles user-related database operations
type UserRepository struct {
db *gorm.DB
}
// NewUserRepository creates a new user repository
func NewUserRepository(db *gorm.DB) *UserRepository {
return &UserRepository{db: db}
}
// FindByID finds a user by ID
func (r *UserRepository) FindByID(id uint) (*models.User, error) {
var user models.User
if err := r.db.First(&user, id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, err
}
return &user, nil
}
// FindByIDWithProfile finds a user by ID with profile preloaded
func (r *UserRepository) FindByIDWithProfile(id uint) (*models.User, error) {
var user models.User
if err := r.db.Preload("Profile").First(&user, id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, err
}
return &user, nil
}
// FindByUsername finds a user by username (case-insensitive)
func (r *UserRepository) FindByUsername(username string) (*models.User, error) {
var user models.User
if err := r.db.Where("LOWER(username) = LOWER(?)", username).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, err
}
return &user, nil
}
// FindByEmail finds a user by email (case-insensitive)
func (r *UserRepository) FindByEmail(email string) (*models.User, error) {
var user models.User
if err := r.db.Where("LOWER(email) = LOWER(?)", email).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, err
}
return &user, nil
}
// FindByUsernameOrEmail finds a user by username or email
func (r *UserRepository) FindByUsernameOrEmail(identifier string) (*models.User, error) {
var user models.User
if err := r.db.Where("LOWER(username) = LOWER(?) OR LOWER(email) = LOWER(?)", identifier, identifier).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, err
}
return &user, nil
}
// Create creates a new user
func (r *UserRepository) Create(user *models.User) error {
return r.db.Create(user).Error
}
// Update updates a user
func (r *UserRepository) Update(user *models.User) error {
return r.db.Save(user).Error
}
// UpdateLastLogin updates the user's last login timestamp
func (r *UserRepository) UpdateLastLogin(userID uint) error {
now := time.Now().UTC()
return r.db.Model(&models.User{}).Where("id = ?", userID).Update("last_login", now).Error
}
// ExistsByUsername checks if a username exists
func (r *UserRepository) ExistsByUsername(username string) (bool, error) {
var count int64
if err := r.db.Model(&models.User{}).Where("LOWER(username) = LOWER(?)", username).Count(&count).Error; err != nil {
return false, err
}
return count > 0, nil
}
// ExistsByEmail checks if an email exists
func (r *UserRepository) ExistsByEmail(email string) (bool, error) {
var count int64
if err := r.db.Model(&models.User{}).Where("LOWER(email) = LOWER(?)", email).Count(&count).Error; err != nil {
return false, err
}
return count > 0, nil
}
// --- Auth Token Methods ---
// GetOrCreateToken gets or creates an auth token for a user
func (r *UserRepository) GetOrCreateToken(userID uint) (*models.AuthToken, error) {
var token models.AuthToken
result := r.db.Where("user_id = ?", userID).First(&token)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
token = models.AuthToken{UserID: userID}
if err := r.db.Create(&token).Error; err != nil {
return nil, err
}
} else if result.Error != nil {
return nil, result.Error
}
return &token, nil
}
// DeleteToken deletes an auth token
func (r *UserRepository) DeleteToken(token string) error {
result := r.db.Where("key = ?", token).Delete(&models.AuthToken{})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return ErrTokenNotFound
}
return nil
}
// DeleteTokenByUserID deletes an auth token by user ID
func (r *UserRepository) DeleteTokenByUserID(userID uint) error {
return r.db.Where("user_id = ?", userID).Delete(&models.AuthToken{}).Error
}
// --- User Profile Methods ---
// GetOrCreateProfile gets or creates a user profile
func (r *UserRepository) GetOrCreateProfile(userID uint) (*models.UserProfile, error) {
var profile models.UserProfile
result := r.db.Where("user_id = ?", userID).First(&profile)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
profile = models.UserProfile{UserID: userID}
if err := r.db.Create(&profile).Error; err != nil {
return nil, err
}
} else if result.Error != nil {
return nil, result.Error
}
return &profile, nil
}
// UpdateProfile updates a user profile
func (r *UserRepository) UpdateProfile(profile *models.UserProfile) error {
return r.db.Save(profile).Error
}
// SetProfileVerified sets the profile verified status
func (r *UserRepository) SetProfileVerified(userID uint, verified bool) error {
return r.db.Model(&models.UserProfile{}).Where("user_id = ?", userID).Update("verified", verified).Error
}
// --- Confirmation Code Methods ---
// CreateConfirmationCode creates a new confirmation code
func (r *UserRepository) CreateConfirmationCode(userID uint, code string, expiresAt time.Time) (*models.ConfirmationCode, error) {
// Invalidate any existing unused codes for this user
r.db.Model(&models.ConfirmationCode{}).
Where("user_id = ? AND is_used = ?", userID, false).
Update("is_used", true)
confirmCode := &models.ConfirmationCode{
UserID: userID,
Code: code,
ExpiresAt: expiresAt,
IsUsed: false,
}
if err := r.db.Create(confirmCode).Error; err != nil {
return nil, err
}
return confirmCode, nil
}
// FindConfirmationCode finds a valid confirmation code for a user
func (r *UserRepository) FindConfirmationCode(userID uint, code string) (*models.ConfirmationCode, error) {
var confirmCode models.ConfirmationCode
if err := r.db.Where("user_id = ? AND code = ? AND is_used = ?", userID, code, false).
First(&confirmCode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrCodeNotFound
}
return nil, err
}
if !confirmCode.IsValid() {
if confirmCode.IsUsed {
return nil, ErrCodeUsed
}
return nil, ErrCodeExpired
}
return &confirmCode, nil
}
// MarkConfirmationCodeUsed marks a confirmation code as used
func (r *UserRepository) MarkConfirmationCodeUsed(codeID uint) error {
return r.db.Model(&models.ConfirmationCode{}).Where("id = ?", codeID).Update("is_used", true).Error
}
// --- Password Reset Code Methods ---
// CreatePasswordResetCode creates a new password reset code
func (r *UserRepository) CreatePasswordResetCode(userID uint, codeHash string, resetToken string, expiresAt time.Time) (*models.PasswordResetCode, error) {
// Invalidate any existing unused codes for this user
r.db.Model(&models.PasswordResetCode{}).
Where("user_id = ? AND used = ?", userID, false).
Update("used", true)
resetCode := &models.PasswordResetCode{
UserID: userID,
CodeHash: codeHash,
ResetToken: resetToken,
ExpiresAt: expiresAt,
Used: false,
Attempts: 0,
MaxAttempts: 5,
}
if err := r.db.Create(resetCode).Error; err != nil {
return nil, err
}
return resetCode, nil
}
// FindPasswordResetCode finds a password reset code by email and checks validity
func (r *UserRepository) FindPasswordResetCodeByEmail(email string) (*models.PasswordResetCode, *models.User, error) {
user, err := r.FindByEmail(email)
if err != nil {
return nil, nil, err
}
var resetCode models.PasswordResetCode
if err := r.db.Where("user_id = ? AND used = ?", user.ID, false).
Order("created_at DESC").
First(&resetCode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, ErrCodeNotFound
}
return nil, nil, err
}
return &resetCode, user, nil
}
// FindPasswordResetCodeByToken finds a password reset code by reset token
func (r *UserRepository) FindPasswordResetCodeByToken(resetToken string) (*models.PasswordResetCode, error) {
var resetCode models.PasswordResetCode
if err := r.db.Where("reset_token = ?", resetToken).First(&resetCode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrCodeNotFound
}
return nil, err
}
if !resetCode.IsValid() {
if resetCode.Used {
return nil, ErrCodeUsed
}
if resetCode.Attempts >= resetCode.MaxAttempts {
return nil, ErrTooManyAttempts
}
return nil, ErrCodeExpired
}
return &resetCode, nil
}
// IncrementResetCodeAttempts increments the attempt counter
func (r *UserRepository) IncrementResetCodeAttempts(codeID uint) error {
return r.db.Model(&models.PasswordResetCode{}).Where("id = ?", codeID).
Update("attempts", gorm.Expr("attempts + 1")).Error
}
// MarkPasswordResetCodeUsed marks a password reset code as used
func (r *UserRepository) MarkPasswordResetCodeUsed(codeID uint) error {
return r.db.Model(&models.PasswordResetCode{}).Where("id = ?", codeID).Update("used", true).Error
}
// CountRecentPasswordResetRequests counts reset requests in the last hour
func (r *UserRepository) CountRecentPasswordResetRequests(userID uint) (int64, error) {
var count int64
oneHourAgo := time.Now().UTC().Add(-1 * time.Hour)
if err := r.db.Model(&models.PasswordResetCode{}).
Where("user_id = ? AND created_at > ?", userID, oneHourAgo).
Count(&count).Error; err != nil {
return 0, err
}
return count, nil
}
// --- Search Methods ---
// SearchUsers searches users by username, email, first name, or last name
func (r *UserRepository) SearchUsers(query string, limit, offset int) ([]models.User, int64, error) {
var users []models.User
var total int64
searchQuery := "%" + strings.ToLower(query) + "%"
baseQuery := r.db.Model(&models.User{}).
Where("LOWER(username) LIKE ? OR LOWER(email) LIKE ? OR LOWER(first_name) LIKE ? OR LOWER(last_name) LIKE ?",
searchQuery, searchQuery, searchQuery, searchQuery)
if err := baseQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
if err := baseQuery.Offset(offset).Limit(limit).Find(&users).Error; err != nil {
return nil, 0, err
}
return users, total, nil
}
// ListUsers lists all users with pagination
func (r *UserRepository) ListUsers(limit, offset int) ([]models.User, int64, error) {
var users []models.User
var total int64
if err := r.db.Model(&models.User{}).Count(&total).Error; err != nil {
return nil, 0, err
}
if err := r.db.Offset(offset).Limit(limit).Find(&users).Error; err != nil {
return nil, 0, err
}
return users, total, nil
}

325
internal/router/router.go Normal file
View File

@@ -0,0 +1,325 @@
package router
import (
"net/http"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"github.com/treytartt/mycrib-api/internal/config"
"github.com/treytartt/mycrib-api/internal/handlers"
"github.com/treytartt/mycrib-api/internal/middleware"
"github.com/treytartt/mycrib-api/internal/push"
"github.com/treytartt/mycrib-api/internal/repositories"
"github.com/treytartt/mycrib-api/internal/services"
"github.com/treytartt/mycrib-api/pkg/utils"
)
const Version = "2.0.0"
// Dependencies holds all dependencies needed by the router
type Dependencies struct {
DB *gorm.DB
Cache *services.CacheService
Config *config.Config
EmailService *services.EmailService
PushClient interface{} // *push.GorushClient - optional
}
// SetupRouter creates and configures the Gin router
func SetupRouter(deps *Dependencies) *gin.Engine {
cfg := deps.Config
// Set Gin mode based on debug setting
if cfg.Server.Debug {
gin.SetMode(gin.DebugMode)
} else {
gin.SetMode(gin.ReleaseMode)
}
r := gin.New()
// Global middleware
r.Use(utils.GinRecovery())
r.Use(utils.GinLogger())
r.Use(corsMiddleware(cfg))
// Health check endpoint (no auth required)
r.GET("/api/health/", healthCheck)
// Initialize repositories
userRepo := repositories.NewUserRepository(deps.DB)
residenceRepo := repositories.NewResidenceRepository(deps.DB)
taskRepo := repositories.NewTaskRepository(deps.DB)
contractorRepo := repositories.NewContractorRepository(deps.DB)
documentRepo := repositories.NewDocumentRepository(deps.DB)
notificationRepo := repositories.NewNotificationRepository(deps.DB)
subscriptionRepo := repositories.NewSubscriptionRepository(deps.DB)
// Initialize push client (optional)
var gorushClient *push.GorushClient
if deps.PushClient != nil {
if gc, ok := deps.PushClient.(*push.GorushClient); ok {
gorushClient = gc
}
}
// Initialize services
authService := services.NewAuthService(userRepo, cfg)
residenceService := services.NewResidenceService(residenceRepo, userRepo, cfg)
taskService := services.NewTaskService(taskRepo, residenceRepo)
contractorService := services.NewContractorService(contractorRepo, residenceRepo)
documentService := services.NewDocumentService(documentRepo, residenceRepo)
notificationService := services.NewNotificationService(notificationRepo, gorushClient)
subscriptionService := services.NewSubscriptionService(subscriptionRepo, residenceRepo, taskRepo, contractorRepo, documentRepo)
// Initialize middleware
authMiddleware := middleware.NewAuthMiddleware(deps.DB, deps.Cache)
// Initialize handlers
authHandler := handlers.NewAuthHandler(authService, deps.EmailService, deps.Cache)
residenceHandler := handlers.NewResidenceHandler(residenceService)
taskHandler := handlers.NewTaskHandler(taskService)
contractorHandler := handlers.NewContractorHandler(contractorService)
documentHandler := handlers.NewDocumentHandler(documentService)
notificationHandler := handlers.NewNotificationHandler(notificationService)
subscriptionHandler := handlers.NewSubscriptionHandler(subscriptionService)
// API group
api := r.Group("/api")
{
// Public auth routes (no auth required)
setupPublicAuthRoutes(api, authHandler)
// Public data routes (no auth required)
setupPublicDataRoutes(api, residenceHandler, taskHandler, contractorHandler)
// Protected routes (auth required)
protected := api.Group("")
protected.Use(authMiddleware.TokenAuth())
{
setupProtectedAuthRoutes(protected, authHandler)
setupResidenceRoutes(protected, residenceHandler)
setupTaskRoutes(protected, taskHandler)
setupContractorRoutes(protected, contractorHandler)
setupDocumentRoutes(protected, documentHandler)
setupNotificationRoutes(protected, notificationHandler)
setupSubscriptionRoutes(protected, subscriptionHandler)
setupUserRoutes(protected)
}
}
return r
}
// corsMiddleware configures CORS
func corsMiddleware(cfg *config.Config) gin.HandlerFunc {
corsConfig := cors.Config{
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
AllowHeaders: []string{"Origin", "Content-Type", "Accept", "Authorization"},
ExposeHeaders: []string{"Content-Length"},
AllowCredentials: true,
MaxAge: 12 * time.Hour,
}
// In debug mode, allow all origins; otherwise use configured hosts
if cfg.Server.Debug {
corsConfig.AllowAllOrigins = true
} else {
corsConfig.AllowOrigins = cfg.Server.AllowedHosts
}
return cors.New(corsConfig)
}
// healthCheck returns API health status
func healthCheck(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"status": "healthy",
"version": Version,
"framework": "Gin",
"timestamp": time.Now().UTC().Format(time.RFC3339),
})
}
// setupPublicAuthRoutes configures public authentication routes
func setupPublicAuthRoutes(api *gin.RouterGroup, authHandler *handlers.AuthHandler) {
auth := api.Group("/auth")
{
auth.POST("/login/", authHandler.Login)
auth.POST("/register/", authHandler.Register)
auth.POST("/forgot-password/", authHandler.ForgotPassword)
auth.POST("/verify-reset-code/", authHandler.VerifyResetCode)
auth.POST("/reset-password/", authHandler.ResetPassword)
}
}
// setupProtectedAuthRoutes configures protected authentication routes
func setupProtectedAuthRoutes(api *gin.RouterGroup, authHandler *handlers.AuthHandler) {
auth := api.Group("/auth")
{
auth.POST("/logout/", authHandler.Logout)
auth.GET("/me/", authHandler.CurrentUser)
auth.PUT("/profile/", authHandler.UpdateProfile)
auth.PATCH("/profile/", authHandler.UpdateProfile)
auth.POST("/verify-email/", authHandler.VerifyEmail)
auth.POST("/resend-verification/", authHandler.ResendVerification)
}
}
// setupPublicDataRoutes configures public data routes (lookups, static data)
func setupPublicDataRoutes(api *gin.RouterGroup, residenceHandler *handlers.ResidenceHandler, taskHandler *handlers.TaskHandler, contractorHandler *handlers.ContractorHandler) {
// Static data routes (public, cached)
staticData := api.Group("/static_data")
{
staticData.GET("/", placeholderHandler("get-static-data"))
staticData.POST("/refresh/", placeholderHandler("refresh-static-data"))
}
// Lookup routes (public)
api.GET("/residences/types/", residenceHandler.GetResidenceTypes)
api.GET("/tasks/categories/", taskHandler.GetCategories)
api.GET("/tasks/priorities/", taskHandler.GetPriorities)
api.GET("/tasks/frequencies/", taskHandler.GetFrequencies)
api.GET("/tasks/statuses/", taskHandler.GetStatuses)
api.GET("/contractors/specialties/", contractorHandler.GetSpecialties)
}
// setupResidenceRoutes configures residence routes
func setupResidenceRoutes(api *gin.RouterGroup, residenceHandler *handlers.ResidenceHandler) {
residences := api.Group("/residences")
{
residences.GET("/", residenceHandler.ListResidences)
residences.POST("/", residenceHandler.CreateResidence)
residences.GET("/my-residences/", residenceHandler.GetMyResidences)
residences.POST("/join-with-code/", residenceHandler.JoinWithCode)
residences.GET("/:id/", residenceHandler.GetResidence)
residences.PUT("/:id/", residenceHandler.UpdateResidence)
residences.PATCH("/:id/", residenceHandler.UpdateResidence)
residences.DELETE("/:id/", residenceHandler.DeleteResidence)
residences.POST("/:id/generate-share-code/", residenceHandler.GenerateShareCode)
residences.POST("/:id/generate-tasks-report/", placeholderHandler("generate-tasks-report"))
residences.GET("/:id/users/", residenceHandler.GetResidenceUsers)
residences.DELETE("/:id/users/:user_id/", residenceHandler.RemoveResidenceUser)
}
}
// setupTaskRoutes configures task routes
func setupTaskRoutes(api *gin.RouterGroup, taskHandler *handlers.TaskHandler) {
tasks := api.Group("/tasks")
{
tasks.GET("/", taskHandler.ListTasks)
tasks.POST("/", taskHandler.CreateTask)
tasks.GET("/by-residence/:residence_id/", taskHandler.GetTasksByResidence)
tasks.GET("/:id/", taskHandler.GetTask)
tasks.PUT("/:id/", taskHandler.UpdateTask)
tasks.PATCH("/:id/", taskHandler.UpdateTask)
tasks.DELETE("/:id/", taskHandler.DeleteTask)
tasks.POST("/:id/mark-in-progress/", taskHandler.MarkInProgress)
tasks.POST("/:id/cancel/", taskHandler.CancelTask)
tasks.POST("/:id/uncancel/", taskHandler.UncancelTask)
tasks.POST("/:id/archive/", taskHandler.ArchiveTask)
tasks.POST("/:id/unarchive/", taskHandler.UnarchiveTask)
}
// Task Completions
completions := api.Group("/task-completions")
{
completions.GET("/", taskHandler.ListCompletions)
completions.POST("/", taskHandler.CreateCompletion)
completions.GET("/:id/", taskHandler.GetCompletion)
completions.DELETE("/:id/", taskHandler.DeleteCompletion)
}
}
// setupContractorRoutes configures contractor routes
func setupContractorRoutes(api *gin.RouterGroup, contractorHandler *handlers.ContractorHandler) {
contractors := api.Group("/contractors")
{
contractors.GET("/", contractorHandler.ListContractors)
contractors.POST("/", contractorHandler.CreateContractor)
contractors.GET("/:id/", contractorHandler.GetContractor)
contractors.PUT("/:id/", contractorHandler.UpdateContractor)
contractors.PATCH("/:id/", contractorHandler.UpdateContractor)
contractors.DELETE("/:id/", contractorHandler.DeleteContractor)
contractors.POST("/:id/toggle-favorite/", contractorHandler.ToggleFavorite)
contractors.GET("/:id/tasks/", contractorHandler.GetContractorTasks)
}
}
// setupDocumentRoutes configures document routes
func setupDocumentRoutes(api *gin.RouterGroup, documentHandler *handlers.DocumentHandler) {
documents := api.Group("/documents")
{
documents.GET("/", documentHandler.ListDocuments)
documents.POST("/", documentHandler.CreateDocument)
documents.GET("/warranties/", documentHandler.ListWarranties)
documents.GET("/:id/", documentHandler.GetDocument)
documents.PUT("/:id/", documentHandler.UpdateDocument)
documents.PATCH("/:id/", documentHandler.UpdateDocument)
documents.DELETE("/:id/", documentHandler.DeleteDocument)
documents.POST("/:id/activate/", documentHandler.ActivateDocument)
documents.POST("/:id/deactivate/", documentHandler.DeactivateDocument)
}
}
// setupNotificationRoutes configures notification routes
func setupNotificationRoutes(api *gin.RouterGroup, notificationHandler *handlers.NotificationHandler) {
notifications := api.Group("/notifications")
{
notifications.GET("/", notificationHandler.ListNotifications)
notifications.GET("/unread-count/", notificationHandler.GetUnreadCount)
notifications.POST("/mark-all-read/", notificationHandler.MarkAllAsRead)
notifications.POST("/:id/read/", notificationHandler.MarkAsRead)
notifications.POST("/devices/", notificationHandler.RegisterDevice)
notifications.GET("/devices/", notificationHandler.ListDevices)
notifications.DELETE("/devices/:id/", notificationHandler.DeleteDevice)
notifications.GET("/preferences/", notificationHandler.GetPreferences)
notifications.PUT("/preferences/", notificationHandler.UpdatePreferences)
notifications.PATCH("/preferences/", notificationHandler.UpdatePreferences)
}
}
// setupSubscriptionRoutes configures subscription routes
func setupSubscriptionRoutes(api *gin.RouterGroup, subscriptionHandler *handlers.SubscriptionHandler) {
subscription := api.Group("/subscription")
{
subscription.GET("/", subscriptionHandler.GetSubscription)
subscription.GET("/status/", subscriptionHandler.GetSubscriptionStatus)
subscription.GET("/upgrade-trigger/:key/", subscriptionHandler.GetUpgradeTrigger)
subscription.GET("/features/", subscriptionHandler.GetFeatureBenefits)
subscription.GET("/promotions/", subscriptionHandler.GetPromotions)
subscription.POST("/purchase/", subscriptionHandler.ProcessPurchase)
subscription.POST("/cancel/", subscriptionHandler.CancelSubscription)
subscription.POST("/restore/", subscriptionHandler.RestoreSubscription)
}
}
// setupUserRoutes configures user routes
func setupUserRoutes(api *gin.RouterGroup) {
users := api.Group("/users")
{
users.GET("/", placeholderHandler("list-users"))
users.GET("/:id/", placeholderHandler("get-user"))
users.GET("/profiles/", placeholderHandler("list-profiles"))
}
}
// placeholderHandler returns a handler that indicates an endpoint is not yet implemented
func placeholderHandler(name string) gin.HandlerFunc {
return func(c *gin.Context) {
c.JSON(http.StatusNotImplemented, gin.H{
"error": "Endpoint not yet implemented",
"endpoint": name,
"message": "This endpoint is planned for future phases",
})
}
}

View File

@@ -0,0 +1,418 @@
package services
import (
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"time"
"golang.org/x/crypto/bcrypt"
"github.com/treytartt/mycrib-api/internal/config"
"github.com/treytartt/mycrib-api/internal/dto/requests"
"github.com/treytartt/mycrib-api/internal/dto/responses"
"github.com/treytartt/mycrib-api/internal/models"
"github.com/treytartt/mycrib-api/internal/repositories"
)
var (
ErrInvalidCredentials = errors.New("invalid credentials")
ErrUsernameTaken = errors.New("username already taken")
ErrEmailTaken = errors.New("email already taken")
ErrUserInactive = errors.New("user account is inactive")
ErrInvalidCode = errors.New("invalid verification code")
ErrCodeExpired = errors.New("verification code expired")
ErrAlreadyVerified = errors.New("email already verified")
ErrRateLimitExceeded = errors.New("too many requests, please try again later")
ErrInvalidResetToken = errors.New("invalid or expired reset token")
)
// AuthService handles authentication business logic
type AuthService struct {
userRepo *repositories.UserRepository
cfg *config.Config
}
// NewAuthService creates a new auth service
func NewAuthService(userRepo *repositories.UserRepository, cfg *config.Config) *AuthService {
return &AuthService{
userRepo: userRepo,
cfg: cfg,
}
}
// Login authenticates a user and returns a token
func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginResponse, error) {
// Find user by username or email
identifier := req.Username
if identifier == "" {
identifier = req.Email
}
user, err := s.userRepo.FindByUsernameOrEmail(identifier)
if err != nil {
if errors.Is(err, repositories.ErrUserNotFound) {
return nil, ErrInvalidCredentials
}
return nil, fmt.Errorf("failed to find user: %w", err)
}
// Check if user is active
if !user.IsActive {
return nil, ErrUserInactive
}
// Verify password
if !user.CheckPassword(req.Password) {
return nil, ErrInvalidCredentials
}
// Get or create auth token
token, err := s.userRepo.GetOrCreateToken(user.ID)
if err != nil {
return nil, fmt.Errorf("failed to create token: %w", err)
}
// Update last login
if err := s.userRepo.UpdateLastLogin(user.ID); err != nil {
// Log error but don't fail the login
fmt.Printf("Failed to update last login: %v\n", err)
}
return &responses.LoginResponse{
Token: token.Key,
User: responses.NewUserResponse(user),
}, nil
}
// Register creates a new user account
func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.RegisterResponse, string, error) {
// Check if username exists
exists, err := s.userRepo.ExistsByUsername(req.Username)
if err != nil {
return nil, "", fmt.Errorf("failed to check username: %w", err)
}
if exists {
return nil, "", ErrUsernameTaken
}
// Check if email exists
exists, err = s.userRepo.ExistsByEmail(req.Email)
if err != nil {
return nil, "", fmt.Errorf("failed to check email: %w", err)
}
if exists {
return nil, "", ErrEmailTaken
}
// Create user
user := &models.User{
Username: req.Username,
Email: req.Email,
FirstName: req.FirstName,
LastName: req.LastName,
IsActive: true,
}
// Hash password
if err := user.SetPassword(req.Password); err != nil {
return nil, "", fmt.Errorf("failed to hash password: %w", err)
}
// Save user
if err := s.userRepo.Create(user); err != nil {
return nil, "", fmt.Errorf("failed to create user: %w", err)
}
// Create user profile
if _, err := s.userRepo.GetOrCreateProfile(user.ID); err != nil {
// Log error but don't fail registration
fmt.Printf("Failed to create user profile: %v\n", err)
}
// Create auth token
token, err := s.userRepo.GetOrCreateToken(user.ID)
if err != nil {
return nil, "", fmt.Errorf("failed to create token: %w", err)
}
// Generate confirmation code
code := generateSixDigitCode()
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
if _, err := s.userRepo.CreateConfirmationCode(user.ID, code, expiresAt); err != nil {
// Log error but don't fail registration
fmt.Printf("Failed to create confirmation code: %v\n", err)
}
return &responses.RegisterResponse{
Token: token.Key,
User: responses.NewUserResponse(user),
Message: "Registration successful. Please check your email to verify your account.",
}, code, nil
}
// Logout invalidates a user's token
func (s *AuthService) Logout(token string) error {
return s.userRepo.DeleteToken(token)
}
// GetCurrentUser returns the current authenticated user with profile
func (s *AuthService) GetCurrentUser(userID uint) (*responses.CurrentUserResponse, error) {
user, err := s.userRepo.FindByIDWithProfile(userID)
if err != nil {
return nil, err
}
response := responses.NewCurrentUserResponse(user)
return &response, nil
}
// UpdateProfile updates a user's profile
func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequest) (*responses.CurrentUserResponse, error) {
user, err := s.userRepo.FindByID(userID)
if err != nil {
return nil, err
}
// Check if new email is taken (if email is being changed)
if req.Email != nil && *req.Email != user.Email {
exists, err := s.userRepo.ExistsByEmail(*req.Email)
if err != nil {
return nil, fmt.Errorf("failed to check email: %w", err)
}
if exists {
return nil, ErrEmailTaken
}
user.Email = *req.Email
}
if req.FirstName != nil {
user.FirstName = *req.FirstName
}
if req.LastName != nil {
user.LastName = *req.LastName
}
if err := s.userRepo.Update(user); err != nil {
return nil, fmt.Errorf("failed to update user: %w", err)
}
// Reload with profile
user, err = s.userRepo.FindByIDWithProfile(userID)
if err != nil {
return nil, err
}
response := responses.NewCurrentUserResponse(user)
return &response, nil
}
// VerifyEmail verifies a user's email with a confirmation code
func (s *AuthService) VerifyEmail(userID uint, code string) error {
// Get user profile
profile, err := s.userRepo.GetOrCreateProfile(userID)
if err != nil {
return fmt.Errorf("failed to get profile: %w", err)
}
// Check if already verified
if profile.Verified {
return ErrAlreadyVerified
}
// Check for test code in debug mode
if s.cfg.Server.Debug && code == "123456" {
if err := s.userRepo.SetProfileVerified(userID, true); err != nil {
return fmt.Errorf("failed to verify profile: %w", err)
}
return nil
}
// Find and validate confirmation code
confirmCode, err := s.userRepo.FindConfirmationCode(userID, code)
if err != nil {
if errors.Is(err, repositories.ErrCodeNotFound) {
return ErrInvalidCode
}
if errors.Is(err, repositories.ErrCodeExpired) {
return ErrCodeExpired
}
return err
}
// Mark code as used
if err := s.userRepo.MarkConfirmationCodeUsed(confirmCode.ID); err != nil {
return fmt.Errorf("failed to mark code as used: %w", err)
}
// Set profile as verified
if err := s.userRepo.SetProfileVerified(userID, true); err != nil {
return fmt.Errorf("failed to verify profile: %w", err)
}
return nil
}
// ResendVerificationCode creates and returns a new verification code
func (s *AuthService) ResendVerificationCode(userID uint) (string, error) {
// Get user profile
profile, err := s.userRepo.GetOrCreateProfile(userID)
if err != nil {
return "", fmt.Errorf("failed to get profile: %w", err)
}
// Check if already verified
if profile.Verified {
return "", ErrAlreadyVerified
}
// Generate new code
code := generateSixDigitCode()
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
if _, err := s.userRepo.CreateConfirmationCode(userID, code, expiresAt); err != nil {
return "", fmt.Errorf("failed to create confirmation code: %w", err)
}
return code, nil
}
// ForgotPassword initiates the password reset process
func (s *AuthService) ForgotPassword(email string) (string, *models.User, error) {
// Find user by email
user, err := s.userRepo.FindByEmail(email)
if err != nil {
if errors.Is(err, repositories.ErrUserNotFound) {
// Don't reveal that the email doesn't exist
return "", nil, nil
}
return "", nil, err
}
// Check rate limit
count, err := s.userRepo.CountRecentPasswordResetRequests(user.ID)
if err != nil {
return "", nil, fmt.Errorf("failed to check rate limit: %w", err)
}
if count >= int64(s.cfg.Security.MaxPasswordResetRate) {
return "", nil, ErrRateLimitExceeded
}
// Generate code and reset token
code := generateSixDigitCode()
resetToken := generateResetToken()
expiresAt := time.Now().UTC().Add(s.cfg.Security.PasswordResetExpiry)
// Hash the code before storing
codeHash, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
if err != nil {
return "", nil, fmt.Errorf("failed to hash code: %w", err)
}
if _, err := s.userRepo.CreatePasswordResetCode(user.ID, string(codeHash), resetToken, expiresAt); err != nil {
return "", nil, fmt.Errorf("failed to create reset code: %w", err)
}
return code, user, nil
}
// VerifyResetCode verifies a password reset code and returns a reset token
func (s *AuthService) VerifyResetCode(email, code string) (string, error) {
// Find the reset code
resetCode, user, err := s.userRepo.FindPasswordResetCodeByEmail(email)
if err != nil {
if errors.Is(err, repositories.ErrUserNotFound) || errors.Is(err, repositories.ErrCodeNotFound) {
return "", ErrInvalidCode
}
return "", err
}
// Check for test code in debug mode
if s.cfg.Server.Debug && code == "123456" {
return resetCode.ResetToken, nil
}
// Verify the code
if !resetCode.CheckCode(code) {
// Increment attempts
s.userRepo.IncrementResetCodeAttempts(resetCode.ID)
return "", ErrInvalidCode
}
// Check if code is still valid
if !resetCode.IsValid() {
if resetCode.Used {
return "", ErrInvalidCode
}
if resetCode.Attempts >= resetCode.MaxAttempts {
return "", ErrRateLimitExceeded
}
return "", ErrCodeExpired
}
_ = user // user available if needed
return resetCode.ResetToken, nil
}
// ResetPassword resets the user's password using a reset token
func (s *AuthService) ResetPassword(resetToken, newPassword string) error {
// Find the reset code by token
resetCode, err := s.userRepo.FindPasswordResetCodeByToken(resetToken)
if err != nil {
if errors.Is(err, repositories.ErrCodeNotFound) || errors.Is(err, repositories.ErrCodeExpired) {
return ErrInvalidResetToken
}
return err
}
// Get the user
user, err := s.userRepo.FindByID(resetCode.UserID)
if err != nil {
return fmt.Errorf("failed to find user: %w", err)
}
// Update password
if err := user.SetPassword(newPassword); err != nil {
return fmt.Errorf("failed to hash password: %w", err)
}
if err := s.userRepo.Update(user); err != nil {
return fmt.Errorf("failed to update user: %w", err)
}
// Mark reset code as used
if err := s.userRepo.MarkPasswordResetCodeUsed(resetCode.ID); err != nil {
// Log error but don't fail
fmt.Printf("Failed to mark reset code as used: %v\n", err)
}
// Invalidate all existing tokens for this user (security measure)
if err := s.userRepo.DeleteTokenByUserID(user.ID); err != nil {
// Log error but don't fail
fmt.Printf("Failed to delete user tokens: %v\n", err)
}
return nil
}
// Helper functions
func generateSixDigitCode() string {
b := make([]byte, 4)
rand.Read(b)
num := int(b[0])<<24 | int(b[1])<<16 | int(b[2])<<8 | int(b[3])
if num < 0 {
num = -num
}
code := num % 1000000
return fmt.Sprintf("%06d", code)
}
func generateResetToken() string {
b := make([]byte, 32)
rand.Read(b)
return hex.EncodeToString(b)
}

View File

@@ -0,0 +1,163 @@
package services
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"github.com/rs/zerolog/log"
"github.com/treytartt/mycrib-api/internal/config"
)
// CacheService provides Redis caching functionality
type CacheService struct {
client *redis.Client
}
var cacheInstance *CacheService
// NewCacheService creates a new cache service
func NewCacheService(cfg *config.RedisConfig) (*CacheService, error) {
opt, err := redis.ParseURL(cfg.URL)
if err != nil {
return nil, fmt.Errorf("failed to parse Redis URL: %w", err)
}
if cfg.Password != "" {
opt.Password = cfg.Password
}
if cfg.DB != 0 {
opt.DB = cfg.DB
}
client := redis.NewClient(opt)
// Test connection
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
}
log.Info().
Str("url", cfg.URL).
Int("db", opt.DB).
Msg("Connected to Redis")
cacheInstance = &CacheService{client: client}
return cacheInstance, nil
}
// GetCache returns the cache service instance
func GetCache() *CacheService {
return cacheInstance
}
// Client returns the underlying Redis client
func (c *CacheService) Client() *redis.Client {
return c.client
}
// Set stores a value with expiration
func (c *CacheService) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
data, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to marshal value: %w", err)
}
return c.client.Set(ctx, key, data, expiration).Err()
}
// Get retrieves a value by key
func (c *CacheService) Get(ctx context.Context, key string, dest interface{}) error {
data, err := c.client.Get(ctx, key).Bytes()
if err != nil {
return err
}
return json.Unmarshal(data, dest)
}
// GetString retrieves a string value by key
func (c *CacheService) GetString(ctx context.Context, key string) (string, error) {
return c.client.Get(ctx, key).Result()
}
// SetString stores a string value with expiration
func (c *CacheService) SetString(ctx context.Context, key string, value string, expiration time.Duration) error {
return c.client.Set(ctx, key, value, expiration).Err()
}
// Delete removes a key
func (c *CacheService) Delete(ctx context.Context, keys ...string) error {
return c.client.Del(ctx, keys...).Err()
}
// Exists checks if a key exists
func (c *CacheService) Exists(ctx context.Context, keys ...string) (int64, error) {
return c.client.Exists(ctx, keys...).Result()
}
// Close closes the Redis connection
func (c *CacheService) Close() error {
if c.client != nil {
return c.client.Close()
}
return nil
}
// Auth token cache helpers
const (
AuthTokenPrefix = "auth_token_"
TokenCacheTTL = 5 * time.Minute
)
// CacheAuthToken caches a user ID for a token
func (c *CacheService) CacheAuthToken(ctx context.Context, token string, userID uint) error {
key := AuthTokenPrefix + token
return c.SetString(ctx, key, fmt.Sprintf("%d", userID), TokenCacheTTL)
}
// GetCachedAuthToken gets a cached user ID for a token
func (c *CacheService) GetCachedAuthToken(ctx context.Context, token string) (uint, error) {
key := AuthTokenPrefix + token
val, err := c.GetString(ctx, key)
if err != nil {
return 0, err
}
var userID uint
_, err = fmt.Sscanf(val, "%d", &userID)
return userID, err
}
// InvalidateAuthToken removes a cached token
func (c *CacheService) InvalidateAuthToken(ctx context.Context, token string) error {
key := AuthTokenPrefix + token
return c.Delete(ctx, key)
}
// Static data cache helpers
const (
StaticDataKey = "static_data"
StaticDataTTL = 1 * time.Hour
)
// CacheStaticData caches static lookup data
func (c *CacheService) CacheStaticData(ctx context.Context, data interface{}) error {
return c.Set(ctx, StaticDataKey, data, StaticDataTTL)
}
// GetCachedStaticData retrieves cached static data
func (c *CacheService) GetCachedStaticData(ctx context.Context, dest interface{}) error {
return c.Get(ctx, StaticDataKey, dest)
}
// InvalidateStaticData removes cached static data
func (c *CacheService) InvalidateStaticData(ctx context.Context) error {
return c.Delete(ctx, StaticDataKey)
}

View File

@@ -0,0 +1,312 @@
package services
import (
"errors"
"gorm.io/gorm"
"github.com/treytartt/mycrib-api/internal/dto/requests"
"github.com/treytartt/mycrib-api/internal/dto/responses"
"github.com/treytartt/mycrib-api/internal/models"
"github.com/treytartt/mycrib-api/internal/repositories"
)
// Contractor-related errors
var (
ErrContractorNotFound = errors.New("contractor not found")
ErrContractorAccessDenied = errors.New("you do not have access to this contractor")
)
// ContractorService handles contractor business logic
type ContractorService struct {
contractorRepo *repositories.ContractorRepository
residenceRepo *repositories.ResidenceRepository
}
// NewContractorService creates a new contractor service
func NewContractorService(contractorRepo *repositories.ContractorRepository, residenceRepo *repositories.ResidenceRepository) *ContractorService {
return &ContractorService{
contractorRepo: contractorRepo,
residenceRepo: residenceRepo,
}
}
// GetContractor gets a contractor by ID with access check
func (s *ContractorService) GetContractor(contractorID, userID uint) (*responses.ContractorResponse, error) {
contractor, err := s.contractorRepo.FindByID(contractorID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrContractorNotFound
}
return nil, err
}
// Check access via residence
hasAccess, err := s.residenceRepo.HasAccess(contractor.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrContractorAccessDenied
}
resp := responses.NewContractorResponse(contractor)
return &resp, nil
}
// ListContractors lists all contractors accessible to a user
func (s *ContractorService) ListContractors(userID uint) (*responses.ContractorListResponse, error) {
residences, err := s.residenceRepo.FindByUser(userID)
if err != nil {
return nil, err
}
residenceIDs := make([]uint, len(residences))
for i, r := range residences {
residenceIDs[i] = r.ID
}
if len(residenceIDs) == 0 {
return &responses.ContractorListResponse{Count: 0, Results: []responses.ContractorResponse{}}, nil
}
contractors, err := s.contractorRepo.FindByUser(residenceIDs)
if err != nil {
return nil, err
}
resp := responses.NewContractorListResponse(contractors)
return &resp, nil
}
// CreateContractor creates a new contractor
func (s *ContractorService) CreateContractor(req *requests.CreateContractorRequest, userID uint) (*responses.ContractorResponse, error) {
// Check residence access
hasAccess, err := s.residenceRepo.HasAccess(req.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrResidenceAccessDenied
}
isFavorite := false
if req.IsFavorite != nil {
isFavorite = *req.IsFavorite
}
contractor := &models.Contractor{
ResidenceID: req.ResidenceID,
CreatedByID: userID,
Name: req.Name,
Company: req.Company,
Phone: req.Phone,
Email: req.Email,
Website: req.Website,
Notes: req.Notes,
StreetAddress: req.StreetAddress,
City: req.City,
StateProvince: req.StateProvince,
PostalCode: req.PostalCode,
Rating: req.Rating,
IsFavorite: isFavorite,
IsActive: true,
}
if err := s.contractorRepo.Create(contractor); err != nil {
return nil, err
}
// Set specialties if provided
if len(req.SpecialtyIDs) > 0 {
if err := s.contractorRepo.SetSpecialties(contractor.ID, req.SpecialtyIDs); err != nil {
return nil, err
}
}
// Reload with relations
contractor, err = s.contractorRepo.FindByID(contractor.ID)
if err != nil {
return nil, err
}
resp := responses.NewContractorResponse(contractor)
return &resp, nil
}
// UpdateContractor updates a contractor
func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *requests.UpdateContractorRequest) (*responses.ContractorResponse, error) {
contractor, err := s.contractorRepo.FindByID(contractorID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrContractorNotFound
}
return nil, err
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(contractor.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrContractorAccessDenied
}
// Apply updates
if req.Name != nil {
contractor.Name = *req.Name
}
if req.Company != nil {
contractor.Company = *req.Company
}
if req.Phone != nil {
contractor.Phone = *req.Phone
}
if req.Email != nil {
contractor.Email = *req.Email
}
if req.Website != nil {
contractor.Website = *req.Website
}
if req.Notes != nil {
contractor.Notes = *req.Notes
}
if req.StreetAddress != nil {
contractor.StreetAddress = *req.StreetAddress
}
if req.City != nil {
contractor.City = *req.City
}
if req.StateProvince != nil {
contractor.StateProvince = *req.StateProvince
}
if req.PostalCode != nil {
contractor.PostalCode = *req.PostalCode
}
if req.Rating != nil {
contractor.Rating = req.Rating
}
if req.IsFavorite != nil {
contractor.IsFavorite = *req.IsFavorite
}
if err := s.contractorRepo.Update(contractor); err != nil {
return nil, err
}
// Update specialties if provided
if req.SpecialtyIDs != nil {
if err := s.contractorRepo.SetSpecialties(contractorID, req.SpecialtyIDs); err != nil {
return nil, err
}
}
// Reload
contractor, err = s.contractorRepo.FindByID(contractorID)
if err != nil {
return nil, err
}
resp := responses.NewContractorResponse(contractor)
return &resp, nil
}
// DeleteContractor soft-deletes a contractor
func (s *ContractorService) DeleteContractor(contractorID, userID uint) error {
contractor, err := s.contractorRepo.FindByID(contractorID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrContractorNotFound
}
return err
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(contractor.ResidenceID, userID)
if err != nil {
return err
}
if !hasAccess {
return ErrContractorAccessDenied
}
return s.contractorRepo.Delete(contractorID)
}
// ToggleFavorite toggles the favorite status of a contractor
func (s *ContractorService) ToggleFavorite(contractorID, userID uint) (*responses.ToggleFavoriteResponse, error) {
contractor, err := s.contractorRepo.FindByID(contractorID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrContractorNotFound
}
return nil, err
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(contractor.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrContractorAccessDenied
}
newStatus, err := s.contractorRepo.ToggleFavorite(contractorID)
if err != nil {
return nil, err
}
message := "Contractor removed from favorites"
if newStatus {
message = "Contractor added to favorites"
}
return &responses.ToggleFavoriteResponse{
Message: message,
IsFavorite: newStatus,
}, nil
}
// GetContractorTasks gets all tasks for a contractor
func (s *ContractorService) GetContractorTasks(contractorID, userID uint) (*responses.TaskListResponse, error) {
contractor, err := s.contractorRepo.FindByID(contractorID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrContractorNotFound
}
return nil, err
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(contractor.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrContractorAccessDenied
}
tasks, err := s.contractorRepo.GetTasksForContractor(contractorID)
if err != nil {
return nil, err
}
resp := responses.NewTaskListResponse(tasks)
return &resp, nil
}
// GetSpecialties returns all contractor specialties
func (s *ContractorService) GetSpecialties() ([]responses.ContractorSpecialtyResponse, error) {
specialties, err := s.contractorRepo.GetAllSpecialties()
if err != nil {
return nil, err
}
result := make([]responses.ContractorSpecialtyResponse, len(specialties))
for i, sp := range specialties {
result[i] = responses.NewContractorSpecialtyResponse(&sp)
}
return result, nil
}

View File

@@ -0,0 +1,313 @@
package services
import (
"errors"
"gorm.io/gorm"
"github.com/treytartt/mycrib-api/internal/dto/requests"
"github.com/treytartt/mycrib-api/internal/dto/responses"
"github.com/treytartt/mycrib-api/internal/models"
"github.com/treytartt/mycrib-api/internal/repositories"
)
// Document-related errors
var (
ErrDocumentNotFound = errors.New("document not found")
ErrDocumentAccessDenied = errors.New("you do not have access to this document")
)
// DocumentService handles document business logic
type DocumentService struct {
documentRepo *repositories.DocumentRepository
residenceRepo *repositories.ResidenceRepository
}
// NewDocumentService creates a new document service
func NewDocumentService(documentRepo *repositories.DocumentRepository, residenceRepo *repositories.ResidenceRepository) *DocumentService {
return &DocumentService{
documentRepo: documentRepo,
residenceRepo: residenceRepo,
}
}
// GetDocument gets a document by ID with access check
func (s *DocumentService) GetDocument(documentID, userID uint) (*responses.DocumentResponse, error) {
document, err := s.documentRepo.FindByID(documentID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrDocumentNotFound
}
return nil, err
}
// Check access via residence
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrDocumentAccessDenied
}
resp := responses.NewDocumentResponse(document)
return &resp, nil
}
// ListDocuments lists all documents accessible to a user
func (s *DocumentService) ListDocuments(userID uint) (*responses.DocumentListResponse, error) {
residences, err := s.residenceRepo.FindByUser(userID)
if err != nil {
return nil, err
}
residenceIDs := make([]uint, len(residences))
for i, r := range residences {
residenceIDs[i] = r.ID
}
if len(residenceIDs) == 0 {
return &responses.DocumentListResponse{Count: 0, Results: []responses.DocumentResponse{}}, nil
}
documents, err := s.documentRepo.FindByUser(residenceIDs)
if err != nil {
return nil, err
}
resp := responses.NewDocumentListResponse(documents)
return &resp, nil
}
// ListWarranties lists all warranty documents
func (s *DocumentService) ListWarranties(userID uint) (*responses.DocumentListResponse, error) {
residences, err := s.residenceRepo.FindByUser(userID)
if err != nil {
return nil, err
}
residenceIDs := make([]uint, len(residences))
for i, r := range residences {
residenceIDs[i] = r.ID
}
if len(residenceIDs) == 0 {
return &responses.DocumentListResponse{Count: 0, Results: []responses.DocumentResponse{}}, nil
}
documents, err := s.documentRepo.FindWarranties(residenceIDs)
if err != nil {
return nil, err
}
resp := responses.NewDocumentListResponse(documents)
return &resp, nil
}
// CreateDocument creates a new document
func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, userID uint) (*responses.DocumentResponse, error) {
// Check residence access
hasAccess, err := s.residenceRepo.HasAccess(req.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrResidenceAccessDenied
}
documentType := req.DocumentType
if documentType == "" {
documentType = models.DocumentTypeGeneral
}
document := &models.Document{
ResidenceID: req.ResidenceID,
CreatedByID: userID,
Title: req.Title,
Description: req.Description,
DocumentType: documentType,
FileURL: req.FileURL,
FileName: req.FileName,
FileSize: req.FileSize,
MimeType: req.MimeType,
PurchaseDate: req.PurchaseDate,
ExpiryDate: req.ExpiryDate,
PurchasePrice: req.PurchasePrice,
Vendor: req.Vendor,
SerialNumber: req.SerialNumber,
ModelNumber: req.ModelNumber,
TaskID: req.TaskID,
IsActive: true,
}
if err := s.documentRepo.Create(document); err != nil {
return nil, err
}
// Reload with relations
document, err = s.documentRepo.FindByID(document.ID)
if err != nil {
return nil, err
}
resp := responses.NewDocumentResponse(document)
return &resp, nil
}
// UpdateDocument updates a document
func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests.UpdateDocumentRequest) (*responses.DocumentResponse, error) {
document, err := s.documentRepo.FindByID(documentID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrDocumentNotFound
}
return nil, err
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrDocumentAccessDenied
}
// Apply updates
if req.Title != nil {
document.Title = *req.Title
}
if req.Description != nil {
document.Description = *req.Description
}
if req.DocumentType != nil {
document.DocumentType = *req.DocumentType
}
if req.FileURL != nil {
document.FileURL = *req.FileURL
}
if req.FileName != nil {
document.FileName = *req.FileName
}
if req.FileSize != nil {
document.FileSize = req.FileSize
}
if req.MimeType != nil {
document.MimeType = *req.MimeType
}
if req.PurchaseDate != nil {
document.PurchaseDate = req.PurchaseDate
}
if req.ExpiryDate != nil {
document.ExpiryDate = req.ExpiryDate
}
if req.PurchasePrice != nil {
document.PurchasePrice = req.PurchasePrice
}
if req.Vendor != nil {
document.Vendor = *req.Vendor
}
if req.SerialNumber != nil {
document.SerialNumber = *req.SerialNumber
}
if req.ModelNumber != nil {
document.ModelNumber = *req.ModelNumber
}
if req.TaskID != nil {
document.TaskID = req.TaskID
}
if err := s.documentRepo.Update(document); err != nil {
return nil, err
}
// Reload
document, err = s.documentRepo.FindByID(documentID)
if err != nil {
return nil, err
}
resp := responses.NewDocumentResponse(document)
return &resp, nil
}
// DeleteDocument soft-deletes a document
func (s *DocumentService) DeleteDocument(documentID, userID uint) error {
document, err := s.documentRepo.FindByID(documentID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrDocumentNotFound
}
return err
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
if err != nil {
return err
}
if !hasAccess {
return ErrDocumentAccessDenied
}
return s.documentRepo.Delete(documentID)
}
// ActivateDocument activates a document
func (s *DocumentService) ActivateDocument(documentID, userID uint) (*responses.DocumentResponse, error) {
// First check if document exists (even if inactive)
var document models.Document
if err := s.documentRepo.FindByIDIncludingInactive(documentID, &document); err != nil {
return nil, ErrDocumentNotFound
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrDocumentAccessDenied
}
if err := s.documentRepo.Activate(documentID); err != nil {
return nil, err
}
// Reload
doc, err := s.documentRepo.FindByID(documentID)
if err != nil {
return nil, err
}
resp := responses.NewDocumentResponse(doc)
return &resp, nil
}
// DeactivateDocument deactivates a document
func (s *DocumentService) DeactivateDocument(documentID, userID uint) (*responses.DocumentResponse, error) {
document, err := s.documentRepo.FindByID(documentID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrDocumentNotFound
}
return nil, err
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrDocumentAccessDenied
}
if err := s.documentRepo.Deactivate(documentID); err != nil {
return nil, err
}
document.IsActive = false
resp := responses.NewDocumentResponse(document)
return &resp, nil
}

View File

@@ -0,0 +1,305 @@
package services
import (
"bytes"
"fmt"
"html/template"
"time"
"github.com/rs/zerolog/log"
"gopkg.in/gomail.v2"
"github.com/treytartt/mycrib-api/internal/config"
)
// EmailService handles sending emails
type EmailService struct {
cfg *config.EmailConfig
dialer *gomail.Dialer
}
// NewEmailService creates a new email service
func NewEmailService(cfg *config.EmailConfig) *EmailService {
dialer := gomail.NewDialer(cfg.Host, cfg.Port, cfg.User, cfg.Password)
return &EmailService{
cfg: cfg,
dialer: dialer,
}
}
// SendEmail sends an email
func (s *EmailService) SendEmail(to, subject, htmlBody, textBody string) error {
m := gomail.NewMessage()
m.SetHeader("From", s.cfg.From)
m.SetHeader("To", to)
m.SetHeader("Subject", subject)
m.SetBody("text/plain", textBody)
m.AddAlternative("text/html", htmlBody)
if err := s.dialer.DialAndSend(m); err != nil {
log.Error().Err(err).Str("to", to).Str("subject", subject).Msg("Failed to send email")
return fmt.Errorf("failed to send email: %w", err)
}
log.Info().Str("to", to).Str("subject", subject).Msg("Email sent successfully")
return nil
}
// SendWelcomeEmail sends a welcome email with verification code
func (s *EmailService) SendWelcomeEmail(to, firstName, code string) error {
subject := "Welcome to MyCrib - Verify Your Email"
name := firstName
if name == "" {
name = "there"
}
htmlBody := fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; }
.container { max-width: 600px; margin: 0 auto; padding: 20px; }
.header { text-align: center; padding: 20px 0; }
.code { background: #f4f4f4; padding: 20px; text-align: center; font-size: 32px; font-weight: bold; letter-spacing: 8px; border-radius: 8px; margin: 20px 0; }
.footer { text-align: center; color: #666; font-size: 12px; margin-top: 40px; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>Welcome to MyCrib!</h1>
</div>
<p>Hi %s,</p>
<p>Thank you for creating a MyCrib account. To complete your registration, please verify your email address by entering the following code:</p>
<div class="code">%s</div>
<p>This code will expire in 24 hours.</p>
<p>If you didn't create a MyCrib account, you can safely ignore this email.</p>
<p>Best regards,<br>The MyCrib Team</p>
<div class="footer">
<p>&copy; %d MyCrib. All rights reserved.</p>
</div>
</div>
</body>
</html>
`, name, code, time.Now().Year())
textBody := fmt.Sprintf(`
Welcome to MyCrib!
Hi %s,
Thank you for creating a MyCrib account. To complete your registration, please verify your email address by entering the following code:
%s
This code will expire in 24 hours.
If you didn't create a MyCrib account, you can safely ignore this email.
Best regards,
The MyCrib Team
`, name, code)
return s.SendEmail(to, subject, htmlBody, textBody)
}
// SendVerificationEmail sends an email verification code
func (s *EmailService) SendVerificationEmail(to, firstName, code string) error {
subject := "MyCrib - Verify Your Email"
name := firstName
if name == "" {
name = "there"
}
htmlBody := fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; }
.container { max-width: 600px; margin: 0 auto; padding: 20px; }
.code { background: #f4f4f4; padding: 20px; text-align: center; font-size: 32px; font-weight: bold; letter-spacing: 8px; border-radius: 8px; margin: 20px 0; }
.footer { text-align: center; color: #666; font-size: 12px; margin-top: 40px; }
</style>
</head>
<body>
<div class="container">
<h1>Verify Your Email</h1>
<p>Hi %s,</p>
<p>Please use the following code to verify your email address:</p>
<div class="code">%s</div>
<p>This code will expire in 24 hours.</p>
<p>If you didn't request this, you can safely ignore this email.</p>
<p>Best regards,<br>The MyCrib Team</p>
<div class="footer">
<p>&copy; %d MyCrib. All rights reserved.</p>
</div>
</div>
</body>
</html>
`, name, code, time.Now().Year())
textBody := fmt.Sprintf(`
Verify Your Email
Hi %s,
Please use the following code to verify your email address:
%s
This code will expire in 24 hours.
If you didn't request this, you can safely ignore this email.
Best regards,
The MyCrib Team
`, name, code)
return s.SendEmail(to, subject, htmlBody, textBody)
}
// SendPasswordResetEmail sends a password reset email
func (s *EmailService) SendPasswordResetEmail(to, firstName, code string) error {
subject := "MyCrib - Password Reset Request"
name := firstName
if name == "" {
name = "there"
}
htmlBody := fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; }
.container { max-width: 600px; margin: 0 auto; padding: 20px; }
.code { background: #f4f4f4; padding: 20px; text-align: center; font-size: 32px; font-weight: bold; letter-spacing: 8px; border-radius: 8px; margin: 20px 0; }
.warning { background: #fff3cd; border: 1px solid #ffc107; padding: 15px; border-radius: 8px; margin: 20px 0; }
.footer { text-align: center; color: #666; font-size: 12px; margin-top: 40px; }
</style>
</head>
<body>
<div class="container">
<h1>Password Reset Request</h1>
<p>Hi %s,</p>
<p>We received a request to reset your password. Use the following code to complete the reset:</p>
<div class="code">%s</div>
<p>This code will expire in 15 minutes.</p>
<div class="warning">
<strong>Security Notice:</strong> If you didn't request a password reset, please ignore this email. Your password will remain unchanged.
</div>
<p>Best regards,<br>The MyCrib Team</p>
<div class="footer">
<p>&copy; %d MyCrib. All rights reserved.</p>
</div>
</div>
</body>
</html>
`, name, code, time.Now().Year())
textBody := fmt.Sprintf(`
Password Reset Request
Hi %s,
We received a request to reset your password. Use the following code to complete the reset:
%s
This code will expire in 15 minutes.
SECURITY NOTICE: If you didn't request a password reset, please ignore this email. Your password will remain unchanged.
Best regards,
The MyCrib Team
`, name, code)
return s.SendEmail(to, subject, htmlBody, textBody)
}
// SendPasswordChangedEmail sends a password changed confirmation email
func (s *EmailService) SendPasswordChangedEmail(to, firstName string) error {
subject := "MyCrib - Your Password Has Been Changed"
name := firstName
if name == "" {
name = "there"
}
htmlBody := fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; }
.container { max-width: 600px; margin: 0 auto; padding: 20px; }
.warning { background: #fff3cd; border: 1px solid #ffc107; padding: 15px; border-radius: 8px; margin: 20px 0; }
.footer { text-align: center; color: #666; font-size: 12px; margin-top: 40px; }
</style>
</head>
<body>
<div class="container">
<h1>Password Changed</h1>
<p>Hi %s,</p>
<p>Your MyCrib password was successfully changed on %s.</p>
<div class="warning">
<strong>Didn't make this change?</strong> If you didn't change your password, please contact us immediately at support@mycrib.com or reset your password.
</div>
<p>Best regards,<br>The MyCrib Team</p>
<div class="footer">
<p>&copy; %d MyCrib. All rights reserved.</p>
</div>
</div>
</body>
</html>
`, name, time.Now().UTC().Format("January 2, 2006 at 3:04 PM UTC"), time.Now().Year())
textBody := fmt.Sprintf(`
Password Changed
Hi %s,
Your MyCrib password was successfully changed on %s.
DIDN'T MAKE THIS CHANGE? If you didn't change your password, please contact us immediately at support@mycrib.com or reset your password.
Best regards,
The MyCrib Team
`, name, time.Now().UTC().Format("January 2, 2006 at 3:04 PM UTC"))
return s.SendEmail(to, subject, htmlBody, textBody)
}
// EmailTemplate represents an email template
type EmailTemplate struct {
name string
template *template.Template
}
// ParseTemplate parses an email template from a string
func ParseTemplate(name, tmpl string) (*EmailTemplate, error) {
t, err := template.New(name).Parse(tmpl)
if err != nil {
return nil, err
}
return &EmailTemplate{name: name, template: t}, nil
}
// Execute executes the template with the given data
func (t *EmailTemplate) Execute(data interface{}) (string, error) {
var buf bytes.Buffer
if err := t.template.Execute(&buf, data); err != nil {
return "", err
}
return buf.String(), nil
}

View File

@@ -0,0 +1,428 @@
package services
import (
"context"
"encoding/json"
"errors"
"gorm.io/gorm"
"github.com/treytartt/mycrib-api/internal/models"
"github.com/treytartt/mycrib-api/internal/push"
"github.com/treytartt/mycrib-api/internal/repositories"
)
// Notification-related errors
var (
ErrNotificationNotFound = errors.New("notification not found")
ErrDeviceNotFound = errors.New("device not found")
ErrInvalidPlatform = errors.New("invalid platform, must be 'ios' or 'android'")
)
// NotificationService handles notification business logic
type NotificationService struct {
notificationRepo *repositories.NotificationRepository
gorushClient *push.GorushClient
}
// NewNotificationService creates a new notification service
func NewNotificationService(notificationRepo *repositories.NotificationRepository, gorushClient *push.GorushClient) *NotificationService {
return &NotificationService{
notificationRepo: notificationRepo,
gorushClient: gorushClient,
}
}
// === Notifications ===
// GetNotifications gets notifications for a user
func (s *NotificationService) GetNotifications(userID uint, limit, offset int) ([]NotificationResponse, error) {
notifications, err := s.notificationRepo.FindByUser(userID, limit, offset)
if err != nil {
return nil, err
}
result := make([]NotificationResponse, len(notifications))
for i, n := range notifications {
result[i] = NewNotificationResponse(&n)
}
return result, nil
}
// GetUnreadCount gets the count of unread notifications
func (s *NotificationService) GetUnreadCount(userID uint) (int64, error) {
return s.notificationRepo.CountUnread(userID)
}
// MarkAsRead marks a notification as read
func (s *NotificationService) MarkAsRead(notificationID, userID uint) error {
notification, err := s.notificationRepo.FindByID(notificationID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotificationNotFound
}
return err
}
if notification.UserID != userID {
return ErrNotificationNotFound
}
return s.notificationRepo.MarkAsRead(notificationID)
}
// MarkAllAsRead marks all notifications as read
func (s *NotificationService) MarkAllAsRead(userID uint) error {
return s.notificationRepo.MarkAllAsRead(userID)
}
// CreateAndSendNotification creates a notification and sends it via push
func (s *NotificationService) CreateAndSendNotification(ctx context.Context, userID uint, notificationType models.NotificationType, title, body string, data map[string]interface{}) error {
// Check user preferences
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
if err != nil {
return err
}
// Check if notification type is enabled
if !s.isNotificationEnabled(prefs, notificationType) {
return nil // Skip silently
}
// Create notification record
dataJSON, _ := json.Marshal(data)
notification := &models.Notification{
UserID: userID,
NotificationType: notificationType,
Title: title,
Body: body,
Data: string(dataJSON),
}
if err := s.notificationRepo.Create(notification); err != nil {
return err
}
// Get device tokens
iosTokens, androidTokens, err := s.notificationRepo.GetActiveTokensForUser(userID)
if err != nil {
return err
}
// Convert data for push
pushData := make(map[string]string)
for k, v := range data {
switch val := v.(type) {
case string:
pushData[k] = val
default:
jsonVal, _ := json.Marshal(val)
pushData[k] = string(jsonVal)
}
}
pushData["notification_id"] = string(rune(notification.ID))
// Send push notification
if s.gorushClient != nil {
err = s.gorushClient.SendToAll(ctx, iosTokens, androidTokens, title, body, pushData)
if err != nil {
s.notificationRepo.SetError(notification.ID, err.Error())
return err
}
}
return s.notificationRepo.MarkAsSent(notification.ID)
}
// isNotificationEnabled checks if a notification type is enabled for user
func (s *NotificationService) isNotificationEnabled(prefs *models.NotificationPreference, notificationType models.NotificationType) bool {
switch notificationType {
case models.NotificationTaskDueSoon:
return prefs.TaskDueSoon
case models.NotificationTaskOverdue:
return prefs.TaskOverdue
case models.NotificationTaskCompleted:
return prefs.TaskCompleted
case models.NotificationTaskAssigned:
return prefs.TaskAssigned
case models.NotificationResidenceShared:
return prefs.ResidenceShared
case models.NotificationWarrantyExpiring:
return prefs.WarrantyExpiring
default:
return true
}
}
// === Notification Preferences ===
// GetPreferences gets notification preferences for a user
func (s *NotificationService) GetPreferences(userID uint) (*NotificationPreferencesResponse, error) {
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
if err != nil {
return nil, err
}
return NewNotificationPreferencesResponse(prefs), nil
}
// UpdatePreferences updates notification preferences
func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferencesRequest) (*NotificationPreferencesResponse, error) {
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
if err != nil {
return nil, err
}
if req.TaskDueSoon != nil {
prefs.TaskDueSoon = *req.TaskDueSoon
}
if req.TaskOverdue != nil {
prefs.TaskOverdue = *req.TaskOverdue
}
if req.TaskCompleted != nil {
prefs.TaskCompleted = *req.TaskCompleted
}
if req.TaskAssigned != nil {
prefs.TaskAssigned = *req.TaskAssigned
}
if req.ResidenceShared != nil {
prefs.ResidenceShared = *req.ResidenceShared
}
if req.WarrantyExpiring != nil {
prefs.WarrantyExpiring = *req.WarrantyExpiring
}
if err := s.notificationRepo.UpdatePreferences(prefs); err != nil {
return nil, err
}
return NewNotificationPreferencesResponse(prefs), nil
}
// === Device Registration ===
// RegisterDevice registers a device for push notifications
func (s *NotificationService) RegisterDevice(userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) {
switch req.Platform {
case push.PlatformIOS:
return s.registerAPNSDevice(userID, req)
case push.PlatformAndroid:
return s.registerGCMDevice(userID, req)
default:
return nil, ErrInvalidPlatform
}
}
func (s *NotificationService) registerAPNSDevice(userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) {
// Check if device exists
existing, err := s.notificationRepo.FindAPNSDeviceByToken(req.RegistrationID)
if err == nil {
// Update existing device
existing.UserID = &userID
existing.Active = true
existing.Name = req.Name
existing.DeviceID = req.DeviceID
if err := s.notificationRepo.UpdateAPNSDevice(existing); err != nil {
return nil, err
}
return NewAPNSDeviceResponse(existing), nil
}
// Create new device
device := &models.APNSDevice{
UserID: &userID,
Name: req.Name,
DeviceID: req.DeviceID,
RegistrationID: req.RegistrationID,
Active: true,
}
if err := s.notificationRepo.CreateAPNSDevice(device); err != nil {
return nil, err
}
return NewAPNSDeviceResponse(device), nil
}
func (s *NotificationService) registerGCMDevice(userID uint, req *RegisterDeviceRequest) (*DeviceResponse, error) {
// Check if device exists
existing, err := s.notificationRepo.FindGCMDeviceByToken(req.RegistrationID)
if err == nil {
// Update existing device
existing.UserID = &userID
existing.Active = true
existing.Name = req.Name
existing.DeviceID = req.DeviceID
if err := s.notificationRepo.UpdateGCMDevice(existing); err != nil {
return nil, err
}
return NewGCMDeviceResponse(existing), nil
}
// Create new device
device := &models.GCMDevice{
UserID: &userID,
Name: req.Name,
DeviceID: req.DeviceID,
RegistrationID: req.RegistrationID,
CloudMessageType: "FCM",
Active: true,
}
if err := s.notificationRepo.CreateGCMDevice(device); err != nil {
return nil, err
}
return NewGCMDeviceResponse(device), nil
}
// ListDevices lists all devices for a user
func (s *NotificationService) ListDevices(userID uint) ([]DeviceResponse, error) {
iosDevices, err := s.notificationRepo.FindAPNSDevicesByUser(userID)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
androidDevices, err := s.notificationRepo.FindGCMDevicesByUser(userID)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
result := make([]DeviceResponse, 0, len(iosDevices)+len(androidDevices))
for _, d := range iosDevices {
result = append(result, *NewAPNSDeviceResponse(&d))
}
for _, d := range androidDevices {
result = append(result, *NewGCMDeviceResponse(&d))
}
return result, nil
}
// DeleteDevice deletes a device
func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userID uint) error {
switch platform {
case push.PlatformIOS:
return s.notificationRepo.DeactivateAPNSDevice(deviceID)
case push.PlatformAndroid:
return s.notificationRepo.DeactivateGCMDevice(deviceID)
default:
return ErrInvalidPlatform
}
}
// === Response/Request Types ===
// NotificationResponse represents a notification in API response
type NotificationResponse struct {
ID uint `json:"id"`
UserID uint `json:"user_id"`
NotificationType models.NotificationType `json:"notification_type"`
Title string `json:"title"`
Body string `json:"body"`
Data map[string]interface{} `json:"data"`
Read bool `json:"read"`
ReadAt *string `json:"read_at"`
Sent bool `json:"sent"`
SentAt *string `json:"sent_at"`
CreatedAt string `json:"created_at"`
}
// NewNotificationResponse creates a NotificationResponse
func NewNotificationResponse(n *models.Notification) NotificationResponse {
resp := NotificationResponse{
ID: n.ID,
UserID: n.UserID,
NotificationType: n.NotificationType,
Title: n.Title,
Body: n.Body,
Read: n.Read,
Sent: n.Sent,
CreatedAt: n.CreatedAt.Format("2006-01-02T15:04:05Z"),
}
if n.Data != "" {
json.Unmarshal([]byte(n.Data), &resp.Data)
}
if n.ReadAt != nil {
t := n.ReadAt.Format("2006-01-02T15:04:05Z")
resp.ReadAt = &t
}
if n.SentAt != nil {
t := n.SentAt.Format("2006-01-02T15:04:05Z")
resp.SentAt = &t
}
return resp
}
// NotificationPreferencesResponse represents notification preferences
type NotificationPreferencesResponse struct {
TaskDueSoon bool `json:"task_due_soon"`
TaskOverdue bool `json:"task_overdue"`
TaskCompleted bool `json:"task_completed"`
TaskAssigned bool `json:"task_assigned"`
ResidenceShared bool `json:"residence_shared"`
WarrantyExpiring bool `json:"warranty_expiring"`
}
// NewNotificationPreferencesResponse creates a NotificationPreferencesResponse
func NewNotificationPreferencesResponse(p *models.NotificationPreference) *NotificationPreferencesResponse {
return &NotificationPreferencesResponse{
TaskDueSoon: p.TaskDueSoon,
TaskOverdue: p.TaskOverdue,
TaskCompleted: p.TaskCompleted,
TaskAssigned: p.TaskAssigned,
ResidenceShared: p.ResidenceShared,
WarrantyExpiring: p.WarrantyExpiring,
}
}
// UpdatePreferencesRequest represents preferences update request
type UpdatePreferencesRequest struct {
TaskDueSoon *bool `json:"task_due_soon"`
TaskOverdue *bool `json:"task_overdue"`
TaskCompleted *bool `json:"task_completed"`
TaskAssigned *bool `json:"task_assigned"`
ResidenceShared *bool `json:"residence_shared"`
WarrantyExpiring *bool `json:"warranty_expiring"`
}
// DeviceResponse represents a device in API response
type DeviceResponse struct {
ID uint `json:"id"`
Name string `json:"name"`
DeviceID string `json:"device_id"`
RegistrationID string `json:"registration_id"`
Platform string `json:"platform"`
Active bool `json:"active"`
DateCreated string `json:"date_created"`
}
// NewAPNSDeviceResponse creates a DeviceResponse from APNS device
func NewAPNSDeviceResponse(d *models.APNSDevice) *DeviceResponse {
return &DeviceResponse{
ID: d.ID,
Name: d.Name,
DeviceID: d.DeviceID,
RegistrationID: d.RegistrationID,
Platform: push.PlatformIOS,
Active: d.Active,
DateCreated: d.DateCreated.Format("2006-01-02T15:04:05Z"),
}
}
// NewGCMDeviceResponse creates a DeviceResponse from GCM device
func NewGCMDeviceResponse(d *models.GCMDevice) *DeviceResponse {
return &DeviceResponse{
ID: d.ID,
Name: d.Name,
DeviceID: d.DeviceID,
RegistrationID: d.RegistrationID,
Platform: push.PlatformAndroid,
Active: d.Active,
DateCreated: d.DateCreated.Format("2006-01-02T15:04:05Z"),
}
}
// RegisterDeviceRequest represents device registration request
type RegisterDeviceRequest struct {
Name string `json:"name"`
DeviceID string `json:"device_id" binding:"required"`
RegistrationID string `json:"registration_id" binding:"required"`
Platform string `json:"platform" binding:"required,oneof=ios android"`
}

View File

@@ -0,0 +1,381 @@
package services
import (
"errors"
"time"
"gorm.io/gorm"
"github.com/treytartt/mycrib-api/internal/config"
"github.com/treytartt/mycrib-api/internal/dto/requests"
"github.com/treytartt/mycrib-api/internal/dto/responses"
"github.com/treytartt/mycrib-api/internal/models"
"github.com/treytartt/mycrib-api/internal/repositories"
)
// Common errors
var (
ErrResidenceNotFound = errors.New("residence not found")
ErrResidenceAccessDenied = errors.New("you do not have access to this residence")
ErrNotResidenceOwner = errors.New("only the residence owner can perform this action")
ErrCannotRemoveOwner = errors.New("cannot remove the owner from the residence")
ErrUserAlreadyMember = errors.New("user is already a member of this residence")
ErrShareCodeInvalid = errors.New("invalid or expired share code")
ErrShareCodeExpired = errors.New("share code has expired")
ErrPropertiesLimitReached = errors.New("you have reached the maximum number of properties for your subscription tier")
)
// ResidenceService handles residence business logic
type ResidenceService struct {
residenceRepo *repositories.ResidenceRepository
userRepo *repositories.UserRepository
config *config.Config
}
// NewResidenceService creates a new residence service
func NewResidenceService(residenceRepo *repositories.ResidenceRepository, userRepo *repositories.UserRepository, cfg *config.Config) *ResidenceService {
return &ResidenceService{
residenceRepo: residenceRepo,
userRepo: userRepo,
config: cfg,
}
}
// GetResidence gets a residence by ID with access check
func (s *ResidenceService) GetResidence(residenceID, userID uint) (*responses.ResidenceResponse, error) {
// Check access
hasAccess, err := s.residenceRepo.HasAccess(residenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrResidenceAccessDenied
}
residence, err := s.residenceRepo.FindByID(residenceID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrResidenceNotFound
}
return nil, err
}
resp := responses.NewResidenceResponse(residence)
return &resp, nil
}
// ListResidences lists all residences accessible to a user
func (s *ResidenceService) ListResidences(userID uint) (*responses.ResidenceListResponse, error) {
residences, err := s.residenceRepo.FindByUser(userID)
if err != nil {
return nil, err
}
resp := responses.NewResidenceListResponse(residences)
return &resp, nil
}
// GetMyResidences returns residences with additional details (tasks, completions, etc.)
// This is the "my-residences" endpoint that returns richer data
func (s *ResidenceService) GetMyResidences(userID uint) (*responses.ResidenceListResponse, error) {
residences, err := s.residenceRepo.FindByUser(userID)
if err != nil {
return nil, err
}
// TODO: In Phase 4, this will include tasks and completions
resp := responses.NewResidenceListResponse(residences)
return &resp, nil
}
// CreateResidence creates a new residence
func (s *ResidenceService) CreateResidence(req *requests.CreateResidenceRequest, ownerID uint) (*responses.ResidenceResponse, error) {
// TODO: Check subscription tier limits
// count, err := s.residenceRepo.CountByOwner(ownerID)
// if err != nil {
// return nil, err
// }
// Check against tier limits...
isPrimary := true
if req.IsPrimary != nil {
isPrimary = *req.IsPrimary
}
// Set default country if not provided
country := req.Country
if country == "" {
country = "USA"
}
residence := &models.Residence{
OwnerID: ownerID,
Name: req.Name,
PropertyTypeID: req.PropertyTypeID,
StreetAddress: req.StreetAddress,
ApartmentUnit: req.ApartmentUnit,
City: req.City,
StateProvince: req.StateProvince,
PostalCode: req.PostalCode,
Country: country,
Bedrooms: req.Bedrooms,
Bathrooms: req.Bathrooms,
SquareFootage: req.SquareFootage,
LotSize: req.LotSize,
YearBuilt: req.YearBuilt,
Description: req.Description,
PurchaseDate: req.PurchaseDate,
PurchasePrice: req.PurchasePrice,
IsPrimary: isPrimary,
IsActive: true,
}
if err := s.residenceRepo.Create(residence); err != nil {
return nil, err
}
// Reload with relations
residence, err := s.residenceRepo.FindByID(residence.ID)
if err != nil {
return nil, err
}
resp := responses.NewResidenceResponse(residence)
return &resp, nil
}
// UpdateResidence updates a residence
func (s *ResidenceService) UpdateResidence(residenceID, userID uint, req *requests.UpdateResidenceRequest) (*responses.ResidenceResponse, error) {
// Check ownership
isOwner, err := s.residenceRepo.IsOwner(residenceID, userID)
if err != nil {
return nil, err
}
if !isOwner {
return nil, ErrNotResidenceOwner
}
residence, err := s.residenceRepo.FindByID(residenceID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrResidenceNotFound
}
return nil, err
}
// Apply updates (only non-nil fields)
if req.Name != nil {
residence.Name = *req.Name
}
if req.PropertyTypeID != nil {
residence.PropertyTypeID = req.PropertyTypeID
}
if req.StreetAddress != nil {
residence.StreetAddress = *req.StreetAddress
}
if req.ApartmentUnit != nil {
residence.ApartmentUnit = *req.ApartmentUnit
}
if req.City != nil {
residence.City = *req.City
}
if req.StateProvince != nil {
residence.StateProvince = *req.StateProvince
}
if req.PostalCode != nil {
residence.PostalCode = *req.PostalCode
}
if req.Country != nil {
residence.Country = *req.Country
}
if req.Bedrooms != nil {
residence.Bedrooms = req.Bedrooms
}
if req.Bathrooms != nil {
residence.Bathrooms = req.Bathrooms
}
if req.SquareFootage != nil {
residence.SquareFootage = req.SquareFootage
}
if req.LotSize != nil {
residence.LotSize = req.LotSize
}
if req.YearBuilt != nil {
residence.YearBuilt = req.YearBuilt
}
if req.Description != nil {
residence.Description = *req.Description
}
if req.PurchaseDate != nil {
residence.PurchaseDate = req.PurchaseDate
}
if req.PurchasePrice != nil {
residence.PurchasePrice = req.PurchasePrice
}
if req.IsPrimary != nil {
residence.IsPrimary = *req.IsPrimary
}
if err := s.residenceRepo.Update(residence); err != nil {
return nil, err
}
// Reload with relations
residence, err = s.residenceRepo.FindByID(residence.ID)
if err != nil {
return nil, err
}
resp := responses.NewResidenceResponse(residence)
return &resp, nil
}
// DeleteResidence soft-deletes a residence
func (s *ResidenceService) DeleteResidence(residenceID, userID uint) error {
// Check ownership
isOwner, err := s.residenceRepo.IsOwner(residenceID, userID)
if err != nil {
return err
}
if !isOwner {
return ErrNotResidenceOwner
}
return s.residenceRepo.Delete(residenceID)
}
// GenerateShareCode generates a new share code for a residence
func (s *ResidenceService) GenerateShareCode(residenceID, userID uint, expiresInHours int) (*responses.GenerateShareCodeResponse, error) {
// Check ownership
isOwner, err := s.residenceRepo.IsOwner(residenceID, userID)
if err != nil {
return nil, err
}
if !isOwner {
return nil, ErrNotResidenceOwner
}
// Default to 24 hours if not specified
if expiresInHours <= 0 {
expiresInHours = 24
}
shareCode, err := s.residenceRepo.CreateShareCode(residenceID, userID, time.Duration(expiresInHours)*time.Hour)
if err != nil {
return nil, err
}
return &responses.GenerateShareCodeResponse{
Message: "Share code generated successfully",
ShareCode: responses.NewShareCodeResponse(shareCode),
}, nil
}
// JoinWithCode allows a user to join a residence using a share code
func (s *ResidenceService) JoinWithCode(code string, userID uint) (*responses.JoinResidenceResponse, error) {
// Find the share code
shareCode, err := s.residenceRepo.FindShareCodeByCode(code)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrShareCodeInvalid
}
return nil, err
}
// Check if already a member
hasAccess, err := s.residenceRepo.HasAccess(shareCode.ResidenceID, userID)
if err != nil {
return nil, err
}
if hasAccess {
return nil, ErrUserAlreadyMember
}
// Add user to residence
if err := s.residenceRepo.AddUser(shareCode.ResidenceID, userID); err != nil {
return nil, err
}
// Get the residence with full details
residence, err := s.residenceRepo.FindByID(shareCode.ResidenceID)
if err != nil {
return nil, err
}
return &responses.JoinResidenceResponse{
Message: "Successfully joined residence",
Residence: responses.NewResidenceResponse(residence),
}, nil
}
// GetResidenceUsers returns all users with access to a residence
func (s *ResidenceService) GetResidenceUsers(residenceID, userID uint) ([]responses.ResidenceUserResponse, error) {
// Check access
hasAccess, err := s.residenceRepo.HasAccess(residenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrResidenceAccessDenied
}
users, err := s.residenceRepo.GetResidenceUsers(residenceID)
if err != nil {
return nil, err
}
result := make([]responses.ResidenceUserResponse, len(users))
for i, user := range users {
result[i] = *responses.NewResidenceUserResponse(&user)
}
return result, nil
}
// RemoveUser removes a user from a residence (owner only)
func (s *ResidenceService) RemoveUser(residenceID, userIDToRemove, requestingUserID uint) error {
// Check ownership
isOwner, err := s.residenceRepo.IsOwner(residenceID, requestingUserID)
if err != nil {
return err
}
if !isOwner {
return ErrNotResidenceOwner
}
// Cannot remove the owner
if userIDToRemove == requestingUserID {
return ErrCannotRemoveOwner
}
// Check if the residence exists
residence, err := s.residenceRepo.FindByIDSimple(residenceID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrResidenceNotFound
}
return err
}
// Cannot remove the owner
if userIDToRemove == residence.OwnerID {
return ErrCannotRemoveOwner
}
return s.residenceRepo.RemoveUser(residenceID, userIDToRemove)
}
// GetResidenceTypes returns all residence types
func (s *ResidenceService) GetResidenceTypes() ([]responses.ResidenceTypeResponse, error) {
types, err := s.residenceRepo.GetAllResidenceTypes()
if err != nil {
return nil, err
}
result := make([]responses.ResidenceTypeResponse, len(types))
for i, t := range types {
result[i] = *responses.NewResidenceTypeResponse(&t)
}
return result, nil
}

View File

@@ -0,0 +1,417 @@
package services
import (
"errors"
"time"
"gorm.io/gorm"
"github.com/treytartt/mycrib-api/internal/models"
"github.com/treytartt/mycrib-api/internal/repositories"
)
// Subscription-related errors
var (
ErrSubscriptionNotFound = errors.New("subscription not found")
ErrPropertiesLimitExceeded = errors.New("properties limit exceeded for your subscription tier")
ErrTasksLimitExceeded = errors.New("tasks limit exceeded for your subscription tier")
ErrContractorsLimitExceeded = errors.New("contractors limit exceeded for your subscription tier")
ErrDocumentsLimitExceeded = errors.New("documents limit exceeded for your subscription tier")
ErrUpgradeTriggerNotFound = errors.New("upgrade trigger not found")
ErrPromotionNotFound = errors.New("promotion not found")
)
// SubscriptionService handles subscription business logic
type SubscriptionService struct {
subscriptionRepo *repositories.SubscriptionRepository
residenceRepo *repositories.ResidenceRepository
taskRepo *repositories.TaskRepository
contractorRepo *repositories.ContractorRepository
documentRepo *repositories.DocumentRepository
}
// NewSubscriptionService creates a new subscription service
func NewSubscriptionService(
subscriptionRepo *repositories.SubscriptionRepository,
residenceRepo *repositories.ResidenceRepository,
taskRepo *repositories.TaskRepository,
contractorRepo *repositories.ContractorRepository,
documentRepo *repositories.DocumentRepository,
) *SubscriptionService {
return &SubscriptionService{
subscriptionRepo: subscriptionRepo,
residenceRepo: residenceRepo,
taskRepo: taskRepo,
contractorRepo: contractorRepo,
documentRepo: documentRepo,
}
}
// GetSubscription gets the subscription for a user
func (s *SubscriptionService) GetSubscription(userID uint) (*SubscriptionResponse, error) {
sub, err := s.subscriptionRepo.GetOrCreate(userID)
if err != nil {
return nil, err
}
return NewSubscriptionResponse(sub), nil
}
// GetSubscriptionStatus gets detailed subscription status including limits
func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionStatusResponse, error) {
sub, err := s.subscriptionRepo.GetOrCreate(userID)
if err != nil {
return nil, err
}
settings, err := s.subscriptionRepo.GetSettings()
if err != nil {
return nil, err
}
limits, err := s.subscriptionRepo.GetTierLimits(sub.Tier)
if err != nil {
return nil, err
}
// Get current usage if limitations are enabled
var usage *UsageResponse
if settings.EnableLimitations {
usage, err = s.getUserUsage(userID)
if err != nil {
return nil, err
}
}
return &SubscriptionStatusResponse{
Subscription: NewSubscriptionResponse(sub),
Limits: NewTierLimitsResponse(limits),
Usage: usage,
LimitationsEnabled: settings.EnableLimitations,
}, nil
}
// getUserUsage calculates current usage for a user
func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error) {
residences, err := s.residenceRepo.FindOwnedByUser(userID)
if err != nil {
return nil, err
}
propertiesCount := int64(len(residences))
// Count tasks, contractors, and documents across all user's residences
var tasksCount, contractorsCount, documentsCount int64
for _, r := range residences {
tc, err := s.taskRepo.CountByResidence(r.ID)
if err != nil {
return nil, err
}
tasksCount += tc
cc, err := s.contractorRepo.CountByResidence(r.ID)
if err != nil {
return nil, err
}
contractorsCount += cc
dc, err := s.documentRepo.CountByResidence(r.ID)
if err != nil {
return nil, err
}
documentsCount += dc
}
return &UsageResponse{
Properties: propertiesCount,
Tasks: tasksCount,
Contractors: contractorsCount,
Documents: documentsCount,
}, nil
}
// CheckLimit checks if a user has exceeded a specific limit
func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error {
settings, err := s.subscriptionRepo.GetSettings()
if err != nil {
return err
}
// If limitations are disabled, allow everything
if !settings.EnableLimitations {
return nil
}
sub, err := s.subscriptionRepo.GetOrCreate(userID)
if err != nil {
return err
}
// Pro users have unlimited access
if sub.IsPro() {
return nil
}
limits, err := s.subscriptionRepo.GetTierLimits(sub.Tier)
if err != nil {
return err
}
usage, err := s.getUserUsage(userID)
if err != nil {
return err
}
switch limitType {
case "properties":
if limits.PropertiesLimit != nil && usage.Properties >= int64(*limits.PropertiesLimit) {
return ErrPropertiesLimitExceeded
}
case "tasks":
if limits.TasksLimit != nil && usage.Tasks >= int64(*limits.TasksLimit) {
return ErrTasksLimitExceeded
}
case "contractors":
if limits.ContractorsLimit != nil && usage.Contractors >= int64(*limits.ContractorsLimit) {
return ErrContractorsLimitExceeded
}
case "documents":
if limits.DocumentsLimit != nil && usage.Documents >= int64(*limits.DocumentsLimit) {
return ErrDocumentsLimitExceeded
}
}
return nil
}
// GetUpgradeTrigger gets an upgrade trigger by key
func (s *SubscriptionService) GetUpgradeTrigger(key string) (*UpgradeTriggerResponse, error) {
trigger, err := s.subscriptionRepo.GetUpgradeTrigger(key)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUpgradeTriggerNotFound
}
return nil, err
}
return NewUpgradeTriggerResponse(trigger), nil
}
// GetFeatureBenefits gets all feature benefits
func (s *SubscriptionService) GetFeatureBenefits() ([]FeatureBenefitResponse, error) {
benefits, err := s.subscriptionRepo.GetFeatureBenefits()
if err != nil {
return nil, err
}
result := make([]FeatureBenefitResponse, len(benefits))
for i, b := range benefits {
result[i] = *NewFeatureBenefitResponse(&b)
}
return result, nil
}
// GetActivePromotions gets active promotions for a user
func (s *SubscriptionService) GetActivePromotions(userID uint) ([]PromotionResponse, error) {
sub, err := s.subscriptionRepo.GetOrCreate(userID)
if err != nil {
return nil, err
}
promotions, err := s.subscriptionRepo.GetActivePromotions(sub.Tier)
if err != nil {
return nil, err
}
result := make([]PromotionResponse, len(promotions))
for i, p := range promotions {
result[i] = *NewPromotionResponse(&p)
}
return result, nil
}
// ProcessApplePurchase processes an Apple IAP purchase
func (s *SubscriptionService) ProcessApplePurchase(userID uint, receiptData string) (*SubscriptionResponse, error) {
// TODO: Implement receipt validation with Apple's servers
// For now, just upgrade the user
// Store receipt data
if err := s.subscriptionRepo.UpdateReceiptData(userID, receiptData); err != nil {
return nil, err
}
// Upgrade to Pro (1 year from now - adjust based on actual subscription)
expiresAt := time.Now().UTC().AddDate(1, 0, 0)
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, "ios"); err != nil {
return nil, err
}
return s.GetSubscription(userID)
}
// ProcessGooglePurchase processes a Google Play purchase
func (s *SubscriptionService) ProcessGooglePurchase(userID uint, purchaseToken string) (*SubscriptionResponse, error) {
// TODO: Implement token validation with Google's servers
// For now, just upgrade the user
// Store purchase token
if err := s.subscriptionRepo.UpdatePurchaseToken(userID, purchaseToken); err != nil {
return nil, err
}
// Upgrade to Pro (1 year from now - adjust based on actual subscription)
expiresAt := time.Now().UTC().AddDate(1, 0, 0)
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, "android"); err != nil {
return nil, err
}
return s.GetSubscription(userID)
}
// CancelSubscription cancels a subscription (downgrades to free at end of period)
func (s *SubscriptionService) CancelSubscription(userID uint) (*SubscriptionResponse, error) {
if err := s.subscriptionRepo.SetAutoRenew(userID, false); err != nil {
return nil, err
}
return s.GetSubscription(userID)
}
// === Response Types ===
// SubscriptionResponse represents a subscription in API response
type SubscriptionResponse struct {
Tier string `json:"tier"`
SubscribedAt *string `json:"subscribed_at"`
ExpiresAt *string `json:"expires_at"`
AutoRenew bool `json:"auto_renew"`
CancelledAt *string `json:"cancelled_at"`
Platform string `json:"platform"`
IsActive bool `json:"is_active"`
IsPro bool `json:"is_pro"`
}
// NewSubscriptionResponse creates a SubscriptionResponse from a model
func NewSubscriptionResponse(s *models.UserSubscription) *SubscriptionResponse {
resp := &SubscriptionResponse{
Tier: string(s.Tier),
AutoRenew: s.AutoRenew,
Platform: s.Platform,
IsActive: s.IsActive(),
IsPro: s.IsPro(),
}
if s.SubscribedAt != nil {
t := s.SubscribedAt.Format("2006-01-02T15:04:05Z")
resp.SubscribedAt = &t
}
if s.ExpiresAt != nil {
t := s.ExpiresAt.Format("2006-01-02T15:04:05Z")
resp.ExpiresAt = &t
}
if s.CancelledAt != nil {
t := s.CancelledAt.Format("2006-01-02T15:04:05Z")
resp.CancelledAt = &t
}
return resp
}
// TierLimitsResponse represents tier limits
type TierLimitsResponse struct {
Tier string `json:"tier"`
PropertiesLimit *int `json:"properties_limit"`
TasksLimit *int `json:"tasks_limit"`
ContractorsLimit *int `json:"contractors_limit"`
DocumentsLimit *int `json:"documents_limit"`
}
// NewTierLimitsResponse creates a TierLimitsResponse from a model
func NewTierLimitsResponse(l *models.TierLimits) *TierLimitsResponse {
return &TierLimitsResponse{
Tier: string(l.Tier),
PropertiesLimit: l.PropertiesLimit,
TasksLimit: l.TasksLimit,
ContractorsLimit: l.ContractorsLimit,
DocumentsLimit: l.DocumentsLimit,
}
}
// UsageResponse represents current usage
type UsageResponse struct {
Properties int64 `json:"properties"`
Tasks int64 `json:"tasks"`
Contractors int64 `json:"contractors"`
Documents int64 `json:"documents"`
}
// SubscriptionStatusResponse represents full subscription status
type SubscriptionStatusResponse struct {
Subscription *SubscriptionResponse `json:"subscription"`
Limits *TierLimitsResponse `json:"limits"`
Usage *UsageResponse `json:"usage,omitempty"`
LimitationsEnabled bool `json:"limitations_enabled"`
}
// UpgradeTriggerResponse represents an upgrade trigger
type UpgradeTriggerResponse struct {
TriggerKey string `json:"trigger_key"`
Title string `json:"title"`
Message string `json:"message"`
PromoHTML string `json:"promo_html"`
ButtonText string `json:"button_text"`
}
// NewUpgradeTriggerResponse creates an UpgradeTriggerResponse from a model
func NewUpgradeTriggerResponse(t *models.UpgradeTrigger) *UpgradeTriggerResponse {
return &UpgradeTriggerResponse{
TriggerKey: t.TriggerKey,
Title: t.Title,
Message: t.Message,
PromoHTML: t.PromoHTML,
ButtonText: t.ButtonText,
}
}
// FeatureBenefitResponse represents a feature benefit
type FeatureBenefitResponse struct {
FeatureName string `json:"feature_name"`
FreeTierText string `json:"free_tier_text"`
ProTierText string `json:"pro_tier_text"`
DisplayOrder int `json:"display_order"`
}
// NewFeatureBenefitResponse creates a FeatureBenefitResponse from a model
func NewFeatureBenefitResponse(f *models.FeatureBenefit) *FeatureBenefitResponse {
return &FeatureBenefitResponse{
FeatureName: f.FeatureName,
FreeTierText: f.FreeTierText,
ProTierText: f.ProTierText,
DisplayOrder: f.DisplayOrder,
}
}
// PromotionResponse represents a promotion
type PromotionResponse struct {
PromotionID string `json:"promotion_id"`
Title string `json:"title"`
Message string `json:"message"`
Link *string `json:"link"`
StartDate string `json:"start_date"`
EndDate string `json:"end_date"`
}
// NewPromotionResponse creates a PromotionResponse from a model
func NewPromotionResponse(p *models.Promotion) *PromotionResponse {
return &PromotionResponse{
PromotionID: p.PromotionID,
Title: p.Title,
Message: p.Message,
Link: p.Link,
StartDate: p.StartDate.Format("2006-01-02"),
EndDate: p.EndDate.Format("2006-01-02"),
}
}
// === Request Types ===
// ProcessPurchaseRequest represents an IAP purchase request
type ProcessPurchaseRequest struct {
ReceiptData string `json:"receipt_data"` // iOS
PurchaseToken string `json:"purchase_token"` // Android
Platform string `json:"platform" binding:"required,oneof=ios android"`
}

View File

@@ -0,0 +1,601 @@
package services
import (
"errors"
"time"
"gorm.io/gorm"
"github.com/treytartt/mycrib-api/internal/dto/requests"
"github.com/treytartt/mycrib-api/internal/dto/responses"
"github.com/treytartt/mycrib-api/internal/models"
"github.com/treytartt/mycrib-api/internal/repositories"
)
// Task-related errors
var (
ErrTaskNotFound = errors.New("task not found")
ErrTaskAccessDenied = errors.New("you do not have access to this task")
ErrTaskAlreadyCancelled = errors.New("task is already cancelled")
ErrTaskAlreadyArchived = errors.New("task is already archived")
ErrCompletionNotFound = errors.New("task completion not found")
)
// TaskService handles task business logic
type TaskService struct {
taskRepo *repositories.TaskRepository
residenceRepo *repositories.ResidenceRepository
}
// NewTaskService creates a new task service
func NewTaskService(taskRepo *repositories.TaskRepository, residenceRepo *repositories.ResidenceRepository) *TaskService {
return &TaskService{
taskRepo: taskRepo,
residenceRepo: residenceRepo,
}
}
// === Task CRUD ===
// GetTask gets a task by ID with access check
func (s *TaskService) GetTask(taskID, userID uint) (*responses.TaskResponse, error) {
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTaskNotFound
}
return nil, err
}
// Check access via residence
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrTaskAccessDenied
}
resp := responses.NewTaskResponse(task)
return &resp, nil
}
// ListTasks lists all tasks accessible to a user
func (s *TaskService) ListTasks(userID uint) (*responses.TaskListResponse, error) {
// Get all residence IDs accessible to user
residences, err := s.residenceRepo.FindByUser(userID)
if err != nil {
return nil, err
}
residenceIDs := make([]uint, len(residences))
for i, r := range residences {
residenceIDs[i] = r.ID
}
if len(residenceIDs) == 0 {
return &responses.TaskListResponse{Count: 0, Results: []responses.TaskResponse{}}, nil
}
tasks, err := s.taskRepo.FindByUser(userID, residenceIDs)
if err != nil {
return nil, err
}
resp := responses.NewTaskListResponse(tasks)
return &resp, nil
}
// GetTasksByResidence gets tasks for a specific residence (kanban board)
func (s *TaskService) GetTasksByResidence(residenceID, userID uint, daysThreshold int) (*responses.KanbanBoardResponse, error) {
// Check access
hasAccess, err := s.residenceRepo.HasAccess(residenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrResidenceAccessDenied
}
if daysThreshold <= 0 {
daysThreshold = 30 // Default
}
board, err := s.taskRepo.GetKanbanData(residenceID, daysThreshold)
if err != nil {
return nil, err
}
resp := responses.NewKanbanBoardResponse(board, residenceID)
return &resp, nil
}
// CreateTask creates a new task
func (s *TaskService) CreateTask(req *requests.CreateTaskRequest, userID uint) (*responses.TaskResponse, error) {
// Check residence access
hasAccess, err := s.residenceRepo.HasAccess(req.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrResidenceAccessDenied
}
task := &models.Task{
ResidenceID: req.ResidenceID,
CreatedByID: userID,
Title: req.Title,
Description: req.Description,
CategoryID: req.CategoryID,
PriorityID: req.PriorityID,
StatusID: req.StatusID,
FrequencyID: req.FrequencyID,
AssignedToID: req.AssignedToID,
DueDate: req.DueDate,
EstimatedCost: req.EstimatedCost,
ContractorID: req.ContractorID,
}
if err := s.taskRepo.Create(task); err != nil {
return nil, err
}
// Reload with relations
task, err = s.taskRepo.FindByID(task.ID)
if err != nil {
return nil, err
}
resp := responses.NewTaskResponse(task)
return &resp, nil
}
// UpdateTask updates a task
func (s *TaskService) UpdateTask(taskID, userID uint, req *requests.UpdateTaskRequest) (*responses.TaskResponse, error) {
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTaskNotFound
}
return nil, err
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrTaskAccessDenied
}
// Apply updates
if req.Title != nil {
task.Title = *req.Title
}
if req.Description != nil {
task.Description = *req.Description
}
if req.CategoryID != nil {
task.CategoryID = req.CategoryID
}
if req.PriorityID != nil {
task.PriorityID = req.PriorityID
}
if req.StatusID != nil {
task.StatusID = req.StatusID
}
if req.FrequencyID != nil {
task.FrequencyID = req.FrequencyID
}
if req.AssignedToID != nil {
task.AssignedToID = req.AssignedToID
}
if req.DueDate != nil {
task.DueDate = req.DueDate
}
if req.EstimatedCost != nil {
task.EstimatedCost = req.EstimatedCost
}
if req.ActualCost != nil {
task.ActualCost = req.ActualCost
}
if req.ContractorID != nil {
task.ContractorID = req.ContractorID
}
if err := s.taskRepo.Update(task); err != nil {
return nil, err
}
// Reload
task, err = s.taskRepo.FindByID(task.ID)
if err != nil {
return nil, err
}
resp := responses.NewTaskResponse(task)
return &resp, nil
}
// DeleteTask deletes a task
func (s *TaskService) DeleteTask(taskID, userID uint) error {
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrTaskNotFound
}
return err
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return err
}
if !hasAccess {
return ErrTaskAccessDenied
}
return s.taskRepo.Delete(taskID)
}
// === Task Actions ===
// MarkInProgress marks a task as in progress
func (s *TaskService) MarkInProgress(taskID, userID uint) (*responses.TaskResponse, error) {
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTaskNotFound
}
return nil, err
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrTaskAccessDenied
}
// Find "In Progress" status
status, err := s.taskRepo.FindStatusByName("In Progress")
if err != nil {
return nil, err
}
if err := s.taskRepo.MarkInProgress(taskID, status.ID); err != nil {
return nil, err
}
// Reload
task, err = s.taskRepo.FindByID(taskID)
if err != nil {
return nil, err
}
resp := responses.NewTaskResponse(task)
return &resp, nil
}
// CancelTask cancels a task
func (s *TaskService) CancelTask(taskID, userID uint) (*responses.TaskResponse, error) {
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTaskNotFound
}
return nil, err
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrTaskAccessDenied
}
if task.IsCancelled {
return nil, ErrTaskAlreadyCancelled
}
if err := s.taskRepo.Cancel(taskID); err != nil {
return nil, err
}
// Reload
task, err = s.taskRepo.FindByID(taskID)
if err != nil {
return nil, err
}
resp := responses.NewTaskResponse(task)
return &resp, nil
}
// UncancelTask uncancels a task
func (s *TaskService) UncancelTask(taskID, userID uint) (*responses.TaskResponse, error) {
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTaskNotFound
}
return nil, err
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrTaskAccessDenied
}
if err := s.taskRepo.Uncancel(taskID); err != nil {
return nil, err
}
// Reload
task, err = s.taskRepo.FindByID(taskID)
if err != nil {
return nil, err
}
resp := responses.NewTaskResponse(task)
return &resp, nil
}
// ArchiveTask archives a task
func (s *TaskService) ArchiveTask(taskID, userID uint) (*responses.TaskResponse, error) {
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTaskNotFound
}
return nil, err
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrTaskAccessDenied
}
if task.IsArchived {
return nil, ErrTaskAlreadyArchived
}
if err := s.taskRepo.Archive(taskID); err != nil {
return nil, err
}
// Reload
task, err = s.taskRepo.FindByID(taskID)
if err != nil {
return nil, err
}
resp := responses.NewTaskResponse(task)
return &resp, nil
}
// UnarchiveTask unarchives a task
func (s *TaskService) UnarchiveTask(taskID, userID uint) (*responses.TaskResponse, error) {
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTaskNotFound
}
return nil, err
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrTaskAccessDenied
}
if err := s.taskRepo.Unarchive(taskID); err != nil {
return nil, err
}
// Reload
task, err = s.taskRepo.FindByID(taskID)
if err != nil {
return nil, err
}
resp := responses.NewTaskResponse(task)
return &resp, nil
}
// === Task Completions ===
// CreateCompletion creates a task completion
func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest, userID uint) (*responses.TaskCompletionResponse, error) {
// Get the task
task, err := s.taskRepo.FindByID(req.TaskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTaskNotFound
}
return nil, err
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrTaskAccessDenied
}
completedAt := time.Now().UTC()
if req.CompletedAt != nil {
completedAt = *req.CompletedAt
}
completion := &models.TaskCompletion{
TaskID: req.TaskID,
CompletedByID: userID,
CompletedAt: completedAt,
Notes: req.Notes,
ActualCost: req.ActualCost,
PhotoURL: req.PhotoURL,
}
if err := s.taskRepo.CreateCompletion(completion); err != nil {
return nil, err
}
// Reload
completion, err = s.taskRepo.FindCompletionByID(completion.ID)
if err != nil {
return nil, err
}
resp := responses.NewTaskCompletionResponse(completion)
return &resp, nil
}
// GetCompletion gets a task completion by ID
func (s *TaskService) GetCompletion(completionID, userID uint) (*responses.TaskCompletionResponse, error) {
completion, err := s.taskRepo.FindCompletionByID(completionID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrCompletionNotFound
}
return nil, err
}
// Check access via task's residence
hasAccess, err := s.residenceRepo.HasAccess(completion.Task.ResidenceID, userID)
if err != nil {
return nil, err
}
if !hasAccess {
return nil, ErrTaskAccessDenied
}
resp := responses.NewTaskCompletionResponse(completion)
return &resp, nil
}
// ListCompletions lists all task completions for a user
func (s *TaskService) ListCompletions(userID uint) (*responses.TaskCompletionListResponse, error) {
// Get all residence IDs
residences, err := s.residenceRepo.FindByUser(userID)
if err != nil {
return nil, err
}
residenceIDs := make([]uint, len(residences))
for i, r := range residences {
residenceIDs[i] = r.ID
}
if len(residenceIDs) == 0 {
return &responses.TaskCompletionListResponse{Count: 0, Results: []responses.TaskCompletionResponse{}}, nil
}
completions, err := s.taskRepo.FindCompletionsByUser(userID, residenceIDs)
if err != nil {
return nil, err
}
resp := responses.NewTaskCompletionListResponse(completions)
return &resp, nil
}
// DeleteCompletion deletes a task completion
func (s *TaskService) DeleteCompletion(completionID, userID uint) error {
completion, err := s.taskRepo.FindCompletionByID(completionID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrCompletionNotFound
}
return err
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(completion.Task.ResidenceID, userID)
if err != nil {
return err
}
if !hasAccess {
return ErrTaskAccessDenied
}
return s.taskRepo.DeleteCompletion(completionID)
}
// === Lookups ===
// GetCategories returns all task categories
func (s *TaskService) GetCategories() ([]responses.TaskCategoryResponse, error) {
categories, err := s.taskRepo.GetAllCategories()
if err != nil {
return nil, err
}
result := make([]responses.TaskCategoryResponse, len(categories))
for i, c := range categories {
result[i] = *responses.NewTaskCategoryResponse(&c)
}
return result, nil
}
// GetPriorities returns all task priorities
func (s *TaskService) GetPriorities() ([]responses.TaskPriorityResponse, error) {
priorities, err := s.taskRepo.GetAllPriorities()
if err != nil {
return nil, err
}
result := make([]responses.TaskPriorityResponse, len(priorities))
for i, p := range priorities {
result[i] = *responses.NewTaskPriorityResponse(&p)
}
return result, nil
}
// GetStatuses returns all task statuses
func (s *TaskService) GetStatuses() ([]responses.TaskStatusResponse, error) {
statuses, err := s.taskRepo.GetAllStatuses()
if err != nil {
return nil, err
}
result := make([]responses.TaskStatusResponse, len(statuses))
for i, st := range statuses {
result[i] = *responses.NewTaskStatusResponse(&st)
}
return result, nil
}
// GetFrequencies returns all task frequencies
func (s *TaskService) GetFrequencies() ([]responses.TaskFrequencyResponse, error) {
frequencies, err := s.taskRepo.GetAllFrequencies()
if err != nil {
return nil, err
}
result := make([]responses.TaskFrequencyResponse, len(frequencies))
for i, f := range frequencies {
result[i] = *responses.NewTaskFrequencyResponse(&f)
}
return result, nil
}

View File

@@ -0,0 +1,117 @@
package jobs
import (
"context"
"encoding/json"
"fmt"
"github.com/hibiken/asynq"
"github.com/rs/zerolog/log"
"github.com/treytartt/mycrib-api/internal/services"
"github.com/treytartt/mycrib-api/internal/worker"
)
// EmailJobHandler handles email-related background jobs
type EmailJobHandler struct {
emailService *services.EmailService
}
// NewEmailJobHandler creates a new email job handler
func NewEmailJobHandler(emailService *services.EmailService) *EmailJobHandler {
return &EmailJobHandler{
emailService: emailService,
}
}
// RegisterHandlers registers all email job handlers with the mux
func (h *EmailJobHandler) RegisterHandlers(mux *asynq.ServeMux) {
mux.HandleFunc(worker.TypeWelcomeEmail, h.HandleWelcomeEmail)
mux.HandleFunc(worker.TypeVerificationEmail, h.HandleVerificationEmail)
mux.HandleFunc(worker.TypePasswordResetEmail, h.HandlePasswordResetEmail)
mux.HandleFunc(worker.TypePasswordChangedEmail, h.HandlePasswordChangedEmail)
}
// HandleWelcomeEmail handles the welcome email task
func (h *EmailJobHandler) HandleWelcomeEmail(ctx context.Context, t *asynq.Task) error {
var p worker.WelcomeEmailPayload
if err := json.Unmarshal(t.Payload(), &p); err != nil {
return fmt.Errorf("failed to unmarshal payload: %w", err)
}
log.Info().
Str("to", p.To).
Str("type", "welcome").
Msg("Processing email job")
if err := h.emailService.SendWelcomeEmail(p.To, p.FirstName, p.ConfirmationCode); err != nil {
log.Error().Err(err).Str("to", p.To).Msg("Failed to send welcome email")
return err
}
log.Info().Str("to", p.To).Msg("Welcome email sent successfully")
return nil
}
// HandleVerificationEmail handles the verification email task
func (h *EmailJobHandler) HandleVerificationEmail(ctx context.Context, t *asynq.Task) error {
var p worker.VerificationEmailPayload
if err := json.Unmarshal(t.Payload(), &p); err != nil {
return fmt.Errorf("failed to unmarshal payload: %w", err)
}
log.Info().
Str("to", p.To).
Str("type", "verification").
Msg("Processing email job")
if err := h.emailService.SendVerificationEmail(p.To, p.FirstName, p.Code); err != nil {
log.Error().Err(err).Str("to", p.To).Msg("Failed to send verification email")
return err
}
log.Info().Str("to", p.To).Msg("Verification email sent successfully")
return nil
}
// HandlePasswordResetEmail handles the password reset email task
func (h *EmailJobHandler) HandlePasswordResetEmail(ctx context.Context, t *asynq.Task) error {
var p worker.PasswordResetEmailPayload
if err := json.Unmarshal(t.Payload(), &p); err != nil {
return fmt.Errorf("failed to unmarshal payload: %w", err)
}
log.Info().
Str("to", p.To).
Str("type", "password_reset").
Msg("Processing email job")
if err := h.emailService.SendPasswordResetEmail(p.To, p.FirstName, p.Code); err != nil {
log.Error().Err(err).Str("to", p.To).Msg("Failed to send password reset email")
return err
}
log.Info().Str("to", p.To).Msg("Password reset email sent successfully")
return nil
}
// HandlePasswordChangedEmail handles the password changed confirmation email task
func (h *EmailJobHandler) HandlePasswordChangedEmail(ctx context.Context, t *asynq.Task) error {
var p worker.EmailPayload
if err := json.Unmarshal(t.Payload(), &p); err != nil {
return fmt.Errorf("failed to unmarshal payload: %w", err)
}
log.Info().
Str("to", p.To).
Str("type", "password_changed").
Msg("Processing email job")
if err := h.emailService.SendPasswordChangedEmail(p.To, p.FirstName); err != nil {
log.Error().Err(err).Str("to", p.To).Msg("Failed to send password changed email")
return err
}
log.Info().Str("to", p.To).Msg("Password changed email sent successfully")
return nil
}

View File

@@ -0,0 +1,162 @@
package jobs
import (
"context"
"encoding/json"
"github.com/hibiken/asynq"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"github.com/treytartt/mycrib-api/internal/config"
"github.com/treytartt/mycrib-api/internal/push"
)
// Task types
const (
TypeTaskReminder = "notification:task_reminder"
TypeOverdueReminder = "notification:overdue_reminder"
TypeDailyDigest = "notification:daily_digest"
TypeSendEmail = "email:send"
TypeSendPush = "push:send"
)
// Handler handles background job processing
type Handler struct {
db *gorm.DB
pushClient *push.GorushClient
config *config.Config
}
// NewHandler creates a new job handler
func NewHandler(db *gorm.DB, pushClient *push.GorushClient, cfg *config.Config) *Handler {
return &Handler{
db: db,
pushClient: pushClient,
config: cfg,
}
}
// HandleTaskReminder processes task reminder notifications
func (h *Handler) HandleTaskReminder(ctx context.Context, task *asynq.Task) error {
log.Info().Msg("Processing task reminder notifications...")
// TODO: Implement task reminder logic
// 1. Query tasks due today or tomorrow
// 2. Get user device tokens
// 3. Send push notifications via Gorush
log.Info().Msg("Task reminder notifications completed")
return nil
}
// HandleOverdueReminder processes overdue task notifications
func (h *Handler) HandleOverdueReminder(ctx context.Context, task *asynq.Task) error {
log.Info().Msg("Processing overdue task notifications...")
// TODO: Implement overdue reminder logic
// 1. Query overdue tasks
// 2. Get user device tokens
// 3. Send push notifications via Gorush
log.Info().Msg("Overdue task notifications completed")
return nil
}
// HandleDailyDigest processes daily digest notifications
func (h *Handler) HandleDailyDigest(ctx context.Context, task *asynq.Task) error {
log.Info().Msg("Processing daily digest notifications...")
// TODO: Implement daily digest logic
// 1. Aggregate task statistics per user
// 2. Get user device tokens
// 3. Send push notifications via Gorush
log.Info().Msg("Daily digest notifications completed")
return nil
}
// EmailPayload represents the payload for email tasks
type EmailPayload struct {
To string `json:"to"`
Subject string `json:"subject"`
Body string `json:"body"`
IsHTML bool `json:"is_html"`
}
// HandleSendEmail processes email sending tasks
func (h *Handler) HandleSendEmail(ctx context.Context, task *asynq.Task) error {
var payload EmailPayload
if err := json.Unmarshal(task.Payload(), &payload); err != nil {
return err
}
log.Info().
Str("to", payload.To).
Str("subject", payload.Subject).
Msg("Sending email...")
// TODO: Implement email sending via EmailService
log.Info().Str("to", payload.To).Msg("Email sent successfully")
return nil
}
// PushPayload represents the payload for push notification tasks
type PushPayload struct {
UserID uint `json:"user_id"`
Title string `json:"title"`
Message string `json:"message"`
Data map[string]string `json:"data,omitempty"`
}
// HandleSendPush processes push notification tasks
func (h *Handler) HandleSendPush(ctx context.Context, task *asynq.Task) error {
var payload PushPayload
if err := json.Unmarshal(task.Payload(), &payload); err != nil {
return err
}
log.Info().
Uint("user_id", payload.UserID).
Str("title", payload.Title).
Msg("Sending push notification...")
if h.pushClient == nil {
log.Warn().Msg("Push client not configured, skipping notification")
return nil
}
// TODO: Get user device tokens and send via Gorush
log.Info().Uint("user_id", payload.UserID).Msg("Push notification sent successfully")
return nil
}
// NewSendEmailTask creates a new email sending task
func NewSendEmailTask(to, subject, body string, isHTML bool) (*asynq.Task, error) {
payload, err := json.Marshal(EmailPayload{
To: to,
Subject: subject,
Body: body,
IsHTML: isHTML,
})
if err != nil {
return nil, err
}
return asynq.NewTask(TypeSendEmail, payload), nil
}
// NewSendPushTask creates a new push notification task
func NewSendPushTask(userID uint, title, message string, data map[string]string) (*asynq.Task, error) {
payload, err := json.Marshal(PushPayload{
UserID: userID,
Title: title,
Message: message,
Data: data,
})
if err != nil {
return nil, err
}
return asynq.NewTask(TypeSendPush, payload), nil
}

View File

@@ -0,0 +1,239 @@
package worker
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/hibiken/asynq"
"github.com/rs/zerolog/log"
)
// Task types
const (
TypeWelcomeEmail = "email:welcome"
TypeVerificationEmail = "email:verification"
TypePasswordResetEmail = "email:password_reset"
TypePasswordChangedEmail = "email:password_changed"
TypeTaskCompletionEmail = "email:task_completion"
TypeGeneratePDFReport = "pdf:generate_report"
TypeUpdateContractorRating = "contractor:update_rating"
TypeDailyNotifications = "notifications:daily"
TypeTaskReminders = "notifications:task_reminders"
TypeOverdueReminders = "notifications:overdue_reminders"
)
// EmailPayload is the base payload for email tasks
type EmailPayload struct {
To string `json:"to"`
FirstName string `json:"first_name"`
}
// WelcomeEmailPayload is the payload for welcome emails
type WelcomeEmailPayload struct {
EmailPayload
ConfirmationCode string `json:"confirmation_code"`
}
// VerificationEmailPayload is the payload for verification emails
type VerificationEmailPayload struct {
EmailPayload
Code string `json:"code"`
}
// PasswordResetEmailPayload is the payload for password reset emails
type PasswordResetEmailPayload struct {
EmailPayload
Code string `json:"code"`
ResetToken string `json:"reset_token"`
}
// TaskClient wraps the asynq client for enqueuing tasks
type TaskClient struct {
client *asynq.Client
}
// NewTaskClient creates a new task client
func NewTaskClient(redisAddr string) *TaskClient {
client := asynq.NewClient(asynq.RedisClientOpt{Addr: redisAddr})
return &TaskClient{client: client}
}
// Close closes the task client
func (c *TaskClient) Close() error {
return c.client.Close()
}
// EnqueueWelcomeEmail enqueues a welcome email task
func (c *TaskClient) EnqueueWelcomeEmail(to, firstName, code string) error {
payload, err := json.Marshal(WelcomeEmailPayload{
EmailPayload: EmailPayload{To: to, FirstName: firstName},
ConfirmationCode: code,
})
if err != nil {
return err
}
task := asynq.NewTask(TypeWelcomeEmail, payload)
_, err = c.client.Enqueue(task, asynq.Queue("default"), asynq.MaxRetry(3))
if err != nil {
log.Error().Err(err).Str("to", to).Msg("Failed to enqueue welcome email")
return err
}
log.Debug().Str("to", to).Msg("Welcome email task enqueued")
return nil
}
// EnqueueVerificationEmail enqueues a verification email task
func (c *TaskClient) EnqueueVerificationEmail(to, firstName, code string) error {
payload, err := json.Marshal(VerificationEmailPayload{
EmailPayload: EmailPayload{To: to, FirstName: firstName},
Code: code,
})
if err != nil {
return err
}
task := asynq.NewTask(TypeVerificationEmail, payload)
_, err = c.client.Enqueue(task, asynq.Queue("default"), asynq.MaxRetry(3))
if err != nil {
log.Error().Err(err).Str("to", to).Msg("Failed to enqueue verification email")
return err
}
log.Debug().Str("to", to).Msg("Verification email task enqueued")
return nil
}
// EnqueuePasswordResetEmail enqueues a password reset email task
func (c *TaskClient) EnqueuePasswordResetEmail(to, firstName, code, resetToken string) error {
payload, err := json.Marshal(PasswordResetEmailPayload{
EmailPayload: EmailPayload{To: to, FirstName: firstName},
Code: code,
ResetToken: resetToken,
})
if err != nil {
return err
}
task := asynq.NewTask(TypePasswordResetEmail, payload)
_, err = c.client.Enqueue(task, asynq.Queue("default"), asynq.MaxRetry(3))
if err != nil {
log.Error().Err(err).Str("to", to).Msg("Failed to enqueue password reset email")
return err
}
log.Debug().Str("to", to).Msg("Password reset email task enqueued")
return nil
}
// EnqueuePasswordChangedEmail enqueues a password changed confirmation email
func (c *TaskClient) EnqueuePasswordChangedEmail(to, firstName string) error {
payload, err := json.Marshal(EmailPayload{To: to, FirstName: firstName})
if err != nil {
return err
}
task := asynq.NewTask(TypePasswordChangedEmail, payload)
_, err = c.client.Enqueue(task, asynq.Queue("default"), asynq.MaxRetry(3))
if err != nil {
log.Error().Err(err).Str("to", to).Msg("Failed to enqueue password changed email")
return err
}
log.Debug().Str("to", to).Msg("Password changed email task enqueued")
return nil
}
// WorkerServer manages the asynq worker server
type WorkerServer struct {
server *asynq.Server
scheduler *asynq.Scheduler
}
// NewWorkerServer creates a new worker server
func NewWorkerServer(redisAddr string, concurrency int) *WorkerServer {
srv := asynq.NewServer(
asynq.RedisClientOpt{Addr: redisAddr},
asynq.Config{
Concurrency: concurrency,
Queues: map[string]int{
"critical": 6,
"default": 3,
"low": 1,
},
ErrorHandler: asynq.ErrorHandlerFunc(func(ctx context.Context, task *asynq.Task, err error) {
log.Error().
Err(err).
Str("type", task.Type()).
Bytes("payload", task.Payload()).
Msg("Task failed")
}),
},
)
// Create scheduler for periodic tasks
loc, _ := time.LoadLocation("UTC")
scheduler := asynq.NewScheduler(
asynq.RedisClientOpt{Addr: redisAddr},
&asynq.SchedulerOpts{
Location: loc,
},
)
return &WorkerServer{
server: srv,
scheduler: scheduler,
}
}
// RegisterHandlers registers task handlers
func (w *WorkerServer) RegisterHandlers(mux *asynq.ServeMux) {
// Handlers will be registered by the main worker process
}
// RegisterScheduledTasks registers periodic tasks
func (w *WorkerServer) RegisterScheduledTasks() error {
// Task reminders - 8:00 PM UTC daily
_, err := w.scheduler.Register("0 20 * * *", asynq.NewTask(TypeTaskReminders, nil))
if err != nil {
return fmt.Errorf("failed to register task reminders: %w", err)
}
// Overdue reminders - 9:00 AM UTC daily
_, err = w.scheduler.Register("0 9 * * *", asynq.NewTask(TypeOverdueReminders, nil))
if err != nil {
return fmt.Errorf("failed to register overdue reminders: %w", err)
}
// Daily notifications - 11:00 AM UTC daily
_, err = w.scheduler.Register("0 11 * * *", asynq.NewTask(TypeDailyNotifications, nil))
if err != nil {
return fmt.Errorf("failed to register daily notifications: %w", err)
}
return nil
}
// Start starts the worker server and scheduler
func (w *WorkerServer) Start(mux *asynq.ServeMux) error {
// Start scheduler
if err := w.scheduler.Start(); err != nil {
return fmt.Errorf("failed to start scheduler: %w", err)
}
// Start server
if err := w.server.Start(mux); err != nil {
return fmt.Errorf("failed to start worker server: %w", err)
}
return nil
}
// Shutdown gracefully shuts down the worker server
func (w *WorkerServer) Shutdown() {
w.scheduler.Shutdown()
w.server.Shutdown()
}