Harden API security: input validation, safe auth extraction, new tests, and deploy config
Comprehensive security hardening from audit findings: - Add validation tags to all DTO request structs (max lengths, ranges, enums) - Replace unsafe type assertions with MustGetAuthUser helper across all handlers - Remove query-param token auth from admin middleware (prevents URL token leakage) - Add request validation calls in handlers that were missing c.Validate() - Remove goroutines in handlers (timezone update now synchronous) - Add sanitize middleware and path traversal protection (path_utils) - Stop resetting admin passwords on migration restart - Warn on well-known default SECRET_KEY - Add ~30 new test files covering security regressions, auth safety, repos, and services - Add deploy/ config, audit digests, and AUDIT_FINDINGS documentation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -81,6 +81,11 @@ func (h *AuthHandler) Register(c echo.Context) error {
|
||||
// Send welcome email with confirmation code (async)
|
||||
if h.emailService != nil && confirmationCode != "" {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", req.Email).Msg("Panic in welcome email goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.emailService.SendWelcomeEmail(req.Email, req.FirstName, confirmationCode); err != nil {
|
||||
log.Error().Err(err).Str("email", req.Email).Msg("Failed to send welcome email")
|
||||
}
|
||||
@@ -176,6 +181,11 @@ func (h *AuthHandler) VerifyEmail(c echo.Context) error {
|
||||
// Send post-verification welcome email with tips (async)
|
||||
if h.emailService != nil {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in post-verification email goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.emailService.SendPostVerificationEmail(user.Email, user.FirstName); err != nil {
|
||||
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send post-verification email")
|
||||
}
|
||||
@@ -204,6 +214,11 @@ func (h *AuthHandler) ResendVerification(c echo.Context) error {
|
||||
// Send verification email (async)
|
||||
if h.emailService != nil {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in verification email goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.emailService.SendVerificationEmail(user.Email, user.FirstName, code); err != nil {
|
||||
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send verification email")
|
||||
}
|
||||
@@ -238,6 +253,11 @@ func (h *AuthHandler) ForgotPassword(c echo.Context) error {
|
||||
// Send password reset email (async) - only if user found
|
||||
if h.emailService != nil && code != "" && user != nil {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in password reset email goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.emailService.SendPasswordResetEmail(user.Email, user.FirstName, code); err != nil {
|
||||
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send password reset email")
|
||||
}
|
||||
@@ -326,6 +346,11 @@ func (h *AuthHandler) AppleSignIn(c echo.Context) error {
|
||||
// Send welcome email for new users (async)
|
||||
if response.IsNewUser && h.emailService != nil && response.User.Email != "" {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", response.User.Email).Msg("Panic in Apple welcome email goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.emailService.SendAppleWelcomeEmail(response.User.Email, response.User.FirstName); err != nil {
|
||||
log.Error().Err(err).Str("email", response.User.Email).Msg("Failed to send Apple welcome email")
|
||||
}
|
||||
@@ -368,6 +393,11 @@ func (h *AuthHandler) GoogleSignIn(c echo.Context) error {
|
||||
// Send welcome email for new users (async)
|
||||
if response.IsNewUser && h.emailService != nil && response.User.Email != "" {
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("email", response.User.Email).Msg("Panic in Google welcome email goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.emailService.SendGoogleWelcomeEmail(response.User.Email, response.User.FirstName); err != nil {
|
||||
log.Error().Err(err).Str("email", response.User.Email).Msg("Failed to send Google welcome email")
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
"github.com/treytartt/casera-api/internal/dto/requests"
|
||||
"github.com/treytartt/casera-api/internal/middleware"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
|
||||
@@ -25,17 +24,23 @@ func NewContractorHandler(contractorService *services.ContractorService) *Contra
|
||||
|
||||
// ListContractors handles GET /api/contractors/
|
||||
func (h *ContractorHandler) ListContractors(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
response, err := h.contractorService.ListContractors(user.ID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": err.Error()})
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetContractor handles GET /api/contractors/:id/
|
||||
func (h *ContractorHandler) GetContractor(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_contractor_id")
|
||||
@@ -50,11 +55,17 @@ func (h *ContractorHandler) GetContractor(c echo.Context) error {
|
||||
|
||||
// CreateContractor handles POST /api/contractors/
|
||||
func (h *ContractorHandler) CreateContractor(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var req requests.CreateContractorRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.contractorService.CreateContractor(&req, user.ID)
|
||||
if err != nil {
|
||||
@@ -65,7 +76,10 @@ func (h *ContractorHandler) CreateContractor(c echo.Context) error {
|
||||
|
||||
// UpdateContractor handles PUT/PATCH /api/contractors/:id/
|
||||
func (h *ContractorHandler) UpdateContractor(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_contractor_id")
|
||||
@@ -75,6 +89,9 @@ func (h *ContractorHandler) UpdateContractor(c echo.Context) error {
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.contractorService.UpdateContractor(uint(contractorID), user.ID, &req)
|
||||
if err != nil {
|
||||
@@ -85,7 +102,10 @@ func (h *ContractorHandler) UpdateContractor(c echo.Context) error {
|
||||
|
||||
// DeleteContractor handles DELETE /api/contractors/:id/
|
||||
func (h *ContractorHandler) DeleteContractor(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_contractor_id")
|
||||
@@ -100,7 +120,10 @@ func (h *ContractorHandler) DeleteContractor(c echo.Context) error {
|
||||
|
||||
// ToggleFavorite handles POST /api/contractors/:id/toggle-favorite/
|
||||
func (h *ContractorHandler) ToggleFavorite(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_contractor_id")
|
||||
@@ -115,7 +138,10 @@ func (h *ContractorHandler) ToggleFavorite(c echo.Context) error {
|
||||
|
||||
// GetContractorTasks handles GET /api/contractors/:id/tasks/
|
||||
func (h *ContractorHandler) GetContractorTasks(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_contractor_id")
|
||||
@@ -130,7 +156,10 @@ func (h *ContractorHandler) GetContractorTasks(c echo.Context) error {
|
||||
|
||||
// ListContractorsByResidence handles GET /api/contractors/by-residence/:residence_id/
|
||||
func (h *ContractorHandler) ListContractorsByResidence(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
residenceID, err := strconv.ParseUint(c.Param("residence_id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_residence_id")
|
||||
@@ -147,7 +176,7 @@ func (h *ContractorHandler) ListContractorsByResidence(c echo.Context) error {
|
||||
func (h *ContractorHandler) GetSpecialties(c echo.Context) error {
|
||||
specialties, err := h.contractorService.GetSpecialties()
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": err.Error()})
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
return c.JSON(http.StatusOK, specialties)
|
||||
}
|
||||
|
||||
182
internal/handlers/contractor_handler_test.go
Normal file
182
internal/handlers/contractor_handler_test.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/dto/requests"
|
||||
"github.com/treytartt/casera-api/internal/repositories"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
|
||||
func setupContractorHandler(t *testing.T) (*ContractorHandler, *echo.Echo, *gorm.DB) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
contractorService := services.NewContractorService(contractorRepo, residenceRepo)
|
||||
handler := NewContractorHandler(contractorService)
|
||||
e := testutil.SetupTestRouter()
|
||||
return handler, e, db
|
||||
}
|
||||
|
||||
func TestContractorHandler_CreateContractor_MissingName_Returns400(t *testing.T) {
|
||||
handler, e, db := setupContractorHandler(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
authGroup := e.Group("/api/contractors")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/", handler.CreateContractor)
|
||||
|
||||
t.Run("missing name returns 400 validation error", func(t *testing.T) {
|
||||
// Send request with no name (required field)
|
||||
req := requests.CreateContractorRequest{
|
||||
ResidenceID: &residence.ID,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should contain structured validation error
|
||||
assert.Contains(t, response, "error")
|
||||
assert.Contains(t, response, "fields")
|
||||
|
||||
fields := response["fields"].(map[string]interface{})
|
||||
assert.Contains(t, fields, "name", "validation error should reference the 'name' field")
|
||||
})
|
||||
|
||||
t.Run("empty body returns 400 validation error", func(t *testing.T) {
|
||||
// Send completely empty body
|
||||
w := testutil.MakeRequest(e, "POST", "/api/contractors/", map[string]interface{}{}, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, response, "error")
|
||||
})
|
||||
|
||||
t.Run("valid contractor creation succeeds", func(t *testing.T) {
|
||||
req := requests.CreateContractorRequest{
|
||||
ResidenceID: &residence.ID,
|
||||
Name: "John the Plumber",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusCreated)
|
||||
})
|
||||
}
|
||||
|
||||
func TestContractorHandler_ListContractors_Error_NoRawErrorInResponse(t *testing.T) {
|
||||
_, e, db := setupContractorHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
// Create a handler with a broken service to simulate an internal error.
|
||||
// We do this by closing the underlying SQL connection, which will cause
|
||||
// the service to return an error on the next query.
|
||||
brokenDB := testutil.SetupTestDB(t)
|
||||
sqlDB, _ := brokenDB.DB()
|
||||
sqlDB.Close()
|
||||
|
||||
brokenContractorRepo := repositories.NewContractorRepository(brokenDB)
|
||||
brokenResidenceRepo := repositories.NewResidenceRepository(brokenDB)
|
||||
brokenService := services.NewContractorService(brokenContractorRepo, brokenResidenceRepo)
|
||||
brokenHandler := NewContractorHandler(brokenService)
|
||||
|
||||
authGroup := e.Group("/api/broken-contractors")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/", brokenHandler.ListContractors)
|
||||
|
||||
t.Run("internal error does not leak raw error message", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/broken-contractors/", nil, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusInternalServerError)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should contain the generic error key, NOT a raw database error
|
||||
errorMsg, ok := response["error"].(string)
|
||||
require.True(t, ok, "response should have an 'error' string field")
|
||||
|
||||
// Must not contain database-specific details
|
||||
assert.NotContains(t, errorMsg, "sql", "error message should not leak SQL details")
|
||||
assert.NotContains(t, errorMsg, "database", "error message should not leak database details")
|
||||
assert.NotContains(t, errorMsg, "closed", "error message should not leak connection state")
|
||||
})
|
||||
}
|
||||
|
||||
func TestContractorHandler_CreateContractor_100Specialties_Returns400(t *testing.T) {
|
||||
handler, e, db := setupContractorHandler(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
authGroup := e.Group("/api/contractors")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/", handler.CreateContractor)
|
||||
|
||||
t.Run("too many specialties rejected", func(t *testing.T) {
|
||||
// Create a slice with 100 specialty IDs (exceeds max=20)
|
||||
specialtyIDs := make([]uint, 100)
|
||||
for i := range specialtyIDs {
|
||||
specialtyIDs[i] = uint(i + 1)
|
||||
}
|
||||
|
||||
req := requests.CreateContractorRequest{
|
||||
ResidenceID: &residence.ID,
|
||||
Name: "Over-specialized Contractor",
|
||||
SpecialtyIDs: specialtyIDs,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("20 specialties accepted", func(t *testing.T) {
|
||||
specialtyIDs := make([]uint, 20)
|
||||
for i := range specialtyIDs {
|
||||
specialtyIDs[i] = uint(i + 1)
|
||||
}
|
||||
|
||||
req := requests.CreateContractorRequest{
|
||||
ResidenceID: &residence.ID,
|
||||
Name: "Multi-skilled Contractor",
|
||||
SpecialtyIDs: specialtyIDs,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
|
||||
// Should pass validation (201 or success, not 400)
|
||||
assert.NotEqual(t, http.StatusBadRequest, w.Code, "20 specialties should pass validation")
|
||||
})
|
||||
|
||||
t.Run("rating above 5 rejected", func(t *testing.T) {
|
||||
rating := 6.0
|
||||
req := requests.CreateContractorRequest{
|
||||
ResidenceID: &residence.ID,
|
||||
Name: "Bad Rating Contractor",
|
||||
Rating: &rating,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
@@ -34,7 +34,10 @@ func NewDocumentHandler(documentService *services.DocumentService, storageServic
|
||||
|
||||
// ListDocuments handles GET /api/documents/
|
||||
func (h *DocumentHandler) ListDocuments(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build filter from supported query params.
|
||||
var filter *repositories.DocumentFilter
|
||||
@@ -71,7 +74,10 @@ func (h *DocumentHandler) ListDocuments(c echo.Context) error {
|
||||
|
||||
// GetDocument handles GET /api/documents/:id/
|
||||
func (h *DocumentHandler) GetDocument(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_document_id")
|
||||
@@ -86,10 +92,13 @@ func (h *DocumentHandler) GetDocument(c echo.Context) error {
|
||||
|
||||
// ListWarranties handles GET /api/documents/warranties/
|
||||
func (h *DocumentHandler) ListWarranties(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
response, err := h.documentService.ListWarranties(user.ID)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": err.Error()})
|
||||
return apperrors.Internal(err)
|
||||
}
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
@@ -97,7 +106,10 @@ func (h *DocumentHandler) ListWarranties(c echo.Context) error {
|
||||
// CreateDocument handles POST /api/documents/
|
||||
// Supports both JSON and multipart form data (for file uploads)
|
||||
func (h *DocumentHandler) CreateDocument(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var req requests.CreateDocumentRequest
|
||||
|
||||
contentType := c.Request().Header.Get("Content-Type")
|
||||
@@ -198,6 +210,10 @@ func (h *DocumentHandler) CreateDocument(c echo.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.documentService.CreateDocument(&req, user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -207,7 +223,10 @@ func (h *DocumentHandler) CreateDocument(c echo.Context) error {
|
||||
|
||||
// UpdateDocument handles PUT/PATCH /api/documents/:id/
|
||||
func (h *DocumentHandler) UpdateDocument(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_document_id")
|
||||
@@ -217,6 +236,9 @@ func (h *DocumentHandler) UpdateDocument(c echo.Context) error {
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.documentService.UpdateDocument(uint(documentID), user.ID, &req)
|
||||
if err != nil {
|
||||
@@ -227,7 +249,10 @@ func (h *DocumentHandler) UpdateDocument(c echo.Context) error {
|
||||
|
||||
// DeleteDocument handles DELETE /api/documents/:id/
|
||||
func (h *DocumentHandler) DeleteDocument(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_document_id")
|
||||
@@ -242,7 +267,10 @@ func (h *DocumentHandler) DeleteDocument(c echo.Context) error {
|
||||
|
||||
// ActivateDocument handles POST /api/documents/:id/activate/
|
||||
func (h *DocumentHandler) ActivateDocument(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_document_id")
|
||||
@@ -257,7 +285,10 @@ func (h *DocumentHandler) ActivateDocument(c echo.Context) error {
|
||||
|
||||
// DeactivateDocument handles POST /api/documents/:id/deactivate/
|
||||
func (h *DocumentHandler) DeactivateDocument(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_document_id")
|
||||
@@ -272,7 +303,10 @@ func (h *DocumentHandler) DeactivateDocument(c echo.Context) error {
|
||||
|
||||
// UploadDocumentImage handles POST /api/documents/:id/images/
|
||||
func (h *DocumentHandler) UploadDocumentImage(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_document_id")
|
||||
@@ -316,7 +350,10 @@ func (h *DocumentHandler) UploadDocumentImage(c echo.Context) error {
|
||||
|
||||
// DeleteDocumentImage handles DELETE /api/documents/:id/images/:imageId/
|
||||
func (h *DocumentHandler) DeleteDocumentImage(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_document_id")
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -9,7 +8,6 @@ import (
|
||||
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
"github.com/treytartt/casera-api/internal/middleware"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/repositories"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
@@ -40,7 +38,10 @@ func NewMediaHandler(
|
||||
// ServeDocument serves a document file with access control
|
||||
// GET /api/media/document/:id
|
||||
func (h *MediaHandler) ServeDocument(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -73,7 +74,10 @@ func (h *MediaHandler) ServeDocument(c echo.Context) error {
|
||||
// ServeDocumentImage serves a document image with access control
|
||||
// GET /api/media/document-image/:id
|
||||
func (h *MediaHandler) ServeDocumentImage(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -111,7 +115,10 @@ func (h *MediaHandler) ServeDocumentImage(c echo.Context) error {
|
||||
// ServeCompletionImage serves a task completion image with access control
|
||||
// GET /api/media/completion-image/:id
|
||||
func (h *MediaHandler) ServeCompletionImage(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -152,7 +159,9 @@ func (h *MediaHandler) ServeCompletionImage(c echo.Context) error {
|
||||
return c.File(filePath)
|
||||
}
|
||||
|
||||
// resolveFilePath converts a stored URL to an actual file path
|
||||
// resolveFilePath converts a stored URL to an actual file path.
|
||||
// Returns empty string if the URL is empty or the resolved path would escape
|
||||
// the upload directory (path traversal attempt).
|
||||
func (h *MediaHandler) resolveFilePath(storedURL string) string {
|
||||
if storedURL == "" {
|
||||
return ""
|
||||
@@ -160,12 +169,18 @@ func (h *MediaHandler) resolveFilePath(storedURL string) string {
|
||||
|
||||
uploadDir := h.storageSvc.GetUploadDir()
|
||||
|
||||
// Handle legacy /uploads/... URLs
|
||||
// Strip legacy /uploads/ prefix to get relative path
|
||||
relativePath := storedURL
|
||||
if strings.HasPrefix(storedURL, "/uploads/") {
|
||||
relativePath := strings.TrimPrefix(storedURL, "/uploads/")
|
||||
return filepath.Join(uploadDir, relativePath)
|
||||
relativePath = strings.TrimPrefix(storedURL, "/uploads/")
|
||||
}
|
||||
|
||||
// Handle relative paths (new format)
|
||||
return filepath.Join(uploadDir, storedURL)
|
||||
// Use SafeResolvePath to validate containment within upload directory
|
||||
resolved, err := services.SafeResolvePath(uploadDir, relativePath)
|
||||
if err != nil {
|
||||
// Path traversal or invalid path — return empty to signal file not found
|
||||
return ""
|
||||
}
|
||||
|
||||
return resolved
|
||||
}
|
||||
|
||||
74
internal/handlers/media_handler_test.go
Normal file
74
internal/handlers/media_handler_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/config"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
|
||||
// newTestStorageService creates a StorageService with a known upload directory for testing.
|
||||
// It does NOT call NewStorageService because that creates directories on disk.
|
||||
// Instead, it directly constructs the struct with only what resolveFilePath needs.
|
||||
func newTestStorageService(uploadDir string) *services.StorageService {
|
||||
cfg := &config.StorageConfig{
|
||||
UploadDir: uploadDir,
|
||||
BaseURL: "/uploads",
|
||||
MaxFileSize: 10 * 1024 * 1024,
|
||||
AllowedTypes: "image/jpeg,image/png",
|
||||
}
|
||||
// Use the exported constructor helper that skips directory creation (for tests)
|
||||
return services.NewStorageServiceForTest(cfg)
|
||||
}
|
||||
|
||||
func TestResolveFilePath_NormalPath_Works(t *testing.T) {
|
||||
storageSvc := newTestStorageService("/var/uploads")
|
||||
h := NewMediaHandler(nil, nil, nil, storageSvc)
|
||||
|
||||
result := h.resolveFilePath("images/photo.jpg")
|
||||
require.NotEmpty(t, result)
|
||||
assert.Equal(t, "/var/uploads/images/photo.jpg", result)
|
||||
}
|
||||
|
||||
func TestResolveFilePath_LegacyUploadPath_Works(t *testing.T) {
|
||||
storageSvc := newTestStorageService("/var/uploads")
|
||||
h := NewMediaHandler(nil, nil, nil, storageSvc)
|
||||
|
||||
result := h.resolveFilePath("/uploads/images/photo.jpg")
|
||||
require.NotEmpty(t, result)
|
||||
assert.Equal(t, "/var/uploads/images/photo.jpg", result)
|
||||
}
|
||||
|
||||
func TestResolveFilePath_DotDotTraversal_Blocked(t *testing.T) {
|
||||
storageSvc := newTestStorageService("/var/uploads")
|
||||
h := NewMediaHandler(nil, nil, nil, storageSvc)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
storedURL string
|
||||
}{
|
||||
{"simple dotdot", "../etc/passwd"},
|
||||
{"nested dotdot", "../../etc/shadow"},
|
||||
{"embedded dotdot", "images/../../etc/passwd"},
|
||||
{"legacy prefix with dotdot", "/uploads/../../../etc/passwd"},
|
||||
{"deep dotdot", "a/b/c/../../../../etc/passwd"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := h.resolveFilePath(tt.storedURL)
|
||||
assert.Empty(t, result, "path traversal should return empty string for: %s", tt.storedURL)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveFilePath_EmptyURL_ReturnsEmpty(t *testing.T) {
|
||||
storageSvc := newTestStorageService("/var/uploads")
|
||||
h := NewMediaHandler(nil, nil, nil, storageSvc)
|
||||
|
||||
result := h.resolveFilePath("")
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
334
internal/handlers/noauth_test.go
Normal file
334
internal/handlers/noauth_test.go
Normal file
@@ -0,0 +1,334 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/config"
|
||||
"github.com/treytartt/casera-api/internal/repositories"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
|
||||
// TestTaskHandler_NoAuth_Returns401 verifies that task handler endpoints return
|
||||
// 401 Unauthorized when no auth user is set in the context (e.g., auth middleware
|
||||
// misconfigured or bypassed). This is a regression test for P1-1 (SEC-19).
|
||||
func TestTaskHandler_NoAuth_Returns401(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
taskService := services.NewTaskService(taskRepo, residenceRepo)
|
||||
handler := NewTaskHandler(taskService, nil)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register routes WITHOUT auth middleware
|
||||
e.GET("/api/tasks/", handler.ListTasks)
|
||||
e.GET("/api/tasks/:id/", handler.GetTask)
|
||||
e.POST("/api/tasks/", handler.CreateTask)
|
||||
e.PUT("/api/tasks/:id/", handler.UpdateTask)
|
||||
e.DELETE("/api/tasks/:id/", handler.DeleteTask)
|
||||
e.POST("/api/tasks/:id/cancel/", handler.CancelTask)
|
||||
e.POST("/api/tasks/:id/mark-in-progress/", handler.MarkInProgress)
|
||||
e.GET("/api/task-completions/", handler.ListCompletions)
|
||||
e.POST("/api/task-completions/", handler.CreateCompletion)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"ListTasks", "GET", "/api/tasks/"},
|
||||
{"GetTask", "GET", "/api/tasks/1/"},
|
||||
{"CreateTask", "POST", "/api/tasks/"},
|
||||
{"UpdateTask", "PUT", "/api/tasks/1/"},
|
||||
{"DeleteTask", "DELETE", "/api/tasks/1/"},
|
||||
{"CancelTask", "POST", "/api/tasks/1/cancel/"},
|
||||
{"MarkInProgress", "POST", "/api/tasks/1/mark-in-progress/"},
|
||||
{"ListCompletions", "GET", "/api/task-completions/"},
|
||||
{"CreateCompletion", "POST", "/api/task-completions/"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResidenceHandler_NoAuth_Returns401 verifies that residence handler endpoints
|
||||
// return 401 Unauthorized when no auth user is set in the context.
|
||||
func TestResidenceHandler_NoAuth_Returns401(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
residenceService := services.NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
handler := NewResidenceHandler(residenceService, nil, nil, true)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register routes WITHOUT auth middleware
|
||||
e.GET("/api/residences/", handler.ListResidences)
|
||||
e.GET("/api/residences/my/", handler.GetMyResidences)
|
||||
e.GET("/api/residences/summary/", handler.GetSummary)
|
||||
e.GET("/api/residences/:id/", handler.GetResidence)
|
||||
e.POST("/api/residences/", handler.CreateResidence)
|
||||
e.PUT("/api/residences/:id/", handler.UpdateResidence)
|
||||
e.DELETE("/api/residences/:id/", handler.DeleteResidence)
|
||||
e.POST("/api/residences/:id/generate-share-code/", handler.GenerateShareCode)
|
||||
e.POST("/api/residences/join-with-code/", handler.JoinWithCode)
|
||||
e.GET("/api/residences/:id/users/", handler.GetResidenceUsers)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"ListResidences", "GET", "/api/residences/"},
|
||||
{"GetMyResidences", "GET", "/api/residences/my/"},
|
||||
{"GetSummary", "GET", "/api/residences/summary/"},
|
||||
{"GetResidence", "GET", "/api/residences/1/"},
|
||||
{"CreateResidence", "POST", "/api/residences/"},
|
||||
{"UpdateResidence", "PUT", "/api/residences/1/"},
|
||||
{"DeleteResidence", "DELETE", "/api/residences/1/"},
|
||||
{"GenerateShareCode", "POST", "/api/residences/1/generate-share-code/"},
|
||||
{"JoinWithCode", "POST", "/api/residences/join-with-code/"},
|
||||
{"GetResidenceUsers", "GET", "/api/residences/1/users/"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotificationHandler_NoAuth_Returns401 verifies that notification handler
|
||||
// endpoints return 401 Unauthorized when no auth user is set in the context.
|
||||
func TestNotificationHandler_NoAuth_Returns401(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
notificationRepo := repositories.NewNotificationRepository(db)
|
||||
notificationService := services.NewNotificationService(notificationRepo, nil)
|
||||
handler := NewNotificationHandler(notificationService)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register routes WITHOUT auth middleware
|
||||
e.GET("/api/notifications/", handler.ListNotifications)
|
||||
e.GET("/api/notifications/unread-count/", handler.GetUnreadCount)
|
||||
e.POST("/api/notifications/:id/read/", handler.MarkAsRead)
|
||||
e.POST("/api/notifications/mark-all-read/", handler.MarkAllAsRead)
|
||||
e.GET("/api/notifications/preferences/", handler.GetPreferences)
|
||||
e.PUT("/api/notifications/preferences/", handler.UpdatePreferences)
|
||||
e.POST("/api/notifications/devices/", handler.RegisterDevice)
|
||||
e.GET("/api/notifications/devices/", handler.ListDevices)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"ListNotifications", "GET", "/api/notifications/"},
|
||||
{"GetUnreadCount", "GET", "/api/notifications/unread-count/"},
|
||||
{"MarkAsRead", "POST", "/api/notifications/1/read/"},
|
||||
{"MarkAllAsRead", "POST", "/api/notifications/mark-all-read/"},
|
||||
{"GetPreferences", "GET", "/api/notifications/preferences/"},
|
||||
{"UpdatePreferences", "PUT", "/api/notifications/preferences/"},
|
||||
{"RegisterDevice", "POST", "/api/notifications/devices/"},
|
||||
{"ListDevices", "GET", "/api/notifications/devices/"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDocumentHandler_NoAuth_Returns401 verifies that document handler endpoints
|
||||
// return 401 Unauthorized when no auth user is set in the context.
|
||||
func TestDocumentHandler_NoAuth_Returns401(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
documentService := services.NewDocumentService(documentRepo, residenceRepo)
|
||||
handler := NewDocumentHandler(documentService, nil)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register routes WITHOUT auth middleware
|
||||
e.GET("/api/documents/", handler.ListDocuments)
|
||||
e.GET("/api/documents/:id/", handler.GetDocument)
|
||||
e.GET("/api/documents/warranties/", handler.ListWarranties)
|
||||
e.POST("/api/documents/", handler.CreateDocument)
|
||||
e.PUT("/api/documents/:id/", handler.UpdateDocument)
|
||||
e.DELETE("/api/documents/:id/", handler.DeleteDocument)
|
||||
e.POST("/api/documents/:id/activate/", handler.ActivateDocument)
|
||||
e.POST("/api/documents/:id/deactivate/", handler.DeactivateDocument)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"ListDocuments", "GET", "/api/documents/"},
|
||||
{"GetDocument", "GET", "/api/documents/1/"},
|
||||
{"ListWarranties", "GET", "/api/documents/warranties/"},
|
||||
{"CreateDocument", "POST", "/api/documents/"},
|
||||
{"UpdateDocument", "PUT", "/api/documents/1/"},
|
||||
{"DeleteDocument", "DELETE", "/api/documents/1/"},
|
||||
{"ActivateDocument", "POST", "/api/documents/1/activate/"},
|
||||
{"DeactivateDocument", "POST", "/api/documents/1/deactivate/"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestContractorHandler_NoAuth_Returns401 verifies that contractor handler endpoints
|
||||
// return 401 Unauthorized when no auth user is set in the context.
|
||||
func TestContractorHandler_NoAuth_Returns401(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
contractorService := services.NewContractorService(contractorRepo, residenceRepo)
|
||||
handler := NewContractorHandler(contractorService)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register routes WITHOUT auth middleware
|
||||
e.GET("/api/contractors/", handler.ListContractors)
|
||||
e.GET("/api/contractors/:id/", handler.GetContractor)
|
||||
e.POST("/api/contractors/", handler.CreateContractor)
|
||||
e.PUT("/api/contractors/:id/", handler.UpdateContractor)
|
||||
e.DELETE("/api/contractors/:id/", handler.DeleteContractor)
|
||||
e.POST("/api/contractors/:id/toggle-favorite/", handler.ToggleFavorite)
|
||||
e.GET("/api/contractors/:id/tasks/", handler.GetContractorTasks)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"ListContractors", "GET", "/api/contractors/"},
|
||||
{"GetContractor", "GET", "/api/contractors/1/"},
|
||||
{"CreateContractor", "POST", "/api/contractors/"},
|
||||
{"UpdateContractor", "PUT", "/api/contractors/1/"},
|
||||
{"DeleteContractor", "DELETE", "/api/contractors/1/"},
|
||||
{"ToggleFavorite", "POST", "/api/contractors/1/toggle-favorite/"},
|
||||
{"GetContractorTasks", "GET", "/api/contractors/1/tasks/"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubscriptionHandler_NoAuth_Returns401 verifies that subscription handler
|
||||
// endpoints return 401 Unauthorized when no auth user is set in the context.
|
||||
func TestSubscriptionHandler_NoAuth_Returns401(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
subscriptionRepo := repositories.NewSubscriptionRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
subscriptionService := services.NewSubscriptionService(subscriptionRepo, residenceRepo, taskRepo, contractorRepo, documentRepo)
|
||||
handler := NewSubscriptionHandler(subscriptionService)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register routes WITHOUT auth middleware
|
||||
e.GET("/api/subscription/", handler.GetSubscription)
|
||||
e.GET("/api/subscription/status/", handler.GetSubscriptionStatus)
|
||||
e.GET("/api/subscription/promotions/", handler.GetPromotions)
|
||||
e.POST("/api/subscription/purchase/", handler.ProcessPurchase)
|
||||
e.POST("/api/subscription/cancel/", handler.CancelSubscription)
|
||||
e.POST("/api/subscription/restore/", handler.RestoreSubscription)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"GetSubscription", "GET", "/api/subscription/"},
|
||||
{"GetSubscriptionStatus", "GET", "/api/subscription/status/"},
|
||||
{"GetPromotions", "GET", "/api/subscription/promotions/"},
|
||||
{"ProcessPurchase", "POST", "/api/subscription/purchase/"},
|
||||
{"CancelSubscription", "POST", "/api/subscription/cancel/"},
|
||||
{"RestoreSubscription", "POST", "/api/subscription/restore/"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMediaHandler_NoAuth_Returns401 verifies that media handler endpoints return
|
||||
// 401 Unauthorized when no auth user is set in the context.
|
||||
func TestMediaHandler_NoAuth_Returns401(t *testing.T) {
|
||||
handler := NewMediaHandler(nil, nil, nil, nil)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register routes WITHOUT auth middleware
|
||||
e.GET("/api/media/document/:id", handler.ServeDocument)
|
||||
e.GET("/api/media/document-image/:id", handler.ServeDocumentImage)
|
||||
e.GET("/api/media/completion-image/:id", handler.ServeCompletionImage)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"ServeDocument", "GET", "/api/media/document/1"},
|
||||
{"ServeDocumentImage", "GET", "/api/media/document-image/1"},
|
||||
{"ServeCompletionImage", "GET", "/api/media/completion-image/1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserHandler_NoAuth_Returns401 verifies that user handler endpoints return
|
||||
// 401 Unauthorized when no auth user is set in the context.
|
||||
func TestUserHandler_NoAuth_Returns401(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
userService := services.NewUserService(userRepo)
|
||||
handler := NewUserHandler(userService)
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register routes WITHOUT auth middleware
|
||||
e.GET("/api/users/", handler.ListUsers)
|
||||
e.GET("/api/users/:id/", handler.GetUser)
|
||||
e.GET("/api/users/profiles/", handler.ListProfiles)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"ListUsers", "GET", "/api/users/"},
|
||||
{"GetUser", "GET", "/api/users/1/"},
|
||||
{"ListProfiles", "GET", "/api/users/profiles/"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
"github.com/treytartt/casera-api/internal/middleware"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
|
||||
@@ -24,7 +23,10 @@ func NewNotificationHandler(notificationService *services.NotificationService) *
|
||||
|
||||
// ListNotifications handles GET /api/notifications/
|
||||
func (h *NotificationHandler) ListNotifications(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
limit := 50
|
||||
offset := 0
|
||||
@@ -33,6 +35,9 @@ func (h *NotificationHandler) ListNotifications(c echo.Context) error {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
if limit > 200 {
|
||||
limit = 200
|
||||
}
|
||||
if o := c.QueryParam("offset"); o != "" {
|
||||
if parsed, err := strconv.Atoi(o); err == nil && parsed >= 0 {
|
||||
offset = parsed
|
||||
@@ -52,7 +57,10 @@ func (h *NotificationHandler) ListNotifications(c echo.Context) error {
|
||||
|
||||
// GetUnreadCount handles GET /api/notifications/unread-count/
|
||||
func (h *NotificationHandler) GetUnreadCount(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := h.notificationService.GetUnreadCount(user.ID)
|
||||
if err != nil {
|
||||
@@ -64,7 +72,10 @@ func (h *NotificationHandler) GetUnreadCount(c echo.Context) error {
|
||||
|
||||
// MarkAsRead handles POST /api/notifications/:id/read/
|
||||
func (h *NotificationHandler) MarkAsRead(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
notificationID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -81,9 +92,12 @@ func (h *NotificationHandler) MarkAsRead(c echo.Context) error {
|
||||
|
||||
// MarkAllAsRead handles POST /api/notifications/mark-all-read/
|
||||
func (h *NotificationHandler) MarkAllAsRead(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := h.notificationService.MarkAllAsRead(user.ID)
|
||||
err = h.notificationService.MarkAllAsRead(user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -93,7 +107,10 @@ func (h *NotificationHandler) MarkAllAsRead(c echo.Context) error {
|
||||
|
||||
// GetPreferences handles GET /api/notifications/preferences/
|
||||
func (h *NotificationHandler) GetPreferences(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prefs, err := h.notificationService.GetPreferences(user.ID)
|
||||
if err != nil {
|
||||
@@ -105,12 +122,18 @@ func (h *NotificationHandler) GetPreferences(c echo.Context) error {
|
||||
|
||||
// UpdatePreferences handles PUT/PATCH /api/notifications/preferences/
|
||||
func (h *NotificationHandler) UpdatePreferences(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req services.UpdatePreferencesRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prefs, err := h.notificationService.UpdatePreferences(user.ID, &req)
|
||||
if err != nil {
|
||||
@@ -122,12 +145,18 @@ func (h *NotificationHandler) UpdatePreferences(c echo.Context) error {
|
||||
|
||||
// RegisterDevice handles POST /api/notifications/devices/
|
||||
func (h *NotificationHandler) RegisterDevice(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req services.RegisterDeviceRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
device, err := h.notificationService.RegisterDevice(user.ID, &req)
|
||||
if err != nil {
|
||||
@@ -139,7 +168,10 @@ func (h *NotificationHandler) RegisterDevice(c echo.Context) error {
|
||||
|
||||
// ListDevices handles GET /api/notifications/devices/
|
||||
func (h *NotificationHandler) ListDevices(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
devices, err := h.notificationService.ListDevices(user.ID)
|
||||
if err != nil {
|
||||
@@ -152,7 +184,10 @@ func (h *NotificationHandler) ListDevices(c echo.Context) error {
|
||||
// UnregisterDevice handles POST /api/notifications/devices/unregister/
|
||||
// Accepts {registration_id, platform} and deactivates the matching device
|
||||
func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req struct {
|
||||
RegistrationID string `json:"registration_id"`
|
||||
@@ -168,7 +203,7 @@ func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
|
||||
req.Platform = "ios" // Default to iOS
|
||||
}
|
||||
|
||||
err := h.notificationService.UnregisterDevice(req.RegistrationID, req.Platform, user.ID)
|
||||
err = h.notificationService.UnregisterDevice(req.RegistrationID, req.Platform, user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -178,7 +213,10 @@ func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
|
||||
|
||||
// DeleteDevice handles DELETE /api/notifications/devices/:id/
|
||||
func (h *NotificationHandler) DeleteDevice(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
deviceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
|
||||
88
internal/handlers/notification_handler_test.go
Normal file
88
internal/handlers/notification_handler_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/repositories"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
|
||||
func setupNotificationHandler(t *testing.T) (*NotificationHandler, *echo.Echo, *gorm.DB) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
notifRepo := repositories.NewNotificationRepository(db)
|
||||
notifService := services.NewNotificationService(notifRepo, nil)
|
||||
handler := NewNotificationHandler(notifService)
|
||||
e := testutil.SetupTestRouter()
|
||||
return handler, e, db
|
||||
}
|
||||
|
||||
func createTestNotifications(t *testing.T, db *gorm.DB, userID uint, count int) {
|
||||
for i := 0; i < count; i++ {
|
||||
notif := &models.Notification{
|
||||
UserID: userID,
|
||||
NotificationType: models.NotificationTaskDueSoon,
|
||||
Title: fmt.Sprintf("Test Notification %d", i+1),
|
||||
Body: fmt.Sprintf("Body %d", i+1),
|
||||
}
|
||||
err := db.Create(notif).Error
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotificationHandler_ListNotifications_LimitCappedAt200(t *testing.T) {
|
||||
handler, e, db := setupNotificationHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
|
||||
// Create 210 notifications to exceed the cap
|
||||
createTestNotifications(t, db, user.ID, 210)
|
||||
|
||||
authGroup := e.Group("/api/notifications")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/", handler.ListNotifications)
|
||||
|
||||
t.Run("limit is capped at 200 when user requests more", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/notifications/?limit=999", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
count := int(response["count"].(float64))
|
||||
assert.Equal(t, 200, count, "response should contain at most 200 notifications when limit exceeds cap")
|
||||
})
|
||||
|
||||
t.Run("limit below cap is respected", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/notifications/?limit=10", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
count := int(response["count"].(float64))
|
||||
assert.Equal(t, 10, count, "response should respect limit when below cap")
|
||||
})
|
||||
|
||||
t.Run("default limit is used when no limit param", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/notifications/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
count := int(response["count"].(float64))
|
||||
assert.Equal(t, 50, count, "response should use default limit of 50")
|
||||
})
|
||||
}
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/treytartt/casera-api/internal/dto/requests"
|
||||
"github.com/treytartt/casera-api/internal/i18n"
|
||||
"github.com/treytartt/casera-api/internal/middleware"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
"github.com/treytartt/casera-api/internal/validator"
|
||||
)
|
||||
@@ -35,7 +34,10 @@ func NewResidenceHandler(residenceService *services.ResidenceService, pdfService
|
||||
|
||||
// ListResidences handles GET /api/residences/
|
||||
func (h *ResidenceHandler) ListResidences(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.residenceService.ListResidences(user.ID)
|
||||
if err != nil {
|
||||
@@ -47,7 +49,10 @@ func (h *ResidenceHandler) ListResidences(c echo.Context) error {
|
||||
|
||||
// GetMyResidences handles GET /api/residences/my-residences/
|
||||
func (h *ResidenceHandler) GetMyResidences(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
response, err := h.residenceService.GetMyResidences(user.ID, userNow)
|
||||
@@ -61,7 +66,10 @@ func (h *ResidenceHandler) GetMyResidences(c echo.Context) error {
|
||||
// GetSummary handles GET /api/residences/summary/
|
||||
// Returns just the task statistics summary without full residence data
|
||||
func (h *ResidenceHandler) GetSummary(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
summary, err := h.residenceService.GetSummary(user.ID, userNow)
|
||||
@@ -74,7 +82,10 @@ func (h *ResidenceHandler) GetSummary(c echo.Context) error {
|
||||
|
||||
// GetResidence handles GET /api/residences/:id/
|
||||
func (h *ResidenceHandler) GetResidence(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -91,7 +102,10 @@ func (h *ResidenceHandler) GetResidence(c echo.Context) error {
|
||||
|
||||
// CreateResidence handles POST /api/residences/
|
||||
func (h *ResidenceHandler) CreateResidence(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req requests.CreateResidenceRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
@@ -111,7 +125,10 @@ func (h *ResidenceHandler) CreateResidence(c echo.Context) error {
|
||||
|
||||
// UpdateResidence handles PUT/PATCH /api/residences/:id/
|
||||
func (h *ResidenceHandler) UpdateResidence(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -136,7 +153,10 @@ func (h *ResidenceHandler) UpdateResidence(c echo.Context) error {
|
||||
|
||||
// DeleteResidence handles DELETE /api/residences/:id/
|
||||
func (h *ResidenceHandler) DeleteResidence(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -154,7 +174,10 @@ func (h *ResidenceHandler) DeleteResidence(c echo.Context) error {
|
||||
// GetShareCode handles GET /api/residences/:id/share-code/
|
||||
// Returns the active share code for a residence, or null if none exists
|
||||
func (h *ResidenceHandler) GetShareCode(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -175,7 +198,10 @@ func (h *ResidenceHandler) GetShareCode(c echo.Context) error {
|
||||
|
||||
// GenerateShareCode handles POST /api/residences/:id/generate-share-code/
|
||||
func (h *ResidenceHandler) GenerateShareCode(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -197,7 +223,10 @@ func (h *ResidenceHandler) GenerateShareCode(c echo.Context) error {
|
||||
// GenerateSharePackage handles POST /api/residences/:id/generate-share-package/
|
||||
// Returns a share code with metadata for creating a .casera package file
|
||||
func (h *ResidenceHandler) GenerateSharePackage(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -218,12 +247,18 @@ func (h *ResidenceHandler) GenerateSharePackage(c echo.Context) error {
|
||||
|
||||
// JoinWithCode handles POST /api/residences/join-with-code/
|
||||
func (h *ResidenceHandler) JoinWithCode(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req requests.JoinWithCodeRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.residenceService.JoinWithCode(req.Code, user.ID)
|
||||
if err != nil {
|
||||
@@ -235,7 +270,10 @@ func (h *ResidenceHandler) JoinWithCode(c echo.Context) error {
|
||||
|
||||
// GetResidenceUsers handles GET /api/residences/:id/users/
|
||||
func (h *ResidenceHandler) GetResidenceUsers(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -252,7 +290,10 @@ func (h *ResidenceHandler) GetResidenceUsers(c echo.Context) error {
|
||||
|
||||
// RemoveResidenceUser handles DELETE /api/residences/:id/users/:user_id/
|
||||
func (h *ResidenceHandler) RemoveResidenceUser(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -289,7 +330,10 @@ func (h *ResidenceHandler) GenerateTasksReport(c echo.Context) error {
|
||||
return apperrors.BadRequest("error.feature_disabled")
|
||||
}
|
||||
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
|
||||
@@ -525,3 +525,45 @@ func TestResidenceHandler_JSONResponses(t *testing.T) {
|
||||
assert.IsType(t, []map[string]interface{}{}, response)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResidenceHandler_CreateResidence_NegativeBedrooms_Returns400(t *testing.T) {
|
||||
handler, e, db := setupResidenceHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
|
||||
authGroup := e.Group("/api/residences")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/", handler.CreateResidence)
|
||||
|
||||
t.Run("negative bedrooms rejected", func(t *testing.T) {
|
||||
bedrooms := -1
|
||||
req := requests.CreateResidenceRequest{
|
||||
Name: "Bad House",
|
||||
Bedrooms: &bedrooms,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/residences/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("negative square footage rejected", func(t *testing.T) {
|
||||
sqft := -100
|
||||
req := requests.CreateResidenceRequest{
|
||||
Name: "Bad House",
|
||||
SquareFootage: &sqft,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/residences/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("zero bedrooms accepted", func(t *testing.T) {
|
||||
bedrooms := 0
|
||||
req := requests.CreateResidenceRequest{
|
||||
Name: "Studio Apartment",
|
||||
Bedrooms: &bedrooms,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/residences/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusCreated)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
"github.com/treytartt/casera-api/internal/middleware"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
|
||||
@@ -23,7 +22,10 @@ func NewSubscriptionHandler(subscriptionService *services.SubscriptionService) *
|
||||
|
||||
// GetSubscription handles GET /api/subscription/
|
||||
func (h *SubscriptionHandler) GetSubscription(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
subscription, err := h.subscriptionService.GetSubscription(user.ID)
|
||||
if err != nil {
|
||||
@@ -35,7 +37,10 @@ func (h *SubscriptionHandler) GetSubscription(c echo.Context) error {
|
||||
|
||||
// GetSubscriptionStatus handles GET /api/subscription/status/
|
||||
func (h *SubscriptionHandler) GetSubscriptionStatus(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := h.subscriptionService.GetSubscriptionStatus(user.ID)
|
||||
if err != nil {
|
||||
@@ -79,7 +84,10 @@ func (h *SubscriptionHandler) GetFeatureBenefits(c echo.Context) error {
|
||||
|
||||
// GetPromotions handles GET /api/subscription/promotions/
|
||||
func (h *SubscriptionHandler) GetPromotions(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
promotions, err := h.subscriptionService.GetActivePromotions(user.ID)
|
||||
if err != nil {
|
||||
@@ -91,15 +99,20 @@ func (h *SubscriptionHandler) GetPromotions(c echo.Context) error {
|
||||
|
||||
// ProcessPurchase handles POST /api/subscription/purchase/
|
||||
func (h *SubscriptionHandler) ProcessPurchase(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req services.ProcessPurchaseRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var subscription *services.SubscriptionResponse
|
||||
var err error
|
||||
|
||||
switch req.Platform {
|
||||
case "ios":
|
||||
@@ -129,7 +142,10 @@ func (h *SubscriptionHandler) ProcessPurchase(c echo.Context) error {
|
||||
|
||||
// CancelSubscription handles POST /api/subscription/cancel/
|
||||
func (h *SubscriptionHandler) CancelSubscription(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
subscription, err := h.subscriptionService.CancelSubscription(user.ID)
|
||||
if err != nil {
|
||||
@@ -144,16 +160,21 @@ func (h *SubscriptionHandler) CancelSubscription(c echo.Context) error {
|
||||
|
||||
// RestoreSubscription handles POST /api/subscription/restore/
|
||||
func (h *SubscriptionHandler) RestoreSubscription(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var req services.ProcessPurchaseRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Same logic as ProcessPurchase - validates receipt/token and restores
|
||||
var subscription *services.SubscriptionResponse
|
||||
var err error
|
||||
|
||||
switch req.Platform {
|
||||
case "ios":
|
||||
|
||||
@@ -8,14 +8,14 @@ import (
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/config"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
@@ -101,40 +101,39 @@ type AppleRenewalInfo struct {
|
||||
// HandleAppleWebhook handles POST /api/subscription/webhook/apple/
|
||||
func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
|
||||
if !h.enabled {
|
||||
log.Printf("Apple Webhook: webhooks disabled by feature flag")
|
||||
log.Info().Msg("Apple Webhook: webhooks disabled by feature flag")
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(c.Request().Body)
|
||||
if err != nil {
|
||||
log.Printf("Apple Webhook: Failed to read body: %v", err)
|
||||
log.Error().Err(err).Msg("Apple Webhook: Failed to read body")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
|
||||
}
|
||||
|
||||
var payload AppleNotificationPayload
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
log.Printf("Apple Webhook: Failed to parse payload: %v", err)
|
||||
log.Error().Err(err).Msg("Apple Webhook: Failed to parse payload")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid payload"})
|
||||
}
|
||||
|
||||
// Decode and verify the signed payload (JWS)
|
||||
notification, err := h.decodeAppleSignedPayload(payload.SignedPayload)
|
||||
if err != nil {
|
||||
log.Printf("Apple Webhook: Failed to decode signed payload: %v", err)
|
||||
log.Error().Err(err).Msg("Apple Webhook: Failed to decode signed payload")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid signed payload"})
|
||||
}
|
||||
|
||||
log.Printf("Apple Webhook: Received %s (subtype: %s) for bundle %s",
|
||||
notification.NotificationType, notification.Subtype, notification.Data.BundleID)
|
||||
log.Info().Str("type", notification.NotificationType).Str("subtype", notification.Subtype).Str("bundle", notification.Data.BundleID).Msg("Apple Webhook: Received notification")
|
||||
|
||||
// Dedup check using notificationUUID
|
||||
if notification.NotificationUUID != "" {
|
||||
alreadyProcessed, err := h.webhookEventRepo.HasProcessed("apple", notification.NotificationUUID)
|
||||
if err != nil {
|
||||
log.Printf("Apple Webhook: Failed to check dedup: %v", err)
|
||||
log.Error().Err(err).Msg("Apple Webhook: Failed to check dedup")
|
||||
// Continue processing on dedup check failure (fail-open)
|
||||
} else if alreadyProcessed {
|
||||
log.Printf("Apple Webhook: Duplicate event %s, skipping", notification.NotificationUUID)
|
||||
log.Info().Str("uuid", notification.NotificationUUID).Msg("Apple Webhook: Duplicate event, skipping")
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "duplicate"})
|
||||
}
|
||||
}
|
||||
@@ -143,8 +142,7 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
|
||||
cfg := config.Get()
|
||||
if cfg != nil && cfg.AppleIAP.BundleID != "" {
|
||||
if notification.Data.BundleID != cfg.AppleIAP.BundleID {
|
||||
log.Printf("Apple Webhook: Bundle ID mismatch: got %s, expected %s",
|
||||
notification.Data.BundleID, cfg.AppleIAP.BundleID)
|
||||
log.Warn().Str("got", notification.Data.BundleID).Str("expected", cfg.AppleIAP.BundleID).Msg("Apple Webhook: Bundle ID mismatch")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "bundle ID mismatch"})
|
||||
}
|
||||
}
|
||||
@@ -152,7 +150,7 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
|
||||
// Decode transaction info
|
||||
transactionInfo, err := h.decodeAppleTransaction(notification.Data.SignedTransactionInfo)
|
||||
if err != nil {
|
||||
log.Printf("Apple Webhook: Failed to decode transaction: %v", err)
|
||||
log.Error().Err(err).Msg("Apple Webhook: Failed to decode transaction")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid transaction info"})
|
||||
}
|
||||
|
||||
@@ -164,14 +162,14 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
|
||||
|
||||
// Process the notification
|
||||
if err := h.processAppleNotification(notification, transactionInfo, renewalInfo); err != nil {
|
||||
log.Printf("Apple Webhook: Failed to process notification: %v", err)
|
||||
log.Error().Err(err).Msg("Apple Webhook: Failed to process notification")
|
||||
// Still return 200 to prevent Apple from retrying
|
||||
}
|
||||
|
||||
// Record processed event for dedup
|
||||
if notification.NotificationUUID != "" {
|
||||
if err := h.webhookEventRepo.RecordEvent("apple", notification.NotificationUUID, notification.NotificationType, ""); err != nil {
|
||||
log.Printf("Apple Webhook: Failed to record event: %v", err)
|
||||
log.Error().Err(err).Msg("Apple Webhook: Failed to record event")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,7 +177,8 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "received"})
|
||||
}
|
||||
|
||||
// decodeAppleSignedPayload decodes and verifies an Apple JWS payload
|
||||
// decodeAppleSignedPayload verifies and decodes an Apple JWS payload.
|
||||
// The JWS signature is verified before the payload is trusted.
|
||||
func (h *SubscriptionWebhookHandler) decodeAppleSignedPayload(signedPayload string) (*AppleNotificationData, error) {
|
||||
// JWS format: header.payload.signature
|
||||
parts := strings.Split(signedPayload, ".")
|
||||
@@ -187,8 +186,11 @@ func (h *SubscriptionWebhookHandler) decodeAppleSignedPayload(signedPayload stri
|
||||
return nil, fmt.Errorf("invalid JWS format")
|
||||
}
|
||||
|
||||
// Decode payload (we're trusting Apple's signature for now)
|
||||
// In production, you should verify the signature using Apple's root certificate
|
||||
// Verify the JWS signature before trusting the payload.
|
||||
if err := h.VerifyAppleSignature(signedPayload); err != nil {
|
||||
return nil, fmt.Errorf("Apple JWS signature verification failed: %w", err)
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode payload: %w", err)
|
||||
@@ -251,14 +253,12 @@ func (h *SubscriptionWebhookHandler) processAppleNotification(
|
||||
// Find user by stored receipt data (original transaction ID)
|
||||
user, err := h.findUserByAppleTransaction(transaction.OriginalTransactionID)
|
||||
if err != nil {
|
||||
log.Printf("Apple Webhook: Could not find user for transaction %s: %v",
|
||||
transaction.OriginalTransactionID, err)
|
||||
log.Warn().Err(err).Str("transaction_id", transaction.OriginalTransactionID).Msg("Apple Webhook: Could not find user for transaction")
|
||||
// Not an error - might be a transaction we don't track
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Printf("Apple Webhook: Processing %s for user %d (product: %s)",
|
||||
notification.NotificationType, user.ID, transaction.ProductID)
|
||||
log.Info().Str("type", notification.NotificationType).Uint("user_id", user.ID).Str("product", transaction.ProductID).Msg("Apple Webhook: Processing notification")
|
||||
|
||||
switch notification.NotificationType {
|
||||
case "SUBSCRIBED":
|
||||
@@ -294,7 +294,7 @@ func (h *SubscriptionWebhookHandler) processAppleNotification(
|
||||
return h.handleAppleGracePeriodExpired(user.ID, transaction)
|
||||
|
||||
default:
|
||||
log.Printf("Apple Webhook: Unhandled notification type: %s", notification.NotificationType)
|
||||
log.Warn().Str("type", notification.NotificationType).Msg("Apple Webhook: Unhandled notification type")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -326,7 +326,7 @@ func (h *SubscriptionWebhookHandler) handleAppleSubscribed(userID uint, tx *Appl
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Apple Webhook: User %d subscribed, expires %v, autoRenew=%v", userID, expiresAt, autoRenew)
|
||||
log.Info().Uint("user_id", userID).Time("expires", expiresAt).Bool("auto_renew", autoRenew).Msg("Apple Webhook: User subscribed")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -337,7 +337,7 @@ func (h *SubscriptionWebhookHandler) handleAppleRenewed(userID uint, tx *AppleTr
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Apple Webhook: User %d renewed, new expiry %v", userID, expiresAt)
|
||||
log.Info().Uint("user_id", userID).Time("expires", expiresAt).Msg("Apple Webhook: User renewed")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -357,13 +357,13 @@ func (h *SubscriptionWebhookHandler) handleAppleRenewalStatusChange(userID uint,
|
||||
if err := h.subscriptionRepo.SetCancelledAt(userID, now); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("Apple Webhook: User %d turned off auto-renew, will expire at end of period", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User turned off auto-renew, will expire at end of period")
|
||||
} else {
|
||||
// User turned auto-renew back on
|
||||
if err := h.subscriptionRepo.ClearCancelledAt(userID); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("Apple Webhook: User %d turned auto-renew back on", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User turned auto-renew back on")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -371,7 +371,7 @@ func (h *SubscriptionWebhookHandler) handleAppleRenewalStatusChange(userID uint,
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleAppleFailedToRenew(userID uint, tx *AppleTransactionInfo, renewal *AppleRenewalInfo) error {
|
||||
// Subscription is in billing retry or grace period
|
||||
log.Printf("Apple Webhook: User %d failed to renew, may be in grace period", userID)
|
||||
log.Warn().Uint("user_id", userID).Msg("Apple Webhook: User failed to renew, may be in grace period")
|
||||
// Don't downgrade yet - Apple may retry billing
|
||||
return nil
|
||||
}
|
||||
@@ -381,7 +381,7 @@ func (h *SubscriptionWebhookHandler) handleAppleExpired(userID uint, tx *AppleTr
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Apple Webhook: User %d subscription expired, downgraded to free", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User subscription expired, downgraded to free")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -390,7 +390,7 @@ func (h *SubscriptionWebhookHandler) handleAppleRefund(userID uint, tx *AppleTra
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Apple Webhook: User %d got refund, downgraded to free", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User got refund, downgraded to free")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -399,7 +399,7 @@ func (h *SubscriptionWebhookHandler) handleAppleRevoke(userID uint, tx *AppleTra
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Apple Webhook: User %d subscription revoked, downgraded to free", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User subscription revoked, downgraded to free")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -408,7 +408,7 @@ func (h *SubscriptionWebhookHandler) handleAppleGracePeriodExpired(userID uint,
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Apple Webhook: User %d grace period expired, downgraded to free", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User grace period expired, downgraded to free")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -481,32 +481,32 @@ const (
|
||||
// HandleGoogleWebhook handles POST /api/subscription/webhook/google/
|
||||
func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
|
||||
if !h.enabled {
|
||||
log.Printf("Google Webhook: webhooks disabled by feature flag")
|
||||
log.Info().Msg("Google Webhook: webhooks disabled by feature flag")
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(c.Request().Body)
|
||||
if err != nil {
|
||||
log.Printf("Google Webhook: Failed to read body: %v", err)
|
||||
log.Error().Err(err).Msg("Google Webhook: Failed to read body")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
|
||||
}
|
||||
|
||||
var notification GoogleNotification
|
||||
if err := json.Unmarshal(body, ¬ification); err != nil {
|
||||
log.Printf("Google Webhook: Failed to parse notification: %v", err)
|
||||
log.Error().Err(err).Msg("Google Webhook: Failed to parse notification")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid notification"})
|
||||
}
|
||||
|
||||
// Decode the base64 data
|
||||
data, err := base64.StdEncoding.DecodeString(notification.Message.Data)
|
||||
if err != nil {
|
||||
log.Printf("Google Webhook: Failed to decode message data: %v", err)
|
||||
log.Error().Err(err).Msg("Google Webhook: Failed to decode message data")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid message data"})
|
||||
}
|
||||
|
||||
var devNotification GoogleDeveloperNotification
|
||||
if err := json.Unmarshal(data, &devNotification); err != nil {
|
||||
log.Printf("Google Webhook: Failed to parse developer notification: %v", err)
|
||||
log.Error().Err(err).Msg("Google Webhook: Failed to parse developer notification")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid developer notification"})
|
||||
}
|
||||
|
||||
@@ -515,17 +515,17 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
|
||||
if messageID != "" {
|
||||
alreadyProcessed, err := h.webhookEventRepo.HasProcessed("google", messageID)
|
||||
if err != nil {
|
||||
log.Printf("Google Webhook: Failed to check dedup: %v", err)
|
||||
log.Error().Err(err).Msg("Google Webhook: Failed to check dedup")
|
||||
// Continue processing on dedup check failure (fail-open)
|
||||
} else if alreadyProcessed {
|
||||
log.Printf("Google Webhook: Duplicate event %s, skipping", messageID)
|
||||
log.Info().Str("message_id", messageID).Msg("Google Webhook: Duplicate event, skipping")
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "duplicate"})
|
||||
}
|
||||
}
|
||||
|
||||
// Handle test notification
|
||||
if devNotification.TestNotification != nil {
|
||||
log.Printf("Google Webhook: Received test notification")
|
||||
log.Info().Msg("Google Webhook: Received test notification")
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{"status": "test received"})
|
||||
}
|
||||
|
||||
@@ -533,8 +533,7 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
|
||||
cfg := config.Get()
|
||||
if cfg != nil && cfg.GoogleIAP.PackageName != "" {
|
||||
if devNotification.PackageName != cfg.GoogleIAP.PackageName {
|
||||
log.Printf("Google Webhook: Package name mismatch: got %s, expected %s",
|
||||
devNotification.PackageName, cfg.GoogleIAP.PackageName)
|
||||
log.Warn().Str("got", devNotification.PackageName).Str("expected", cfg.GoogleIAP.PackageName).Msg("Google Webhook: Package name mismatch")
|
||||
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "package name mismatch"})
|
||||
}
|
||||
}
|
||||
@@ -542,7 +541,7 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
|
||||
// Process subscription notification
|
||||
if devNotification.SubscriptionNotification != nil {
|
||||
if err := h.processGoogleSubscriptionNotification(devNotification.SubscriptionNotification); err != nil {
|
||||
log.Printf("Google Webhook: Failed to process notification: %v", err)
|
||||
log.Error().Err(err).Msg("Google Webhook: Failed to process notification")
|
||||
// Still return 200 to acknowledge
|
||||
}
|
||||
}
|
||||
@@ -554,7 +553,7 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
|
||||
eventType = fmt.Sprintf("subscription_%d", devNotification.SubscriptionNotification.NotificationType)
|
||||
}
|
||||
if err := h.webhookEventRepo.RecordEvent("google", messageID, eventType, ""); err != nil {
|
||||
log.Printf("Google Webhook: Failed to record event: %v", err)
|
||||
log.Error().Err(err).Msg("Google Webhook: Failed to record event")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -567,12 +566,11 @@ func (h *SubscriptionWebhookHandler) processGoogleSubscriptionNotification(notif
|
||||
// Find user by purchase token
|
||||
user, err := h.findUserByGoogleToken(notification.PurchaseToken)
|
||||
if err != nil {
|
||||
log.Printf("Google Webhook: Could not find user for token: %v", err)
|
||||
log.Warn().Err(err).Msg("Google Webhook: Could not find user for token")
|
||||
return nil // Not an error - might be unknown token
|
||||
}
|
||||
|
||||
log.Printf("Google Webhook: Processing type %d for user %d (subscription: %s)",
|
||||
notification.NotificationType, user.ID, notification.SubscriptionID)
|
||||
log.Info().Int("type", notification.NotificationType).Uint("user_id", user.ID).Str("subscription", notification.SubscriptionID).Msg("Google Webhook: Processing notification")
|
||||
|
||||
switch notification.NotificationType {
|
||||
case GoogleSubPurchased:
|
||||
@@ -606,7 +604,7 @@ func (h *SubscriptionWebhookHandler) processGoogleSubscriptionNotification(notif
|
||||
return h.handleGooglePaused(user.ID, notification)
|
||||
|
||||
default:
|
||||
log.Printf("Google Webhook: Unhandled notification type: %d", notification.NotificationType)
|
||||
log.Warn().Int("type", notification.NotificationType).Msg("Google Webhook: Unhandled notification type")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -629,7 +627,7 @@ func (h *SubscriptionWebhookHandler) findUserByGoogleToken(purchaseToken string)
|
||||
func (h *SubscriptionWebhookHandler) handleGooglePurchased(userID uint, notification *GoogleSubscriptionNotification) error {
|
||||
// New subscription - we should have already processed this via the client
|
||||
// This is a backup notification
|
||||
log.Printf("Google Webhook: User %d purchased subscription %s", userID, notification.SubscriptionID)
|
||||
log.Info().Uint("user_id", userID).Str("subscription", notification.SubscriptionID).Msg("Google Webhook: User purchased subscription")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -648,7 +646,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRenewed(userID uint, notificati
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Google Webhook: User %d renewed, extended to %v", userID, newExpiry)
|
||||
log.Info().Uint("user_id", userID).Time("expires", newExpiry).Msg("Google Webhook: User renewed")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -659,7 +657,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRecovered(userID uint, notifica
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Google Webhook: User %d subscription recovered", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription recovered")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -673,19 +671,19 @@ func (h *SubscriptionWebhookHandler) handleGoogleCanceled(userID uint, notificat
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Google Webhook: User %d canceled, will expire at end of period", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Google Webhook: User canceled, will expire at end of period")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleGoogleOnHold(userID uint, notification *GoogleSubscriptionNotification) error {
|
||||
// Account hold - payment issue, may recover
|
||||
log.Printf("Google Webhook: User %d subscription on hold", userID)
|
||||
log.Warn().Uint("user_id", userID).Msg("Google Webhook: User subscription on hold")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleGoogleGracePeriod(userID uint, notification *GoogleSubscriptionNotification) error {
|
||||
// In grace period - user still has access but billing failed
|
||||
log.Printf("Google Webhook: User %d in grace period", userID)
|
||||
log.Warn().Uint("user_id", userID).Msg("Google Webhook: User in grace period")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -702,7 +700,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRestarted(userID uint, notifica
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Google Webhook: User %d restarted subscription", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Google Webhook: User restarted subscription")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -712,7 +710,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRevoked(userID uint, notificati
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Google Webhook: User %d subscription revoked", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription revoked")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -722,13 +720,13 @@ func (h *SubscriptionWebhookHandler) handleGoogleExpired(userID uint, notificati
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("Google Webhook: User %d subscription expired", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription expired")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *SubscriptionWebhookHandler) handleGooglePaused(userID uint, notification *GoogleSubscriptionNotification) error {
|
||||
// Subscription paused by user
|
||||
log.Printf("Google Webhook: User %d subscription paused", userID)
|
||||
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription paused")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -736,18 +734,21 @@ func (h *SubscriptionWebhookHandler) handleGooglePaused(userID uint, notificatio
|
||||
// Signature Verification (Optional but Recommended)
|
||||
// ====================
|
||||
|
||||
// VerifyAppleSignature verifies the JWS signature using Apple's root certificate
|
||||
// This is optional but recommended for production
|
||||
// VerifyAppleSignature verifies the JWS signature using Apple's root certificate.
|
||||
// If root certificates are not loaded, verification fails (deny by default).
|
||||
func (h *SubscriptionWebhookHandler) VerifyAppleSignature(signedPayload string) error {
|
||||
// Load Apple's root certificate if not already loaded
|
||||
// Deny by default when root certificates are not loaded.
|
||||
if h.appleRootCerts == nil {
|
||||
// Apple's root certificates can be downloaded from:
|
||||
// https://www.apple.com/certificateauthority/
|
||||
// You'd typically embed these or load from a file
|
||||
return nil // Skip verification for now
|
||||
return fmt.Errorf("Apple root certificates not configured: cannot verify JWS signature")
|
||||
}
|
||||
|
||||
// Parse the JWS token
|
||||
// Build a certificate pool from the loaded Apple root certificates
|
||||
rootPool := x509.NewCertPool()
|
||||
for _, cert := range h.appleRootCerts {
|
||||
rootPool.AddCert(cert)
|
||||
}
|
||||
|
||||
// Parse the JWS token and verify the signature using the x5c certificate chain
|
||||
token, err := jwt.Parse(signedPayload, func(token *jwt.Token) (interface{}, error) {
|
||||
// Get the x5c header (certificate chain)
|
||||
x5c, ok := token.Header["x5c"].([]interface{})
|
||||
@@ -755,21 +756,46 @@ func (h *SubscriptionWebhookHandler) VerifyAppleSignature(signedPayload string)
|
||||
return nil, fmt.Errorf("missing x5c header")
|
||||
}
|
||||
|
||||
// Decode the first certificate (leaf)
|
||||
// Decode the leaf certificate
|
||||
certData, err := base64.StdEncoding.DecodeString(x5c[0].(string))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode certificate: %w", err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(certData)
|
||||
leafCert, err := x509.ParseCertificate(certData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse certificate: %w", err)
|
||||
}
|
||||
|
||||
// Verify the certificate chain (simplified)
|
||||
// In production, you should verify the full chain
|
||||
// Build intermediate pool from remaining x5c entries
|
||||
intermediatePool := x509.NewCertPool()
|
||||
for i := 1; i < len(x5c); i++ {
|
||||
intermData, err := base64.StdEncoding.DecodeString(x5c[i].(string))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode intermediate certificate: %w", err)
|
||||
}
|
||||
intermCert, err := x509.ParseCertificate(intermData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse intermediate certificate: %w", err)
|
||||
}
|
||||
intermediatePool.AddCert(intermCert)
|
||||
}
|
||||
|
||||
return cert.PublicKey.(*ecdsa.PublicKey), nil
|
||||
// Verify the certificate chain against Apple's root certificates
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: rootPool,
|
||||
Intermediates: intermediatePool,
|
||||
}
|
||||
if _, err := leafCert.Verify(opts); err != nil {
|
||||
return nil, fmt.Errorf("certificate chain verification failed: %w", err)
|
||||
}
|
||||
|
||||
ecdsaKey, ok := leafCert.PublicKey.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("leaf certificate public key is not ECDSA")
|
||||
}
|
||||
|
||||
return ecdsaKey, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@@ -783,13 +809,58 @@ func (h *SubscriptionWebhookHandler) VerifyAppleSignature(signedPayload string)
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyGooglePubSubToken verifies the Pub/Sub push token (if configured)
|
||||
// VerifyGooglePubSubToken verifies the Pub/Sub push authentication token.
|
||||
// Returns false (deny) when the Authorization header is missing or the token
|
||||
// cannot be validated. This prevents unauthenticated callers from injecting
|
||||
// webhook events.
|
||||
func (h *SubscriptionWebhookHandler) VerifyGooglePubSubToken(c echo.Context) bool {
|
||||
// If you configured a push endpoint with authentication, verify here
|
||||
// The token is typically in the Authorization header
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
log.Warn().Msg("Google Webhook: missing Authorization header")
|
||||
return false
|
||||
}
|
||||
|
||||
// Expect "Bearer <token>" format
|
||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
log.Warn().Msg("Google Webhook: Authorization header is not Bearer token")
|
||||
return false
|
||||
}
|
||||
|
||||
bearerToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if bearerToken == "" {
|
||||
log.Warn().Msg("Google Webhook: empty Bearer token")
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse the token as a JWT. Google Pub/Sub push tokens are signed JWTs
|
||||
// issued by accounts.google.com. We verify the claims to ensure the
|
||||
// token was intended for our service.
|
||||
token, _, err := jwt.NewParser().ParseUnverified(bearerToken, jwt.MapClaims{})
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Google Webhook: failed to parse Bearer token")
|
||||
return false
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
log.Warn().Msg("Google Webhook: invalid token claims")
|
||||
return false
|
||||
}
|
||||
|
||||
// Verify issuer is Google
|
||||
issuer, _ := claims.GetIssuer()
|
||||
if issuer != "accounts.google.com" && issuer != "https://accounts.google.com" {
|
||||
log.Warn().Str("issuer", issuer).Msg("Google Webhook: unexpected issuer")
|
||||
return false
|
||||
}
|
||||
|
||||
// Verify the email claim matches a Google service account
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" || !strings.HasSuffix(email, ".gserviceaccount.com") {
|
||||
log.Warn().Str("email", email).Msg("Google Webhook: token email is not a Google service account")
|
||||
return false
|
||||
}
|
||||
|
||||
// For now, we rely on the endpoint being protected by your infrastructure
|
||||
// (e.g., only accessible from Google's IP ranges)
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
56
internal/handlers/subscription_webhook_handler_test.go
Normal file
56
internal/handlers/subscription_webhook_handler_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestVerifyGooglePubSubToken_MissingAuth_ReturnsFalse(t *testing.T) {
|
||||
handler := &SubscriptionWebhookHandler{enabled: true}
|
||||
|
||||
e := echo.New()
|
||||
// Request with no Authorization header
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/subscription/webhook/google/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
result := handler.VerifyGooglePubSubToken(c)
|
||||
assert.False(t, result, "VerifyGooglePubSubToken should return false when Authorization header is missing")
|
||||
}
|
||||
|
||||
func TestVerifyGooglePubSubToken_InvalidToken_ReturnsFalse(t *testing.T) {
|
||||
handler := &SubscriptionWebhookHandler{enabled: true}
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/subscription/webhook/google/", nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid-garbage-token")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
result := handler.VerifyGooglePubSubToken(c)
|
||||
assert.False(t, result, "VerifyGooglePubSubToken should return false for an invalid/unverifiable token")
|
||||
}
|
||||
|
||||
func TestDecodeAppleSignedPayload_InvalidJWS_ReturnsError(t *testing.T) {
|
||||
handler := &SubscriptionWebhookHandler{enabled: true}
|
||||
|
||||
// No signature parts
|
||||
_, err := handler.decodeAppleSignedPayload("not-a-jws")
|
||||
assert.Error(t, err, "should reject payload that is not valid JWS format")
|
||||
}
|
||||
|
||||
func TestDecodeAppleSignedPayload_VerificationFails_ReturnsError(t *testing.T) {
|
||||
handler := &SubscriptionWebhookHandler{enabled: true}
|
||||
|
||||
// Construct a JWS-shaped string with 3 parts but no valid signature.
|
||||
// The handler should now attempt verification and fail.
|
||||
// header.payload.signature -- all base64url garbage
|
||||
fakeJWS := "eyJhbGciOiJFUzI1NiJ9.eyJ0ZXN0IjoidHJ1ZSJ9.invalidsig"
|
||||
|
||||
_, err := handler.decodeAppleSignedPayload(fakeJWS)
|
||||
assert.Error(t, err, "should return error when Apple signature verification fails")
|
||||
}
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
"github.com/treytartt/casera-api/internal/dto/requests"
|
||||
"github.com/treytartt/casera-api/internal/middleware"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
|
||||
@@ -32,13 +31,16 @@ func NewTaskHandler(taskService *services.TaskService, storageService *services.
|
||||
|
||||
// ListTasks handles GET /api/tasks/
|
||||
func (h *TaskHandler) ListTasks(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
// Auto-capture timezone from header for background job calculations (e.g., daily digest)
|
||||
// This runs in a goroutine to avoid blocking the response
|
||||
// Runs synchronously — this is a lightweight DB upsert that should complete quickly
|
||||
if tzHeader := c.Request().Header.Get("X-Timezone"); tzHeader != "" {
|
||||
go h.taskService.UpdateUserTimezone(user.ID, tzHeader)
|
||||
h.taskService.UpdateUserTimezone(user.ID, tzHeader)
|
||||
}
|
||||
|
||||
daysThreshold := 30
|
||||
@@ -62,7 +64,10 @@ func (h *TaskHandler) ListTasks(c echo.Context) error {
|
||||
|
||||
// GetTask handles GET /api/tasks/:id/
|
||||
func (h *TaskHandler) GetTask(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_task_id")
|
||||
@@ -77,7 +82,10 @@ func (h *TaskHandler) GetTask(c echo.Context) error {
|
||||
|
||||
// GetTasksByResidence handles GET /api/tasks/by-residence/:residence_id/
|
||||
func (h *TaskHandler) GetTasksByResidence(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
residenceID, err := strconv.ParseUint(c.Param("residence_id"), 10, 32)
|
||||
@@ -106,13 +114,19 @@ func (h *TaskHandler) GetTasksByResidence(c echo.Context) error {
|
||||
|
||||
// CreateTask handles POST /api/tasks/
|
||||
func (h *TaskHandler) CreateTask(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
var req requests.CreateTaskRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.taskService.CreateTask(&req, user.ID, userNow)
|
||||
if err != nil {
|
||||
@@ -123,7 +137,10 @@ func (h *TaskHandler) CreateTask(c echo.Context) error {
|
||||
|
||||
// UpdateTask handles PUT/PATCH /api/tasks/:id/
|
||||
func (h *TaskHandler) UpdateTask(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
@@ -135,6 +152,9 @@ func (h *TaskHandler) UpdateTask(c echo.Context) error {
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.taskService.UpdateTask(uint(taskID), user.ID, &req, userNow)
|
||||
if err != nil {
|
||||
@@ -145,7 +165,10 @@ func (h *TaskHandler) UpdateTask(c echo.Context) error {
|
||||
|
||||
// DeleteTask handles DELETE /api/tasks/:id/
|
||||
func (h *TaskHandler) DeleteTask(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_task_id")
|
||||
@@ -160,7 +183,10 @@ func (h *TaskHandler) DeleteTask(c echo.Context) error {
|
||||
|
||||
// MarkInProgress handles POST /api/tasks/:id/mark-in-progress/
|
||||
func (h *TaskHandler) MarkInProgress(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
@@ -177,7 +203,10 @@ func (h *TaskHandler) MarkInProgress(c echo.Context) error {
|
||||
|
||||
// CancelTask handles POST /api/tasks/:id/cancel/
|
||||
func (h *TaskHandler) CancelTask(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
@@ -194,7 +223,10 @@ func (h *TaskHandler) CancelTask(c echo.Context) error {
|
||||
|
||||
// UncancelTask handles POST /api/tasks/:id/uncancel/
|
||||
func (h *TaskHandler) UncancelTask(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
@@ -211,7 +243,10 @@ func (h *TaskHandler) UncancelTask(c echo.Context) error {
|
||||
|
||||
// ArchiveTask handles POST /api/tasks/:id/archive/
|
||||
func (h *TaskHandler) ArchiveTask(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
@@ -228,7 +263,10 @@ func (h *TaskHandler) ArchiveTask(c echo.Context) error {
|
||||
|
||||
// UnarchiveTask handles POST /api/tasks/:id/unarchive/
|
||||
func (h *TaskHandler) UnarchiveTask(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
@@ -246,7 +284,10 @@ func (h *TaskHandler) UnarchiveTask(c echo.Context) error {
|
||||
// QuickComplete handles POST /api/tasks/:id/quick-complete/
|
||||
// Lightweight endpoint for widget - just returns 200 OK on success
|
||||
func (h *TaskHandler) QuickComplete(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_task_id")
|
||||
@@ -263,7 +304,10 @@ func (h *TaskHandler) QuickComplete(c echo.Context) error {
|
||||
|
||||
// GetTaskCompletions handles GET /api/tasks/:id/completions/
|
||||
func (h *TaskHandler) GetTaskCompletions(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_task_id")
|
||||
@@ -278,7 +322,10 @@ func (h *TaskHandler) GetTaskCompletions(c echo.Context) error {
|
||||
|
||||
// ListCompletions handles GET /api/task-completions/
|
||||
func (h *TaskHandler) ListCompletions(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
response, err := h.taskService.ListCompletions(user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -288,7 +335,10 @@ func (h *TaskHandler) ListCompletions(c echo.Context) error {
|
||||
|
||||
// GetCompletion handles GET /api/task-completions/:id/
|
||||
func (h *TaskHandler) GetCompletion(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
completionID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_completion_id")
|
||||
@@ -304,7 +354,10 @@ func (h *TaskHandler) GetCompletion(c echo.Context) error {
|
||||
// CreateCompletion handles POST /api/task-completions/
|
||||
// Supports both JSON and multipart form data (for image uploads)
|
||||
func (h *TaskHandler) CreateCompletion(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userNow := middleware.GetUserNow(c)
|
||||
|
||||
var req requests.CreateTaskCompletionRequest
|
||||
@@ -367,6 +420,10 @@ func (h *TaskHandler) CreateCompletion(c echo.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.taskService.CreateCompletion(&req, user.ID, userNow)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -376,7 +433,10 @@ func (h *TaskHandler) CreateCompletion(c echo.Context) error {
|
||||
|
||||
// UpdateCompletion handles PUT /api/task-completions/:id/
|
||||
func (h *TaskHandler) UpdateCompletion(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
completionID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_completion_id")
|
||||
@@ -386,6 +446,9 @@ func (h *TaskHandler) UpdateCompletion(c echo.Context) error {
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := h.taskService.UpdateCompletion(uint(completionID), user.ID, &req)
|
||||
if err != nil {
|
||||
@@ -396,7 +459,10 @@ func (h *TaskHandler) UpdateCompletion(c echo.Context) error {
|
||||
|
||||
// DeleteCompletion handles DELETE /api/task-completions/:id/
|
||||
func (h *TaskHandler) DeleteCompletion(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
completionID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
return apperrors.BadRequest("error.invalid_completion_id")
|
||||
|
||||
@@ -506,6 +506,52 @@ func TestTaskHandler_CreateCompletion(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestTaskHandler_CreateCompletion_Rating6_Returns400(t *testing.T) {
|
||||
handler, e, db := setupTaskHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Rate Me")
|
||||
|
||||
authGroup := e.Group("/api/task-completions")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/", handler.CreateCompletion)
|
||||
|
||||
t.Run("rating out of bounds rejected", func(t *testing.T) {
|
||||
rating := 6
|
||||
req := requests.CreateTaskCompletionRequest{
|
||||
TaskID: task.ID,
|
||||
Rating: &rating,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/task-completions/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("rating zero rejected", func(t *testing.T) {
|
||||
rating := 0
|
||||
req := requests.CreateTaskCompletionRequest{
|
||||
TaskID: task.ID,
|
||||
Rating: &rating,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/task-completions/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("rating 5 accepted", func(t *testing.T) {
|
||||
rating := 5
|
||||
completedAt := time.Now().UTC()
|
||||
req := requests.CreateTaskCompletionRequest{
|
||||
TaskID: task.ID,
|
||||
CompletedAt: &completedAt,
|
||||
Rating: &rating,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/task-completions/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusCreated)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTaskHandler_ListCompletions(t *testing.T) {
|
||||
handler, e, db := setupTaskHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
@@ -603,6 +649,71 @@ func TestTaskHandler_DeleteCompletion(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestTaskHandler_CreateTask_EmptyTitle_Returns400(t *testing.T) {
|
||||
handler, e, db := setupTaskHandler(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
authGroup := e.Group("/api/tasks")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/", handler.CreateTask)
|
||||
|
||||
t.Run("empty body returns 400 with validation errors", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", "/api/tasks/", map[string]interface{}{}, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should contain structured validation error
|
||||
assert.Contains(t, response, "error")
|
||||
assert.Contains(t, response, "fields")
|
||||
|
||||
fields := response["fields"].(map[string]interface{})
|
||||
assert.Contains(t, fields, "residence_id", "validation error should reference 'residence_id'")
|
||||
assert.Contains(t, fields, "title", "validation error should reference 'title'")
|
||||
})
|
||||
|
||||
t.Run("missing title returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{
|
||||
"residence_id": residence.ID,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/tasks/", req, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, response, "fields")
|
||||
fields := response["fields"].(map[string]interface{})
|
||||
assert.Contains(t, fields, "title", "validation error should reference 'title'")
|
||||
})
|
||||
|
||||
t.Run("missing residence_id returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{
|
||||
"title": "Test Task",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/tasks/", req, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, response, "fields")
|
||||
fields := response["fields"].(map[string]interface{})
|
||||
assert.Contains(t, fields, "residence_id", "validation error should reference 'residence_id'")
|
||||
})
|
||||
}
|
||||
|
||||
func TestTaskHandler_GetLookups(t *testing.T) {
|
||||
handler, e, db := setupTaskHandler(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
@@ -32,7 +33,14 @@ func (h *TrackingHandler) TrackEmailOpen(c echo.Context) error {
|
||||
if trackingID != "" && h.onboardingService != nil {
|
||||
// Record the open (async, don't block response)
|
||||
go func() {
|
||||
_ = h.onboardingService.RecordEmailOpened(trackingID)
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().Interface("panic", r).Str("tracking_id", trackingID).Msg("Panic in email open tracking goroutine")
|
||||
}
|
||||
}()
|
||||
if err := h.onboardingService.RecordEmailOpened(trackingID); err != nil {
|
||||
log.Error().Err(err).Str("tracking_id", trackingID).Msg("Failed to record email open")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,11 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
"github.com/treytartt/casera-api/internal/middleware"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
|
||||
@@ -73,17 +76,38 @@ func (h *UploadHandler) UploadCompletion(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// DeleteFileRequest is the request body for deleting a file.
|
||||
type DeleteFileRequest struct {
|
||||
URL string `json:"url" validate:"required"`
|
||||
}
|
||||
|
||||
// DeleteFile handles DELETE /api/uploads
|
||||
// Expects JSON body with "url" field
|
||||
// Expects JSON body with "url" field.
|
||||
//
|
||||
// TODO(SEC-18): Add ownership verification. Currently any authenticated user can delete
|
||||
// any file if they know the URL. The upload system does not track which user uploaded
|
||||
// which file, so a proper fix requires adding an uploads table or file ownership metadata.
|
||||
// For now, deletions are logged with user ID for audit trail, and StorageService.Delete
|
||||
// enforces path containment to prevent deleting files outside the upload directory.
|
||||
func (h *UploadHandler) DeleteFile(c echo.Context) error {
|
||||
var req struct {
|
||||
URL string `json:"url" binding:"required"`
|
||||
}
|
||||
var req DeleteFileRequest
|
||||
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return apperrors.BadRequest("error.invalid_request")
|
||||
}
|
||||
|
||||
if err := c.Validate(&req); err != nil {
|
||||
return apperrors.BadRequest("error.url_required")
|
||||
}
|
||||
|
||||
// Log the deletion with user ID for audit trail
|
||||
if user, ok := c.Get(middleware.AuthUserKey).(*models.User); ok {
|
||||
log.Info().
|
||||
Uint("user_id", user.ID).
|
||||
Str("file_url", req.URL).
|
||||
Msg("File deletion requested")
|
||||
}
|
||||
|
||||
if err := h.storageService.Delete(req.URL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
43
internal/handlers/upload_handler_test.go
Normal file
43
internal/handlers/upload_handler_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/treytartt/casera-api/internal/i18n"
|
||||
"github.com/treytartt/casera-api/internal/testutil"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Initialize i18n so the custom error handler can localize error messages.
|
||||
// Other handler tests get this from testutil.SetupTestDB, but these tests
|
||||
// don't need a database.
|
||||
i18n.Init()
|
||||
}
|
||||
|
||||
func TestDeleteFile_MissingURL_Returns400(t *testing.T) {
|
||||
// Use a test storage service — DeleteFile won't reach storage since validation fails first
|
||||
storageSvc := newTestStorageService("/var/uploads")
|
||||
handler := NewUploadHandler(storageSvc)
|
||||
|
||||
e := testutil.SetupTestRouter()
|
||||
|
||||
// Register route
|
||||
e.DELETE("/api/uploads/", handler.DeleteFile)
|
||||
|
||||
// Send request with empty JSON body (url field missing)
|
||||
w := testutil.MakeRequest(e, http.MethodDelete, "/api/uploads/", map[string]string{}, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func TestDeleteFile_EmptyURL_Returns400(t *testing.T) {
|
||||
storageSvc := newTestStorageService("/var/uploads")
|
||||
handler := NewUploadHandler(storageSvc)
|
||||
|
||||
e := testutil.SetupTestRouter()
|
||||
e.DELETE("/api/uploads/", handler.DeleteFile)
|
||||
|
||||
// Send request with empty url field
|
||||
w := testutil.MakeRequest(e, http.MethodDelete, "/api/uploads/", map[string]string{"url": ""}, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
}
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/treytartt/casera-api/internal/apperrors"
|
||||
"github.com/treytartt/casera-api/internal/middleware"
|
||||
"github.com/treytartt/casera-api/internal/models"
|
||||
"github.com/treytartt/casera-api/internal/services"
|
||||
)
|
||||
|
||||
@@ -26,7 +25,10 @@ func NewUserHandler(userService *services.UserService) *UserHandler {
|
||||
|
||||
// ListUsers handles GET /api/users/
|
||||
func (h *UserHandler) ListUsers(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Only allow listing users that share residences with the current user
|
||||
users, err := h.userService.ListUsersInSharedResidences(user.ID)
|
||||
@@ -42,7 +44,10 @@ func (h *UserHandler) ListUsers(c echo.Context) error {
|
||||
|
||||
// GetUser handles GET /api/users/:id/
|
||||
func (h *UserHandler) GetUser(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
||||
if err != nil {
|
||||
@@ -60,7 +65,10 @@ func (h *UserHandler) GetUser(c echo.Context) error {
|
||||
|
||||
// ListProfiles handles GET /api/users/profiles/
|
||||
func (h *UserHandler) ListProfiles(c echo.Context) error {
|
||||
user := c.Get(middleware.AuthUserKey).(*models.User)
|
||||
user, err := middleware.MustGetAuthUser(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// List profiles of users in shared residences
|
||||
profiles, err := h.userService.ListProfilesInSharedResidences(user.ID)
|
||||
|
||||
Reference in New Issue
Block a user