Migrate from Gin to Echo framework and add comprehensive integration tests

Major changes:
- Migrate all handlers from Gin to Echo framework
- Add new apperrors, echohelpers, and validator packages
- Update middleware for Echo compatibility
- Add ArchivedHandler to task categorization chain (archived tasks go to cancelled_tasks column)
- Add 6 new integration tests:
  - RecurringTaskLifecycle: NextDueDate advancement for weekly/monthly tasks
  - MultiUserSharing: Complex sharing with user removal
  - TaskStateTransitions: All state transitions and kanban column changes
  - DateBoundaryEdgeCases: Threshold boundary testing
  - CascadeOperations: Residence deletion cascade effects
  - MultiUserOperations: Shared residence collaboration
- Add single-purpose repository functions for kanban columns (GetOverdueTasks, GetDueSoonTasks, etc.)
- Fix RemoveUser route param mismatch (userId -> user_id)
- Fix determineExpectedColumn helper to correctly prioritize in_progress over overdue

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Trey t
2025-12-16 13:52:08 -06:00
parent c51f1ce34a
commit 6dac34e373
98 changed files with 8209 additions and 4425 deletions

View File

@@ -11,6 +11,7 @@ import (
"golang.org/x/crypto/bcrypt"
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/config"
"github.com/treytartt/casera-api/internal/dto/requests"
"github.com/treytartt/casera-api/internal/dto/responses"
@@ -18,18 +19,20 @@ import (
"github.com/treytartt/casera-api/internal/repositories"
)
// Deprecated: Legacy error constants - kept for reference during transition
// Use apperrors package instead
var (
ErrInvalidCredentials = errors.New("invalid credentials")
ErrUsernameTaken = errors.New("username already taken")
ErrEmailTaken = errors.New("email already taken")
ErrUserInactive = errors.New("user account is inactive")
ErrInvalidCode = errors.New("invalid verification code")
ErrCodeExpired = errors.New("verification code expired")
ErrAlreadyVerified = errors.New("email already verified")
ErrRateLimitExceeded = errors.New("too many requests, please try again later")
ErrInvalidResetToken = errors.New("invalid or expired reset token")
ErrAppleSignInFailed = errors.New("Apple Sign In failed")
ErrGoogleSignInFailed = errors.New("Google Sign In failed")
// ErrInvalidCredentials = errors.New("invalid credentials")
// ErrUsernameTaken = errors.New("username already taken")
// ErrEmailTaken = errors.New("email already taken")
// ErrUserInactive = errors.New("user account is inactive")
// ErrInvalidCode = errors.New("invalid verification code")
// ErrCodeExpired = errors.New("verification code expired")
// ErrAlreadyVerified = errors.New("email already verified")
// ErrRateLimitExceeded = errors.New("too many requests, please try again later")
// ErrInvalidResetToken = errors.New("invalid or expired reset token")
ErrAppleSignInFailed = errors.New("Apple Sign In failed")
ErrGoogleSignInFailed = errors.New("Google Sign In failed")
)
// AuthService handles authentication business logic
@@ -63,25 +66,25 @@ func (s *AuthService) Login(req *requests.LoginRequest) (*responses.LoginRespons
user, err := s.userRepo.FindByUsernameOrEmail(identifier)
if err != nil {
if errors.Is(err, repositories.ErrUserNotFound) {
return nil, ErrInvalidCredentials
return nil, apperrors.Unauthorized("error.invalid_credentials")
}
return nil, fmt.Errorf("failed to find user: %w", err)
return nil, apperrors.Internal(err)
}
// Check if user is active
if !user.IsActive {
return nil, ErrUserInactive
return nil, apperrors.Unauthorized("error.account_inactive")
}
// Verify password
if !user.CheckPassword(req.Password) {
return nil, ErrInvalidCredentials
return nil, apperrors.Unauthorized("error.invalid_credentials")
}
// Get or create auth token
token, err := s.userRepo.GetOrCreateToken(user.ID)
if err != nil {
return nil, fmt.Errorf("failed to create token: %w", err)
return nil, apperrors.Internal(err)
}
// Update last login
@@ -101,19 +104,19 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
// Check if username exists
exists, err := s.userRepo.ExistsByUsername(req.Username)
if err != nil {
return nil, "", fmt.Errorf("failed to check username: %w", err)
return nil, "", apperrors.Internal(err)
}
if exists {
return nil, "", ErrUsernameTaken
return nil, "", apperrors.Conflict("error.username_taken")
}
// Check if email exists
exists, err = s.userRepo.ExistsByEmail(req.Email)
if err != nil {
return nil, "", fmt.Errorf("failed to check email: %w", err)
return nil, "", apperrors.Internal(err)
}
if exists {
return nil, "", ErrEmailTaken
return nil, "", apperrors.Conflict("error.email_taken")
}
// Create user
@@ -127,12 +130,12 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
// Hash password
if err := user.SetPassword(req.Password); err != nil {
return nil, "", fmt.Errorf("failed to hash password: %w", err)
return nil, "", apperrors.Internal(err)
}
// Save user
if err := s.userRepo.Create(user); err != nil {
return nil, "", fmt.Errorf("failed to create user: %w", err)
return nil, "", apperrors.Internal(err)
}
// Create user profile
@@ -152,7 +155,7 @@ func (s *AuthService) Register(req *requests.RegisterRequest) (*responses.Regist
// Create auth token
token, err := s.userRepo.GetOrCreateToken(user.ID)
if err != nil {
return nil, "", fmt.Errorf("failed to create token: %w", err)
return nil, "", apperrors.Internal(err)
}
// Generate confirmation code - use fixed code in debug mode for easier local testing
@@ -203,10 +206,10 @@ func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequ
if req.Email != nil && *req.Email != user.Email {
exists, err := s.userRepo.ExistsByEmail(*req.Email)
if err != nil {
return nil, fmt.Errorf("failed to check email: %w", err)
return nil, apperrors.Internal(err)
}
if exists {
return nil, ErrEmailTaken
return nil, apperrors.Conflict("error.email_already_taken")
}
user.Email = *req.Email
}
@@ -219,7 +222,7 @@ func (s *AuthService) UpdateProfile(userID uint, req *requests.UpdateProfileRequ
}
if err := s.userRepo.Update(user); err != nil {
return nil, fmt.Errorf("failed to update user: %w", err)
return nil, apperrors.Internal(err)
}
// Reload with profile
@@ -237,18 +240,18 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error {
// Get user profile
profile, err := s.userRepo.GetOrCreateProfile(userID)
if err != nil {
return fmt.Errorf("failed to get profile: %w", err)
return apperrors.Internal(err)
}
// Check if already verified
if profile.Verified {
return ErrAlreadyVerified
return apperrors.BadRequest("error.email_already_verified")
}
// Check for test code in debug mode
if s.cfg.Server.Debug && code == "123456" {
if err := s.userRepo.SetProfileVerified(userID, true); err != nil {
return fmt.Errorf("failed to verify profile: %w", err)
return apperrors.Internal(err)
}
return nil
}
@@ -257,22 +260,22 @@ func (s *AuthService) VerifyEmail(userID uint, code string) error {
confirmCode, err := s.userRepo.FindConfirmationCode(userID, code)
if err != nil {
if errors.Is(err, repositories.ErrCodeNotFound) {
return ErrInvalidCode
return apperrors.BadRequest("error.invalid_verification_code")
}
if errors.Is(err, repositories.ErrCodeExpired) {
return ErrCodeExpired
return apperrors.BadRequest("error.verification_code_expired")
}
return err
return apperrors.Internal(err)
}
// Mark code as used
if err := s.userRepo.MarkConfirmationCodeUsed(confirmCode.ID); err != nil {
return fmt.Errorf("failed to mark code as used: %w", err)
return apperrors.Internal(err)
}
// Set profile as verified
if err := s.userRepo.SetProfileVerified(userID, true); err != nil {
return fmt.Errorf("failed to verify profile: %w", err)
return apperrors.Internal(err)
}
return nil
@@ -283,12 +286,12 @@ func (s *AuthService) ResendVerificationCode(userID uint) (string, error) {
// Get user profile
profile, err := s.userRepo.GetOrCreateProfile(userID)
if err != nil {
return "", fmt.Errorf("failed to get profile: %w", err)
return "", apperrors.Internal(err)
}
// Check if already verified
if profile.Verified {
return "", ErrAlreadyVerified
return "", apperrors.BadRequest("error.email_already_verified")
}
// Generate new code - use fixed code in debug mode for easier local testing
@@ -301,7 +304,7 @@ func (s *AuthService) ResendVerificationCode(userID uint) (string, error) {
expiresAt := time.Now().UTC().Add(s.cfg.Security.ConfirmationExpiry)
if _, err := s.userRepo.CreateConfirmationCode(userID, code, expiresAt); err != nil {
return "", fmt.Errorf("failed to create confirmation code: %w", err)
return "", apperrors.Internal(err)
}
return code, nil
@@ -322,10 +325,10 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error)
// Check rate limit
count, err := s.userRepo.CountRecentPasswordResetRequests(user.ID)
if err != nil {
return "", nil, fmt.Errorf("failed to check rate limit: %w", err)
return "", nil, apperrors.Internal(err)
}
if count >= int64(s.cfg.Security.MaxPasswordResetRate) {
return "", nil, ErrRateLimitExceeded
return "", nil, apperrors.TooManyRequests("error.rate_limit_exceeded")
}
// Generate code and reset token - use fixed code in debug mode for easier local testing
@@ -341,11 +344,11 @@ func (s *AuthService) ForgotPassword(email string) (string, *models.User, error)
// Hash the code before storing
codeHash, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
if err != nil {
return "", nil, fmt.Errorf("failed to hash code: %w", err)
return "", nil, apperrors.Internal(err)
}
if _, err := s.userRepo.CreatePasswordResetCode(user.ID, string(codeHash), resetToken, expiresAt); err != nil {
return "", nil, fmt.Errorf("failed to create reset code: %w", err)
return "", nil, apperrors.Internal(err)
}
return code, user, nil
@@ -357,9 +360,9 @@ func (s *AuthService) VerifyResetCode(email, code string) (string, error) {
resetCode, user, err := s.userRepo.FindPasswordResetCodeByEmail(email)
if err != nil {
if errors.Is(err, repositories.ErrUserNotFound) || errors.Is(err, repositories.ErrCodeNotFound) {
return "", ErrInvalidCode
return "", apperrors.BadRequest("error.invalid_verification_code")
}
return "", err
return "", apperrors.Internal(err)
}
// Check for test code in debug mode
@@ -371,18 +374,18 @@ func (s *AuthService) VerifyResetCode(email, code string) (string, error) {
if !resetCode.CheckCode(code) {
// Increment attempts
s.userRepo.IncrementResetCodeAttempts(resetCode.ID)
return "", ErrInvalidCode
return "", apperrors.BadRequest("error.invalid_verification_code")
}
// Check if code is still valid
if !resetCode.IsValid() {
if resetCode.Used {
return "", ErrInvalidCode
return "", apperrors.BadRequest("error.invalid_verification_code")
}
if resetCode.Attempts >= resetCode.MaxAttempts {
return "", ErrRateLimitExceeded
return "", apperrors.TooManyRequests("error.rate_limit_exceeded")
}
return "", ErrCodeExpired
return "", apperrors.BadRequest("error.verification_code_expired")
}
_ = user // user available if needed
@@ -396,24 +399,24 @@ func (s *AuthService) ResetPassword(resetToken, newPassword string) error {
resetCode, err := s.userRepo.FindPasswordResetCodeByToken(resetToken)
if err != nil {
if errors.Is(err, repositories.ErrCodeNotFound) || errors.Is(err, repositories.ErrCodeExpired) {
return ErrInvalidResetToken
return apperrors.BadRequest("error.invalid_reset_token")
}
return err
return apperrors.Internal(err)
}
// Get the user
user, err := s.userRepo.FindByID(resetCode.UserID)
if err != nil {
return fmt.Errorf("failed to find user: %w", err)
return apperrors.Internal(err)
}
// Update password
if err := user.SetPassword(newPassword); err != nil {
return fmt.Errorf("failed to hash password: %w", err)
return apperrors.Internal(err)
}
if err := s.userRepo.Update(user); err != nil {
return fmt.Errorf("failed to update user: %w", err)
return apperrors.Internal(err)
}
// Mark reset code as used
@@ -436,7 +439,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
// 1. Verify the Apple JWT token
claims, err := appleAuth.VerifyIdentityToken(ctx, req.IDToken)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrAppleSignInFailed, err)
return nil, apperrors.Unauthorized("error.invalid_credentials").Wrap(err)
}
// Use the subject from claims as the authoritative Apple ID
@@ -451,17 +454,17 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
// User already linked with this Apple ID - log them in
user, err := s.userRepo.FindByIDWithProfile(existingAuth.UserID)
if err != nil {
return nil, fmt.Errorf("failed to find user: %w", err)
return nil, apperrors.Internal(err)
}
if !user.IsActive {
return nil, ErrUserInactive
return nil, apperrors.Unauthorized("error.account_inactive")
}
// Get or create token
token, err := s.userRepo.GetOrCreateToken(user.ID)
if err != nil {
return nil, fmt.Errorf("failed to create token: %w", err)
return nil, apperrors.Internal(err)
}
// Update last login
@@ -487,7 +490,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
IsPrivateEmail: isPrivateRelayEmail(email) || claims.IsPrivateRelayEmail(),
}
if err := s.userRepo.CreateAppleSocialAuth(appleAuthRecord); err != nil {
return nil, fmt.Errorf("failed to link Apple ID: %w", err)
return nil, apperrors.Internal(err)
}
// Mark as verified since Apple verified the email
@@ -496,7 +499,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
// Get or create token
token, err := s.userRepo.GetOrCreateToken(existingUser.ID)
if err != nil {
return nil, fmt.Errorf("failed to create token: %w", err)
return nil, apperrors.Internal(err)
}
// Update last login
@@ -529,7 +532,7 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
_ = user.SetPassword(randomPassword)
if err := s.userRepo.Create(user); err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
return nil, apperrors.Internal(err)
}
// Create profile (already verified since Apple verified)
@@ -554,13 +557,13 @@ func (s *AuthService) AppleSignIn(ctx context.Context, appleAuth *AppleAuthServi
IsPrivateEmail: isPrivateRelayEmail(email) || claims.IsPrivateRelayEmail(),
}
if err := s.userRepo.CreateAppleSocialAuth(appleAuthRecord); err != nil {
return nil, fmt.Errorf("failed to create Apple auth: %w", err)
return nil, apperrors.Internal(err)
}
// Create token
token, err := s.userRepo.GetOrCreateToken(user.ID)
if err != nil {
return nil, fmt.Errorf("failed to create token: %w", err)
return nil, apperrors.Internal(err)
}
// Reload user with profile
@@ -578,12 +581,12 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
// 1. Verify the Google ID token
tokenInfo, err := googleAuth.VerifyIDToken(ctx, req.IDToken)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrGoogleSignInFailed, err)
return nil, apperrors.Unauthorized("error.invalid_credentials").Wrap(err)
}
googleID := tokenInfo.Sub
if googleID == "" {
return nil, fmt.Errorf("%w: missing subject claim", ErrGoogleSignInFailed)
return nil, apperrors.Unauthorized("error.invalid_credentials")
}
// 2. Check if this Google ID is already linked to an account
@@ -592,17 +595,17 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
// User already linked with this Google ID - log them in
user, err := s.userRepo.FindByIDWithProfile(existingAuth.UserID)
if err != nil {
return nil, fmt.Errorf("failed to find user: %w", err)
return nil, apperrors.Internal(err)
}
if !user.IsActive {
return nil, ErrUserInactive
return nil, apperrors.Unauthorized("error.account_inactive")
}
// Get or create token
token, err := s.userRepo.GetOrCreateToken(user.ID)
if err != nil {
return nil, fmt.Errorf("failed to create token: %w", err)
return nil, apperrors.Internal(err)
}
// Update last login
@@ -629,7 +632,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
Picture: tokenInfo.Picture,
}
if err := s.userRepo.CreateGoogleSocialAuth(googleAuthRecord); err != nil {
return nil, fmt.Errorf("failed to link Google ID: %w", err)
return nil, apperrors.Internal(err)
}
// Mark as verified since Google verified the email
@@ -640,7 +643,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
// Get or create token
token, err := s.userRepo.GetOrCreateToken(existingUser.ID)
if err != nil {
return nil, fmt.Errorf("failed to create token: %w", err)
return nil, apperrors.Internal(err)
}
// Update last login
@@ -673,7 +676,7 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
_ = user.SetPassword(randomPassword)
if err := s.userRepo.Create(user); err != nil {
return nil, fmt.Errorf("failed to create user: %w", err)
return nil, apperrors.Internal(err)
}
// Create profile (already verified if Google verified email)
@@ -699,13 +702,13 @@ func (s *AuthService) GoogleSignIn(ctx context.Context, googleAuth *GoogleAuthSe
Picture: tokenInfo.Picture,
}
if err := s.userRepo.CreateGoogleSocialAuth(googleAuthRecord); err != nil {
return nil, fmt.Errorf("failed to create Google auth: %w", err)
return nil, apperrors.Internal(err)
}
// Create token
token, err := s.userRepo.GetOrCreateToken(user.ID)
if err != nil {
return nil, fmt.Errorf("failed to create token: %w", err)
return nil, apperrors.Internal(err)
}
// Reload user with profile

View File

@@ -5,17 +5,18 @@ import (
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/dto/requests"
"github.com/treytartt/casera-api/internal/dto/responses"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/repositories"
)
// Contractor-related errors
var (
ErrContractorNotFound = errors.New("contractor not found")
ErrContractorAccessDenied = errors.New("you do not have access to this contractor")
)
// Deprecated: Use apperrors.NotFound("error.contractor_not_found") instead
// var (
// ErrContractorNotFound = errors.New("contractor not found")
// ErrContractorAccessDenied = errors.New("you do not have access to this contractor")
// )
// ContractorService handles contractor business logic
type ContractorService struct {
@@ -36,14 +37,14 @@ func (s *ContractorService) GetContractor(contractorID, userID uint) (*responses
contractor, err := s.contractorRepo.FindByID(contractorID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrContractorNotFound
return nil, apperrors.NotFound("error.contractor_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access
if !s.hasContractorAccess(contractor, userID) {
return nil, ErrContractorAccessDenied
return nil, apperrors.Forbidden("error.contractor_access_denied")
}
resp := responses.NewContractorResponse(contractor)
@@ -73,13 +74,13 @@ func (s *ContractorService) ListContractors(userID uint) ([]responses.Contractor
// Get residence IDs (lightweight - no preloads)
residenceIDs, err := s.residenceRepo.FindResidenceIDsByUser(userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// FindByUser now handles both personal and residence contractors
contractors, err := s.contractorRepo.FindByUser(userID, residenceIDs)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return responses.NewContractorListResponse(contractors), nil
@@ -91,10 +92,10 @@ func (s *ContractorService) CreateContractor(req *requests.CreateContractorReque
if req.ResidenceID != nil {
hasAccess, err := s.residenceRepo.HasAccess(*req.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrResidenceAccessDenied
return nil, apperrors.Forbidden("error.residence_access_denied")
}
}
@@ -122,20 +123,20 @@ func (s *ContractorService) CreateContractor(req *requests.CreateContractorReque
}
if err := s.contractorRepo.Create(contractor); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Set specialties if provided
if len(req.SpecialtyIDs) > 0 {
if err := s.contractorRepo.SetSpecialties(contractor.ID, req.SpecialtyIDs); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
}
// Reload with relations
contractor, reloadErr := s.contractorRepo.FindByID(contractor.ID)
if reloadErr != nil {
return nil, reloadErr
return nil, apperrors.Internal(reloadErr)
}
resp := responses.NewContractorResponse(contractor)
@@ -147,14 +148,14 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req
contractor, err := s.contractorRepo.FindByID(contractorID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrContractorNotFound
return nil, apperrors.NotFound("error.contractor_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access
if !s.hasContractorAccess(contractor, userID) {
return nil, ErrContractorAccessDenied
return nil, apperrors.Forbidden("error.contractor_access_denied")
}
// Apply updates
@@ -199,20 +200,20 @@ func (s *ContractorService) UpdateContractor(contractorID, userID uint, req *req
contractor.ResidenceID = req.ResidenceID
if err := s.contractorRepo.Update(contractor); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Update specialties if provided
if req.SpecialtyIDs != nil {
if err := s.contractorRepo.SetSpecialties(contractorID, req.SpecialtyIDs); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
}
// Reload
contractor, err = s.contractorRepo.FindByID(contractorID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
resp := responses.NewContractorResponse(contractor)
@@ -224,17 +225,21 @@ func (s *ContractorService) DeleteContractor(contractorID, userID uint) error {
contractor, err := s.contractorRepo.FindByID(contractorID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrContractorNotFound
return apperrors.NotFound("error.contractor_not_found")
}
return err
return apperrors.Internal(err)
}
// Check access
if !s.hasContractorAccess(contractor, userID) {
return ErrContractorAccessDenied
return apperrors.Forbidden("error.contractor_access_denied")
}
return s.contractorRepo.Delete(contractorID)
if err := s.contractorRepo.Delete(contractorID); err != nil {
return apperrors.Internal(err)
}
return nil
}
// ToggleFavorite toggles the favorite status of a contractor and returns the updated contractor
@@ -242,25 +247,25 @@ func (s *ContractorService) ToggleFavorite(contractorID, userID uint) (*response
contractor, err := s.contractorRepo.FindByID(contractorID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrContractorNotFound
return nil, apperrors.NotFound("error.contractor_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access
if !s.hasContractorAccess(contractor, userID) {
return nil, ErrContractorAccessDenied
return nil, apperrors.Forbidden("error.contractor_access_denied")
}
_, err = s.contractorRepo.ToggleFavorite(contractorID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Re-fetch the contractor to get the updated state with all relations
contractor, err = s.contractorRepo.FindByID(contractorID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
resp := responses.NewContractorResponse(contractor)
@@ -272,19 +277,19 @@ func (s *ContractorService) GetContractorTasks(contractorID, userID uint) ([]res
contractor, err := s.contractorRepo.FindByID(contractorID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrContractorNotFound
return nil, apperrors.NotFound("error.contractor_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access
if !s.hasContractorAccess(contractor, userID) {
return nil, ErrContractorAccessDenied
return nil, apperrors.Forbidden("error.contractor_access_denied")
}
tasks, err := s.contractorRepo.GetTasksForContractor(contractorID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return responses.NewTaskListResponse(tasks), nil
@@ -295,15 +300,15 @@ func (s *ContractorService) ListContractorsByResidence(residenceID, userID uint)
// Check user has access to the residence
hasAccess, err := s.residenceRepo.HasAccess(residenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrResidenceAccessDenied
return nil, apperrors.Forbidden("error.residence_access_denied")
}
contractors, err := s.contractorRepo.FindByResidence(residenceID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return responses.NewContractorListResponse(contractors), nil
@@ -313,7 +318,7 @@ func (s *ContractorService) ListContractorsByResidence(residenceID, userID uint)
func (s *ContractorService) GetSpecialties() ([]responses.ContractorSpecialtyResponse, error) {
specialties, err := s.contractorRepo.GetAllSpecialties()
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
result := make([]responses.ContractorSpecialtyResponse, len(specialties))

View File

@@ -5,6 +5,7 @@ import (
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/dto/requests"
"github.com/treytartt/casera-api/internal/dto/responses"
"github.com/treytartt/casera-api/internal/models"
@@ -12,10 +13,11 @@ import (
)
// Document-related errors
var (
ErrDocumentNotFound = errors.New("document not found")
ErrDocumentAccessDenied = errors.New("you do not have access to this document")
)
// DEPRECATED: These constants are deprecated. Use apperrors package instead.
// var (
// ErrDocumentNotFound = errors.New("document not found")
// ErrDocumentAccessDenied = errors.New("you do not have access to this document")
// )
// DocumentService handles document business logic
type DocumentService struct {
@@ -36,18 +38,18 @@ func (s *DocumentService) GetDocument(documentID, userID uint) (*responses.Docum
document, err := s.documentRepo.FindByID(documentID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrDocumentNotFound
return nil, apperrors.NotFound("error.document_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access via residence
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrDocumentAccessDenied
return nil, apperrors.Forbidden("error.document_access_denied")
}
resp := responses.NewDocumentResponse(document)
@@ -59,7 +61,7 @@ func (s *DocumentService) ListDocuments(userID uint) ([]responses.DocumentRespon
// Get residence IDs (lightweight - no preloads)
residenceIDs, err := s.residenceRepo.FindResidenceIDsByUser(userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if len(residenceIDs) == 0 {
@@ -68,7 +70,7 @@ func (s *DocumentService) ListDocuments(userID uint) ([]responses.DocumentRespon
documents, err := s.documentRepo.FindByUser(residenceIDs)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return responses.NewDocumentListResponse(documents), nil
@@ -79,7 +81,7 @@ func (s *DocumentService) ListWarranties(userID uint) ([]responses.DocumentRespo
// Get residence IDs (lightweight - no preloads)
residenceIDs, err := s.residenceRepo.FindResidenceIDsByUser(userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if len(residenceIDs) == 0 {
@@ -88,7 +90,7 @@ func (s *DocumentService) ListWarranties(userID uint) ([]responses.DocumentRespo
documents, err := s.documentRepo.FindWarranties(residenceIDs)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return responses.NewDocumentListResponse(documents), nil
@@ -99,10 +101,10 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us
// Check residence access
hasAccess, err := s.residenceRepo.HasAccess(req.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrResidenceAccessDenied
return nil, apperrors.Forbidden("error.residence_access_denied")
}
documentType := req.DocumentType
@@ -131,7 +133,7 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us
}
if err := s.documentRepo.Create(document); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Create images if provided
@@ -151,7 +153,7 @@ func (s *DocumentService) CreateDocument(req *requests.CreateDocumentRequest, us
// Reload with relations
document, err = s.documentRepo.FindByID(document.ID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
resp := responses.NewDocumentResponse(document)
@@ -163,18 +165,18 @@ func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests.
document, err := s.documentRepo.FindByID(documentID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrDocumentNotFound
return nil, apperrors.NotFound("error.document_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrDocumentAccessDenied
return nil, apperrors.Forbidden("error.document_access_denied")
}
// Apply updates
@@ -222,13 +224,13 @@ func (s *DocumentService) UpdateDocument(documentID, userID uint, req *requests.
}
if err := s.documentRepo.Update(document); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Reload
document, err = s.documentRepo.FindByID(documentID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
resp := responses.NewDocumentResponse(document)
@@ -240,21 +242,25 @@ func (s *DocumentService) DeleteDocument(documentID, userID uint) error {
document, err := s.documentRepo.FindByID(documentID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrDocumentNotFound
return apperrors.NotFound("error.document_not_found")
}
return err
return apperrors.Internal(err)
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
if err != nil {
return err
return apperrors.Internal(err)
}
if !hasAccess {
return ErrDocumentAccessDenied
return apperrors.Forbidden("error.document_access_denied")
}
return s.documentRepo.Delete(documentID)
if err := s.documentRepo.Delete(documentID); err != nil {
return apperrors.Internal(err)
}
return nil
}
// ActivateDocument activates a document
@@ -262,26 +268,26 @@ func (s *DocumentService) ActivateDocument(documentID, userID uint) (*responses.
// First check if document exists (even if inactive)
var document models.Document
if err := s.documentRepo.FindByIDIncludingInactive(documentID, &document); err != nil {
return nil, ErrDocumentNotFound
return nil, apperrors.NotFound("error.document_not_found")
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrDocumentAccessDenied
return nil, apperrors.Forbidden("error.document_access_denied")
}
if err := s.documentRepo.Activate(documentID); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Reload
doc, err := s.documentRepo.FindByID(documentID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
resp := responses.NewDocumentResponse(doc)
@@ -293,22 +299,22 @@ func (s *DocumentService) DeactivateDocument(documentID, userID uint) (*response
document, err := s.documentRepo.FindByID(documentID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrDocumentNotFound
return nil, apperrors.NotFound("error.document_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(document.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrDocumentAccessDenied
return nil, apperrors.Forbidden("error.document_access_denied")
}
if err := s.documentRepo.Deactivate(documentID); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
document.IsActive = false

View File

@@ -8,6 +8,7 @@ import (
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/push"
"github.com/treytartt/casera-api/internal/repositories"
@@ -15,8 +16,11 @@ import (
// Notification-related errors
var (
// Deprecated: Use apperrors.NotFound("error.notification_not_found") instead
ErrNotificationNotFound = errors.New("notification not found")
// Deprecated: Use apperrors.NotFound("error.device_not_found") instead
ErrDeviceNotFound = errors.New("device not found")
// Deprecated: Use apperrors.BadRequest("error.invalid_platform") instead
ErrInvalidPlatform = errors.New("invalid platform, must be 'ios' or 'android'")
)
@@ -40,7 +44,7 @@ func NewNotificationService(notificationRepo *repositories.NotificationRepositor
func (s *NotificationService) GetNotifications(userID uint, limit, offset int) ([]NotificationResponse, error) {
notifications, err := s.notificationRepo.FindByUser(userID, limit, offset)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
result := make([]NotificationResponse, len(notifications))
@@ -52,7 +56,11 @@ func (s *NotificationService) GetNotifications(userID uint, limit, offset int) (
// GetUnreadCount gets the count of unread notifications
func (s *NotificationService) GetUnreadCount(userID uint) (int64, error) {
return s.notificationRepo.CountUnread(userID)
count, err := s.notificationRepo.CountUnread(userID)
if err != nil {
return 0, apperrors.Internal(err)
}
return count, nil
}
// MarkAsRead marks a notification as read
@@ -60,21 +68,27 @@ func (s *NotificationService) MarkAsRead(notificationID, userID uint) error {
notification, err := s.notificationRepo.FindByID(notificationID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrNotificationNotFound
return apperrors.NotFound("error.notification_not_found")
}
return err
return apperrors.Internal(err)
}
if notification.UserID != userID {
return ErrNotificationNotFound
return apperrors.NotFound("error.notification_not_found")
}
return s.notificationRepo.MarkAsRead(notificationID)
if err := s.notificationRepo.MarkAsRead(notificationID); err != nil {
return apperrors.Internal(err)
}
return nil
}
// MarkAllAsRead marks all notifications as read
func (s *NotificationService) MarkAllAsRead(userID uint) error {
return s.notificationRepo.MarkAllAsRead(userID)
if err := s.notificationRepo.MarkAllAsRead(userID); err != nil {
return apperrors.Internal(err)
}
return nil
}
// CreateAndSendNotification creates a notification and sends it via push
@@ -82,7 +96,7 @@ func (s *NotificationService) CreateAndSendNotification(ctx context.Context, use
// Check user preferences
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
if err != nil {
return err
return apperrors.Internal(err)
}
// Check if notification type is enabled
@@ -101,13 +115,13 @@ func (s *NotificationService) CreateAndSendNotification(ctx context.Context, use
}
if err := s.notificationRepo.Create(notification); err != nil {
return err
return apperrors.Internal(err)
}
// Get device tokens
iosTokens, androidTokens, err := s.notificationRepo.GetActiveTokensForUser(userID)
if err != nil {
return err
return apperrors.Internal(err)
}
// Convert data for push
@@ -128,11 +142,14 @@ func (s *NotificationService) CreateAndSendNotification(ctx context.Context, use
err = s.pushClient.SendToAll(ctx, iosTokens, androidTokens, title, body, pushData)
if err != nil {
s.notificationRepo.SetError(notification.ID, err.Error())
return err
return apperrors.Internal(err)
}
}
return s.notificationRepo.MarkAsSent(notification.ID)
if err := s.notificationRepo.MarkAsSent(notification.ID); err != nil {
return apperrors.Internal(err)
}
return nil
}
// isNotificationEnabled checks if a notification type is enabled for user
@@ -161,7 +178,7 @@ func (s *NotificationService) isNotificationEnabled(prefs *models.NotificationPr
func (s *NotificationService) GetPreferences(userID uint) (*NotificationPreferencesResponse, error) {
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return NewNotificationPreferencesResponse(prefs), nil
}
@@ -170,7 +187,7 @@ func (s *NotificationService) GetPreferences(userID uint) (*NotificationPreferen
func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferencesRequest) (*NotificationPreferencesResponse, error) {
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if req.TaskDueSoon != nil {
@@ -214,7 +231,7 @@ func (s *NotificationService) UpdatePreferences(userID uint, req *UpdatePreferen
}
if err := s.notificationRepo.UpdatePreferences(prefs); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return NewNotificationPreferencesResponse(prefs), nil
@@ -230,7 +247,7 @@ func (s *NotificationService) RegisterDevice(userID uint, req *RegisterDeviceReq
case push.PlatformAndroid:
return s.registerGCMDevice(userID, req)
default:
return nil, ErrInvalidPlatform
return nil, apperrors.BadRequest("error.invalid_platform")
}
}
@@ -244,7 +261,7 @@ func (s *NotificationService) registerAPNSDevice(userID uint, req *RegisterDevic
existing.Name = req.Name
existing.DeviceID = req.DeviceID
if err := s.notificationRepo.UpdateAPNSDevice(existing); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return NewAPNSDeviceResponse(existing), nil
}
@@ -258,7 +275,7 @@ func (s *NotificationService) registerAPNSDevice(userID uint, req *RegisterDevic
Active: true,
}
if err := s.notificationRepo.CreateAPNSDevice(device); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return NewAPNSDeviceResponse(device), nil
}
@@ -273,7 +290,7 @@ func (s *NotificationService) registerGCMDevice(userID uint, req *RegisterDevice
existing.Name = req.Name
existing.DeviceID = req.DeviceID
if err := s.notificationRepo.UpdateGCMDevice(existing); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return NewGCMDeviceResponse(existing), nil
}
@@ -288,7 +305,7 @@ func (s *NotificationService) registerGCMDevice(userID uint, req *RegisterDevice
Active: true,
}
if err := s.notificationRepo.CreateGCMDevice(device); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return NewGCMDeviceResponse(device), nil
}
@@ -297,12 +314,12 @@ func (s *NotificationService) registerGCMDevice(userID uint, req *RegisterDevice
func (s *NotificationService) ListDevices(userID uint) ([]DeviceResponse, error) {
iosDevices, err := s.notificationRepo.FindAPNSDevicesByUser(userID)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
return nil, apperrors.Internal(err)
}
androidDevices, err := s.notificationRepo.FindGCMDevicesByUser(userID)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
return nil, apperrors.Internal(err)
}
result := make([]DeviceResponse, 0, len(iosDevices)+len(androidDevices))
@@ -317,14 +334,19 @@ func (s *NotificationService) ListDevices(userID uint) ([]DeviceResponse, error)
// DeleteDevice deletes a device
func (s *NotificationService) DeleteDevice(deviceID uint, platform string, userID uint) error {
var err error
switch platform {
case push.PlatformIOS:
return s.notificationRepo.DeactivateAPNSDevice(deviceID)
err = s.notificationRepo.DeactivateAPNSDevice(deviceID)
case push.PlatformAndroid:
return s.notificationRepo.DeactivateGCMDevice(deviceID)
err = s.notificationRepo.DeactivateGCMDevice(deviceID)
default:
return ErrInvalidPlatform
return apperrors.BadRequest("error.invalid_platform")
}
if err != nil {
return apperrors.Internal(err)
}
return nil
}
// === Response/Request Types ===
@@ -490,7 +512,7 @@ func (s *NotificationService) CreateAndSendTaskNotification(
// Check user notification preferences
prefs, err := s.notificationRepo.GetOrCreatePreferences(userID)
if err != nil {
return err
return apperrors.Internal(err)
}
if !s.isNotificationEnabled(prefs, notificationType) {
return nil // Skip silently
@@ -527,13 +549,13 @@ func (s *NotificationService) CreateAndSendTaskNotification(
}
if err := s.notificationRepo.Create(notification); err != nil {
return err
return apperrors.Internal(err)
}
// Get device tokens
iosTokens, androidTokens, err := s.notificationRepo.GetActiveTokensForUser(userID)
if err != nil {
return err
return apperrors.Internal(err)
}
// Convert data for push payload
@@ -556,9 +578,12 @@ func (s *NotificationService) CreateAndSendTaskNotification(
err = s.pushClient.SendActionableNotification(ctx, iosTokens, androidTokens, title, body, pushData, iosCategoryID)
if err != nil {
s.notificationRepo.SetError(notification.ID, err.Error())
return err
return apperrors.Internal(err)
}
}
return s.notificationRepo.MarkAsSent(notification.ID)
if err := s.notificationRepo.MarkAsSent(notification.ID); err != nil {
return apperrors.Internal(err)
}
return nil
}

View File

@@ -6,6 +6,7 @@ import (
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/config"
"github.com/treytartt/casera-api/internal/dto/requests"
"github.com/treytartt/casera-api/internal/dto/responses"
@@ -14,15 +15,17 @@ import (
"github.com/treytartt/casera-api/internal/task/predicates"
)
// Common errors
// Common errors (deprecated - kept for reference, now using apperrors package)
// Most errors have been migrated to apperrors, but some are still used by other handlers
// TODO: Migrate handlers to use apperrors instead of these constants
var (
ErrResidenceNotFound = errors.New("residence not found")
ErrResidenceAccessDenied = errors.New("you do not have access to this residence")
ErrNotResidenceOwner = errors.New("only the residence owner can perform this action")
ErrCannotRemoveOwner = errors.New("cannot remove the owner from the residence")
ErrUserAlreadyMember = errors.New("user is already a member of this residence")
ErrShareCodeInvalid = errors.New("invalid or expired share code")
ErrShareCodeExpired = errors.New("share code has expired")
ErrResidenceNotFound = errors.New("residence not found")
ErrResidenceAccessDenied = errors.New("you do not have access to this residence")
ErrNotResidenceOwner = errors.New("only the residence owner can perform this action")
ErrCannotRemoveOwner = errors.New("cannot remove the owner from the residence")
ErrUserAlreadyMember = errors.New("user is already a member of this residence")
ErrShareCodeInvalid = errors.New("invalid or expired share code")
ErrShareCodeExpired = errors.New("share code has expired")
ErrPropertiesLimitReached = errors.New("you have reached the maximum number of properties for your subscription tier")
)
@@ -53,18 +56,18 @@ func (s *ResidenceService) GetResidence(residenceID, userID uint) (*responses.Re
// Check access
hasAccess, err := s.residenceRepo.HasAccess(residenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrResidenceAccessDenied
return nil, apperrors.Forbidden("error.residence_access_denied")
}
residence, err := s.residenceRepo.FindByID(residenceID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrResidenceNotFound
return nil, apperrors.NotFound("error.residence_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
resp := responses.NewResidenceResponse(residence)
@@ -75,7 +78,7 @@ func (s *ResidenceService) GetResidence(residenceID, userID uint) (*responses.Re
func (s *ResidenceService) ListResidences(userID uint) ([]responses.ResidenceResponse, error) {
residences, err := s.residenceRepo.FindByUser(userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return responses.NewResidenceListResponse(residences), nil
@@ -84,38 +87,31 @@ func (s *ResidenceService) ListResidences(userID uint) ([]responses.ResidenceRes
// GetMyResidences returns residences with additional details (tasks, completions, etc.)
// This is the "my-residences" endpoint that returns richer data.
// The `now` parameter should be the start of day in the user's timezone for accurate overdue detection.
//
// NOTE: Summary statistics (TotalTasks, TotalOverdue, etc.) are now calculated client-side
// from kanban data for performance. Only TotalResidences and per-residence OverdueCount
// are returned from the server.
func (s *ResidenceService) GetMyResidences(userID uint, now time.Time) (*responses.MyResidencesResponse, error) {
residences, err := s.residenceRepo.FindByUser(userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
residenceResponses := responses.NewResidenceListResponse(residences)
// Build summary with real task statistics
// Summary statistics (TotalTasks, TotalOverdue, etc.) are calculated client-side
// from kanban data. We only populate TotalResidences here.
summary := responses.TotalSummary{
TotalResidences: len(residences),
}
// Get task statistics if task repository is available
// Get per-residence overdue counts for residence card badges
if s.taskRepo != nil && len(residences) > 0 {
// Collect residence IDs
residenceIDs := make([]uint, len(residences))
for i, r := range residences {
residenceIDs[i] = r.ID
}
// Get aggregated statistics using user's timezone-aware time
stats, err := s.taskRepo.GetTaskStatistics(residenceIDs, now)
if err == nil && stats != nil {
summary.TotalTasks = stats.TotalTasks
summary.TotalPending = stats.TotalPending
summary.TotalOverdue = stats.TotalOverdue
summary.TasksDueNextWeek = stats.TasksDueNextWeek
summary.TasksDueNextMonth = stats.TasksDueNextMonth
}
// Get per-residence overdue counts using user's timezone-aware time
overdueCounts, err := s.taskRepo.GetOverdueCountByResidence(residenceIDs, now)
if err == nil && overdueCounts != nil {
for i := range residenceResponses {
@@ -134,32 +130,22 @@ func (s *ResidenceService) GetMyResidences(userID uint, now time.Time) (*respons
// GetSummary returns just the task summary statistics for a user's residences.
// This is a lightweight endpoint for refreshing summary counts without full residence data.
// The `now` parameter should be the start of day in the user's timezone for accurate overdue detection.
//
// DEPRECATED: Summary statistics are now calculated client-side from kanban data.
// This endpoint only returns TotalResidences; other fields will be zero.
// Clients should use calculateSummaryFromKanban() instead.
func (s *ResidenceService) GetSummary(userID uint, now time.Time) (*responses.TotalSummary, error) {
// Get residence IDs (lightweight - no preloads)
residenceIDs, err := s.residenceRepo.FindResidenceIDsByUser(userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
summary := &responses.TotalSummary{
// Summary statistics are calculated client-side from kanban data.
// We only return TotalResidences here.
return &responses.TotalSummary{
TotalResidences: len(residenceIDs),
}
// Get task statistics if task repository is available
if s.taskRepo != nil && len(residenceIDs) > 0 {
// Get aggregated statistics using user's timezone-aware time
stats, err := s.taskRepo.GetTaskStatistics(residenceIDs, now)
if err == nil && stats != nil {
summary.TotalTasks = stats.TotalTasks
summary.TotalPending = stats.TotalPending
summary.TotalOverdue = stats.TotalOverdue
summary.TasksDueNextWeek = stats.TasksDueNextWeek
summary.TasksDueNextMonth = stats.TasksDueNextMonth
}
}
return summary, nil
}, nil
}
// getSummaryForUser returns an empty summary placeholder.
@@ -215,13 +201,13 @@ func (s *ResidenceService) CreateResidence(req *requests.CreateResidenceRequest,
}
if err := s.residenceRepo.Create(residence); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Reload with relations
residence, err := s.residenceRepo.FindByID(residence.ID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Get updated summary
@@ -238,18 +224,18 @@ func (s *ResidenceService) UpdateResidence(residenceID, userID uint, req *reques
// Check ownership
isOwner, err := s.residenceRepo.IsOwner(residenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !isOwner {
return nil, ErrNotResidenceOwner
return nil, apperrors.Forbidden("error.not_residence_owner")
}
residence, err := s.residenceRepo.FindByID(residenceID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrResidenceNotFound
return nil, apperrors.NotFound("error.residence_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Apply updates (only non-nil fields)
@@ -306,13 +292,13 @@ func (s *ResidenceService) UpdateResidence(residenceID, userID uint, req *reques
}
if err := s.residenceRepo.Update(residence); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Reload with relations
residence, err = s.residenceRepo.FindByID(residence.ID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Get updated summary
@@ -329,14 +315,14 @@ func (s *ResidenceService) DeleteResidence(residenceID, userID uint) (*responses
// Check ownership
isOwner, err := s.residenceRepo.IsOwner(residenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !isOwner {
return nil, ErrNotResidenceOwner
return nil, apperrors.Forbidden("error.not_residence_owner")
}
if err := s.residenceRepo.Delete(residenceID); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Get updated summary
@@ -353,10 +339,10 @@ func (s *ResidenceService) GenerateShareCode(residenceID, userID uint, expiresIn
// Check ownership
isOwner, err := s.residenceRepo.IsOwner(residenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !isOwner {
return nil, ErrNotResidenceOwner
return nil, apperrors.Forbidden("error.not_residence_owner")
}
// Default to 24 hours if not specified
@@ -366,7 +352,7 @@ func (s *ResidenceService) GenerateShareCode(residenceID, userID uint, expiresIn
shareCode, err := s.residenceRepo.CreateShareCode(residenceID, userID, time.Duration(expiresInHours)*time.Hour)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return &responses.GenerateShareCodeResponse{
@@ -380,22 +366,22 @@ func (s *ResidenceService) GenerateSharePackage(residenceID, userID uint, expire
// Check ownership (only owners can share residences)
isOwner, err := s.residenceRepo.IsOwner(residenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !isOwner {
return nil, ErrNotResidenceOwner
return nil, apperrors.Forbidden("error.not_residence_owner")
}
// Get residence details for the package
residence, err := s.residenceRepo.FindByID(residenceID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Get the user who's sharing
user, err := s.userRepo.FindByID(userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Default to 24 hours if not specified
@@ -406,7 +392,7 @@ func (s *ResidenceService) GenerateSharePackage(residenceID, userID uint, expire
// Generate the share code
shareCode, err := s.residenceRepo.CreateShareCode(residenceID, userID, time.Duration(expiresInHours)*time.Hour)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return &responses.SharePackageResponse{
@@ -423,23 +409,23 @@ func (s *ResidenceService) JoinWithCode(code string, userID uint) (*responses.Jo
shareCode, err := s.residenceRepo.FindShareCodeByCode(code)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrShareCodeInvalid
return nil, apperrors.NotFound("error.share_code_invalid")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check if already a member
hasAccess, err := s.residenceRepo.HasAccess(shareCode.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if hasAccess {
return nil, ErrUserAlreadyMember
return nil, apperrors.Conflict("error.user_already_member")
}
// Add user to residence
if err := s.residenceRepo.AddUser(shareCode.ResidenceID, userID); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Mark share code as used (one-time use)
@@ -451,7 +437,7 @@ func (s *ResidenceService) JoinWithCode(code string, userID uint) (*responses.Jo
// Get the residence with full details
residence, err := s.residenceRepo.FindByID(shareCode.ResidenceID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Get updated summary for the user
@@ -469,15 +455,15 @@ func (s *ResidenceService) GetResidenceUsers(residenceID, userID uint) ([]respon
// Check access
hasAccess, err := s.residenceRepo.HasAccess(residenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrResidenceAccessDenied
return nil, apperrors.Forbidden("error.residence_access_denied")
}
users, err := s.residenceRepo.GetResidenceUsers(residenceID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
result := make([]responses.ResidenceUserResponse, len(users))
@@ -493,39 +479,43 @@ func (s *ResidenceService) RemoveUser(residenceID, userIDToRemove, requestingUse
// Check ownership
isOwner, err := s.residenceRepo.IsOwner(residenceID, requestingUserID)
if err != nil {
return err
return apperrors.Internal(err)
}
if !isOwner {
return ErrNotResidenceOwner
return apperrors.Forbidden("error.not_residence_owner")
}
// Cannot remove the owner
if userIDToRemove == requestingUserID {
return ErrCannotRemoveOwner
return apperrors.BadRequest("error.cannot_remove_owner")
}
// Check if the residence exists
residence, err := s.residenceRepo.FindByIDSimple(residenceID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrResidenceNotFound
return apperrors.NotFound("error.residence_not_found")
}
return err
return apperrors.Internal(err)
}
// Cannot remove the owner
if userIDToRemove == residence.OwnerID {
return ErrCannotRemoveOwner
return apperrors.BadRequest("error.cannot_remove_owner")
}
return s.residenceRepo.RemoveUser(residenceID, userIDToRemove)
if err := s.residenceRepo.RemoveUser(residenceID, userIDToRemove); err != nil {
return apperrors.Internal(err)
}
return nil
}
// GetResidenceTypes returns all residence types
func (s *ResidenceService) GetResidenceTypes() ([]responses.ResidenceTypeResponse, error) {
types, err := s.residenceRepo.GetAllResidenceTypes()
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
result := make([]responses.ResidenceTypeResponse, len(types))
@@ -567,22 +557,25 @@ func (s *ResidenceService) GenerateTasksReport(residenceID, userID uint) (*Tasks
// Check access
hasAccess, err := s.residenceRepo.HasAccess(residenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrResidenceAccessDenied
return nil, apperrors.Forbidden("error.residence_access_denied")
}
// Get residence details
residence, err := s.residenceRepo.FindByIDSimple(residenceID)
if err != nil {
return nil, ErrResidenceNotFound
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.NotFound("error.residence_not_found")
}
return nil, apperrors.Internal(err)
}
// Get all tasks for the residence
tasks, err := s.residenceRepo.GetTasksForReport(residenceID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
now := time.Now().UTC()

View File

@@ -1,6 +1,7 @@
package services
import (
"net/http"
"testing"
"github.com/shopspring/decimal"
@@ -115,7 +116,7 @@ func TestResidenceService_GetResidence_AccessDenied(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
_, err := service.GetResidence(residence.ID, otherUser.ID)
assert.ErrorIs(t, err, ErrResidenceAccessDenied)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
}
func TestResidenceService_GetResidence_NotFound(t *testing.T) {
@@ -188,7 +189,7 @@ func TestResidenceService_UpdateResidence_NotOwner(t *testing.T) {
req := &requests.UpdateResidenceRequest{Name: &newName}
_, err := service.UpdateResidence(residence.ID, sharedUser.ID, req)
assert.ErrorIs(t, err, ErrNotResidenceOwner)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.not_residence_owner")
}
func TestResidenceService_DeleteResidence(t *testing.T) {
@@ -222,7 +223,7 @@ func TestResidenceService_DeleteResidence_NotOwner(t *testing.T) {
residenceRepo.AddUser(residence.ID, sharedUser.ID)
_, err := service.DeleteResidence(residence.ID, sharedUser.ID)
assert.ErrorIs(t, err, ErrNotResidenceOwner)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.not_residence_owner")
}
func TestResidenceService_GenerateShareCode(t *testing.T) {
@@ -280,7 +281,7 @@ func TestResidenceService_JoinWithCode_AlreadyMember(t *testing.T) {
// Owner tries to join their own residence
_, err := service.JoinWithCode(shareResp.ShareCode.Code, owner.ID)
assert.ErrorIs(t, err, ErrUserAlreadyMember)
testutil.AssertAppError(t, err, http.StatusConflict, "error.user_already_member")
}
func TestResidenceService_GetResidenceUsers(t *testing.T) {
@@ -330,5 +331,5 @@ func TestResidenceService_RemoveUser_CannotRemoveOwner(t *testing.T) {
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
err := service.RemoveUser(residence.ID, owner.ID, owner.ID)
assert.ErrorIs(t, err, ErrCannotRemoveOwner)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.cannot_remove_owner")
}

View File

@@ -8,6 +8,7 @@ import (
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/config"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/repositories"
@@ -15,12 +16,19 @@ import (
// Subscription-related errors
var (
// Deprecated: Use apperrors.NotFound("error.subscription_not_found") instead
ErrSubscriptionNotFound = errors.New("subscription not found")
// Deprecated: Use apperrors.Forbidden("error.properties_limit_exceeded") instead
ErrPropertiesLimitExceeded = errors.New("properties limit exceeded for your subscription tier")
// Deprecated: Use apperrors.Forbidden("error.tasks_limit_exceeded") instead
ErrTasksLimitExceeded = errors.New("tasks limit exceeded for your subscription tier")
// Deprecated: Use apperrors.Forbidden("error.contractors_limit_exceeded") instead
ErrContractorsLimitExceeded = errors.New("contractors limit exceeded for your subscription tier")
// Deprecated: Use apperrors.Forbidden("error.documents_limit_exceeded") instead
ErrDocumentsLimitExceeded = errors.New("documents limit exceeded for your subscription tier")
// Deprecated: Use apperrors.NotFound("error.upgrade_trigger_not_found") instead
ErrUpgradeTriggerNotFound = errors.New("upgrade trigger not found")
// Deprecated: Use apperrors.NotFound("error.promotion_not_found") instead
ErrPromotionNotFound = errors.New("promotion not found")
)
@@ -93,7 +101,7 @@ func NewSubscriptionService(
func (s *SubscriptionService) GetSubscription(userID uint) (*SubscriptionResponse, error) {
sub, err := s.subscriptionRepo.GetOrCreate(userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return NewSubscriptionResponse(sub), nil
}
@@ -102,18 +110,18 @@ func (s *SubscriptionService) GetSubscription(userID uint) (*SubscriptionRespons
func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionStatusResponse, error) {
sub, err := s.subscriptionRepo.GetOrCreate(userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
settings, err := s.subscriptionRepo.GetSettings()
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Get all tier limits and build a map
allLimits, err := s.subscriptionRepo.GetAllTierLimits()
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
limitsMap := make(map[string]*TierLimitsClientResponse)
@@ -169,7 +177,7 @@ func (s *SubscriptionService) GetSubscriptionStatus(userID uint) (*SubscriptionS
func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error) {
residences, err := s.residenceRepo.FindOwnedByUser(userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
propertiesCount := int64(len(residences))
@@ -178,19 +186,19 @@ func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error)
for _, r := range residences {
tc, err := s.taskRepo.CountByResidence(r.ID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
tasksCount += tc
cc, err := s.contractorRepo.CountByResidence(r.ID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
contractorsCount += cc
dc, err := s.documentRepo.CountByResidence(r.ID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
documentsCount += dc
}
@@ -207,7 +215,7 @@ func (s *SubscriptionService) getUserUsage(userID uint) (*UsageResponse, error)
func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error {
settings, err := s.subscriptionRepo.GetSettings()
if err != nil {
return err
return apperrors.Internal(err)
}
// If limitations are disabled globally, allow everything
@@ -217,7 +225,7 @@ func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error {
sub, err := s.subscriptionRepo.GetOrCreate(userID)
if err != nil {
return err
return apperrors.Internal(err)
}
// IsFree users bypass all limitations
@@ -232,7 +240,7 @@ func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error {
limits, err := s.subscriptionRepo.GetTierLimits(sub.Tier)
if err != nil {
return err
return apperrors.Internal(err)
}
usage, err := s.getUserUsage(userID)
@@ -243,19 +251,19 @@ func (s *SubscriptionService) CheckLimit(userID uint, limitType string) error {
switch limitType {
case "properties":
if limits.PropertiesLimit != nil && usage.PropertiesCount >= int64(*limits.PropertiesLimit) {
return ErrPropertiesLimitExceeded
return apperrors.Forbidden("error.properties_limit_exceeded")
}
case "tasks":
if limits.TasksLimit != nil && usage.TasksCount >= int64(*limits.TasksLimit) {
return ErrTasksLimitExceeded
return apperrors.Forbidden("error.tasks_limit_exceeded")
}
case "contractors":
if limits.ContractorsLimit != nil && usage.ContractorsCount >= int64(*limits.ContractorsLimit) {
return ErrContractorsLimitExceeded
return apperrors.Forbidden("error.contractors_limit_exceeded")
}
case "documents":
if limits.DocumentsLimit != nil && usage.DocumentsCount >= int64(*limits.DocumentsLimit) {
return ErrDocumentsLimitExceeded
return apperrors.Forbidden("error.documents_limit_exceeded")
}
}
@@ -267,9 +275,9 @@ func (s *SubscriptionService) GetUpgradeTrigger(key string) (*UpgradeTriggerResp
trigger, err := s.subscriptionRepo.GetUpgradeTrigger(key)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUpgradeTriggerNotFound
return nil, apperrors.NotFound("error.upgrade_trigger_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
return NewUpgradeTriggerResponse(trigger), nil
}
@@ -279,7 +287,7 @@ func (s *SubscriptionService) GetUpgradeTrigger(key string) (*UpgradeTriggerResp
func (s *SubscriptionService) GetAllUpgradeTriggers() (map[string]*UpgradeTriggerDataResponse, error) {
triggers, err := s.subscriptionRepo.GetAllUpgradeTriggers()
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
result := make(map[string]*UpgradeTriggerDataResponse)
@@ -293,7 +301,7 @@ func (s *SubscriptionService) GetAllUpgradeTriggers() (map[string]*UpgradeTrigge
func (s *SubscriptionService) GetFeatureBenefits() ([]FeatureBenefitResponse, error) {
benefits, err := s.subscriptionRepo.GetFeatureBenefits()
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
result := make([]FeatureBenefitResponse, len(benefits))
@@ -307,12 +315,12 @@ func (s *SubscriptionService) GetFeatureBenefits() ([]FeatureBenefitResponse, er
func (s *SubscriptionService) GetActivePromotions(userID uint) ([]PromotionResponse, error) {
sub, err := s.subscriptionRepo.GetOrCreate(userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
promotions, err := s.subscriptionRepo.GetActivePromotions(sub.Tier)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
result := make([]PromotionResponse, len(promotions))
@@ -331,7 +339,7 @@ func (s *SubscriptionService) ProcessApplePurchase(userID uint, receiptData stri
dataToStore = transactionID
}
if err := s.subscriptionRepo.UpdateReceiptData(userID, dataToStore); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Validate with Apple if client is configured
@@ -375,7 +383,7 @@ func (s *SubscriptionService) ProcessApplePurchase(userID uint, receiptData stri
// Upgrade to Pro with the determined expiration
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, "ios"); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return s.GetSubscription(userID)
@@ -386,7 +394,7 @@ func (s *SubscriptionService) ProcessApplePurchase(userID uint, receiptData stri
func (s *SubscriptionService) ProcessGooglePurchase(userID uint, purchaseToken string, productID string) (*SubscriptionResponse, error) {
// Store purchase token first
if err := s.subscriptionRepo.UpdatePurchaseToken(userID, purchaseToken); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Validate the purchase with Google if client is configured
@@ -443,7 +451,7 @@ func (s *SubscriptionService) ProcessGooglePurchase(userID uint, purchaseToken s
// Upgrade to Pro with the determined expiration
if err := s.subscriptionRepo.UpgradeToPro(userID, expiresAt, "android"); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return s.GetSubscription(userID)
@@ -452,7 +460,7 @@ func (s *SubscriptionService) ProcessGooglePurchase(userID uint, purchaseToken s
// CancelSubscription cancels a subscription (downgrades to free at end of period)
func (s *SubscriptionService) CancelSubscription(userID uint) (*SubscriptionResponse, error) {
if err := s.subscriptionRepo.SetAutoRenew(userID, false); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return s.GetSubscription(userID)
}

View File

@@ -869,8 +869,12 @@ func TestEdgeCase_TaskDueExactlyAtThreshold(t *testing.T) {
}
func TestEdgeCase_TaskDueJustBeforeThreshold(t *testing.T) {
// 29 days and 23 hours from now
dueDate := time.Now().UTC().Add(29*24*time.Hour + 23*time.Hour)
// Task due 29 days from today's start of day should be "due_soon"
// (within the 30-day threshold)
now := time.Now().UTC()
startOfToday := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
dueDate := startOfToday.AddDate(0, 0, 29) // 29 days from start of today
task := &models.Task{
NextDueDate: &dueDate,
}
@@ -878,7 +882,7 @@ func TestEdgeCase_TaskDueJustBeforeThreshold(t *testing.T) {
column := responses.DetermineKanbanColumn(task, 30)
assert.Equal(t, "due_soon_tasks", column,
"Task due just before threshold should be in due_soon")
"Task due 29 days from today should be in due_soon (within 30-day threshold)")
}
func TestEdgeCase_TaskDueInPast_ButHasCompletionAfter(t *testing.T) {
@@ -955,20 +959,23 @@ func TestEdgeCase_MonthlyRecurringTask(t *testing.T) {
CompletedAt: completedAt,
})
// Update NextDueDate
// Update NextDueDate - set to 29 days from today (within 30-day threshold)
task, _ = taskRepo.FindByID(task.ID)
nextDue := completedAt.AddDate(0, 0, 30) // 30 days from now
now := time.Now().UTC()
startOfToday := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
nextDue := startOfToday.AddDate(0, 0, 29) // 29 days from start of today (within threshold)
task.NextDueDate = &nextDue
db.Save(task)
task, _ = taskRepo.FindByID(task.ID)
// 30 days from now is at/within threshold boundary - due to time precision,
// a task at exactly the threshold boundary is considered "due_soon" not "upcoming"
// because the check is NextDueDate.Before(threshold) which includes boundary due to ms precision
// With day-based comparisons:
// - Threshold = start of today + 30 days
// - A task due on day 29 is Before(threshold), so it's "due_soon"
// - A task due on day 30+ is NOT Before(threshold), so it's "upcoming"
column := responses.DetermineKanbanColumn(task, 30)
assert.Equal(t, "due_soon_tasks", column,
"Monthly task at 30-day boundary should be due_soon (at threshold)")
"Monthly task within 30-day threshold should be due_soon")
}
func TestEdgeCase_ZeroDayFrequency_TreatedAsOneTime(t *testing.T) {

View File

@@ -8,16 +8,18 @@ import (
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/dto/requests"
"github.com/treytartt/casera-api/internal/dto/responses"
"github.com/treytartt/casera-api/internal/models"
"github.com/treytartt/casera-api/internal/repositories"
)
// Task-related errors
// Task-related errors (DEPRECATED - kept for reference, use apperrors instead)
// TODO: Migrate handlers to use apperrors instead of these constants
var (
ErrTaskNotFound = errors.New("task not found")
ErrTaskAccessDenied = errors.New("you do not have access to this task")
ErrTaskNotFound = errors.New("task not found")
ErrTaskAccessDenied = errors.New("you do not have access to this task")
ErrTaskAlreadyCancelled = errors.New("task is already cancelled")
ErrTaskAlreadyArchived = errors.New("task is already archived")
ErrCompletionNotFound = errors.New("task completion not found")
@@ -71,18 +73,18 @@ func (s *TaskService) GetTask(taskID, userID uint) (*responses.TaskResponse, err
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTaskNotFound
return nil, apperrors.NotFound("error.task_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access via residence
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrTaskAccessDenied
return nil, apperrors.Forbidden("error.task_access_denied")
}
resp := responses.NewTaskResponse(task)
@@ -95,7 +97,7 @@ func (s *TaskService) ListTasks(userID uint, now time.Time) (*responses.KanbanBo
// Get all residence IDs accessible to user (lightweight - no preloads)
residenceIDs, err := s.residenceRepo.FindResidenceIDsByUser(userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if len(residenceIDs) == 0 {
@@ -110,7 +112,7 @@ func (s *TaskService) ListTasks(userID uint, now time.Time) (*responses.KanbanBo
// Get kanban data aggregated across all residences using user's timezone-aware time
board, err := s.taskRepo.GetKanbanDataForMultipleResidences(residenceIDs, 30, now)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
resp := responses.NewKanbanBoardResponseForAll(board)
@@ -126,10 +128,10 @@ func (s *TaskService) GetTasksByResidence(residenceID, userID uint, daysThreshol
// Check access
hasAccess, err := s.residenceRepo.HasAccess(residenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrResidenceAccessDenied
return nil, apperrors.Forbidden("error.residence_access_denied")
}
if daysThreshold <= 0 {
@@ -139,7 +141,7 @@ func (s *TaskService) GetTasksByResidence(residenceID, userID uint, daysThreshol
// Get kanban data using user's timezone-aware time
board, err := s.taskRepo.GetKanbanData(residenceID, daysThreshold, now)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
resp := responses.NewKanbanBoardResponse(board, residenceID)
@@ -155,10 +157,10 @@ func (s *TaskService) CreateTask(req *requests.CreateTaskRequest, userID uint, n
// Check residence access
hasAccess, err := s.residenceRepo.HasAccess(req.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrResidenceAccessDenied
return nil, apperrors.Forbidden("error.residence_access_denied")
}
dueDate := req.DueDate.ToTimePtr()
@@ -180,13 +182,13 @@ func (s *TaskService) CreateTask(req *requests.CreateTaskRequest, userID uint, n
}
if err := s.taskRepo.Create(task); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Reload with relations
task, err = s.taskRepo.FindByID(task.ID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return &responses.TaskWithSummaryResponse{
@@ -201,18 +203,18 @@ func (s *TaskService) UpdateTask(taskID, userID uint, req *requests.UpdateTaskRe
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTaskNotFound
return nil, apperrors.NotFound("error.task_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrTaskAccessDenied
return nil, apperrors.Forbidden("error.task_access_denied")
}
// Apply updates
@@ -260,13 +262,13 @@ func (s *TaskService) UpdateTask(taskID, userID uint, req *requests.UpdateTaskRe
}
if err := s.taskRepo.Update(task); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Reload
task, err = s.taskRepo.FindByID(task.ID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return &responses.TaskWithSummaryResponse{
@@ -280,22 +282,22 @@ func (s *TaskService) DeleteTask(taskID, userID uint) (*responses.DeleteWithSumm
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTaskNotFound
return nil, apperrors.NotFound("error.task_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrTaskAccessDenied
return nil, apperrors.Forbidden("error.task_access_denied")
}
if err := s.taskRepo.Delete(taskID); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return &responses.DeleteWithSummaryResponse{
@@ -312,28 +314,28 @@ func (s *TaskService) MarkInProgress(taskID, userID uint, now time.Time) (*respo
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTaskNotFound
return nil, apperrors.NotFound("error.task_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrTaskAccessDenied
return nil, apperrors.Forbidden("error.task_access_denied")
}
if err := s.taskRepo.MarkInProgress(taskID); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Reload
task, err = s.taskRepo.FindByID(taskID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return &responses.TaskWithSummaryResponse{
@@ -348,32 +350,32 @@ func (s *TaskService) CancelTask(taskID, userID uint, now time.Time) (*responses
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTaskNotFound
return nil, apperrors.NotFound("error.task_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrTaskAccessDenied
return nil, apperrors.Forbidden("error.task_access_denied")
}
if task.IsCancelled {
return nil, ErrTaskAlreadyCancelled
return nil, apperrors.BadRequest("error.task_already_cancelled")
}
if err := s.taskRepo.Cancel(taskID); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Reload
task, err = s.taskRepo.FindByID(taskID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return &responses.TaskWithSummaryResponse{
@@ -388,28 +390,28 @@ func (s *TaskService) UncancelTask(taskID, userID uint, now time.Time) (*respons
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTaskNotFound
return nil, apperrors.NotFound("error.task_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrTaskAccessDenied
return nil, apperrors.Forbidden("error.task_access_denied")
}
if err := s.taskRepo.Uncancel(taskID); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Reload
task, err = s.taskRepo.FindByID(taskID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return &responses.TaskWithSummaryResponse{
@@ -424,32 +426,32 @@ func (s *TaskService) ArchiveTask(taskID, userID uint, now time.Time) (*response
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTaskNotFound
return nil, apperrors.NotFound("error.task_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrTaskAccessDenied
return nil, apperrors.Forbidden("error.task_access_denied")
}
if task.IsArchived {
return nil, ErrTaskAlreadyArchived
return nil, apperrors.BadRequest("error.task_already_archived")
}
if err := s.taskRepo.Archive(taskID); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Reload
task, err = s.taskRepo.FindByID(taskID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return &responses.TaskWithSummaryResponse{
@@ -464,28 +466,28 @@ func (s *TaskService) UnarchiveTask(taskID, userID uint, now time.Time) (*respon
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTaskNotFound
return nil, apperrors.NotFound("error.task_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrTaskAccessDenied
return nil, apperrors.Forbidden("error.task_access_denied")
}
if err := s.taskRepo.Unarchive(taskID); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Reload
task, err = s.taskRepo.FindByID(taskID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return &responses.TaskWithSummaryResponse{
@@ -503,18 +505,18 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
task, err := s.taskRepo.FindByID(req.TaskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTaskNotFound
return nil, apperrors.NotFound("error.task_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrTaskAccessDenied
return nil, apperrors.Forbidden("error.task_access_denied")
}
completedAt := time.Now().UTC()
@@ -532,7 +534,7 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
}
if err := s.taskRepo.CreateCompletion(completion); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Update next_due_date and in_progress based on frequency
@@ -589,7 +591,7 @@ func (s *TaskService) CreateCompletion(req *requests.CreateTaskCompletionRequest
// Reload completion with user info and images
completion, err = s.taskRepo.FindCompletionByID(completion.ID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
// Reload task with updated completions (so client can update kanban column)
@@ -622,18 +624,18 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrTaskNotFound
return apperrors.NotFound("error.task_not_found")
}
return err
return apperrors.Internal(err)
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return err
return apperrors.Internal(err)
}
if !hasAccess {
return ErrTaskAccessDenied
return apperrors.Forbidden("error.task_access_denied")
}
completedAt := time.Now().UTC()
@@ -646,7 +648,7 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
}
if err := s.taskRepo.CreateCompletion(completion); err != nil {
return err
return apperrors.Internal(err)
}
// Update next_due_date and in_progress based on frequency
@@ -692,7 +694,7 @@ func (s *TaskService) QuickComplete(taskID uint, userID uint) error {
}
if err := s.taskRepo.Update(task); err != nil {
log.Error().Err(err).Uint("task_id", task.ID).Msg("Failed to update task after quick completion")
return err // Return error so caller knows the update failed
return apperrors.Internal(err) // Return error so caller knows the update failed
}
log.Info().Uint("task_id", task.ID).Msg("QuickComplete: Task updated successfully")
@@ -771,18 +773,18 @@ func (s *TaskService) GetCompletion(completionID, userID uint) (*responses.TaskC
completion, err := s.taskRepo.FindCompletionByID(completionID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrCompletionNotFound
return nil, apperrors.NotFound("error.completion_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access via task's residence
hasAccess, err := s.residenceRepo.HasAccess(completion.Task.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrTaskAccessDenied
return nil, apperrors.Forbidden("error.task_access_denied")
}
resp := responses.NewTaskCompletionResponse(completion)
@@ -794,7 +796,7 @@ func (s *TaskService) ListCompletions(userID uint) ([]responses.TaskCompletionRe
// Get all residence IDs (lightweight - no preloads)
residenceIDs, err := s.residenceRepo.FindResidenceIDsByUser(userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if len(residenceIDs) == 0 {
@@ -803,7 +805,7 @@ func (s *TaskService) ListCompletions(userID uint) ([]responses.TaskCompletionRe
completions, err := s.taskRepo.FindCompletionsByUser(userID, residenceIDs)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return responses.NewTaskCompletionListResponse(completions), nil
@@ -814,22 +816,22 @@ func (s *TaskService) DeleteCompletion(completionID, userID uint) (*responses.De
completion, err := s.taskRepo.FindCompletionByID(completionID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrCompletionNotFound
return nil, apperrors.NotFound("error.completion_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access
hasAccess, err := s.residenceRepo.HasAccess(completion.Task.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrTaskAccessDenied
return nil, apperrors.Forbidden("error.task_access_denied")
}
if err := s.taskRepo.DeleteCompletion(completionID); err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return &responses.DeleteWithSummaryResponse{
@@ -844,24 +846,24 @@ func (s *TaskService) GetCompletionsByTask(taskID, userID uint) ([]responses.Tas
task, err := s.taskRepo.FindByID(taskID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrTaskNotFound
return nil, apperrors.NotFound("error.task_not_found")
}
return nil, err
return nil, apperrors.Internal(err)
}
// Check access via residence
hasAccess, err := s.residenceRepo.HasAccess(task.ResidenceID, userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if !hasAccess {
return nil, ErrTaskAccessDenied
return nil, apperrors.Forbidden("error.task_access_denied")
}
// Get completions for the task
completions, err := s.taskRepo.FindCompletionsByTask(taskID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
return responses.NewTaskCompletionListResponse(completions), nil
@@ -873,7 +875,7 @@ func (s *TaskService) GetCompletionsByTask(taskID, userID uint) ([]responses.Tas
func (s *TaskService) GetCategories() ([]responses.TaskCategoryResponse, error) {
categories, err := s.taskRepo.GetAllCategories()
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
result := make([]responses.TaskCategoryResponse, len(categories))
@@ -887,7 +889,7 @@ func (s *TaskService) GetCategories() ([]responses.TaskCategoryResponse, error)
func (s *TaskService) GetPriorities() ([]responses.TaskPriorityResponse, error) {
priorities, err := s.taskRepo.GetAllPriorities()
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
result := make([]responses.TaskPriorityResponse, len(priorities))
@@ -901,7 +903,7 @@ func (s *TaskService) GetPriorities() ([]responses.TaskPriorityResponse, error)
func (s *TaskService) GetFrequencies() ([]responses.TaskFrequencyResponse, error) {
frequencies, err := s.taskRepo.GetAllFrequencies()
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
result := make([]responses.TaskFrequencyResponse, len(frequencies))

View File

@@ -1,6 +1,7 @@
package services
import (
"net/http"
"testing"
"time"
@@ -105,7 +106,7 @@ func TestTaskService_CreateTask_AccessDenied(t *testing.T) {
now := time.Now().UTC()
_, err := service.CreateTask(req, otherUser.ID, now)
// When creating a task, residence access is checked first
assert.ErrorIs(t, err, ErrResidenceAccessDenied)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
}
func TestTaskService_GetTask(t *testing.T) {
@@ -138,7 +139,7 @@ func TestTaskService_GetTask_AccessDenied(t *testing.T) {
task := testutil.CreateTestTask(t, db, residence.ID, owner.ID, "Test Task")
_, err := service.GetTask(task.ID, otherUser.ID)
assert.ErrorIs(t, err, ErrTaskAccessDenied)
testutil.AssertAppError(t, err, http.StatusForbidden, "error.task_access_denied")
}
func TestTaskService_ListTasks(t *testing.T) {
@@ -239,7 +240,7 @@ func TestTaskService_CancelTask_AlreadyCancelled(t *testing.T) {
now := time.Now().UTC()
service.CancelTask(task.ID, user.ID, now)
_, err := service.CancelTask(task.ID, user.ID, now)
assert.ErrorIs(t, err, ErrTaskAlreadyCancelled)
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.task_already_cancelled")
}
func TestTaskService_UncancelTask(t *testing.T) {

View File

@@ -3,11 +3,13 @@ package services
import (
"errors"
"github.com/treytartt/casera-api/internal/apperrors"
"github.com/treytartt/casera-api/internal/dto/responses"
"github.com/treytartt/casera-api/internal/repositories"
)
var (
// Deprecated: Use apperrors.NotFound("error.user_not_found") instead
ErrUserNotFound = errors.New("user not found")
)
@@ -27,7 +29,7 @@ func NewUserService(userRepo *repositories.UserRepository) *UserService {
func (s *UserService) ListUsersInSharedResidences(userID uint) ([]responses.UserSummary, error) {
users, err := s.userRepo.FindUsersInSharedResidences(userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
var result []responses.UserSummary
@@ -48,10 +50,10 @@ func (s *UserService) ListUsersInSharedResidences(userID uint) ([]responses.User
func (s *UserService) GetUserIfSharedResidence(targetUserID, requestingUserID uint) (*responses.UserSummary, error) {
user, err := s.userRepo.FindUserIfSharedResidence(targetUserID, requestingUserID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
if user == nil {
return nil, ErrUserNotFound
return nil, apperrors.NotFound("error.user_not_found")
}
return &responses.UserSummary{
@@ -67,7 +69,7 @@ func (s *UserService) GetUserIfSharedResidence(targetUserID, requestingUserID ui
func (s *UserService) ListProfilesInSharedResidences(userID uint) ([]responses.UserProfileSummary, error) {
profiles, err := s.userRepo.FindProfilesInSharedResidences(userID)
if err != nil {
return nil, err
return nil, apperrors.Internal(err)
}
var result []responses.UserProfileSummary