Harden API security: input validation, safe auth extraction, new tests, and deploy config

Comprehensive security hardening from audit findings:
- Add validation tags to all DTO request structs (max lengths, ranges, enums)
- Replace unsafe type assertions with MustGetAuthUser helper across all handlers
- Remove query-param token auth from admin middleware (prevents URL token leakage)
- Add request validation calls in handlers that were missing c.Validate()
- Remove goroutines in handlers (timezone update now synchronous)
- Add sanitize middleware and path traversal protection (path_utils)
- Stop resetting admin passwords on migration restart
- Warn on well-known default SECRET_KEY
- Add ~30 new test files covering security regressions, auth safety, repos, and services
- Add deploy/ config, audit digests, and AUDIT_FINDINGS documentation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Trey t
2026-03-02 09:48:01 -06:00
parent 56d6fa4514
commit 7690f07a2b
123 changed files with 8321 additions and 750 deletions

View File

@@ -81,6 +81,11 @@ func (h *AuthHandler) Register(c echo.Context) error {
// Send welcome email with confirmation code (async)
if h.emailService != nil && confirmationCode != "" {
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", req.Email).Msg("Panic in welcome email goroutine")
}
}()
if err := h.emailService.SendWelcomeEmail(req.Email, req.FirstName, confirmationCode); err != nil {
log.Error().Err(err).Str("email", req.Email).Msg("Failed to send welcome email")
}
@@ -176,6 +181,11 @@ func (h *AuthHandler) VerifyEmail(c echo.Context) error {
// Send post-verification welcome email with tips (async)
if h.emailService != nil {
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in post-verification email goroutine")
}
}()
if err := h.emailService.SendPostVerificationEmail(user.Email, user.FirstName); err != nil {
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send post-verification email")
}
@@ -204,6 +214,11 @@ func (h *AuthHandler) ResendVerification(c echo.Context) error {
// Send verification email (async)
if h.emailService != nil {
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in verification email goroutine")
}
}()
if err := h.emailService.SendVerificationEmail(user.Email, user.FirstName, code); err != nil {
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send verification email")
}
@@ -238,6 +253,11 @@ func (h *AuthHandler) ForgotPassword(c echo.Context) error {
// Send password reset email (async) - only if user found
if h.emailService != nil && code != "" && user != nil {
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", user.Email).Msg("Panic in password reset email goroutine")
}
}()
if err := h.emailService.SendPasswordResetEmail(user.Email, user.FirstName, code); err != nil {
log.Error().Err(err).Str("email", user.Email).Msg("Failed to send password reset email")
}
@@ -326,6 +346,11 @@ func (h *AuthHandler) AppleSignIn(c echo.Context) error {
// Send welcome email for new users (async)
if response.IsNewUser && h.emailService != nil && response.User.Email != "" {
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", response.User.Email).Msg("Panic in Apple welcome email goroutine")
}
}()
if err := h.emailService.SendAppleWelcomeEmail(response.User.Email, response.User.FirstName); err != nil {
log.Error().Err(err).Str("email", response.User.Email).Msg("Failed to send Apple welcome email")
}
@@ -368,6 +393,11 @@ func (h *AuthHandler) GoogleSignIn(c echo.Context) error {
// Send welcome email for new users (async)
if response.IsNewUser && h.emailService != nil && response.User.Email != "" {
go func() {
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("email", response.User.Email).Msg("Panic in Google welcome email goroutine")
}
}()
if err := h.emailService.SendGoogleWelcomeEmail(response.User.Email, response.User.FirstName); err != nil {
log.Error().Err(err).Str("email", response.User.Email).Msg("Failed to send Google welcome email")
}

View File

@@ -9,7 +9,6 @@ import (
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/dto/requests"
"github.com/treytartt/casera-api/internal/middleware"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/services"
)
@@ -25,17 +24,23 @@ func NewContractorHandler(contractorService *services.ContractorService) *Contra
// ListContractors handles GET /api/contractors/
func (h *ContractorHandler) ListContractors(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
response, err := h.contractorService.ListContractors(user.ID)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": err.Error()})
return apperrors.Internal(err)
}
return c.JSON(http.StatusOK, response)
}
// GetContractor handles GET /api/contractors/:id/
func (h *ContractorHandler) GetContractor(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_contractor_id")
@@ -50,11 +55,17 @@ func (h *ContractorHandler) GetContractor(c echo.Context) error {
// CreateContractor handles POST /api/contractors/
func (h *ContractorHandler) CreateContractor(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req requests.CreateContractorRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
response, err := h.contractorService.CreateContractor(&req, user.ID)
if err != nil {
@@ -65,7 +76,10 @@ func (h *ContractorHandler) CreateContractor(c echo.Context) error {
// UpdateContractor handles PUT/PATCH /api/contractors/:id/
func (h *ContractorHandler) UpdateContractor(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_contractor_id")
@@ -75,6 +89,9 @@ func (h *ContractorHandler) UpdateContractor(c echo.Context) error {
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
response, err := h.contractorService.UpdateContractor(uint(contractorID), user.ID, &req)
if err != nil {
@@ -85,7 +102,10 @@ func (h *ContractorHandler) UpdateContractor(c echo.Context) error {
// DeleteContractor handles DELETE /api/contractors/:id/
func (h *ContractorHandler) DeleteContractor(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_contractor_id")
@@ -100,7 +120,10 @@ func (h *ContractorHandler) DeleteContractor(c echo.Context) error {
// ToggleFavorite handles POST /api/contractors/:id/toggle-favorite/
func (h *ContractorHandler) ToggleFavorite(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_contractor_id")
@@ -115,7 +138,10 @@ func (h *ContractorHandler) ToggleFavorite(c echo.Context) error {
// GetContractorTasks handles GET /api/contractors/:id/tasks/
func (h *ContractorHandler) GetContractorTasks(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
contractorID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_contractor_id")
@@ -130,7 +156,10 @@ func (h *ContractorHandler) GetContractorTasks(c echo.Context) error {
// ListContractorsByResidence handles GET /api/contractors/by-residence/:residence_id/
func (h *ContractorHandler) ListContractorsByResidence(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("residence_id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_residence_id")
@@ -147,7 +176,7 @@ func (h *ContractorHandler) ListContractorsByResidence(c echo.Context) error {
func (h *ContractorHandler) GetSpecialties(c echo.Context) error {
specialties, err := h.contractorService.GetSpecialties()
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": err.Error()})
return apperrors.Internal(err)
}
return c.JSON(http.StatusOK, specialties)
}

View File

@@ -0,0 +1,182 @@
package handlers
import (
"encoding/json"
"net/http"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/dto/requests"
"github.com/treytartt/casera-api/internal/repositories"
"github.com/treytartt/casera-api/internal/services"
"github.com/treytartt/casera-api/internal/testutil"
)
func setupContractorHandler(t *testing.T) (*ContractorHandler, *echo.Echo, *gorm.DB) {
db := testutil.SetupTestDB(t)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
contractorService := services.NewContractorService(contractorRepo, residenceRepo)
handler := NewContractorHandler(contractorService)
e := testutil.SetupTestRouter()
return handler, e, db
}
func TestContractorHandler_CreateContractor_MissingName_Returns400(t *testing.T) {
handler, e, db := setupContractorHandler(t)
testutil.SeedLookupData(t, db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
authGroup := e.Group("/api/contractors")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/", handler.CreateContractor)
t.Run("missing name returns 400 validation error", func(t *testing.T) {
// Send request with no name (required field)
req := requests.CreateContractorRequest{
ResidenceID: &residence.ID,
}
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
// Should contain structured validation error
assert.Contains(t, response, "error")
assert.Contains(t, response, "fields")
fields := response["fields"].(map[string]interface{})
assert.Contains(t, fields, "name", "validation error should reference the 'name' field")
})
t.Run("empty body returns 400 validation error", func(t *testing.T) {
// Send completely empty body
w := testutil.MakeRequest(e, "POST", "/api/contractors/", map[string]interface{}{}, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response, "error")
})
t.Run("valid contractor creation succeeds", func(t *testing.T) {
req := requests.CreateContractorRequest{
ResidenceID: &residence.ID,
Name: "John the Plumber",
}
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusCreated)
})
}
func TestContractorHandler_ListContractors_Error_NoRawErrorInResponse(t *testing.T) {
_, e, db := setupContractorHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
testutil.CreateTestResidence(t, db, user.ID, "Test House")
// Create a handler with a broken service to simulate an internal error.
// We do this by closing the underlying SQL connection, which will cause
// the service to return an error on the next query.
brokenDB := testutil.SetupTestDB(t)
sqlDB, _ := brokenDB.DB()
sqlDB.Close()
brokenContractorRepo := repositories.NewContractorRepository(brokenDB)
brokenResidenceRepo := repositories.NewResidenceRepository(brokenDB)
brokenService := services.NewContractorService(brokenContractorRepo, brokenResidenceRepo)
brokenHandler := NewContractorHandler(brokenService)
authGroup := e.Group("/api/broken-contractors")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/", brokenHandler.ListContractors)
t.Run("internal error does not leak raw error message", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/broken-contractors/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusInternalServerError)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
// Should contain the generic error key, NOT a raw database error
errorMsg, ok := response["error"].(string)
require.True(t, ok, "response should have an 'error' string field")
// Must not contain database-specific details
assert.NotContains(t, errorMsg, "sql", "error message should not leak SQL details")
assert.NotContains(t, errorMsg, "database", "error message should not leak database details")
assert.NotContains(t, errorMsg, "closed", "error message should not leak connection state")
})
}
func TestContractorHandler_CreateContractor_100Specialties_Returns400(t *testing.T) {
handler, e, db := setupContractorHandler(t)
testutil.SeedLookupData(t, db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
authGroup := e.Group("/api/contractors")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/", handler.CreateContractor)
t.Run("too many specialties rejected", func(t *testing.T) {
// Create a slice with 100 specialty IDs (exceeds max=20)
specialtyIDs := make([]uint, 100)
for i := range specialtyIDs {
specialtyIDs[i] = uint(i + 1)
}
req := requests.CreateContractorRequest{
ResidenceID: &residence.ID,
Name: "Over-specialized Contractor",
SpecialtyIDs: specialtyIDs,
}
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("20 specialties accepted", func(t *testing.T) {
specialtyIDs := make([]uint, 20)
for i := range specialtyIDs {
specialtyIDs[i] = uint(i + 1)
}
req := requests.CreateContractorRequest{
ResidenceID: &residence.ID,
Name: "Multi-skilled Contractor",
SpecialtyIDs: specialtyIDs,
}
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
// Should pass validation (201 or success, not 400)
assert.NotEqual(t, http.StatusBadRequest, w.Code, "20 specialties should pass validation")
})
t.Run("rating above 5 rejected", func(t *testing.T) {
rating := 6.0
req := requests.CreateContractorRequest{
ResidenceID: &residence.ID,
Name: "Bad Rating Contractor",
Rating: &rating,
}
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
}

View File

@@ -34,7 +34,10 @@ func NewDocumentHandler(documentService *services.DocumentService, storageServic
// ListDocuments handles GET /api/documents/
func (h *DocumentHandler) ListDocuments(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
// Build filter from supported query params.
var filter *repositories.DocumentFilter
@@ -71,7 +74,10 @@ func (h *DocumentHandler) ListDocuments(c echo.Context) error {
// GetDocument handles GET /api/documents/:id/
func (h *DocumentHandler) GetDocument(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_document_id")
@@ -86,10 +92,13 @@ func (h *DocumentHandler) GetDocument(c echo.Context) error {
// ListWarranties handles GET /api/documents/warranties/
func (h *DocumentHandler) ListWarranties(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
response, err := h.documentService.ListWarranties(user.ID)
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": err.Error()})
return apperrors.Internal(err)
}
return c.JSON(http.StatusOK, response)
}
@@ -97,7 +106,10 @@ func (h *DocumentHandler) ListWarranties(c echo.Context) error {
// CreateDocument handles POST /api/documents/
// Supports both JSON and multipart form data (for file uploads)
func (h *DocumentHandler) CreateDocument(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req requests.CreateDocumentRequest
contentType := c.Request().Header.Get("Content-Type")
@@ -198,6 +210,10 @@ func (h *DocumentHandler) CreateDocument(c echo.Context) error {
}
}
if err := c.Validate(&req); err != nil {
return err
}
response, err := h.documentService.CreateDocument(&req, user.ID)
if err != nil {
return err
@@ -207,7 +223,10 @@ func (h *DocumentHandler) CreateDocument(c echo.Context) error {
// UpdateDocument handles PUT/PATCH /api/documents/:id/
func (h *DocumentHandler) UpdateDocument(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_document_id")
@@ -217,6 +236,9 @@ func (h *DocumentHandler) UpdateDocument(c echo.Context) error {
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
response, err := h.documentService.UpdateDocument(uint(documentID), user.ID, &req)
if err != nil {
@@ -227,7 +249,10 @@ func (h *DocumentHandler) UpdateDocument(c echo.Context) error {
// DeleteDocument handles DELETE /api/documents/:id/
func (h *DocumentHandler) DeleteDocument(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_document_id")
@@ -242,7 +267,10 @@ func (h *DocumentHandler) DeleteDocument(c echo.Context) error {
// ActivateDocument handles POST /api/documents/:id/activate/
func (h *DocumentHandler) ActivateDocument(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_document_id")
@@ -257,7 +285,10 @@ func (h *DocumentHandler) ActivateDocument(c echo.Context) error {
// DeactivateDocument handles POST /api/documents/:id/deactivate/
func (h *DocumentHandler) DeactivateDocument(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_document_id")
@@ -272,7 +303,10 @@ func (h *DocumentHandler) DeactivateDocument(c echo.Context) error {
// UploadDocumentImage handles POST /api/documents/:id/images/
func (h *DocumentHandler) UploadDocumentImage(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_document_id")
@@ -316,7 +350,10 @@ func (h *DocumentHandler) UploadDocumentImage(c echo.Context) error {
// DeleteDocumentImage handles DELETE /api/documents/:id/images/:imageId/
func (h *DocumentHandler) DeleteDocumentImage(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
documentID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_document_id")

View File

@@ -1,7 +1,6 @@
package handlers
import (
"path/filepath"
"strconv"
"strings"
@@ -9,7 +8,6 @@ import (
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/middleware"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/repositories"
"github.com/treytartt/casera-api/internal/services"
)
@@ -40,7 +38,10 @@ func NewMediaHandler(
// ServeDocument serves a document file with access control
// GET /api/media/document/:id
func (h *MediaHandler) ServeDocument(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -73,7 +74,10 @@ func (h *MediaHandler) ServeDocument(c echo.Context) error {
// ServeDocumentImage serves a document image with access control
// GET /api/media/document-image/:id
func (h *MediaHandler) ServeDocumentImage(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -111,7 +115,10 @@ func (h *MediaHandler) ServeDocumentImage(c echo.Context) error {
// ServeCompletionImage serves a task completion image with access control
// GET /api/media/completion-image/:id
func (h *MediaHandler) ServeCompletionImage(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
id, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -152,7 +159,9 @@ func (h *MediaHandler) ServeCompletionImage(c echo.Context) error {
return c.File(filePath)
}
// resolveFilePath converts a stored URL to an actual file path
// resolveFilePath converts a stored URL to an actual file path.
// Returns empty string if the URL is empty or the resolved path would escape
// the upload directory (path traversal attempt).
func (h *MediaHandler) resolveFilePath(storedURL string) string {
if storedURL == "" {
return ""
@@ -160,12 +169,18 @@ func (h *MediaHandler) resolveFilePath(storedURL string) string {
uploadDir := h.storageSvc.GetUploadDir()
// Handle legacy /uploads/... URLs
// Strip legacy /uploads/ prefix to get relative path
relativePath := storedURL
if strings.HasPrefix(storedURL, "/uploads/") {
relativePath := strings.TrimPrefix(storedURL, "/uploads/")
return filepath.Join(uploadDir, relativePath)
relativePath = strings.TrimPrefix(storedURL, "/uploads/")
}
// Handle relative paths (new format)
return filepath.Join(uploadDir, storedURL)
// Use SafeResolvePath to validate containment within upload directory
resolved, err := services.SafeResolvePath(uploadDir, relativePath)
if err != nil {
// Path traversal or invalid path — return empty to signal file not found
return ""
}
return resolved
}

View File

@@ -0,0 +1,74 @@
package handlers
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/treytartt/casera-api/internal/config"
"github.com/treytartt/casera-api/internal/services"
)
// newTestStorageService creates a StorageService with a known upload directory for testing.
// It does NOT call NewStorageService because that creates directories on disk.
// Instead, it directly constructs the struct with only what resolveFilePath needs.
func newTestStorageService(uploadDir string) *services.StorageService {
cfg := &config.StorageConfig{
UploadDir: uploadDir,
BaseURL: "/uploads",
MaxFileSize: 10 * 1024 * 1024,
AllowedTypes: "image/jpeg,image/png",
}
// Use the exported constructor helper that skips directory creation (for tests)
return services.NewStorageServiceForTest(cfg)
}
func TestResolveFilePath_NormalPath_Works(t *testing.T) {
storageSvc := newTestStorageService("/var/uploads")
h := NewMediaHandler(nil, nil, nil, storageSvc)
result := h.resolveFilePath("images/photo.jpg")
require.NotEmpty(t, result)
assert.Equal(t, "/var/uploads/images/photo.jpg", result)
}
func TestResolveFilePath_LegacyUploadPath_Works(t *testing.T) {
storageSvc := newTestStorageService("/var/uploads")
h := NewMediaHandler(nil, nil, nil, storageSvc)
result := h.resolveFilePath("/uploads/images/photo.jpg")
require.NotEmpty(t, result)
assert.Equal(t, "/var/uploads/images/photo.jpg", result)
}
func TestResolveFilePath_DotDotTraversal_Blocked(t *testing.T) {
storageSvc := newTestStorageService("/var/uploads")
h := NewMediaHandler(nil, nil, nil, storageSvc)
tests := []struct {
name string
storedURL string
}{
{"simple dotdot", "../etc/passwd"},
{"nested dotdot", "../../etc/shadow"},
{"embedded dotdot", "images/../../etc/passwd"},
{"legacy prefix with dotdot", "/uploads/../../../etc/passwd"},
{"deep dotdot", "a/b/c/../../../../etc/passwd"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := h.resolveFilePath(tt.storedURL)
assert.Empty(t, result, "path traversal should return empty string for: %s", tt.storedURL)
})
}
}
func TestResolveFilePath_EmptyURL_ReturnsEmpty(t *testing.T) {
storageSvc := newTestStorageService("/var/uploads")
h := NewMediaHandler(nil, nil, nil, storageSvc)
result := h.resolveFilePath("")
assert.Empty(t, result)
}

View File

@@ -0,0 +1,334 @@
package handlers
import (
"net/http"
"testing"
"github.com/treytartt/casera-api/internal/config"
"github.com/treytartt/casera-api/internal/repositories"
"github.com/treytartt/casera-api/internal/services"
"github.com/treytartt/casera-api/internal/testutil"
)
// TestTaskHandler_NoAuth_Returns401 verifies that task handler endpoints return
// 401 Unauthorized when no auth user is set in the context (e.g., auth middleware
// misconfigured or bypassed). This is a regression test for P1-1 (SEC-19).
func TestTaskHandler_NoAuth_Returns401(t *testing.T) {
db := testutil.SetupTestDB(t)
taskRepo := repositories.NewTaskRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
taskService := services.NewTaskService(taskRepo, residenceRepo)
handler := NewTaskHandler(taskService, nil)
e := testutil.SetupTestRouter()
// Register routes WITHOUT auth middleware
e.GET("/api/tasks/", handler.ListTasks)
e.GET("/api/tasks/:id/", handler.GetTask)
e.POST("/api/tasks/", handler.CreateTask)
e.PUT("/api/tasks/:id/", handler.UpdateTask)
e.DELETE("/api/tasks/:id/", handler.DeleteTask)
e.POST("/api/tasks/:id/cancel/", handler.CancelTask)
e.POST("/api/tasks/:id/mark-in-progress/", handler.MarkInProgress)
e.GET("/api/task-completions/", handler.ListCompletions)
e.POST("/api/task-completions/", handler.CreateCompletion)
tests := []struct {
name string
method string
path string
}{
{"ListTasks", "GET", "/api/tasks/"},
{"GetTask", "GET", "/api/tasks/1/"},
{"CreateTask", "POST", "/api/tasks/"},
{"UpdateTask", "PUT", "/api/tasks/1/"},
{"DeleteTask", "DELETE", "/api/tasks/1/"},
{"CancelTask", "POST", "/api/tasks/1/cancel/"},
{"MarkInProgress", "POST", "/api/tasks/1/mark-in-progress/"},
{"ListCompletions", "GET", "/api/task-completions/"},
{"CreateCompletion", "POST", "/api/task-completions/"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
})
}
}
// TestResidenceHandler_NoAuth_Returns401 verifies that residence handler endpoints
// return 401 Unauthorized when no auth user is set in the context.
func TestResidenceHandler_NoAuth_Returns401(t *testing.T) {
db := testutil.SetupTestDB(t)
residenceRepo := repositories.NewResidenceRepository(db)
userRepo := repositories.NewUserRepository(db)
cfg := &config.Config{}
residenceService := services.NewResidenceService(residenceRepo, userRepo, cfg)
handler := NewResidenceHandler(residenceService, nil, nil, true)
e := testutil.SetupTestRouter()
// Register routes WITHOUT auth middleware
e.GET("/api/residences/", handler.ListResidences)
e.GET("/api/residences/my/", handler.GetMyResidences)
e.GET("/api/residences/summary/", handler.GetSummary)
e.GET("/api/residences/:id/", handler.GetResidence)
e.POST("/api/residences/", handler.CreateResidence)
e.PUT("/api/residences/:id/", handler.UpdateResidence)
e.DELETE("/api/residences/:id/", handler.DeleteResidence)
e.POST("/api/residences/:id/generate-share-code/", handler.GenerateShareCode)
e.POST("/api/residences/join-with-code/", handler.JoinWithCode)
e.GET("/api/residences/:id/users/", handler.GetResidenceUsers)
tests := []struct {
name string
method string
path string
}{
{"ListResidences", "GET", "/api/residences/"},
{"GetMyResidences", "GET", "/api/residences/my/"},
{"GetSummary", "GET", "/api/residences/summary/"},
{"GetResidence", "GET", "/api/residences/1/"},
{"CreateResidence", "POST", "/api/residences/"},
{"UpdateResidence", "PUT", "/api/residences/1/"},
{"DeleteResidence", "DELETE", "/api/residences/1/"},
{"GenerateShareCode", "POST", "/api/residences/1/generate-share-code/"},
{"JoinWithCode", "POST", "/api/residences/join-with-code/"},
{"GetResidenceUsers", "GET", "/api/residences/1/users/"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
})
}
}
// TestNotificationHandler_NoAuth_Returns401 verifies that notification handler
// endpoints return 401 Unauthorized when no auth user is set in the context.
func TestNotificationHandler_NoAuth_Returns401(t *testing.T) {
db := testutil.SetupTestDB(t)
notificationRepo := repositories.NewNotificationRepository(db)
notificationService := services.NewNotificationService(notificationRepo, nil)
handler := NewNotificationHandler(notificationService)
e := testutil.SetupTestRouter()
// Register routes WITHOUT auth middleware
e.GET("/api/notifications/", handler.ListNotifications)
e.GET("/api/notifications/unread-count/", handler.GetUnreadCount)
e.POST("/api/notifications/:id/read/", handler.MarkAsRead)
e.POST("/api/notifications/mark-all-read/", handler.MarkAllAsRead)
e.GET("/api/notifications/preferences/", handler.GetPreferences)
e.PUT("/api/notifications/preferences/", handler.UpdatePreferences)
e.POST("/api/notifications/devices/", handler.RegisterDevice)
e.GET("/api/notifications/devices/", handler.ListDevices)
tests := []struct {
name string
method string
path string
}{
{"ListNotifications", "GET", "/api/notifications/"},
{"GetUnreadCount", "GET", "/api/notifications/unread-count/"},
{"MarkAsRead", "POST", "/api/notifications/1/read/"},
{"MarkAllAsRead", "POST", "/api/notifications/mark-all-read/"},
{"GetPreferences", "GET", "/api/notifications/preferences/"},
{"UpdatePreferences", "PUT", "/api/notifications/preferences/"},
{"RegisterDevice", "POST", "/api/notifications/devices/"},
{"ListDevices", "GET", "/api/notifications/devices/"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
})
}
}
// TestDocumentHandler_NoAuth_Returns401 verifies that document handler endpoints
// return 401 Unauthorized when no auth user is set in the context.
func TestDocumentHandler_NoAuth_Returns401(t *testing.T) {
db := testutil.SetupTestDB(t)
documentRepo := repositories.NewDocumentRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
documentService := services.NewDocumentService(documentRepo, residenceRepo)
handler := NewDocumentHandler(documentService, nil)
e := testutil.SetupTestRouter()
// Register routes WITHOUT auth middleware
e.GET("/api/documents/", handler.ListDocuments)
e.GET("/api/documents/:id/", handler.GetDocument)
e.GET("/api/documents/warranties/", handler.ListWarranties)
e.POST("/api/documents/", handler.CreateDocument)
e.PUT("/api/documents/:id/", handler.UpdateDocument)
e.DELETE("/api/documents/:id/", handler.DeleteDocument)
e.POST("/api/documents/:id/activate/", handler.ActivateDocument)
e.POST("/api/documents/:id/deactivate/", handler.DeactivateDocument)
tests := []struct {
name string
method string
path string
}{
{"ListDocuments", "GET", "/api/documents/"},
{"GetDocument", "GET", "/api/documents/1/"},
{"ListWarranties", "GET", "/api/documents/warranties/"},
{"CreateDocument", "POST", "/api/documents/"},
{"UpdateDocument", "PUT", "/api/documents/1/"},
{"DeleteDocument", "DELETE", "/api/documents/1/"},
{"ActivateDocument", "POST", "/api/documents/1/activate/"},
{"DeactivateDocument", "POST", "/api/documents/1/deactivate/"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
})
}
}
// TestContractorHandler_NoAuth_Returns401 verifies that contractor handler endpoints
// return 401 Unauthorized when no auth user is set in the context.
func TestContractorHandler_NoAuth_Returns401(t *testing.T) {
db := testutil.SetupTestDB(t)
contractorRepo := repositories.NewContractorRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
contractorService := services.NewContractorService(contractorRepo, residenceRepo)
handler := NewContractorHandler(contractorService)
e := testutil.SetupTestRouter()
// Register routes WITHOUT auth middleware
e.GET("/api/contractors/", handler.ListContractors)
e.GET("/api/contractors/:id/", handler.GetContractor)
e.POST("/api/contractors/", handler.CreateContractor)
e.PUT("/api/contractors/:id/", handler.UpdateContractor)
e.DELETE("/api/contractors/:id/", handler.DeleteContractor)
e.POST("/api/contractors/:id/toggle-favorite/", handler.ToggleFavorite)
e.GET("/api/contractors/:id/tasks/", handler.GetContractorTasks)
tests := []struct {
name string
method string
path string
}{
{"ListContractors", "GET", "/api/contractors/"},
{"GetContractor", "GET", "/api/contractors/1/"},
{"CreateContractor", "POST", "/api/contractors/"},
{"UpdateContractor", "PUT", "/api/contractors/1/"},
{"DeleteContractor", "DELETE", "/api/contractors/1/"},
{"ToggleFavorite", "POST", "/api/contractors/1/toggle-favorite/"},
{"GetContractorTasks", "GET", "/api/contractors/1/tasks/"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
})
}
}
// TestSubscriptionHandler_NoAuth_Returns401 verifies that subscription handler
// endpoints return 401 Unauthorized when no auth user is set in the context.
func TestSubscriptionHandler_NoAuth_Returns401(t *testing.T) {
db := testutil.SetupTestDB(t)
subscriptionRepo := repositories.NewSubscriptionRepository(db)
residenceRepo := repositories.NewResidenceRepository(db)
taskRepo := repositories.NewTaskRepository(db)
contractorRepo := repositories.NewContractorRepository(db)
documentRepo := repositories.NewDocumentRepository(db)
subscriptionService := services.NewSubscriptionService(subscriptionRepo, residenceRepo, taskRepo, contractorRepo, documentRepo)
handler := NewSubscriptionHandler(subscriptionService)
e := testutil.SetupTestRouter()
// Register routes WITHOUT auth middleware
e.GET("/api/subscription/", handler.GetSubscription)
e.GET("/api/subscription/status/", handler.GetSubscriptionStatus)
e.GET("/api/subscription/promotions/", handler.GetPromotions)
e.POST("/api/subscription/purchase/", handler.ProcessPurchase)
e.POST("/api/subscription/cancel/", handler.CancelSubscription)
e.POST("/api/subscription/restore/", handler.RestoreSubscription)
tests := []struct {
name string
method string
path string
}{
{"GetSubscription", "GET", "/api/subscription/"},
{"GetSubscriptionStatus", "GET", "/api/subscription/status/"},
{"GetPromotions", "GET", "/api/subscription/promotions/"},
{"ProcessPurchase", "POST", "/api/subscription/purchase/"},
{"CancelSubscription", "POST", "/api/subscription/cancel/"},
{"RestoreSubscription", "POST", "/api/subscription/restore/"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
})
}
}
// TestMediaHandler_NoAuth_Returns401 verifies that media handler endpoints return
// 401 Unauthorized when no auth user is set in the context.
func TestMediaHandler_NoAuth_Returns401(t *testing.T) {
handler := NewMediaHandler(nil, nil, nil, nil)
e := testutil.SetupTestRouter()
// Register routes WITHOUT auth middleware
e.GET("/api/media/document/:id", handler.ServeDocument)
e.GET("/api/media/document-image/:id", handler.ServeDocumentImage)
e.GET("/api/media/completion-image/:id", handler.ServeCompletionImage)
tests := []struct {
name string
method string
path string
}{
{"ServeDocument", "GET", "/api/media/document/1"},
{"ServeDocumentImage", "GET", "/api/media/document-image/1"},
{"ServeCompletionImage", "GET", "/api/media/completion-image/1"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
})
}
}
// TestUserHandler_NoAuth_Returns401 verifies that user handler endpoints return
// 401 Unauthorized when no auth user is set in the context.
func TestUserHandler_NoAuth_Returns401(t *testing.T) {
db := testutil.SetupTestDB(t)
userRepo := repositories.NewUserRepository(db)
userService := services.NewUserService(userRepo)
handler := NewUserHandler(userService)
e := testutil.SetupTestRouter()
// Register routes WITHOUT auth middleware
e.GET("/api/users/", handler.ListUsers)
e.GET("/api/users/:id/", handler.GetUser)
e.GET("/api/users/profiles/", handler.ListProfiles)
tests := []struct {
name string
method string
path string
}{
{"ListUsers", "GET", "/api/users/"},
{"GetUser", "GET", "/api/users/1/"},
{"ListProfiles", "GET", "/api/users/profiles/"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := testutil.MakeRequest(e, tt.method, tt.path, nil, "")
testutil.AssertStatusCode(t, w, http.StatusUnauthorized)
})
}
}

View File

@@ -8,7 +8,6 @@ import (
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/middleware"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/services"
)
@@ -24,7 +23,10 @@ func NewNotificationHandler(notificationService *services.NotificationService) *
// ListNotifications handles GET /api/notifications/
func (h *NotificationHandler) ListNotifications(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
limit := 50
offset := 0
@@ -33,6 +35,9 @@ func (h *NotificationHandler) ListNotifications(c echo.Context) error {
limit = parsed
}
}
if limit > 200 {
limit = 200
}
if o := c.QueryParam("offset"); o != "" {
if parsed, err := strconv.Atoi(o); err == nil && parsed >= 0 {
offset = parsed
@@ -52,7 +57,10 @@ func (h *NotificationHandler) ListNotifications(c echo.Context) error {
// GetUnreadCount handles GET /api/notifications/unread-count/
func (h *NotificationHandler) GetUnreadCount(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
count, err := h.notificationService.GetUnreadCount(user.ID)
if err != nil {
@@ -64,7 +72,10 @@ func (h *NotificationHandler) GetUnreadCount(c echo.Context) error {
// MarkAsRead handles POST /api/notifications/:id/read/
func (h *NotificationHandler) MarkAsRead(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
notificationID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -81,9 +92,12 @@ func (h *NotificationHandler) MarkAsRead(c echo.Context) error {
// MarkAllAsRead handles POST /api/notifications/mark-all-read/
func (h *NotificationHandler) MarkAllAsRead(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
err := h.notificationService.MarkAllAsRead(user.ID)
err = h.notificationService.MarkAllAsRead(user.ID)
if err != nil {
return err
}
@@ -93,7 +107,10 @@ func (h *NotificationHandler) MarkAllAsRead(c echo.Context) error {
// GetPreferences handles GET /api/notifications/preferences/
func (h *NotificationHandler) GetPreferences(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
prefs, err := h.notificationService.GetPreferences(user.ID)
if err != nil {
@@ -105,12 +122,18 @@ func (h *NotificationHandler) GetPreferences(c echo.Context) error {
// UpdatePreferences handles PUT/PATCH /api/notifications/preferences/
func (h *NotificationHandler) UpdatePreferences(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req services.UpdatePreferencesRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
prefs, err := h.notificationService.UpdatePreferences(user.ID, &req)
if err != nil {
@@ -122,12 +145,18 @@ func (h *NotificationHandler) UpdatePreferences(c echo.Context) error {
// RegisterDevice handles POST /api/notifications/devices/
func (h *NotificationHandler) RegisterDevice(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req services.RegisterDeviceRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
device, err := h.notificationService.RegisterDevice(user.ID, &req)
if err != nil {
@@ -139,7 +168,10 @@ func (h *NotificationHandler) RegisterDevice(c echo.Context) error {
// ListDevices handles GET /api/notifications/devices/
func (h *NotificationHandler) ListDevices(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
devices, err := h.notificationService.ListDevices(user.ID)
if err != nil {
@@ -152,7 +184,10 @@ func (h *NotificationHandler) ListDevices(c echo.Context) error {
// UnregisterDevice handles POST /api/notifications/devices/unregister/
// Accepts {registration_id, platform} and deactivates the matching device
func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req struct {
RegistrationID string `json:"registration_id"`
@@ -168,7 +203,7 @@ func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
req.Platform = "ios" // Default to iOS
}
err := h.notificationService.UnregisterDevice(req.RegistrationID, req.Platform, user.ID)
err = h.notificationService.UnregisterDevice(req.RegistrationID, req.Platform, user.ID)
if err != nil {
return err
}
@@ -178,7 +213,10 @@ func (h *NotificationHandler) UnregisterDevice(c echo.Context) error {
// DeleteDevice handles DELETE /api/notifications/devices/:id/
func (h *NotificationHandler) DeleteDevice(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
deviceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {

View File

@@ -0,0 +1,88 @@
package handlers
import (
"encoding/json"
"fmt"
"net/http"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/repositories"
"github.com/treytartt/casera-api/internal/services"
"github.com/treytartt/casera-api/internal/testutil"
)
func setupNotificationHandler(t *testing.T) (*NotificationHandler, *echo.Echo, *gorm.DB) {
db := testutil.SetupTestDB(t)
notifRepo := repositories.NewNotificationRepository(db)
notifService := services.NewNotificationService(notifRepo, nil)
handler := NewNotificationHandler(notifService)
e := testutil.SetupTestRouter()
return handler, e, db
}
func createTestNotifications(t *testing.T, db *gorm.DB, userID uint, count int) {
for i := 0; i < count; i++ {
notif := &models.Notification{
UserID: userID,
NotificationType: models.NotificationTaskDueSoon,
Title: fmt.Sprintf("Test Notification %d", i+1),
Body: fmt.Sprintf("Body %d", i+1),
}
err := db.Create(notif).Error
require.NoError(t, err)
}
}
func TestNotificationHandler_ListNotifications_LimitCappedAt200(t *testing.T) {
handler, e, db := setupNotificationHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
// Create 210 notifications to exceed the cap
createTestNotifications(t, db, user.ID, 210)
authGroup := e.Group("/api/notifications")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.GET("/", handler.ListNotifications)
t.Run("limit is capped at 200 when user requests more", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/notifications/?limit=999", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
count := int(response["count"].(float64))
assert.Equal(t, 200, count, "response should contain at most 200 notifications when limit exceeds cap")
})
t.Run("limit below cap is respected", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/notifications/?limit=10", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
count := int(response["count"].(float64))
assert.Equal(t, 10, count, "response should respect limit when below cap")
})
t.Run("default limit is used when no limit param", func(t *testing.T) {
w := testutil.MakeRequest(e, "GET", "/api/notifications/", nil, "test-token")
testutil.AssertStatusCode(t, w, http.StatusOK)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
count := int(response["count"].(float64))
assert.Equal(t, 50, count, "response should use default limit of 50")
})
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/treytartt/casera-api/internal/dto/requests"
"github.com/treytartt/casera-api/internal/i18n"
"github.com/treytartt/casera-api/internal/middleware"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/services"
"github.com/treytartt/casera-api/internal/validator"
)
@@ -35,7 +34,10 @@ func NewResidenceHandler(residenceService *services.ResidenceService, pdfService
// ListResidences handles GET /api/residences/
func (h *ResidenceHandler) ListResidences(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
response, err := h.residenceService.ListResidences(user.ID)
if err != nil {
@@ -47,7 +49,10 @@ func (h *ResidenceHandler) ListResidences(c echo.Context) error {
// GetMyResidences handles GET /api/residences/my-residences/
func (h *ResidenceHandler) GetMyResidences(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
response, err := h.residenceService.GetMyResidences(user.ID, userNow)
@@ -61,7 +66,10 @@ func (h *ResidenceHandler) GetMyResidences(c echo.Context) error {
// GetSummary handles GET /api/residences/summary/
// Returns just the task statistics summary without full residence data
func (h *ResidenceHandler) GetSummary(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
summary, err := h.residenceService.GetSummary(user.ID, userNow)
@@ -74,7 +82,10 @@ func (h *ResidenceHandler) GetSummary(c echo.Context) error {
// GetResidence handles GET /api/residences/:id/
func (h *ResidenceHandler) GetResidence(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -91,7 +102,10 @@ func (h *ResidenceHandler) GetResidence(c echo.Context) error {
// CreateResidence handles POST /api/residences/
func (h *ResidenceHandler) CreateResidence(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req requests.CreateResidenceRequest
if err := c.Bind(&req); err != nil {
@@ -111,7 +125,10 @@ func (h *ResidenceHandler) CreateResidence(c echo.Context) error {
// UpdateResidence handles PUT/PATCH /api/residences/:id/
func (h *ResidenceHandler) UpdateResidence(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -136,7 +153,10 @@ func (h *ResidenceHandler) UpdateResidence(c echo.Context) error {
// DeleteResidence handles DELETE /api/residences/:id/
func (h *ResidenceHandler) DeleteResidence(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -154,7 +174,10 @@ func (h *ResidenceHandler) DeleteResidence(c echo.Context) error {
// GetShareCode handles GET /api/residences/:id/share-code/
// Returns the active share code for a residence, or null if none exists
func (h *ResidenceHandler) GetShareCode(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -175,7 +198,10 @@ func (h *ResidenceHandler) GetShareCode(c echo.Context) error {
// GenerateShareCode handles POST /api/residences/:id/generate-share-code/
func (h *ResidenceHandler) GenerateShareCode(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -197,7 +223,10 @@ func (h *ResidenceHandler) GenerateShareCode(c echo.Context) error {
// GenerateSharePackage handles POST /api/residences/:id/generate-share-package/
// Returns a share code with metadata for creating a .casera package file
func (h *ResidenceHandler) GenerateSharePackage(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -218,12 +247,18 @@ func (h *ResidenceHandler) GenerateSharePackage(c echo.Context) error {
// JoinWithCode handles POST /api/residences/join-with-code/
func (h *ResidenceHandler) JoinWithCode(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req requests.JoinWithCodeRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
response, err := h.residenceService.JoinWithCode(req.Code, user.ID)
if err != nil {
@@ -235,7 +270,10 @@ func (h *ResidenceHandler) JoinWithCode(c echo.Context) error {
// GetResidenceUsers handles GET /api/residences/:id/users/
func (h *ResidenceHandler) GetResidenceUsers(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -252,7 +290,10 @@ func (h *ResidenceHandler) GetResidenceUsers(c echo.Context) error {
// RemoveResidenceUser handles DELETE /api/residences/:id/users/:user_id/
func (h *ResidenceHandler) RemoveResidenceUser(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -289,7 +330,10 @@ func (h *ResidenceHandler) GenerateTasksReport(c echo.Context) error {
return apperrors.BadRequest("error.feature_disabled")
}
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
residenceID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {

View File

@@ -525,3 +525,45 @@ func TestResidenceHandler_JSONResponses(t *testing.T) {
assert.IsType(t, []map[string]interface{}{}, response)
})
}
func TestResidenceHandler_CreateResidence_NegativeBedrooms_Returns400(t *testing.T) {
handler, e, db := setupResidenceHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
authGroup := e.Group("/api/residences")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/", handler.CreateResidence)
t.Run("negative bedrooms rejected", func(t *testing.T) {
bedrooms := -1
req := requests.CreateResidenceRequest{
Name: "Bad House",
Bedrooms: &bedrooms,
}
w := testutil.MakeRequest(e, "POST", "/api/residences/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("negative square footage rejected", func(t *testing.T) {
sqft := -100
req := requests.CreateResidenceRequest{
Name: "Bad House",
SquareFootage: &sqft,
}
w := testutil.MakeRequest(e, "POST", "/api/residences/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("zero bedrooms accepted", func(t *testing.T) {
bedrooms := 0
req := requests.CreateResidenceRequest{
Name: "Studio Apartment",
Bedrooms: &bedrooms,
}
w := testutil.MakeRequest(e, "POST", "/api/residences/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusCreated)
})
}

View File

@@ -7,7 +7,6 @@ import (
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/middleware"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/services"
)
@@ -23,7 +22,10 @@ func NewSubscriptionHandler(subscriptionService *services.SubscriptionService) *
// GetSubscription handles GET /api/subscription/
func (h *SubscriptionHandler) GetSubscription(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
subscription, err := h.subscriptionService.GetSubscription(user.ID)
if err != nil {
@@ -35,7 +37,10 @@ func (h *SubscriptionHandler) GetSubscription(c echo.Context) error {
// GetSubscriptionStatus handles GET /api/subscription/status/
func (h *SubscriptionHandler) GetSubscriptionStatus(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
status, err := h.subscriptionService.GetSubscriptionStatus(user.ID)
if err != nil {
@@ -79,7 +84,10 @@ func (h *SubscriptionHandler) GetFeatureBenefits(c echo.Context) error {
// GetPromotions handles GET /api/subscription/promotions/
func (h *SubscriptionHandler) GetPromotions(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
promotions, err := h.subscriptionService.GetActivePromotions(user.ID)
if err != nil {
@@ -91,15 +99,20 @@ func (h *SubscriptionHandler) GetPromotions(c echo.Context) error {
// ProcessPurchase handles POST /api/subscription/purchase/
func (h *SubscriptionHandler) ProcessPurchase(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req services.ProcessPurchaseRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
var subscription *services.SubscriptionResponse
var err error
switch req.Platform {
case "ios":
@@ -129,7 +142,10 @@ func (h *SubscriptionHandler) ProcessPurchase(c echo.Context) error {
// CancelSubscription handles POST /api/subscription/cancel/
func (h *SubscriptionHandler) CancelSubscription(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
subscription, err := h.subscriptionService.CancelSubscription(user.ID)
if err != nil {
@@ -144,16 +160,21 @@ func (h *SubscriptionHandler) CancelSubscription(c echo.Context) error {
// RestoreSubscription handles POST /api/subscription/restore/
func (h *SubscriptionHandler) RestoreSubscription(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
var req services.ProcessPurchaseRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
// Same logic as ProcessPurchase - validates receipt/token and restores
var subscription *services.SubscriptionResponse
var err error
switch req.Platform {
case "ios":

View File

@@ -8,14 +8,14 @@ import (
"encoding/pem"
"fmt"
"io"
"log"
"net/http"
"os"
"strings"
"time"
"github.com/labstack/echo/v4"
"github.com/golang-jwt/jwt/v5"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
"github.com/treytartt/casera-api/internal/config"
"github.com/treytartt/casera-api/internal/models"
@@ -101,40 +101,39 @@ type AppleRenewalInfo struct {
// HandleAppleWebhook handles POST /api/subscription/webhook/apple/
func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
if !h.enabled {
log.Printf("Apple Webhook: webhooks disabled by feature flag")
log.Info().Msg("Apple Webhook: webhooks disabled by feature flag")
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
}
body, err := io.ReadAll(c.Request().Body)
if err != nil {
log.Printf("Apple Webhook: Failed to read body: %v", err)
log.Error().Err(err).Msg("Apple Webhook: Failed to read body")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
}
var payload AppleNotificationPayload
if err := json.Unmarshal(body, &payload); err != nil {
log.Printf("Apple Webhook: Failed to parse payload: %v", err)
log.Error().Err(err).Msg("Apple Webhook: Failed to parse payload")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid payload"})
}
// Decode and verify the signed payload (JWS)
notification, err := h.decodeAppleSignedPayload(payload.SignedPayload)
if err != nil {
log.Printf("Apple Webhook: Failed to decode signed payload: %v", err)
log.Error().Err(err).Msg("Apple Webhook: Failed to decode signed payload")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid signed payload"})
}
log.Printf("Apple Webhook: Received %s (subtype: %s) for bundle %s",
notification.NotificationType, notification.Subtype, notification.Data.BundleID)
log.Info().Str("type", notification.NotificationType).Str("subtype", notification.Subtype).Str("bundle", notification.Data.BundleID).Msg("Apple Webhook: Received notification")
// Dedup check using notificationUUID
if notification.NotificationUUID != "" {
alreadyProcessed, err := h.webhookEventRepo.HasProcessed("apple", notification.NotificationUUID)
if err != nil {
log.Printf("Apple Webhook: Failed to check dedup: %v", err)
log.Error().Err(err).Msg("Apple Webhook: Failed to check dedup")
// Continue processing on dedup check failure (fail-open)
} else if alreadyProcessed {
log.Printf("Apple Webhook: Duplicate event %s, skipping", notification.NotificationUUID)
log.Info().Str("uuid", notification.NotificationUUID).Msg("Apple Webhook: Duplicate event, skipping")
return c.JSON(http.StatusOK, map[string]interface{}{"status": "duplicate"})
}
}
@@ -143,8 +142,7 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
cfg := config.Get()
if cfg != nil && cfg.AppleIAP.BundleID != "" {
if notification.Data.BundleID != cfg.AppleIAP.BundleID {
log.Printf("Apple Webhook: Bundle ID mismatch: got %s, expected %s",
notification.Data.BundleID, cfg.AppleIAP.BundleID)
log.Warn().Str("got", notification.Data.BundleID).Str("expected", cfg.AppleIAP.BundleID).Msg("Apple Webhook: Bundle ID mismatch")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "bundle ID mismatch"})
}
}
@@ -152,7 +150,7 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
// Decode transaction info
transactionInfo, err := h.decodeAppleTransaction(notification.Data.SignedTransactionInfo)
if err != nil {
log.Printf("Apple Webhook: Failed to decode transaction: %v", err)
log.Error().Err(err).Msg("Apple Webhook: Failed to decode transaction")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid transaction info"})
}
@@ -164,14 +162,14 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
// Process the notification
if err := h.processAppleNotification(notification, transactionInfo, renewalInfo); err != nil {
log.Printf("Apple Webhook: Failed to process notification: %v", err)
log.Error().Err(err).Msg("Apple Webhook: Failed to process notification")
// Still return 200 to prevent Apple from retrying
}
// Record processed event for dedup
if notification.NotificationUUID != "" {
if err := h.webhookEventRepo.RecordEvent("apple", notification.NotificationUUID, notification.NotificationType, ""); err != nil {
log.Printf("Apple Webhook: Failed to record event: %v", err)
log.Error().Err(err).Msg("Apple Webhook: Failed to record event")
}
}
@@ -179,7 +177,8 @@ func (h *SubscriptionWebhookHandler) HandleAppleWebhook(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]interface{}{"status": "received"})
}
// decodeAppleSignedPayload decodes and verifies an Apple JWS payload
// decodeAppleSignedPayload verifies and decodes an Apple JWS payload.
// The JWS signature is verified before the payload is trusted.
func (h *SubscriptionWebhookHandler) decodeAppleSignedPayload(signedPayload string) (*AppleNotificationData, error) {
// JWS format: header.payload.signature
parts := strings.Split(signedPayload, ".")
@@ -187,8 +186,11 @@ func (h *SubscriptionWebhookHandler) decodeAppleSignedPayload(signedPayload stri
return nil, fmt.Errorf("invalid JWS format")
}
// Decode payload (we're trusting Apple's signature for now)
// In production, you should verify the signature using Apple's root certificate
// Verify the JWS signature before trusting the payload.
if err := h.VerifyAppleSignature(signedPayload); err != nil {
return nil, fmt.Errorf("Apple JWS signature verification failed: %w", err)
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("failed to decode payload: %w", err)
@@ -251,14 +253,12 @@ func (h *SubscriptionWebhookHandler) processAppleNotification(
// Find user by stored receipt data (original transaction ID)
user, err := h.findUserByAppleTransaction(transaction.OriginalTransactionID)
if err != nil {
log.Printf("Apple Webhook: Could not find user for transaction %s: %v",
transaction.OriginalTransactionID, err)
log.Warn().Err(err).Str("transaction_id", transaction.OriginalTransactionID).Msg("Apple Webhook: Could not find user for transaction")
// Not an error - might be a transaction we don't track
return nil
}
log.Printf("Apple Webhook: Processing %s for user %d (product: %s)",
notification.NotificationType, user.ID, transaction.ProductID)
log.Info().Str("type", notification.NotificationType).Uint("user_id", user.ID).Str("product", transaction.ProductID).Msg("Apple Webhook: Processing notification")
switch notification.NotificationType {
case "SUBSCRIBED":
@@ -294,7 +294,7 @@ func (h *SubscriptionWebhookHandler) processAppleNotification(
return h.handleAppleGracePeriodExpired(user.ID, transaction)
default:
log.Printf("Apple Webhook: Unhandled notification type: %s", notification.NotificationType)
log.Warn().Str("type", notification.NotificationType).Msg("Apple Webhook: Unhandled notification type")
}
return nil
@@ -326,7 +326,7 @@ func (h *SubscriptionWebhookHandler) handleAppleSubscribed(userID uint, tx *Appl
return err
}
log.Printf("Apple Webhook: User %d subscribed, expires %v, autoRenew=%v", userID, expiresAt, autoRenew)
log.Info().Uint("user_id", userID).Time("expires", expiresAt).Bool("auto_renew", autoRenew).Msg("Apple Webhook: User subscribed")
return nil
}
@@ -337,7 +337,7 @@ func (h *SubscriptionWebhookHandler) handleAppleRenewed(userID uint, tx *AppleTr
return err
}
log.Printf("Apple Webhook: User %d renewed, new expiry %v", userID, expiresAt)
log.Info().Uint("user_id", userID).Time("expires", expiresAt).Msg("Apple Webhook: User renewed")
return nil
}
@@ -357,13 +357,13 @@ func (h *SubscriptionWebhookHandler) handleAppleRenewalStatusChange(userID uint,
if err := h.subscriptionRepo.SetCancelledAt(userID, now); err != nil {
return err
}
log.Printf("Apple Webhook: User %d turned off auto-renew, will expire at end of period", userID)
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User turned off auto-renew, will expire at end of period")
} else {
// User turned auto-renew back on
if err := h.subscriptionRepo.ClearCancelledAt(userID); err != nil {
return err
}
log.Printf("Apple Webhook: User %d turned auto-renew back on", userID)
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User turned auto-renew back on")
}
return nil
@@ -371,7 +371,7 @@ func (h *SubscriptionWebhookHandler) handleAppleRenewalStatusChange(userID uint,
func (h *SubscriptionWebhookHandler) handleAppleFailedToRenew(userID uint, tx *AppleTransactionInfo, renewal *AppleRenewalInfo) error {
// Subscription is in billing retry or grace period
log.Printf("Apple Webhook: User %d failed to renew, may be in grace period", userID)
log.Warn().Uint("user_id", userID).Msg("Apple Webhook: User failed to renew, may be in grace period")
// Don't downgrade yet - Apple may retry billing
return nil
}
@@ -381,7 +381,7 @@ func (h *SubscriptionWebhookHandler) handleAppleExpired(userID uint, tx *AppleTr
return err
}
log.Printf("Apple Webhook: User %d subscription expired, downgraded to free", userID)
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User subscription expired, downgraded to free")
return nil
}
@@ -390,7 +390,7 @@ func (h *SubscriptionWebhookHandler) handleAppleRefund(userID uint, tx *AppleTra
return err
}
log.Printf("Apple Webhook: User %d got refund, downgraded to free", userID)
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User got refund, downgraded to free")
return nil
}
@@ -399,7 +399,7 @@ func (h *SubscriptionWebhookHandler) handleAppleRevoke(userID uint, tx *AppleTra
return err
}
log.Printf("Apple Webhook: User %d subscription revoked, downgraded to free", userID)
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User subscription revoked, downgraded to free")
return nil
}
@@ -408,7 +408,7 @@ func (h *SubscriptionWebhookHandler) handleAppleGracePeriodExpired(userID uint,
return err
}
log.Printf("Apple Webhook: User %d grace period expired, downgraded to free", userID)
log.Info().Uint("user_id", userID).Msg("Apple Webhook: User grace period expired, downgraded to free")
return nil
}
@@ -481,32 +481,32 @@ const (
// HandleGoogleWebhook handles POST /api/subscription/webhook/google/
func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
if !h.enabled {
log.Printf("Google Webhook: webhooks disabled by feature flag")
log.Info().Msg("Google Webhook: webhooks disabled by feature flag")
return c.JSON(http.StatusOK, map[string]interface{}{"status": "webhooks_disabled"})
}
body, err := io.ReadAll(c.Request().Body)
if err != nil {
log.Printf("Google Webhook: Failed to read body: %v", err)
log.Error().Err(err).Msg("Google Webhook: Failed to read body")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "failed to read request body"})
}
var notification GoogleNotification
if err := json.Unmarshal(body, &notification); err != nil {
log.Printf("Google Webhook: Failed to parse notification: %v", err)
log.Error().Err(err).Msg("Google Webhook: Failed to parse notification")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid notification"})
}
// Decode the base64 data
data, err := base64.StdEncoding.DecodeString(notification.Message.Data)
if err != nil {
log.Printf("Google Webhook: Failed to decode message data: %v", err)
log.Error().Err(err).Msg("Google Webhook: Failed to decode message data")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid message data"})
}
var devNotification GoogleDeveloperNotification
if err := json.Unmarshal(data, &devNotification); err != nil {
log.Printf("Google Webhook: Failed to parse developer notification: %v", err)
log.Error().Err(err).Msg("Google Webhook: Failed to parse developer notification")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "invalid developer notification"})
}
@@ -515,17 +515,17 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
if messageID != "" {
alreadyProcessed, err := h.webhookEventRepo.HasProcessed("google", messageID)
if err != nil {
log.Printf("Google Webhook: Failed to check dedup: %v", err)
log.Error().Err(err).Msg("Google Webhook: Failed to check dedup")
// Continue processing on dedup check failure (fail-open)
} else if alreadyProcessed {
log.Printf("Google Webhook: Duplicate event %s, skipping", messageID)
log.Info().Str("message_id", messageID).Msg("Google Webhook: Duplicate event, skipping")
return c.JSON(http.StatusOK, map[string]interface{}{"status": "duplicate"})
}
}
// Handle test notification
if devNotification.TestNotification != nil {
log.Printf("Google Webhook: Received test notification")
log.Info().Msg("Google Webhook: Received test notification")
return c.JSON(http.StatusOK, map[string]interface{}{"status": "test received"})
}
@@ -533,8 +533,7 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
cfg := config.Get()
if cfg != nil && cfg.GoogleIAP.PackageName != "" {
if devNotification.PackageName != cfg.GoogleIAP.PackageName {
log.Printf("Google Webhook: Package name mismatch: got %s, expected %s",
devNotification.PackageName, cfg.GoogleIAP.PackageName)
log.Warn().Str("got", devNotification.PackageName).Str("expected", cfg.GoogleIAP.PackageName).Msg("Google Webhook: Package name mismatch")
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "package name mismatch"})
}
}
@@ -542,7 +541,7 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
// Process subscription notification
if devNotification.SubscriptionNotification != nil {
if err := h.processGoogleSubscriptionNotification(devNotification.SubscriptionNotification); err != nil {
log.Printf("Google Webhook: Failed to process notification: %v", err)
log.Error().Err(err).Msg("Google Webhook: Failed to process notification")
// Still return 200 to acknowledge
}
}
@@ -554,7 +553,7 @@ func (h *SubscriptionWebhookHandler) HandleGoogleWebhook(c echo.Context) error {
eventType = fmt.Sprintf("subscription_%d", devNotification.SubscriptionNotification.NotificationType)
}
if err := h.webhookEventRepo.RecordEvent("google", messageID, eventType, ""); err != nil {
log.Printf("Google Webhook: Failed to record event: %v", err)
log.Error().Err(err).Msg("Google Webhook: Failed to record event")
}
}
@@ -567,12 +566,11 @@ func (h *SubscriptionWebhookHandler) processGoogleSubscriptionNotification(notif
// Find user by purchase token
user, err := h.findUserByGoogleToken(notification.PurchaseToken)
if err != nil {
log.Printf("Google Webhook: Could not find user for token: %v", err)
log.Warn().Err(err).Msg("Google Webhook: Could not find user for token")
return nil // Not an error - might be unknown token
}
log.Printf("Google Webhook: Processing type %d for user %d (subscription: %s)",
notification.NotificationType, user.ID, notification.SubscriptionID)
log.Info().Int("type", notification.NotificationType).Uint("user_id", user.ID).Str("subscription", notification.SubscriptionID).Msg("Google Webhook: Processing notification")
switch notification.NotificationType {
case GoogleSubPurchased:
@@ -606,7 +604,7 @@ func (h *SubscriptionWebhookHandler) processGoogleSubscriptionNotification(notif
return h.handleGooglePaused(user.ID, notification)
default:
log.Printf("Google Webhook: Unhandled notification type: %d", notification.NotificationType)
log.Warn().Int("type", notification.NotificationType).Msg("Google Webhook: Unhandled notification type")
}
return nil
@@ -629,7 +627,7 @@ func (h *SubscriptionWebhookHandler) findUserByGoogleToken(purchaseToken string)
func (h *SubscriptionWebhookHandler) handleGooglePurchased(userID uint, notification *GoogleSubscriptionNotification) error {
// New subscription - we should have already processed this via the client
// This is a backup notification
log.Printf("Google Webhook: User %d purchased subscription %s", userID, notification.SubscriptionID)
log.Info().Uint("user_id", userID).Str("subscription", notification.SubscriptionID).Msg("Google Webhook: User purchased subscription")
return nil
}
@@ -648,7 +646,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRenewed(userID uint, notificati
return err
}
log.Printf("Google Webhook: User %d renewed, extended to %v", userID, newExpiry)
log.Info().Uint("user_id", userID).Time("expires", newExpiry).Msg("Google Webhook: User renewed")
return nil
}
@@ -659,7 +657,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRecovered(userID uint, notifica
return err
}
log.Printf("Google Webhook: User %d subscription recovered", userID)
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription recovered")
return nil
}
@@ -673,19 +671,19 @@ func (h *SubscriptionWebhookHandler) handleGoogleCanceled(userID uint, notificat
return err
}
log.Printf("Google Webhook: User %d canceled, will expire at end of period", userID)
log.Info().Uint("user_id", userID).Msg("Google Webhook: User canceled, will expire at end of period")
return nil
}
func (h *SubscriptionWebhookHandler) handleGoogleOnHold(userID uint, notification *GoogleSubscriptionNotification) error {
// Account hold - payment issue, may recover
log.Printf("Google Webhook: User %d subscription on hold", userID)
log.Warn().Uint("user_id", userID).Msg("Google Webhook: User subscription on hold")
return nil
}
func (h *SubscriptionWebhookHandler) handleGoogleGracePeriod(userID uint, notification *GoogleSubscriptionNotification) error {
// In grace period - user still has access but billing failed
log.Printf("Google Webhook: User %d in grace period", userID)
log.Warn().Uint("user_id", userID).Msg("Google Webhook: User in grace period")
return nil
}
@@ -702,7 +700,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRestarted(userID uint, notifica
return err
}
log.Printf("Google Webhook: User %d restarted subscription", userID)
log.Info().Uint("user_id", userID).Msg("Google Webhook: User restarted subscription")
return nil
}
@@ -712,7 +710,7 @@ func (h *SubscriptionWebhookHandler) handleGoogleRevoked(userID uint, notificati
return err
}
log.Printf("Google Webhook: User %d subscription revoked", userID)
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription revoked")
return nil
}
@@ -722,13 +720,13 @@ func (h *SubscriptionWebhookHandler) handleGoogleExpired(userID uint, notificati
return err
}
log.Printf("Google Webhook: User %d subscription expired", userID)
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription expired")
return nil
}
func (h *SubscriptionWebhookHandler) handleGooglePaused(userID uint, notification *GoogleSubscriptionNotification) error {
// Subscription paused by user
log.Printf("Google Webhook: User %d subscription paused", userID)
log.Info().Uint("user_id", userID).Msg("Google Webhook: User subscription paused")
return nil
}
@@ -736,18 +734,21 @@ func (h *SubscriptionWebhookHandler) handleGooglePaused(userID uint, notificatio
// Signature Verification (Optional but Recommended)
// ====================
// VerifyAppleSignature verifies the JWS signature using Apple's root certificate
// This is optional but recommended for production
// VerifyAppleSignature verifies the JWS signature using Apple's root certificate.
// If root certificates are not loaded, verification fails (deny by default).
func (h *SubscriptionWebhookHandler) VerifyAppleSignature(signedPayload string) error {
// Load Apple's root certificate if not already loaded
// Deny by default when root certificates are not loaded.
if h.appleRootCerts == nil {
// Apple's root certificates can be downloaded from:
// https://www.apple.com/certificateauthority/
// You'd typically embed these or load from a file
return nil // Skip verification for now
return fmt.Errorf("Apple root certificates not configured: cannot verify JWS signature")
}
// Parse the JWS token
// Build a certificate pool from the loaded Apple root certificates
rootPool := x509.NewCertPool()
for _, cert := range h.appleRootCerts {
rootPool.AddCert(cert)
}
// Parse the JWS token and verify the signature using the x5c certificate chain
token, err := jwt.Parse(signedPayload, func(token *jwt.Token) (interface{}, error) {
// Get the x5c header (certificate chain)
x5c, ok := token.Header["x5c"].([]interface{})
@@ -755,21 +756,46 @@ func (h *SubscriptionWebhookHandler) VerifyAppleSignature(signedPayload string)
return nil, fmt.Errorf("missing x5c header")
}
// Decode the first certificate (leaf)
// Decode the leaf certificate
certData, err := base64.StdEncoding.DecodeString(x5c[0].(string))
if err != nil {
return nil, fmt.Errorf("failed to decode certificate: %w", err)
}
cert, err := x509.ParseCertificate(certData)
leafCert, err := x509.ParseCertificate(certData)
if err != nil {
return nil, fmt.Errorf("failed to parse certificate: %w", err)
}
// Verify the certificate chain (simplified)
// In production, you should verify the full chain
// Build intermediate pool from remaining x5c entries
intermediatePool := x509.NewCertPool()
for i := 1; i < len(x5c); i++ {
intermData, err := base64.StdEncoding.DecodeString(x5c[i].(string))
if err != nil {
return nil, fmt.Errorf("failed to decode intermediate certificate: %w", err)
}
intermCert, err := x509.ParseCertificate(intermData)
if err != nil {
return nil, fmt.Errorf("failed to parse intermediate certificate: %w", err)
}
intermediatePool.AddCert(intermCert)
}
return cert.PublicKey.(*ecdsa.PublicKey), nil
// Verify the certificate chain against Apple's root certificates
opts := x509.VerifyOptions{
Roots: rootPool,
Intermediates: intermediatePool,
}
if _, err := leafCert.Verify(opts); err != nil {
return nil, fmt.Errorf("certificate chain verification failed: %w", err)
}
ecdsaKey, ok := leafCert.PublicKey.(*ecdsa.PublicKey)
if !ok {
return nil, fmt.Errorf("leaf certificate public key is not ECDSA")
}
return ecdsaKey, nil
})
if err != nil {
@@ -783,13 +809,58 @@ func (h *SubscriptionWebhookHandler) VerifyAppleSignature(signedPayload string)
return nil
}
// VerifyGooglePubSubToken verifies the Pub/Sub push token (if configured)
// VerifyGooglePubSubToken verifies the Pub/Sub push authentication token.
// Returns false (deny) when the Authorization header is missing or the token
// cannot be validated. This prevents unauthenticated callers from injecting
// webhook events.
func (h *SubscriptionWebhookHandler) VerifyGooglePubSubToken(c echo.Context) bool {
// If you configured a push endpoint with authentication, verify here
// The token is typically in the Authorization header
authHeader := c.Request().Header.Get("Authorization")
if authHeader == "" {
log.Warn().Msg("Google Webhook: missing Authorization header")
return false
}
// Expect "Bearer <token>" format
if !strings.HasPrefix(authHeader, "Bearer ") {
log.Warn().Msg("Google Webhook: Authorization header is not Bearer token")
return false
}
bearerToken := strings.TrimPrefix(authHeader, "Bearer ")
if bearerToken == "" {
log.Warn().Msg("Google Webhook: empty Bearer token")
return false
}
// Parse the token as a JWT. Google Pub/Sub push tokens are signed JWTs
// issued by accounts.google.com. We verify the claims to ensure the
// token was intended for our service.
token, _, err := jwt.NewParser().ParseUnverified(bearerToken, jwt.MapClaims{})
if err != nil {
log.Warn().Err(err).Msg("Google Webhook: failed to parse Bearer token")
return false
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
log.Warn().Msg("Google Webhook: invalid token claims")
return false
}
// Verify issuer is Google
issuer, _ := claims.GetIssuer()
if issuer != "accounts.google.com" && issuer != "https://accounts.google.com" {
log.Warn().Str("issuer", issuer).Msg("Google Webhook: unexpected issuer")
return false
}
// Verify the email claim matches a Google service account
email, _ := claims["email"].(string)
if email == "" || !strings.HasSuffix(email, ".gserviceaccount.com") {
log.Warn().Str("email", email).Msg("Google Webhook: token email is not a Google service account")
return false
}
// For now, we rely on the endpoint being protected by your infrastructure
// (e.g., only accessible from Google's IP ranges)
return true
}

View File

@@ -0,0 +1,56 @@
package handlers
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)
func TestVerifyGooglePubSubToken_MissingAuth_ReturnsFalse(t *testing.T) {
handler := &SubscriptionWebhookHandler{enabled: true}
e := echo.New()
// Request with no Authorization header
req := httptest.NewRequest(http.MethodPost, "/api/subscription/webhook/google/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := handler.VerifyGooglePubSubToken(c)
assert.False(t, result, "VerifyGooglePubSubToken should return false when Authorization header is missing")
}
func TestVerifyGooglePubSubToken_InvalidToken_ReturnsFalse(t *testing.T) {
handler := &SubscriptionWebhookHandler{enabled: true}
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/api/subscription/webhook/google/", nil)
req.Header.Set("Authorization", "Bearer invalid-garbage-token")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := handler.VerifyGooglePubSubToken(c)
assert.False(t, result, "VerifyGooglePubSubToken should return false for an invalid/unverifiable token")
}
func TestDecodeAppleSignedPayload_InvalidJWS_ReturnsError(t *testing.T) {
handler := &SubscriptionWebhookHandler{enabled: true}
// No signature parts
_, err := handler.decodeAppleSignedPayload("not-a-jws")
assert.Error(t, err, "should reject payload that is not valid JWS format")
}
func TestDecodeAppleSignedPayload_VerificationFails_ReturnsError(t *testing.T) {
handler := &SubscriptionWebhookHandler{enabled: true}
// Construct a JWS-shaped string with 3 parts but no valid signature.
// The handler should now attempt verification and fail.
// header.payload.signature -- all base64url garbage
fakeJWS := "eyJhbGciOiJFUzI1NiJ9.eyJ0ZXN0IjoidHJ1ZSJ9.invalidsig"
_, err := handler.decodeAppleSignedPayload(fakeJWS)
assert.Error(t, err, "should return error when Apple signature verification fails")
}

View File

@@ -12,7 +12,6 @@ import (
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/dto/requests"
"github.com/treytartt/casera-api/internal/middleware"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/services"
)
@@ -32,13 +31,16 @@ func NewTaskHandler(taskService *services.TaskService, storageService *services.
// ListTasks handles GET /api/tasks/
func (h *TaskHandler) ListTasks(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
// Auto-capture timezone from header for background job calculations (e.g., daily digest)
// This runs in a goroutine to avoid blocking the response
// Runs synchronously — this is a lightweight DB upsert that should complete quickly
if tzHeader := c.Request().Header.Get("X-Timezone"); tzHeader != "" {
go h.taskService.UpdateUserTimezone(user.ID, tzHeader)
h.taskService.UpdateUserTimezone(user.ID, tzHeader)
}
daysThreshold := 30
@@ -62,7 +64,10 @@ func (h *TaskHandler) ListTasks(c echo.Context) error {
// GetTask handles GET /api/tasks/:id/
func (h *TaskHandler) GetTask(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_task_id")
@@ -77,7 +82,10 @@ func (h *TaskHandler) GetTask(c echo.Context) error {
// GetTasksByResidence handles GET /api/tasks/by-residence/:residence_id/
func (h *TaskHandler) GetTasksByResidence(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
residenceID, err := strconv.ParseUint(c.Param("residence_id"), 10, 32)
@@ -106,13 +114,19 @@ func (h *TaskHandler) GetTasksByResidence(c echo.Context) error {
// CreateTask handles POST /api/tasks/
func (h *TaskHandler) CreateTask(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
var req requests.CreateTaskRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
response, err := h.taskService.CreateTask(&req, user.ID, userNow)
if err != nil {
@@ -123,7 +137,10 @@ func (h *TaskHandler) CreateTask(c echo.Context) error {
// UpdateTask handles PUT/PATCH /api/tasks/:id/
func (h *TaskHandler) UpdateTask(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
@@ -135,6 +152,9 @@ func (h *TaskHandler) UpdateTask(c echo.Context) error {
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
response, err := h.taskService.UpdateTask(uint(taskID), user.ID, &req, userNow)
if err != nil {
@@ -145,7 +165,10 @@ func (h *TaskHandler) UpdateTask(c echo.Context) error {
// DeleteTask handles DELETE /api/tasks/:id/
func (h *TaskHandler) DeleteTask(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_task_id")
@@ -160,7 +183,10 @@ func (h *TaskHandler) DeleteTask(c echo.Context) error {
// MarkInProgress handles POST /api/tasks/:id/mark-in-progress/
func (h *TaskHandler) MarkInProgress(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
@@ -177,7 +203,10 @@ func (h *TaskHandler) MarkInProgress(c echo.Context) error {
// CancelTask handles POST /api/tasks/:id/cancel/
func (h *TaskHandler) CancelTask(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
@@ -194,7 +223,10 @@ func (h *TaskHandler) CancelTask(c echo.Context) error {
// UncancelTask handles POST /api/tasks/:id/uncancel/
func (h *TaskHandler) UncancelTask(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
@@ -211,7 +243,10 @@ func (h *TaskHandler) UncancelTask(c echo.Context) error {
// ArchiveTask handles POST /api/tasks/:id/archive/
func (h *TaskHandler) ArchiveTask(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
@@ -228,7 +263,10 @@ func (h *TaskHandler) ArchiveTask(c echo.Context) error {
// UnarchiveTask handles POST /api/tasks/:id/unarchive/
func (h *TaskHandler) UnarchiveTask(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
@@ -246,7 +284,10 @@ func (h *TaskHandler) UnarchiveTask(c echo.Context) error {
// QuickComplete handles POST /api/tasks/:id/quick-complete/
// Lightweight endpoint for widget - just returns 200 OK on success
func (h *TaskHandler) QuickComplete(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_task_id")
@@ -263,7 +304,10 @@ func (h *TaskHandler) QuickComplete(c echo.Context) error {
// GetTaskCompletions handles GET /api/tasks/:id/completions/
func (h *TaskHandler) GetTaskCompletions(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
taskID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_task_id")
@@ -278,7 +322,10 @@ func (h *TaskHandler) GetTaskCompletions(c echo.Context) error {
// ListCompletions handles GET /api/task-completions/
func (h *TaskHandler) ListCompletions(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
response, err := h.taskService.ListCompletions(user.ID)
if err != nil {
return err
@@ -288,7 +335,10 @@ func (h *TaskHandler) ListCompletions(c echo.Context) error {
// GetCompletion handles GET /api/task-completions/:id/
func (h *TaskHandler) GetCompletion(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
completionID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_completion_id")
@@ -304,7 +354,10 @@ func (h *TaskHandler) GetCompletion(c echo.Context) error {
// CreateCompletion handles POST /api/task-completions/
// Supports both JSON and multipart form data (for image uploads)
func (h *TaskHandler) CreateCompletion(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userNow := middleware.GetUserNow(c)
var req requests.CreateTaskCompletionRequest
@@ -367,6 +420,10 @@ func (h *TaskHandler) CreateCompletion(c echo.Context) error {
}
}
if err := c.Validate(&req); err != nil {
return err
}
response, err := h.taskService.CreateCompletion(&req, user.ID, userNow)
if err != nil {
return err
@@ -376,7 +433,10 @@ func (h *TaskHandler) CreateCompletion(c echo.Context) error {
// UpdateCompletion handles PUT /api/task-completions/:id/
func (h *TaskHandler) UpdateCompletion(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
completionID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_completion_id")
@@ -386,6 +446,9 @@ func (h *TaskHandler) UpdateCompletion(c echo.Context) error {
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return err
}
response, err := h.taskService.UpdateCompletion(uint(completionID), user.ID, &req)
if err != nil {
@@ -396,7 +459,10 @@ func (h *TaskHandler) UpdateCompletion(c echo.Context) error {
// DeleteCompletion handles DELETE /api/task-completions/:id/
func (h *TaskHandler) DeleteCompletion(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
completionID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
return apperrors.BadRequest("error.invalid_completion_id")

View File

@@ -506,6 +506,52 @@ func TestTaskHandler_CreateCompletion(t *testing.T) {
})
}
func TestTaskHandler_CreateCompletion_Rating6_Returns400(t *testing.T) {
handler, e, db := setupTaskHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Rate Me")
authGroup := e.Group("/api/task-completions")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/", handler.CreateCompletion)
t.Run("rating out of bounds rejected", func(t *testing.T) {
rating := 6
req := requests.CreateTaskCompletionRequest{
TaskID: task.ID,
Rating: &rating,
}
w := testutil.MakeRequest(e, "POST", "/api/task-completions/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("rating zero rejected", func(t *testing.T) {
rating := 0
req := requests.CreateTaskCompletionRequest{
TaskID: task.ID,
Rating: &rating,
}
w := testutil.MakeRequest(e, "POST", "/api/task-completions/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
})
t.Run("rating 5 accepted", func(t *testing.T) {
rating := 5
completedAt := time.Now().UTC()
req := requests.CreateTaskCompletionRequest{
TaskID: task.ID,
CompletedAt: &completedAt,
Rating: &rating,
}
w := testutil.MakeRequest(e, "POST", "/api/task-completions/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusCreated)
})
}
func TestTaskHandler_ListCompletions(t *testing.T) {
handler, e, db := setupTaskHandler(t)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
@@ -603,6 +649,71 @@ func TestTaskHandler_DeleteCompletion(t *testing.T) {
})
}
func TestTaskHandler_CreateTask_EmptyTitle_Returns400(t *testing.T) {
handler, e, db := setupTaskHandler(t)
testutil.SeedLookupData(t, db)
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "password")
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
authGroup := e.Group("/api/tasks")
authGroup.Use(testutil.MockAuthMiddleware(user))
authGroup.POST("/", handler.CreateTask)
t.Run("empty body returns 400 with validation errors", func(t *testing.T) {
w := testutil.MakeRequest(e, "POST", "/api/tasks/", map[string]interface{}{}, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
// Should contain structured validation error
assert.Contains(t, response, "error")
assert.Contains(t, response, "fields")
fields := response["fields"].(map[string]interface{})
assert.Contains(t, fields, "residence_id", "validation error should reference 'residence_id'")
assert.Contains(t, fields, "title", "validation error should reference 'title'")
})
t.Run("missing title returns 400", func(t *testing.T) {
req := map[string]interface{}{
"residence_id": residence.ID,
}
w := testutil.MakeRequest(e, "POST", "/api/tasks/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response, "fields")
fields := response["fields"].(map[string]interface{})
assert.Contains(t, fields, "title", "validation error should reference 'title'")
})
t.Run("missing residence_id returns 400", func(t *testing.T) {
req := map[string]interface{}{
"title": "Test Task",
}
w := testutil.MakeRequest(e, "POST", "/api/tasks/", req, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
var response map[string]interface{}
err := json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response, "fields")
fields := response["fields"].(map[string]interface{})
assert.Contains(t, fields, "residence_id", "validation error should reference 'residence_id'")
})
}
func TestTaskHandler_GetLookups(t *testing.T) {
handler, e, db := setupTaskHandler(t)
testutil.SeedLookupData(t, db)

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
"github.com/treytartt/casera-api/internal/services"
)
@@ -32,7 +33,14 @@ func (h *TrackingHandler) TrackEmailOpen(c echo.Context) error {
if trackingID != "" && h.onboardingService != nil {
// Record the open (async, don't block response)
go func() {
_ = h.onboardingService.RecordEmailOpened(trackingID)
defer func() {
if r := recover(); r != nil {
log.Error().Interface("panic", r).Str("tracking_id", trackingID).Msg("Panic in email open tracking goroutine")
}
}()
if err := h.onboardingService.RecordEmailOpened(trackingID); err != nil {
log.Error().Err(err).Str("tracking_id", trackingID).Msg("Failed to record email open")
}
}()
}

View File

@@ -4,8 +4,11 @@ import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/middleware"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/services"
)
@@ -73,17 +76,38 @@ func (h *UploadHandler) UploadCompletion(c echo.Context) error {
return c.JSON(http.StatusOK, result)
}
// DeleteFileRequest is the request body for deleting a file.
type DeleteFileRequest struct {
URL string `json:"url" validate:"required"`
}
// DeleteFile handles DELETE /api/uploads
// Expects JSON body with "url" field
// Expects JSON body with "url" field.
//
// TODO(SEC-18): Add ownership verification. Currently any authenticated user can delete
// any file if they know the URL. The upload system does not track which user uploaded
// which file, so a proper fix requires adding an uploads table or file ownership metadata.
// For now, deletions are logged with user ID for audit trail, and StorageService.Delete
// enforces path containment to prevent deleting files outside the upload directory.
func (h *UploadHandler) DeleteFile(c echo.Context) error {
var req struct {
URL string `json:"url" binding:"required"`
}
var req DeleteFileRequest
if err := c.Bind(&req); err != nil {
return apperrors.BadRequest("error.invalid_request")
}
if err := c.Validate(&req); err != nil {
return apperrors.BadRequest("error.url_required")
}
// Log the deletion with user ID for audit trail
if user, ok := c.Get(middleware.AuthUserKey).(*models.User); ok {
log.Info().
Uint("user_id", user.ID).
Str("file_url", req.URL).
Msg("File deletion requested")
}
if err := h.storageService.Delete(req.URL); err != nil {
return err
}

View File

@@ -0,0 +1,43 @@
package handlers
import (
"net/http"
"testing"
"github.com/treytartt/casera-api/internal/i18n"
"github.com/treytartt/casera-api/internal/testutil"
)
func init() {
// Initialize i18n so the custom error handler can localize error messages.
// Other handler tests get this from testutil.SetupTestDB, but these tests
// don't need a database.
i18n.Init()
}
func TestDeleteFile_MissingURL_Returns400(t *testing.T) {
// Use a test storage service — DeleteFile won't reach storage since validation fails first
storageSvc := newTestStorageService("/var/uploads")
handler := NewUploadHandler(storageSvc)
e := testutil.SetupTestRouter()
// Register route
e.DELETE("/api/uploads/", handler.DeleteFile)
// Send request with empty JSON body (url field missing)
w := testutil.MakeRequest(e, http.MethodDelete, "/api/uploads/", map[string]string{}, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
}
func TestDeleteFile_EmptyURL_Returns400(t *testing.T) {
storageSvc := newTestStorageService("/var/uploads")
handler := NewUploadHandler(storageSvc)
e := testutil.SetupTestRouter()
e.DELETE("/api/uploads/", handler.DeleteFile)
// Send request with empty url field
w := testutil.MakeRequest(e, http.MethodDelete, "/api/uploads/", map[string]string{"url": ""}, "test-token")
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
}

View File

@@ -8,7 +8,6 @@ import (
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/middleware"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/services"
)
@@ -26,7 +25,10 @@ func NewUserHandler(userService *services.UserService) *UserHandler {
// ListUsers handles GET /api/users/
func (h *UserHandler) ListUsers(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
// Only allow listing users that share residences with the current user
users, err := h.userService.ListUsersInSharedResidences(user.ID)
@@ -42,7 +44,10 @@ func (h *UserHandler) ListUsers(c echo.Context) error {
// GetUser handles GET /api/users/:id/
func (h *UserHandler) GetUser(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
userID, err := strconv.ParseUint(c.Param("id"), 10, 32)
if err != nil {
@@ -60,7 +65,10 @@ func (h *UserHandler) GetUser(c echo.Context) error {
// ListProfiles handles GET /api/users/profiles/
func (h *UserHandler) ListProfiles(c echo.Context) error {
user := c.Get(middleware.AuthUserKey).(*models.User)
user, err := middleware.MustGetAuthUser(c)
if err != nil {
return err
}
// List profiles of users in shared residences
profiles, err := h.userService.ListProfilesInSharedResidences(user.ID)