Coverage priorities 1-5: test pure functions, extract interfaces, mock-based handler tests
- Priority 1: Test NewSendEmailTask + NewSendPushTask (5 tests) - Priority 2: Test customHTTPErrorHandler — all 15+ branches (21 tests) - Priority 3: Extract Enqueuer interface + payload builders in worker pkg (5 tests) - Priority 4: Extract ClassifyFile/ComputeRelPath in migrate-encrypt (6 tests) - Priority 5: Define Handler interfaces, refactor to accept them, mock-based tests (14 tests) - Fix .gitignore: /worker instead of worker to stop ignoring internal/worker/ Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -12,7 +12,9 @@
|
||||
"Bash(git add:*)",
|
||||
"Bash(docker ps:*)",
|
||||
"Bash(git commit:*)",
|
||||
"Bash(git push:*)"
|
||||
"Bash(git push:*)",
|
||||
"Bash(docker info:*)",
|
||||
"Bash(curl:*)"
|
||||
]
|
||||
},
|
||||
"enableAllProjectMcpServers": true,
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -6,7 +6,7 @@
|
||||
# Binaries
|
||||
bin/
|
||||
api
|
||||
worker
|
||||
/worker
|
||||
/admin
|
||||
!admin/
|
||||
*.exe
|
||||
|
||||
61
cmd/backfill-completion-columns/main_test.go
Normal file
61
cmd/backfill-completion-columns/main_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestClassifyCompletion_CompletedAfterDue_ReturnsOverdue(t *testing.T) {
|
||||
due := time.Date(2025, 6, 1, 0, 0, 0, 0, time.UTC)
|
||||
completed := time.Date(2025, 6, 5, 14, 0, 0, 0, time.UTC)
|
||||
got := classifyCompletion(completed, due, 7)
|
||||
if got != "overdue_tasks" {
|
||||
t.Errorf("got %q, want overdue_tasks", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyCompletion_CompletedOnDueDate_ReturnsDueSoon(t *testing.T) {
|
||||
due := time.Date(2025, 6, 1, 0, 0, 0, 0, time.UTC)
|
||||
completed := time.Date(2025, 6, 1, 10, 0, 0, 0, time.UTC)
|
||||
got := classifyCompletion(completed, due, 7)
|
||||
if got != "due_soon_tasks" {
|
||||
t.Errorf("got %q, want due_soon_tasks", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyCompletion_CompletedWithinThreshold_ReturnsDueSoon(t *testing.T) {
|
||||
due := time.Date(2025, 6, 10, 0, 0, 0, 0, time.UTC)
|
||||
completed := time.Date(2025, 6, 5, 0, 0, 0, 0, time.UTC) // 5 days before due, threshold 7
|
||||
got := classifyCompletion(completed, due, 7)
|
||||
if got != "due_soon_tasks" {
|
||||
t.Errorf("got %q, want due_soon_tasks", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyCompletion_CompletedAtExactThreshold_ReturnsDueSoon(t *testing.T) {
|
||||
due := time.Date(2025, 6, 10, 0, 0, 0, 0, time.UTC)
|
||||
completed := time.Date(2025, 6, 3, 0, 0, 0, 0, time.UTC) // exactly 7 days before due
|
||||
got := classifyCompletion(completed, due, 7)
|
||||
if got != "due_soon_tasks" {
|
||||
t.Errorf("got %q, want due_soon_tasks", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyCompletion_CompletedBeyondThreshold_ReturnsUpcoming(t *testing.T) {
|
||||
due := time.Date(2025, 6, 30, 0, 0, 0, 0, time.UTC)
|
||||
completed := time.Date(2025, 6, 1, 0, 0, 0, 0, time.UTC) // 29 days before due, threshold 7
|
||||
got := classifyCompletion(completed, due, 7)
|
||||
if got != "upcoming_tasks" {
|
||||
t.Errorf("got %q, want upcoming_tasks", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyCompletion_TimeNormalization_SameDayDifferentTimes(t *testing.T) {
|
||||
due := time.Date(2025, 6, 1, 23, 59, 59, 0, time.UTC)
|
||||
completed := time.Date(2025, 6, 1, 0, 0, 1, 0, time.UTC) // same day, different times
|
||||
got := classifyCompletion(completed, due, 7)
|
||||
// Same day → daysBefore == 0 → within threshold → due_soon
|
||||
if got != "due_soon_tasks" {
|
||||
t.Errorf("got %q, want due_soon_tasks", got)
|
||||
}
|
||||
}
|
||||
50
cmd/migrate-encrypt/helpers.go
Normal file
50
cmd/migrate-encrypt/helpers.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// isEncrypted checks if a file path ends with .enc
|
||||
func isEncrypted(path string) bool {
|
||||
return strings.HasSuffix(path, ".enc")
|
||||
}
|
||||
|
||||
// encryptedPath appends .enc to the file path.
|
||||
func encryptedPath(path string) string {
|
||||
return path + ".enc"
|
||||
}
|
||||
|
||||
// shouldProcessFile returns true if the file should be encrypted.
|
||||
func shouldProcessFile(isDir bool, path string) bool {
|
||||
return !isDir && !isEncrypted(path)
|
||||
}
|
||||
|
||||
// FileAction represents the decision about what to do with a file during encryption migration.
|
||||
type FileAction int
|
||||
|
||||
const (
|
||||
ActionSkipDir FileAction = iota // Directory, skip
|
||||
ActionSkipEncrypted // Already encrypted, skip
|
||||
ActionDryRun // Would encrypt (dry run mode)
|
||||
ActionEncrypt // Should encrypt
|
||||
)
|
||||
|
||||
// ClassifyFile determines what action to take for a file during the walk.
|
||||
func ClassifyFile(isDir bool, path string, dryRun bool) FileAction {
|
||||
if isDir {
|
||||
return ActionSkipDir
|
||||
}
|
||||
if isEncrypted(path) {
|
||||
return ActionSkipEncrypted
|
||||
}
|
||||
if dryRun {
|
||||
return ActionDryRun
|
||||
}
|
||||
return ActionEncrypt
|
||||
}
|
||||
|
||||
// ComputeRelPath computes the relative path from base to path.
|
||||
func ComputeRelPath(base, path string) (string, error) {
|
||||
return filepath.Rel(base, path)
|
||||
}
|
||||
96
cmd/migrate-encrypt/helpers_test.go
Normal file
96
cmd/migrate-encrypt/helpers_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package main
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsEncrypted_EncFile_True(t *testing.T) {
|
||||
if !isEncrypted("photo.jpg.enc") {
|
||||
t.Error("expected true for .enc file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsEncrypted_PdfFile_False(t *testing.T) {
|
||||
if isEncrypted("doc.pdf") {
|
||||
t.Error("expected false for .pdf file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsEncrypted_DotEncOnly_True(t *testing.T) {
|
||||
if !isEncrypted(".enc") {
|
||||
t.Error("expected true for '.enc'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptedPath_AppendsDotEnc(t *testing.T) {
|
||||
got := encryptedPath("uploads/photo.jpg")
|
||||
want := "uploads/photo.jpg.enc"
|
||||
if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldProcessFile_RegularFile_True(t *testing.T) {
|
||||
if !shouldProcessFile(false, "photo.jpg") {
|
||||
t.Error("expected true for regular file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldProcessFile_Directory_False(t *testing.T) {
|
||||
if shouldProcessFile(true, "uploads") {
|
||||
t.Error("expected false for directory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldProcessFile_AlreadyEncrypted_False(t *testing.T) {
|
||||
if shouldProcessFile(false, "photo.jpg.enc") {
|
||||
t.Error("expected false for already encrypted file")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ClassifyFile ---
|
||||
|
||||
func TestClassifyFile_Directory_SkipDir(t *testing.T) {
|
||||
if got := ClassifyFile(true, "uploads", false); got != ActionSkipDir {
|
||||
t.Errorf("got %d, want ActionSkipDir", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyFile_EncryptedFile_SkipEncrypted(t *testing.T) {
|
||||
if got := ClassifyFile(false, "photo.jpg.enc", false); got != ActionSkipEncrypted {
|
||||
t.Errorf("got %d, want ActionSkipEncrypted", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyFile_DryRun_DryRun(t *testing.T) {
|
||||
if got := ClassifyFile(false, "photo.jpg", true); got != ActionDryRun {
|
||||
t.Errorf("got %d, want ActionDryRun", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyFile_Normal_Encrypt(t *testing.T) {
|
||||
if got := ClassifyFile(false, "photo.jpg", false); got != ActionEncrypt {
|
||||
t.Errorf("got %d, want ActionEncrypt", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- ComputeRelPath ---
|
||||
|
||||
func TestComputeRelPath_Valid(t *testing.T) {
|
||||
got, err := ComputeRelPath("/uploads", "/uploads/photo.jpg")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "photo.jpg" {
|
||||
t.Errorf("got %q, want %q", got, "photo.jpg")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeRelPath_NestedPath(t *testing.T) {
|
||||
got, err := ComputeRelPath("/uploads", "/uploads/2024/01/photo.jpg")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
want := "2024/01/photo.jpg"
|
||||
if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"flag"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
@@ -87,13 +86,11 @@ func main() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip directories
|
||||
if info.IsDir() {
|
||||
action := ClassifyFile(info.IsDir(), path, *dryRun)
|
||||
switch action {
|
||||
case ActionSkipDir:
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip files already encrypted
|
||||
if strings.HasSuffix(path, ".enc") {
|
||||
case ActionSkipEncrypted:
|
||||
skipped++
|
||||
return nil
|
||||
}
|
||||
@@ -101,14 +98,14 @@ func main() {
|
||||
totalFiles++
|
||||
|
||||
// Compute the relative path from upload dir
|
||||
relPath, err := filepath.Rel(absUploadDir, path)
|
||||
relPath, err := ComputeRelPath(absUploadDir, path)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("path", path).Msg("Failed to compute relative path")
|
||||
errCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
if *dryRun {
|
||||
if action == ActionDryRun {
|
||||
log.Info().Str("file", relPath).Msg("[DRY RUN] Would encrypt")
|
||||
return nil
|
||||
}
|
||||
|
||||
24
cmd/worker/startup.go
Normal file
24
cmd/worker/startup.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package main
|
||||
|
||||
import "github.com/treytartt/honeydue-api/internal/worker/jobs"
|
||||
|
||||
// queuePriorities returns the Asynq queue priority map.
|
||||
func queuePriorities() map[string]int {
|
||||
return map[string]int{
|
||||
"critical": 6,
|
||||
"default": 3,
|
||||
"low": 1,
|
||||
}
|
||||
}
|
||||
|
||||
// allJobTypes returns all registered job type strings.
|
||||
func allJobTypes() []string {
|
||||
return []string{
|
||||
jobs.TypeSmartReminder,
|
||||
jobs.TypeDailyDigest,
|
||||
jobs.TypeSendEmail,
|
||||
jobs.TypeSendPush,
|
||||
jobs.TypeOnboardingEmails,
|
||||
jobs.TypeReminderLogCleanup,
|
||||
}
|
||||
}
|
||||
45
cmd/worker/startup_test.go
Normal file
45
cmd/worker/startup_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestQueuePriorities_CriticalHighest(t *testing.T) {
|
||||
p := queuePriorities()
|
||||
if p["critical"] <= p["default"] || p["critical"] <= p["low"] {
|
||||
t.Errorf("critical (%d) should be highest", p["critical"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueuePriorities_ThreeQueues(t *testing.T) {
|
||||
p := queuePriorities()
|
||||
if len(p) != 3 {
|
||||
t.Errorf("len = %d, want 3", len(p))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllJobTypes_Count(t *testing.T) {
|
||||
types := allJobTypes()
|
||||
if len(types) != 6 {
|
||||
t.Errorf("len = %d, want 6", len(types))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllJobTypes_NoDuplicates(t *testing.T) {
|
||||
types := allJobTypes()
|
||||
seen := make(map[string]bool)
|
||||
for _, typ := range types {
|
||||
if seen[typ] {
|
||||
t.Errorf("duplicate job type: %q", typ)
|
||||
}
|
||||
seen[typ] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllJobTypes_AllNonEmpty(t *testing.T) {
|
||||
for _, typ := range allJobTypes() {
|
||||
if typ == "" {
|
||||
t.Error("found empty job type")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2704,6 +2704,105 @@ paths:
|
||||
'404':
|
||||
$ref: '#/components/responses/NotFound'
|
||||
|
||||
/auth/account/:
|
||||
delete:
|
||||
tags: [Authentication]
|
||||
summary: Delete user account
|
||||
description: Permanently deletes the authenticated user's account and all associated data
|
||||
security:
|
||||
- tokenAuth: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
password:
|
||||
type: string
|
||||
description: Required for email-auth users
|
||||
confirmation:
|
||||
type: string
|
||||
description: Must be "DELETE" for social-auth users
|
||||
responses:
|
||||
'200':
|
||||
description: Account deleted successfully
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest'
|
||||
'401':
|
||||
$ref: '#/components/responses/Unauthorized'
|
||||
|
||||
/auth/refresh/:
|
||||
post:
|
||||
tags: [Authentication]
|
||||
summary: Refresh auth token
|
||||
description: Returns a new token if current token is in the renewal window (60-90 days old)
|
||||
security:
|
||||
- tokenAuth: []
|
||||
responses:
|
||||
'200':
|
||||
description: Token refreshed
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
token:
|
||||
type: string
|
||||
'401':
|
||||
$ref: '#/components/responses/Unauthorized'
|
||||
|
||||
/health/live:
|
||||
get:
|
||||
tags: [Health]
|
||||
summary: Liveness probe
|
||||
description: Simple liveness check, always returns 200
|
||||
responses:
|
||||
'200':
|
||||
description: Alive
|
||||
|
||||
/tasks/suggestions/:
|
||||
get:
|
||||
tags: [Tasks]
|
||||
summary: Get personalized task template suggestions
|
||||
description: Returns task templates ranked by relevance to the residence's home profile
|
||||
security:
|
||||
- tokenAuth: []
|
||||
parameters:
|
||||
- name: residence_id
|
||||
in: query
|
||||
required: true
|
||||
schema:
|
||||
type: integer
|
||||
responses:
|
||||
'200':
|
||||
description: Suggestions with relevance scores
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
suggestions:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
template:
|
||||
$ref: '#/components/schemas/TaskTemplate'
|
||||
relevance_score:
|
||||
type: number
|
||||
match_reasons:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
total_count:
|
||||
type: integer
|
||||
profile_completeness:
|
||||
type: number
|
||||
'401':
|
||||
$ref: '#/components/responses/Unauthorized'
|
||||
'403':
|
||||
$ref: '#/components/responses/Forbidden'
|
||||
|
||||
# =============================================================================
|
||||
# Components
|
||||
# =============================================================================
|
||||
|
||||
302
docs/server_2026_2_24.md
Normal file
302
docs/server_2026_2_24.md
Normal file
@@ -0,0 +1,302 @@
|
||||
# Casera Infrastructure Plan — February 2026
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
┌─────────────┐
|
||||
│ Cloudflare │
|
||||
│ (CDN/DNS) │
|
||||
└──────┬──────┘
|
||||
│ HTTPS
|
||||
┌──────┴──────┐
|
||||
│ Hetzner LB │
|
||||
│ ($5.99) │
|
||||
└──────┬──────┘
|
||||
│
|
||||
┌────────────────┼────────────────┐
|
||||
│ │ │
|
||||
┌──────┴──────┐ ┌──────┴──────┐ ┌──────┴──────┐
|
||||
│ CX33 #1 │ │ CX33 #2 │ │ CX33 #3 │
|
||||
│ (manager) │ │ (manager) │ │ (manager) │
|
||||
│ │ │ │ │ │
|
||||
│ api (x2) │ │ api (x2) │ │ api (x1) │
|
||||
│ admin │ │ worker │ │ worker │
|
||||
│ redis │ │ dozzle │ │ │
|
||||
└──────┬──────┘ └──────┬──────┘ └──────┬──────┘
|
||||
│ │ │
|
||||
│ Docker Swarm Overlay (IPsec) │
|
||||
└────────────────┼────────────────┘
|
||||
│
|
||||
┌────────────┼────────────────┐
|
||||
│ │
|
||||
┌──────┴──────┐ ┌───────┴──────┐
|
||||
│ Neon │ │ Backblaze │
|
||||
│ (Postgres) │ │ B2 │
|
||||
│ Launch │ │ (media) │
|
||||
└─────────────┘ └──────────────┘
|
||||
```
|
||||
|
||||
## Swarm Nodes — Hetzner CX33
|
||||
|
||||
All 3 nodes are manager+worker (Raft consensus requires 3 managers for fault tolerance — 1 node can go down and the cluster stays operational).
|
||||
|
||||
| Spec | Value |
|
||||
|------|-------|
|
||||
| Plan | CX33 (Shared Regular Performance) |
|
||||
| vCPU | 4 |
|
||||
| RAM | 8 GB |
|
||||
| Disk | 80 GB SSD |
|
||||
| Traffic | 20 TB/mo included |
|
||||
| Price | $6.59/mo per node |
|
||||
| Region | Pick closest to users (US: Ashburn or Hillsboro, EU: Nuremberg/Falkenstein/Helsinki) |
|
||||
|
||||
**Why CX33 over CX23:** 8 GB RAM gives headroom for Redis, multiple API replicas, and the admin panel without pressure. The $2.50/mo difference per node isn't worth optimizing away.
|
||||
|
||||
### Container Distribution
|
||||
|
||||
| Container | Replicas | Notes |
|
||||
|-----------|----------|-------|
|
||||
| api | 3-6 | Spread across all nodes by Swarm |
|
||||
| worker | 2-3 | Asynq workers pull jobs from Redis concurrently |
|
||||
| admin | 1 | Next.js admin panel |
|
||||
| redis | 1 | Pinned to one node with its volume |
|
||||
| dozzle | 1 | Pinned to a manager node (needs Docker socket) |
|
||||
|
||||
### Scaling Path
|
||||
|
||||
- Need more capacity? Add another CX33 with `docker swarm join`. Swarm rebalances automatically.
|
||||
- Need more API throughput? Bump replicas in the compose file. No infra change.
|
||||
- Only infrastructure addition needed at scale: the Hetzner Load Balancer ($5.99/mo).
|
||||
|
||||
## Load Balancer — Hetzner LB
|
||||
|
||||
| Spec | Value |
|
||||
|------|-------|
|
||||
| Price | $5.99/mo |
|
||||
| Purpose | Distribute traffic across Swarm nodes, TLS termination |
|
||||
| When to add | When you need redundant ingress (not required day 1 if using Cloudflare to proxy to a single node) |
|
||||
|
||||
## Database — Neon Postgres (Launch Plan)
|
||||
|
||||
| Spec | Value |
|
||||
|------|-------|
|
||||
| Plan | Launch (usage-based, no monthly minimum) |
|
||||
| Compute | $0.106/CU-hr, up to 16 CU (64 GB RAM) |
|
||||
| Storage | $0.35/GB-month |
|
||||
| Connections | Up to 10,000 via built-in PgBouncer |
|
||||
| Typical cost | ~$5-15/mo for light load, ~$20-40/mo at 100k users |
|
||||
| Free tier | Available for dev/staging (100 CU-hrs/mo, 0.5 GB) |
|
||||
|
||||
### Connection Pooling
|
||||
|
||||
Neon includes built-in PgBouncer on all plans. Enable by adding `-pooler` to the hostname:
|
||||
|
||||
```
|
||||
# Direct connection
|
||||
ep-cool-darkness-123456.us-east-2.aws.neon.tech
|
||||
|
||||
# Pooled connection (use this in production)
|
||||
ep-cool-darkness-123456-pooler.us-east-2.aws.neon.tech
|
||||
```
|
||||
|
||||
Runs in transaction mode — compatible with GORM out of the box.
|
||||
|
||||
### Configuration
|
||||
|
||||
```env
|
||||
DB_HOST=ep-xxxxx-pooler.us-east-2.aws.neon.tech
|
||||
DB_PORT=5432
|
||||
DB_SSLMODE=require
|
||||
POSTGRES_USER=<from neon dashboard>
|
||||
POSTGRES_PASSWORD=<from neon dashboard>
|
||||
POSTGRES_DB=casera
|
||||
```
|
||||
|
||||
## Object Storage — Backblaze B2
|
||||
|
||||
| Spec | Value |
|
||||
|------|-------|
|
||||
| Storage | $6/TB/mo ($0.006/GB) |
|
||||
| Egress | $0.01/GB (first 3x stored amount is free) |
|
||||
| Free tier | 10 GB storage always free |
|
||||
| API calls | Class A free, Class B/C free first 2,500/day |
|
||||
| Spending cap | Built-in data caps with alerts at 75% and 100% |
|
||||
|
||||
### Bucket Setup
|
||||
|
||||
| Bucket | Visibility | Key Permissions | Contents |
|
||||
|--------|------------|-----------------|----------|
|
||||
| `casera-uploads` | Private | Read/Write (API containers) | User-uploaded photos, documents |
|
||||
| `casera-certs` | Private | Read-only (API + worker) | APNs push certificates |
|
||||
|
||||
Serve files through the API using signed URLs — never expose buckets publicly.
|
||||
|
||||
### Why B2 Over Others
|
||||
|
||||
- **Spending cap**: only S3-compatible provider with built-in hard caps and alerts. No surprise bills.
|
||||
- **Cheapest storage**: $6/TB vs Cloudflare R2 at $15/TB vs Tigris at $20/TB.
|
||||
- **Free egress partner CDNs**: Cloudflare, Fastly, bunny.net — zero egress when behind Cloudflare.
|
||||
|
||||
## CDN — Cloudflare (Free Tier)
|
||||
|
||||
| Spec | Value |
|
||||
|------|-------|
|
||||
| Price | $0 |
|
||||
| Purpose | DNS, CDN caching, DDoS protection, TLS termination |
|
||||
| Setup | Point DNS to Cloudflare, proxy traffic to Hetzner LB (or directly to a Swarm node) |
|
||||
|
||||
Add this on day 1. No reason not to.
|
||||
|
||||
## Logging — Dozzle
|
||||
|
||||
| Spec | Value |
|
||||
|------|-------|
|
||||
| Price | $0 (open source) |
|
||||
| Port | 9999 (internal only — do not expose publicly) |
|
||||
| Features | Real-time log viewer, webhook support for alerts |
|
||||
|
||||
Runs as a container in the Swarm. Needs Docker socket access, so it's pinned to a manager node.
|
||||
|
||||
For 100k+ users, consider adding Prometheus + Grafana (self-hosted, free) or Betterstack (~$10/mo) for metrics and alerting beyond log viewing.
|
||||
|
||||
## Security
|
||||
|
||||
### Swarm Node Firewall (Hetzner Cloud Firewall — free)
|
||||
|
||||
| Port | Protocol | Source | Purpose |
|
||||
|------|----------|--------|---------|
|
||||
| Custom (e.g. 2222) | TCP | Your IP only | SSH |
|
||||
| 80, 443 | TCP | Anywhere | Public traffic |
|
||||
| 2377 | TCP | Swarm nodes only | Cluster management |
|
||||
| 7946 | TCP/UDP | Swarm nodes only | Node discovery |
|
||||
| 4789 | UDP | Swarm nodes only | Overlay network (VXLAN) |
|
||||
| Everything else | — | — | Blocked |
|
||||
|
||||
Set up once in Hetzner dashboard, apply to all 3 nodes.
|
||||
|
||||
### SSH Hardening
|
||||
|
||||
```
|
||||
# /etc/ssh/sshd_config
|
||||
Port 2222 # Non-default port
|
||||
PermitRootLogin no # No root SSH
|
||||
PasswordAuthentication no # Key-only auth
|
||||
PubkeyAuthentication yes
|
||||
AllowUsers deploy # Only your deploy user
|
||||
```
|
||||
|
||||
### Swarm ↔ Neon (Postgres)
|
||||
|
||||
| Layer | Method |
|
||||
|-------|--------|
|
||||
| Encryption | TLS enforced by Neon (`DB_SSLMODE=require`) |
|
||||
| Authentication | Strong password stored as Docker secret |
|
||||
| Access control | IP allowlist in Neon dashboard — restrict to 3 Swarm node IPs |
|
||||
|
||||
### Swarm ↔ B2 (Object Storage)
|
||||
|
||||
| Layer | Method |
|
||||
|-------|--------|
|
||||
| Encryption | HTTPS always (enforced by B2 API) |
|
||||
| Authentication | Scoped application keys (not master key) |
|
||||
| Access control | Per-bucket key permissions (read-only where possible) |
|
||||
|
||||
### Swarm Internal
|
||||
|
||||
| Layer | Method |
|
||||
|-------|--------|
|
||||
| Overlay encryption | `driver_opts: encrypted: "true"` on overlay network (IPsec between nodes) |
|
||||
| Secrets | Use `docker secret create` for DB password, SECRET_KEY, B2 keys, APNs keys. Mounted at `/run/secrets/`, encrypted in Swarm raft log. |
|
||||
| Container isolation | Non-root users in all containers (already configured in Dockerfile) |
|
||||
|
||||
### Docker Secrets Migration
|
||||
|
||||
Current setup uses environment variables for secrets. Migrate to Docker secrets for production:
|
||||
|
||||
```bash
|
||||
# Create secrets
|
||||
echo "your-db-password" | docker secret create postgres_password -
|
||||
echo "your-secret-key" | docker secret create secret_key -
|
||||
echo "your-b2-app-key" | docker secret create b2_app_key -
|
||||
|
||||
# Reference in compose file
|
||||
services:
|
||||
api:
|
||||
secrets:
|
||||
- postgres_password
|
||||
- secret_key
|
||||
secrets:
|
||||
postgres_password:
|
||||
external: true
|
||||
secret_key:
|
||||
external: true
|
||||
```
|
||||
|
||||
Application code reads from `/run/secrets/<name>` instead of env vars.
|
||||
|
||||
## Redis (In-Cluster)
|
||||
|
||||
Redis stays inside the Swarm — no need to externalize.
|
||||
|
||||
| Purpose | Details |
|
||||
|---------|---------|
|
||||
| Asynq job queue | Background jobs: push notifications, digests, reminders, onboarding emails |
|
||||
| Static data cache | Cached lookup tables with ETag support |
|
||||
| Resource usage | ~20-50 MB RAM, negligible CPU |
|
||||
|
||||
At 100k users, Redis handles job queuing for nightly digests (100k enqueue + dequeue operations) without issue. A single Redis instance handles millions of operations per second.
|
||||
|
||||
Asynq coordinates multiple worker replicas automatically — each job is dequeued atomically by exactly one worker, no double-processing.
|
||||
|
||||
## Performance Estimates
|
||||
|
||||
| Metric | Value |
|
||||
|--------|-------|
|
||||
| Single CX33 API throughput | ~1,000-2,000 req/s (blended, with Neon latency) |
|
||||
| 3-node cluster throughput | ~3,000-6,000 req/s |
|
||||
| Avg requests per user per day | ~50 |
|
||||
| Estimated user capacity (3 nodes) | ~200k-500k registered users |
|
||||
| Bottleneck at scale | Neon compute tier, not Go or Swarm |
|
||||
|
||||
These are napkin estimates. Load test before launch.
|
||||
|
||||
## Monthly Cost Summary
|
||||
|
||||
### Starting Out
|
||||
|
||||
| Component | Provider | Cost |
|
||||
|-----------|----------|------|
|
||||
| 3x Swarm nodes | Hetzner CX33 | $19.77/mo |
|
||||
| Postgres | Neon Launch | ~$5-15/mo |
|
||||
| Object storage | Backblaze B2 | <$1/mo |
|
||||
| CDN | Cloudflare Free | $0 |
|
||||
| Logging | Dozzle (self-hosted) | $0 |
|
||||
| **Total** | | **~$25-35/mo** |
|
||||
|
||||
### At Scale (100k users)
|
||||
|
||||
| Component | Provider | Cost |
|
||||
|-----------|----------|------|
|
||||
| 3x Swarm nodes | Hetzner CX33 | $19.77/mo |
|
||||
| Load balancer | Hetzner LB | $5.99/mo |
|
||||
| Postgres | Neon Launch | ~$20-40/mo |
|
||||
| Object storage | Backblaze B2 | ~$1-3/mo |
|
||||
| CDN | Cloudflare Free | $0 |
|
||||
| Monitoring | Betterstack or self-hosted | ~$0-10/mo |
|
||||
| **Total** | | **~$47-79/mo** |
|
||||
|
||||
## TODO
|
||||
|
||||
- [ ] Set up 3x Hetzner CX33 instances
|
||||
- [ ] Initialize Docker Swarm (`docker swarm init` on first node, `docker swarm join` on others)
|
||||
- [ ] Configure Hetzner Cloud Firewall
|
||||
- [ ] Harden SSH on all nodes
|
||||
- [ ] Create Neon project (Launch plan), configure IP allowlist
|
||||
- [ ] Create Backblaze B2 buckets with scoped application keys
|
||||
- [ ] Set up Cloudflare DNS proxying
|
||||
- [ ] Update prod compose file: remove `db` service, add overlay encryption, add Docker secrets
|
||||
- [ ] Add B2 SDK integration for file uploads (code change)
|
||||
- [ ] Update config to read from `/run/secrets/` for Docker secrets
|
||||
- [ ] Set B2 spending cap and alerts
|
||||
- [ ] Load test the deployed stack
|
||||
- [ ] Add Hetzner LB when needed
|
||||
176
internal/admin/dto/dto_test.go
Normal file
176
internal/admin/dto/dto_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// --- GetPage ---
|
||||
|
||||
func TestGetPage_Zero_Returns1(t *testing.T) {
|
||||
p := &PaginationParams{Page: 0}
|
||||
if got := p.GetPage(); got != 1 {
|
||||
t.Errorf("GetPage(0) = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPage_Negative_Returns1(t *testing.T) {
|
||||
p := &PaginationParams{Page: -5}
|
||||
if got := p.GetPage(); got != 1 {
|
||||
t.Errorf("GetPage(-5) = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPage_Valid_ReturnsValue(t *testing.T) {
|
||||
p := &PaginationParams{Page: 3}
|
||||
if got := p.GetPage(); got != 3 {
|
||||
t.Errorf("GetPage(3) = %d, want 3", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- GetPerPage ---
|
||||
|
||||
func TestGetPerPage_Zero_Returns20(t *testing.T) {
|
||||
p := &PaginationParams{PerPage: 0}
|
||||
if got := p.GetPerPage(); got != 20 {
|
||||
t.Errorf("GetPerPage(0) = %d, want 20", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPerPage_Negative_Returns20(t *testing.T) {
|
||||
p := &PaginationParams{PerPage: -1}
|
||||
if got := p.GetPerPage(); got != 20 {
|
||||
t.Errorf("GetPerPage(-1) = %d, want 20", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPerPage_TooLarge_Returns10000(t *testing.T) {
|
||||
p := &PaginationParams{PerPage: 20000}
|
||||
if got := p.GetPerPage(); got != 10000 {
|
||||
t.Errorf("GetPerPage(20000) = %d, want 10000", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPerPage_Valid_ReturnsValue(t *testing.T) {
|
||||
p := &PaginationParams{PerPage: 50}
|
||||
if got := p.GetPerPage(); got != 50 {
|
||||
t.Errorf("GetPerPage(50) = %d, want 50", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- GetOffset ---
|
||||
|
||||
func TestGetOffset_Page1_Returns0(t *testing.T) {
|
||||
p := &PaginationParams{Page: 1, PerPage: 20}
|
||||
if got := p.GetOffset(); got != 0 {
|
||||
t.Errorf("GetOffset(page=1, perPage=20) = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOffset_Page3_PerPage10_Returns20(t *testing.T) {
|
||||
p := &PaginationParams{Page: 3, PerPage: 10}
|
||||
if got := p.GetOffset(); got != 20 {
|
||||
t.Errorf("GetOffset(page=3, perPage=10) = %d, want 20", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOffset_Defaults_Returns0(t *testing.T) {
|
||||
p := &PaginationParams{}
|
||||
if got := p.GetOffset(); got != 0 {
|
||||
t.Errorf("GetOffset(defaults) = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- GetSortDir ---
|
||||
|
||||
func TestGetSortDir_Asc(t *testing.T) {
|
||||
p := &PaginationParams{SortDir: "asc"}
|
||||
if got := p.GetSortDir(); got != "ASC" {
|
||||
t.Errorf("GetSortDir('asc') = %q, want 'ASC'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSortDir_Desc(t *testing.T) {
|
||||
p := &PaginationParams{SortDir: "desc"}
|
||||
if got := p.GetSortDir(); got != "DESC" {
|
||||
t.Errorf("GetSortDir('desc') = %q, want 'DESC'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSortDir_Empty_ReturnsDesc(t *testing.T) {
|
||||
p := &PaginationParams{SortDir: ""}
|
||||
if got := p.GetSortDir(); got != "DESC" {
|
||||
t.Errorf("GetSortDir('') = %q, want 'DESC'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSortDir_Invalid_ReturnsDesc(t *testing.T) {
|
||||
p := &PaginationParams{SortDir: "RANDOM"}
|
||||
if got := p.GetSortDir(); got != "DESC" {
|
||||
t.Errorf("GetSortDir('RANDOM') = %q, want 'DESC'", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- GetSafeSortBy ---
|
||||
|
||||
func TestGetSafeSortBy_Allowed(t *testing.T) {
|
||||
p := &PaginationParams{SortBy: "name"}
|
||||
got := p.GetSafeSortBy([]string{"name", "email"}, "id")
|
||||
if got != "name" {
|
||||
t.Errorf("GetSafeSortBy('name') = %q, want 'name'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSafeSortBy_NotAllowed_ReturnsDefault(t *testing.T) {
|
||||
p := &PaginationParams{SortBy: "password"}
|
||||
got := p.GetSafeSortBy([]string{"name", "email"}, "id")
|
||||
if got != "id" {
|
||||
t.Errorf("GetSafeSortBy('password') = %q, want 'id'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSafeSortBy_Empty_ReturnsDefault(t *testing.T) {
|
||||
p := &PaginationParams{SortBy: ""}
|
||||
got := p.GetSafeSortBy([]string{"name", "email"}, "id")
|
||||
if got != "id" {
|
||||
t.Errorf("GetSafeSortBy('') = %q, want 'id'", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- NewPaginatedResponse ---
|
||||
|
||||
func TestNewPaginatedResponse_ExactPages(t *testing.T) {
|
||||
resp := NewPaginatedResponse([]string{"a", "b"}, 40, 1, 20)
|
||||
if resp.TotalPages != 2 {
|
||||
t.Errorf("TotalPages = %d, want 2", resp.TotalPages)
|
||||
}
|
||||
if resp.Total != 40 {
|
||||
t.Errorf("Total = %d, want 40", resp.Total)
|
||||
}
|
||||
if resp.Page != 1 {
|
||||
t.Errorf("Page = %d, want 1", resp.Page)
|
||||
}
|
||||
if resp.PerPage != 20 {
|
||||
t.Errorf("PerPage = %d, want 20", resp.PerPage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPaginatedResponse_PartialLastPage(t *testing.T) {
|
||||
resp := NewPaginatedResponse(nil, 21, 1, 20)
|
||||
if resp.TotalPages != 2 {
|
||||
t.Errorf("TotalPages = %d, want 2", resp.TotalPages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPaginatedResponse_SinglePage(t *testing.T) {
|
||||
resp := NewPaginatedResponse(nil, 5, 1, 20)
|
||||
if resp.TotalPages != 1 {
|
||||
t.Errorf("TotalPages = %d, want 1", resp.TotalPages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPaginatedResponse_ZeroTotal(t *testing.T) {
|
||||
resp := NewPaginatedResponse(nil, 0, 1, 20)
|
||||
if resp.TotalPages != 0 {
|
||||
t.Errorf("TotalPages = %d, want 0", resp.TotalPages)
|
||||
}
|
||||
}
|
||||
109
internal/apperrors/apperrors_test.go
Normal file
109
internal/apperrors/apperrors_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package apperrors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNotFound(t *testing.T) {
|
||||
err := NotFound("error.task_not_found")
|
||||
assert.Equal(t, http.StatusNotFound, err.Code)
|
||||
assert.Equal(t, "error.task_not_found", err.MessageKey)
|
||||
assert.Empty(t, err.Message)
|
||||
assert.Nil(t, err.Err)
|
||||
}
|
||||
|
||||
func TestForbidden(t *testing.T) {
|
||||
err := Forbidden("error.residence_access_denied")
|
||||
assert.Equal(t, http.StatusForbidden, err.Code)
|
||||
assert.Equal(t, "error.residence_access_denied", err.MessageKey)
|
||||
}
|
||||
|
||||
func TestBadRequest(t *testing.T) {
|
||||
err := BadRequest("error.invalid_request_body")
|
||||
assert.Equal(t, http.StatusBadRequest, err.Code)
|
||||
assert.Equal(t, "error.invalid_request_body", err.MessageKey)
|
||||
}
|
||||
|
||||
func TestUnauthorized(t *testing.T) {
|
||||
err := Unauthorized("error.not_authenticated")
|
||||
assert.Equal(t, http.StatusUnauthorized, err.Code)
|
||||
assert.Equal(t, "error.not_authenticated", err.MessageKey)
|
||||
}
|
||||
|
||||
func TestConflict(t *testing.T) {
|
||||
err := Conflict("error.email_taken")
|
||||
assert.Equal(t, http.StatusConflict, err.Code)
|
||||
assert.Equal(t, "error.email_taken", err.MessageKey)
|
||||
}
|
||||
|
||||
func TestTooManyRequests(t *testing.T) {
|
||||
err := TooManyRequests("error.rate_limit_exceeded")
|
||||
assert.Equal(t, http.StatusTooManyRequests, err.Code)
|
||||
assert.Equal(t, "error.rate_limit_exceeded", err.MessageKey)
|
||||
}
|
||||
|
||||
func TestInternal(t *testing.T) {
|
||||
underlying := errors.New("database connection failed")
|
||||
err := Internal(underlying)
|
||||
assert.Equal(t, http.StatusInternalServerError, err.Code)
|
||||
assert.Equal(t, "error.internal", err.MessageKey)
|
||||
assert.Equal(t, underlying, err.Err)
|
||||
}
|
||||
|
||||
func TestAppError_Error_WithWrappedError(t *testing.T) {
|
||||
underlying := errors.New("connection refused")
|
||||
err := Internal(underlying).WithMessage("database error")
|
||||
assert.Equal(t, "database error: connection refused", err.Error())
|
||||
}
|
||||
|
||||
func TestAppError_Error_WithMessageOnly(t *testing.T) {
|
||||
err := NotFound("error.task_not_found").WithMessage("Task not found")
|
||||
assert.Equal(t, "Task not found", err.Error())
|
||||
}
|
||||
|
||||
func TestAppError_Error_MessageKeyFallback(t *testing.T) {
|
||||
err := NotFound("error.task_not_found")
|
||||
// No Message set, no Err set — should fall back to MessageKey
|
||||
assert.Equal(t, "error.task_not_found", err.Error())
|
||||
}
|
||||
|
||||
func TestAppError_Unwrap(t *testing.T) {
|
||||
underlying := errors.New("wrapped error")
|
||||
err := Internal(underlying)
|
||||
assert.Equal(t, underlying, errors.Unwrap(err))
|
||||
}
|
||||
|
||||
func TestAppError_Unwrap_Nil(t *testing.T) {
|
||||
err := NotFound("error.task_not_found")
|
||||
assert.Nil(t, errors.Unwrap(err))
|
||||
}
|
||||
|
||||
func TestAppError_WithMessage(t *testing.T) {
|
||||
err := NotFound("error.task_not_found").WithMessage("custom message")
|
||||
assert.Equal(t, "custom message", err.Message)
|
||||
assert.Equal(t, "error.task_not_found", err.MessageKey)
|
||||
}
|
||||
|
||||
func TestAppError_Wrap(t *testing.T) {
|
||||
underlying := errors.New("some error")
|
||||
err := BadRequest("error.invalid_request_body").Wrap(underlying)
|
||||
assert.Equal(t, underlying, err.Err)
|
||||
assert.Equal(t, http.StatusBadRequest, err.Code)
|
||||
}
|
||||
|
||||
func TestAppError_ImplementsError(t *testing.T) {
|
||||
var err error = NotFound("error.task_not_found")
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, "error.task_not_found", err.Error())
|
||||
}
|
||||
|
||||
func TestAppError_ErrorsAs(t *testing.T) {
|
||||
var appErr *AppError
|
||||
err := NotFound("error.task_not_found")
|
||||
assert.True(t, errors.As(err, &appErr))
|
||||
assert.Equal(t, http.StatusNotFound, appErr.Code)
|
||||
}
|
||||
324
internal/config/config_test.go
Normal file
324
internal/config/config_test.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// resetConfigState resets the package-level singleton so each test starts fresh.
|
||||
func resetConfigState() {
|
||||
cfg = nil
|
||||
cfgOnce = sync.Once{}
|
||||
viper.Reset()
|
||||
}
|
||||
|
||||
func TestLoad_DefaultValues(t *testing.T) {
|
||||
resetConfigState()
|
||||
// Provide required SECRET_KEY so validation passes
|
||||
t.Setenv("SECRET_KEY", "a-strong-secret-key-for-tests")
|
||||
|
||||
c, err := Load()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Server defaults
|
||||
assert.Equal(t, 8000, c.Server.Port)
|
||||
assert.False(t, c.Server.Debug)
|
||||
assert.False(t, c.Server.DebugFixedCodes)
|
||||
assert.Equal(t, "UTC", c.Server.Timezone)
|
||||
assert.Equal(t, "/app/static", c.Server.StaticDir)
|
||||
assert.Equal(t, "https://api.myhoneydue.com", c.Server.BaseURL)
|
||||
|
||||
// Database defaults
|
||||
assert.Equal(t, "localhost", c.Database.Host)
|
||||
assert.Equal(t, 5432, c.Database.Port)
|
||||
assert.Equal(t, "postgres", c.Database.User)
|
||||
assert.Equal(t, "honeydue", c.Database.Database)
|
||||
assert.Equal(t, "disable", c.Database.SSLMode)
|
||||
assert.Equal(t, 25, c.Database.MaxOpenConns)
|
||||
assert.Equal(t, 10, c.Database.MaxIdleConns)
|
||||
|
||||
// Redis defaults
|
||||
assert.Equal(t, "redis://localhost:6379/0", c.Redis.URL)
|
||||
assert.Equal(t, 0, c.Redis.DB)
|
||||
|
||||
// Worker defaults
|
||||
assert.Equal(t, 14, c.Worker.TaskReminderHour)
|
||||
assert.Equal(t, 15, c.Worker.OverdueReminderHour)
|
||||
assert.Equal(t, 3, c.Worker.DailyNotifHour)
|
||||
|
||||
// Token expiry defaults
|
||||
assert.Equal(t, 90, c.Security.TokenExpiryDays)
|
||||
assert.Equal(t, 60, c.Security.TokenRefreshDays)
|
||||
|
||||
// Feature flags default to true
|
||||
assert.True(t, c.Features.PushEnabled)
|
||||
assert.True(t, c.Features.EmailEnabled)
|
||||
assert.True(t, c.Features.WebhooksEnabled)
|
||||
assert.True(t, c.Features.OnboardingEmailsEnabled)
|
||||
assert.True(t, c.Features.PDFReportsEnabled)
|
||||
assert.True(t, c.Features.WorkerEnabled)
|
||||
}
|
||||
|
||||
func TestLoad_EnvOverrides(t *testing.T) {
|
||||
resetConfigState()
|
||||
t.Setenv("SECRET_KEY", "a-strong-secret-key-for-tests")
|
||||
t.Setenv("PORT", "9090")
|
||||
t.Setenv("DEBUG", "true")
|
||||
t.Setenv("DB_HOST", "db.example.com")
|
||||
t.Setenv("DB_PORT", "5433")
|
||||
t.Setenv("TOKEN_EXPIRY_DAYS", "180")
|
||||
t.Setenv("TOKEN_REFRESH_DAYS", "120")
|
||||
t.Setenv("FEATURE_PUSH_ENABLED", "false")
|
||||
|
||||
c, err := Load()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 9090, c.Server.Port)
|
||||
assert.True(t, c.Server.Debug)
|
||||
assert.Equal(t, "db.example.com", c.Database.Host)
|
||||
assert.Equal(t, 5433, c.Database.Port)
|
||||
assert.Equal(t, 180, c.Security.TokenExpiryDays)
|
||||
assert.Equal(t, 120, c.Security.TokenRefreshDays)
|
||||
assert.False(t, c.Features.PushEnabled)
|
||||
}
|
||||
|
||||
func TestLoad_Validation_MissingSecretKey_Production(t *testing.T) {
|
||||
// Test validate() directly to avoid the sync.Once mutex issue
|
||||
// that occurs when Load() resets cfgOnce inside cfgOnce.Do()
|
||||
cfg := &Config{
|
||||
Server: ServerConfig{Debug: false},
|
||||
Security: SecurityConfig{SecretKey: ""},
|
||||
}
|
||||
|
||||
err := validate(cfg)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "SECRET_KEY")
|
||||
}
|
||||
|
||||
func TestLoad_Validation_MissingSecretKey_DebugMode(t *testing.T) {
|
||||
resetConfigState()
|
||||
t.Setenv("SECRET_KEY", "")
|
||||
t.Setenv("DEBUG", "true")
|
||||
|
||||
c, err := Load()
|
||||
require.NoError(t, err)
|
||||
// In debug mode, a default key is assigned
|
||||
assert.Equal(t, "change-me-in-production-secret-key-12345", c.Security.SecretKey)
|
||||
}
|
||||
|
||||
func TestLoad_Validation_WeakSecretKey_Production(t *testing.T) {
|
||||
// Test validate() directly to avoid the sync.Once mutex issue
|
||||
cfg := &Config{
|
||||
Server: ServerConfig{Debug: false},
|
||||
Security: SecurityConfig{SecretKey: "password"},
|
||||
}
|
||||
|
||||
err := validate(cfg)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "well-known weak value")
|
||||
}
|
||||
|
||||
func TestLoad_Validation_WeakSecretKey_DebugMode(t *testing.T) {
|
||||
resetConfigState()
|
||||
t.Setenv("SECRET_KEY", "secret")
|
||||
t.Setenv("DEBUG", "true")
|
||||
|
||||
// In debug mode, weak keys produce a warning but no error
|
||||
c, err := Load()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "secret", c.Security.SecretKey)
|
||||
}
|
||||
|
||||
func TestLoad_Validation_EncryptionKey_Valid(t *testing.T) {
|
||||
resetConfigState()
|
||||
t.Setenv("SECRET_KEY", "a-strong-secret-key-for-tests")
|
||||
// Valid 64-char hex key (32 bytes)
|
||||
t.Setenv("STORAGE_ENCRYPTION_KEY", "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")
|
||||
|
||||
c, err := Load()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", c.Storage.EncryptionKey)
|
||||
}
|
||||
|
||||
func TestLoad_Validation_EncryptionKey_WrongLength(t *testing.T) {
|
||||
// Test validate() directly to avoid the sync.Once mutex issue
|
||||
cfg := &Config{
|
||||
Server: ServerConfig{Debug: false},
|
||||
Security: SecurityConfig{SecretKey: "a-strong-secret-key-for-tests"},
|
||||
Storage: StorageConfig{EncryptionKey: "tooshort"},
|
||||
}
|
||||
|
||||
err := validate(cfg)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "STORAGE_ENCRYPTION_KEY must be exactly 64 hex characters")
|
||||
}
|
||||
|
||||
func TestLoad_Validation_EncryptionKey_InvalidHex(t *testing.T) {
|
||||
// Test validate() directly to avoid the sync.Once mutex issue
|
||||
cfg := &Config{
|
||||
Server: ServerConfig{Debug: false},
|
||||
Security: SecurityConfig{SecretKey: "a-strong-secret-key-for-tests"},
|
||||
Storage: StorageConfig{EncryptionKey: "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz"},
|
||||
}
|
||||
|
||||
err := validate(cfg)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid hex")
|
||||
}
|
||||
|
||||
func TestLoad_DatabaseURL_Override(t *testing.T) {
|
||||
resetConfigState()
|
||||
t.Setenv("SECRET_KEY", "a-strong-secret-key-for-tests")
|
||||
t.Setenv("DATABASE_URL", "postgres://myuser:mypass@dbhost:5433/mydb?sslmode=require")
|
||||
|
||||
c, err := Load()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "dbhost", c.Database.Host)
|
||||
assert.Equal(t, 5433, c.Database.Port)
|
||||
assert.Equal(t, "myuser", c.Database.User)
|
||||
assert.Equal(t, "mypass", c.Database.Password)
|
||||
assert.Equal(t, "mydb", c.Database.Database)
|
||||
assert.Equal(t, "require", c.Database.SSLMode)
|
||||
}
|
||||
|
||||
func TestDSN(t *testing.T) {
|
||||
d := DatabaseConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "testuser",
|
||||
Password: "Password123",
|
||||
Database: "testdb",
|
||||
SSLMode: "disable",
|
||||
}
|
||||
|
||||
dsn := d.DSN()
|
||||
assert.Contains(t, dsn, "host=localhost")
|
||||
assert.Contains(t, dsn, "port=5432")
|
||||
assert.Contains(t, dsn, "user=testuser")
|
||||
assert.Contains(t, dsn, "password=Password123")
|
||||
assert.Contains(t, dsn, "dbname=testdb")
|
||||
assert.Contains(t, dsn, "sslmode=disable")
|
||||
}
|
||||
|
||||
func TestMaskURLCredentials(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "URL with password",
|
||||
input: "postgres://user:secret@host:5432/db",
|
||||
expected: "postgres://user:xxxxx@host:5432/db",
|
||||
},
|
||||
{
|
||||
name: "URL without password",
|
||||
input: "postgres://user@host:5432/db",
|
||||
expected: "postgres://user@host:5432/db",
|
||||
},
|
||||
{
|
||||
name: "URL without user info",
|
||||
input: "postgres://host:5432/db",
|
||||
expected: "postgres://host:5432/db",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := MaskURLCredentials(tc.input)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCorsOrigins(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{"empty string", "", nil},
|
||||
{"single origin", "https://example.com", []string{"https://example.com"}},
|
||||
{"multiple origins", "https://a.com, https://b.com", []string{"https://a.com", "https://b.com"}},
|
||||
{"whitespace trimmed", " https://a.com , https://b.com ", []string{"https://a.com", "https://b.com"}},
|
||||
{"empty parts skipped", "https://a.com,,https://b.com", []string{"https://a.com", "https://b.com"}},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := parseCorsOrigins(tc.input)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDatabaseURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
wantHost string
|
||||
wantPort int
|
||||
wantUser string
|
||||
wantPass string
|
||||
wantDB string
|
||||
wantSSL string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "full URL",
|
||||
url: "postgres://user:Password123@host:5433/mydb?sslmode=require",
|
||||
wantHost: "host",
|
||||
wantPort: 5433,
|
||||
wantUser: "user",
|
||||
wantPass: "Password123",
|
||||
wantDB: "mydb",
|
||||
wantSSL: "require",
|
||||
},
|
||||
{
|
||||
name: "default port",
|
||||
url: "postgres://user:pass@host/mydb",
|
||||
wantHost: "host",
|
||||
wantPort: 5432,
|
||||
wantUser: "user",
|
||||
wantPass: "pass",
|
||||
wantDB: "mydb",
|
||||
wantSSL: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result, err := parseDatabaseURL(tc.url)
|
||||
if tc.expectError {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.wantHost, result.Host)
|
||||
assert.Equal(t, tc.wantPort, result.Port)
|
||||
assert.Equal(t, tc.wantUser, result.User)
|
||||
assert.Equal(t, tc.wantPass, result.Password)
|
||||
assert.Equal(t, tc.wantDB, result.Database)
|
||||
assert.Equal(t, tc.wantSSL, result.SSLMode)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsWeakSecretKey(t *testing.T) {
|
||||
assert.True(t, isWeakSecretKey("secret"))
|
||||
assert.True(t, isWeakSecretKey("Secret")) // case-insensitive
|
||||
assert.True(t, isWeakSecretKey(" changeme ")) // whitespace trimmed
|
||||
assert.True(t, isWeakSecretKey("password"))
|
||||
assert.True(t, isWeakSecretKey("change-me"))
|
||||
assert.False(t, isWeakSecretKey("a-strong-unique-production-key"))
|
||||
}
|
||||
|
||||
func TestGet_ReturnsNilBeforeLoad(t *testing.T) {
|
||||
resetConfigState()
|
||||
assert.Nil(t, Get())
|
||||
}
|
||||
103
internal/database/database_test.go
Normal file
103
internal/database/database_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// --- Unit tests for Paginate parameter clamping ---
|
||||
|
||||
func TestPaginate_PageZeroDefaultsToOne(t *testing.T) {
|
||||
scope := Paginate(0, 10)
|
||||
|
||||
db := openTestDB(t)
|
||||
createTestRows(t, db, 5)
|
||||
|
||||
var rows []testRow
|
||||
err := db.Scopes(scope).Find(&rows).Error
|
||||
require.NoError(t, err)
|
||||
// page=0 normalised to page=1, pageSize=10 → should get all 5 rows
|
||||
assert.Len(t, rows, 5)
|
||||
}
|
||||
|
||||
func TestPaginate_PageSizeZeroDefaultsTo100(t *testing.T) {
|
||||
scope := Paginate(1, 0)
|
||||
|
||||
db := openTestDB(t)
|
||||
createTestRows(t, db, 5)
|
||||
|
||||
var rows []testRow
|
||||
err := db.Scopes(scope).Find(&rows).Error
|
||||
require.NoError(t, err)
|
||||
// pageSize=0 normalised to 100, only 5 rows exist → 5 returned
|
||||
assert.Len(t, rows, 5)
|
||||
}
|
||||
|
||||
func TestPaginate_PageSizeOverMaxCappedAt1000(t *testing.T) {
|
||||
scope := Paginate(1, 2000)
|
||||
|
||||
db := openTestDB(t)
|
||||
createTestRows(t, db, 5)
|
||||
|
||||
var rows []testRow
|
||||
err := db.Scopes(scope).Find(&rows).Error
|
||||
require.NoError(t, err)
|
||||
// pageSize=2000 capped to 1000, only 5 rows → 5 returned
|
||||
assert.Len(t, rows, 5)
|
||||
}
|
||||
|
||||
func TestPaginate_NormalValues(t *testing.T) {
|
||||
scope := Paginate(1, 3)
|
||||
|
||||
db := openTestDB(t)
|
||||
createTestRows(t, db, 10)
|
||||
|
||||
var rows []testRow
|
||||
err := db.Scopes(scope).Order("id ASC").Find(&rows).Error
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, rows, 3)
|
||||
assert.Equal(t, "row_1", rows[0].Name)
|
||||
assert.Equal(t, "row_3", rows[2].Name)
|
||||
}
|
||||
|
||||
func TestPaginate_SQLiteIntegration_Page2Size10(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestRows(t, db, 25)
|
||||
|
||||
scope := Paginate(2, 10)
|
||||
var rows []testRow
|
||||
err := db.Scopes(scope).Order("id ASC").Find(&rows).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Page 2 with size 10 → rows 11..20
|
||||
assert.Len(t, rows, 10)
|
||||
assert.Equal(t, "row_11", rows[0].Name)
|
||||
assert.Equal(t, "row_20", rows[9].Name)
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
type testRow struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
Name string
|
||||
}
|
||||
|
||||
func openTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.AutoMigrate(&testRow{}))
|
||||
return db
|
||||
}
|
||||
|
||||
func createTestRows(t *testing.T, db *gorm.DB, n int) {
|
||||
t.Helper()
|
||||
for i := 1; i <= n; i++ {
|
||||
require.NoError(t, db.Create(&testRow{Name: fmt.Sprintf("row_%d", i)}).Error)
|
||||
}
|
||||
}
|
||||
47
internal/database/migration_backfill_test.go
Normal file
47
internal/database/migration_backfill_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestClassifyCompletion_CompletedAfterDue(t *testing.T) {
|
||||
dueDate := time.Date(2025, 6, 1, 0, 0, 0, 0, time.UTC)
|
||||
completedAt := time.Date(2025, 6, 5, 14, 30, 0, 0, time.UTC) // 4 days after due
|
||||
|
||||
result := classifyCompletion(completedAt, dueDate, 30)
|
||||
|
||||
assert.Equal(t, "overdue_tasks", result)
|
||||
}
|
||||
|
||||
func TestClassifyCompletion_CompletedOnDueDate(t *testing.T) {
|
||||
dueDate := time.Date(2025, 6, 15, 0, 0, 0, 0, time.UTC)
|
||||
completedAt := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC) // same day
|
||||
|
||||
result := classifyCompletion(completedAt, dueDate, 30)
|
||||
|
||||
// Completed on the due date: daysBefore == 0, which is <= threshold → due_soon_tasks
|
||||
assert.Equal(t, "due_soon_tasks", result)
|
||||
}
|
||||
|
||||
func TestClassifyCompletion_CompletedWithinThreshold(t *testing.T) {
|
||||
dueDate := time.Date(2025, 7, 1, 0, 0, 0, 0, time.UTC)
|
||||
completedAt := time.Date(2025, 6, 10, 8, 0, 0, 0, time.UTC) // 21 days before due
|
||||
|
||||
result := classifyCompletion(completedAt, dueDate, 30)
|
||||
|
||||
// 21 days before due, within 30-day threshold → due_soon_tasks
|
||||
assert.Equal(t, "due_soon_tasks", result)
|
||||
}
|
||||
|
||||
func TestClassifyCompletion_CompletedBeyondThreshold(t *testing.T) {
|
||||
dueDate := time.Date(2025, 9, 1, 0, 0, 0, 0, time.UTC)
|
||||
completedAt := time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC) // 92 days before due
|
||||
|
||||
result := classifyCompletion(completedAt, dueDate, 30)
|
||||
|
||||
// 92 days before due, beyond 30-day threshold → upcoming_tasks
|
||||
assert.Equal(t, "upcoming_tasks", result)
|
||||
}
|
||||
31
internal/database/migration_helpers.go
Normal file
31
internal/database/migration_helpers.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package database
|
||||
|
||||
import "sort"
|
||||
|
||||
// sortMigrationNames returns a sorted copy of the names slice.
|
||||
func sortMigrationNames(names []string) []string {
|
||||
sorted := make([]string, len(names))
|
||||
copy(sorted, names)
|
||||
sort.Strings(sorted)
|
||||
return sorted
|
||||
}
|
||||
|
||||
// buildAppliedSet converts a list of applied migrations to a lookup set.
|
||||
func buildAppliedSet(applied []DataMigration) map[string]bool {
|
||||
set := make(map[string]bool, len(applied))
|
||||
for _, m := range applied {
|
||||
set[m.Name] = true
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
// filterPending returns names not present in the applied set.
|
||||
func filterPending(names []string, applied map[string]bool) []string {
|
||||
var pending []string
|
||||
for _, name := range names {
|
||||
if !applied[name] {
|
||||
pending = append(pending, name)
|
||||
}
|
||||
}
|
||||
return pending
|
||||
}
|
||||
82
internal/database/migration_helpers_test.go
Normal file
82
internal/database/migration_helpers_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// --- sortMigrationNames ---
|
||||
|
||||
func TestSortMigrationNames_Alphabetical(t *testing.T) {
|
||||
input := []string{"charlie", "alpha", "bravo"}
|
||||
result := sortMigrationNames(input)
|
||||
|
||||
assert.Equal(t, []string{"alpha", "bravo", "charlie"}, result)
|
||||
// Verify original slice is not mutated
|
||||
assert.Equal(t, []string{"charlie", "alpha", "bravo"}, input)
|
||||
}
|
||||
|
||||
func TestSortMigrationNames_Empty(t *testing.T) {
|
||||
result := sortMigrationNames([]string{})
|
||||
assert.Equal(t, []string{}, result)
|
||||
assert.Len(t, result, 0)
|
||||
}
|
||||
|
||||
// --- buildAppliedSet ---
|
||||
|
||||
func TestBuildAppliedSet_Multiple(t *testing.T) {
|
||||
applied := []DataMigration{
|
||||
{ID: 1, Name: "20250101_first", AppliedAt: time.Now()},
|
||||
{ID: 2, Name: "20250201_second", AppliedAt: time.Now()},
|
||||
{ID: 3, Name: "20250301_third", AppliedAt: time.Now()},
|
||||
}
|
||||
|
||||
set := buildAppliedSet(applied)
|
||||
|
||||
assert.Len(t, set, 3)
|
||||
assert.True(t, set["20250101_first"])
|
||||
assert.True(t, set["20250201_second"])
|
||||
assert.True(t, set["20250301_third"])
|
||||
assert.False(t, set["nonexistent"])
|
||||
}
|
||||
|
||||
func TestBuildAppliedSet_Empty(t *testing.T) {
|
||||
set := buildAppliedSet([]DataMigration{})
|
||||
assert.Len(t, set, 0)
|
||||
}
|
||||
|
||||
// --- filterPending ---
|
||||
|
||||
func TestFilterPending_SomePending(t *testing.T) {
|
||||
names := []string{"20250101_first", "20250201_second", "20250301_third"}
|
||||
applied := map[string]bool{
|
||||
"20250101_first": true,
|
||||
}
|
||||
|
||||
pending := filterPending(names, applied)
|
||||
|
||||
assert.Equal(t, []string{"20250201_second", "20250301_third"}, pending)
|
||||
}
|
||||
|
||||
func TestFilterPending_AllApplied(t *testing.T) {
|
||||
names := []string{"20250101_first", "20250201_second"}
|
||||
applied := map[string]bool{
|
||||
"20250101_first": true,
|
||||
"20250201_second": true,
|
||||
}
|
||||
|
||||
pending := filterPending(names, applied)
|
||||
|
||||
assert.Nil(t, pending)
|
||||
}
|
||||
|
||||
func TestFilterPending_NoneApplied(t *testing.T) {
|
||||
names := []string{"20250101_first", "20250201_second", "20250301_third"}
|
||||
applied := map[string]bool{}
|
||||
|
||||
pending := filterPending(names, applied)
|
||||
|
||||
assert.Equal(t, []string{"20250101_first", "20250201_second", "20250301_third"}, pending)
|
||||
}
|
||||
130
internal/dto/requests/requests_test.go
Normal file
130
internal/dto/requests/requests_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package requests
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestFlexibleDate_UnmarshalJSON_DateOnly(t *testing.T) {
|
||||
var fd FlexibleDate
|
||||
err := fd.UnmarshalJSON([]byte(`"2025-11-27"`))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
want := time.Date(2025, 11, 27, 0, 0, 0, 0, time.UTC)
|
||||
if !fd.Time.Equal(want) {
|
||||
t.Errorf("got %v, want %v", fd.Time, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlexibleDate_UnmarshalJSON_RFC3339(t *testing.T) {
|
||||
var fd FlexibleDate
|
||||
err := fd.UnmarshalJSON([]byte(`"2025-11-27T15:30:00Z"`))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
want := time.Date(2025, 11, 27, 15, 30, 0, 0, time.UTC)
|
||||
if !fd.Time.Equal(want) {
|
||||
t.Errorf("got %v, want %v", fd.Time, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlexibleDate_UnmarshalJSON_Null(t *testing.T) {
|
||||
var fd FlexibleDate
|
||||
err := fd.UnmarshalJSON([]byte(`null`))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !fd.Time.IsZero() {
|
||||
t.Errorf("expected zero time, got %v", fd.Time)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlexibleDate_UnmarshalJSON_EmptyString(t *testing.T) {
|
||||
var fd FlexibleDate
|
||||
err := fd.UnmarshalJSON([]byte(`""`))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !fd.Time.IsZero() {
|
||||
t.Errorf("expected zero time, got %v", fd.Time)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlexibleDate_UnmarshalJSON_Invalid(t *testing.T) {
|
||||
var fd FlexibleDate
|
||||
err := fd.UnmarshalJSON([]byte(`"not-a-date"`))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid date, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlexibleDate_MarshalJSON_Valid(t *testing.T) {
|
||||
fd := FlexibleDate{Time: time.Date(2025, 11, 27, 15, 30, 0, 0, time.UTC)}
|
||||
data, err := fd.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
t.Fatalf("result is not a JSON string: %v", err)
|
||||
}
|
||||
want := "2025-11-27T15:30:00Z"
|
||||
if s != want {
|
||||
t.Errorf("got %q, want %q", s, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlexibleDate_MarshalJSON_Zero(t *testing.T) {
|
||||
fd := FlexibleDate{}
|
||||
data, err := fd.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if string(data) != "null" {
|
||||
t.Errorf("got %s, want null", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlexibleDate_ToTimePtr_Valid(t *testing.T) {
|
||||
fd := &FlexibleDate{Time: time.Date(2025, 11, 27, 0, 0, 0, 0, time.UTC)}
|
||||
ptr := fd.ToTimePtr()
|
||||
if ptr == nil {
|
||||
t.Fatal("expected non-nil pointer")
|
||||
}
|
||||
if !ptr.Equal(fd.Time) {
|
||||
t.Errorf("got %v, want %v", *ptr, fd.Time)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlexibleDate_ToTimePtr_Zero(t *testing.T) {
|
||||
fd := &FlexibleDate{}
|
||||
ptr := fd.ToTimePtr()
|
||||
if ptr != nil {
|
||||
t.Errorf("expected nil, got %v", *ptr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlexibleDate_ToTimePtr_NilReceiver(t *testing.T) {
|
||||
var fd *FlexibleDate
|
||||
ptr := fd.ToTimePtr()
|
||||
if ptr != nil {
|
||||
t.Errorf("expected nil for nil receiver, got %v", *ptr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlexibleDate_RoundTrip(t *testing.T) {
|
||||
original := FlexibleDate{Time: time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)}
|
||||
data, err := original.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("marshal error: %v", err)
|
||||
}
|
||||
var restored FlexibleDate
|
||||
if err := restored.UnmarshalJSON(data); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
if !original.Time.Equal(restored.Time) {
|
||||
t.Errorf("round-trip mismatch: original %v, restored %v", original.Time, restored.Time)
|
||||
}
|
||||
}
|
||||
833
internal/dto/responses/responses_test.go
Normal file
833
internal/dto/responses/responses_test.go
Normal file
@@ -0,0 +1,833 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func timePtr(t time.Time) *time.Time { return &t }
|
||||
func uintPtr(v uint) *uint { return &v }
|
||||
func intPtr(v int) *int { return &v }
|
||||
func strPtr(v string) *string { return &v }
|
||||
func float64Ptr(v float64) *float64 { return &v }
|
||||
|
||||
var fixedNow = time.Date(2025, 6, 15, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
func makeUser() *models.User {
|
||||
return &models.User{
|
||||
ID: 1,
|
||||
Username: "john",
|
||||
Email: "john@example.com",
|
||||
FirstName: "John",
|
||||
LastName: "Doe",
|
||||
IsActive: true,
|
||||
DateJoined: fixedNow,
|
||||
LastLogin: timePtr(fixedNow),
|
||||
Profile: &models.UserProfile{
|
||||
BaseModel: models.BaseModel{ID: 10},
|
||||
UserID: 1,
|
||||
Verified: true,
|
||||
Bio: "hello",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func makeUserNoProfile() *models.User {
|
||||
u := makeUser()
|
||||
u.Profile = nil
|
||||
return u
|
||||
}
|
||||
|
||||
// ==================== auth.go ====================
|
||||
|
||||
func TestNewUserResponse_AllFields(t *testing.T) {
|
||||
u := makeUser()
|
||||
resp := NewUserResponse(u)
|
||||
if resp.ID != 1 {
|
||||
t.Errorf("ID = %d, want 1", resp.ID)
|
||||
}
|
||||
if resp.Username != "john" {
|
||||
t.Errorf("Username = %q", resp.Username)
|
||||
}
|
||||
if !resp.Verified {
|
||||
t.Error("Verified should be true when profile is verified")
|
||||
}
|
||||
if resp.LastLogin == nil {
|
||||
t.Error("LastLogin should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUserResponse_NilProfile(t *testing.T) {
|
||||
u := makeUserNoProfile()
|
||||
resp := NewUserResponse(u)
|
||||
if resp.Verified {
|
||||
t.Error("Verified should be false when profile is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUserProfileResponse_Nil(t *testing.T) {
|
||||
resp := NewUserProfileResponse(nil)
|
||||
if resp != nil {
|
||||
t.Error("expected nil for nil profile")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUserProfileResponse_Valid(t *testing.T) {
|
||||
p := &models.UserProfile{
|
||||
BaseModel: models.BaseModel{ID: 5},
|
||||
UserID: 1,
|
||||
Verified: true,
|
||||
Bio: "bio",
|
||||
}
|
||||
resp := NewUserProfileResponse(p)
|
||||
if resp == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if resp.ID != 5 || resp.UserID != 1 || !resp.Verified || resp.Bio != "bio" {
|
||||
t.Errorf("unexpected response: %+v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCurrentUserResponse(t *testing.T) {
|
||||
u := makeUser()
|
||||
resp := NewCurrentUserResponse(u, "apple")
|
||||
if resp.AuthProvider != "apple" {
|
||||
t.Errorf("AuthProvider = %q, want apple", resp.AuthProvider)
|
||||
}
|
||||
if resp.Profile == nil {
|
||||
t.Error("Profile should not be nil")
|
||||
}
|
||||
if resp.ID != 1 {
|
||||
t.Errorf("ID = %d, want 1", resp.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewLoginResponse(t *testing.T) {
|
||||
u := makeUser()
|
||||
resp := NewLoginResponse("tok123", u)
|
||||
if resp.Token != "tok123" {
|
||||
t.Errorf("Token = %q", resp.Token)
|
||||
}
|
||||
if resp.User.ID != 1 {
|
||||
t.Errorf("User.ID = %d", resp.User.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRegisterResponse(t *testing.T) {
|
||||
u := makeUser()
|
||||
resp := NewRegisterResponse("tok456", u)
|
||||
if resp.Token != "tok456" {
|
||||
t.Errorf("Token = %q", resp.Token)
|
||||
}
|
||||
if resp.Message == "" {
|
||||
t.Error("Message should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAppleSignInResponse(t *testing.T) {
|
||||
u := makeUser()
|
||||
resp := NewAppleSignInResponse("atok", u, true)
|
||||
if !resp.IsNewUser {
|
||||
t.Error("IsNewUser should be true")
|
||||
}
|
||||
if resp.Token != "atok" {
|
||||
t.Errorf("Token = %q", resp.Token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGoogleSignInResponse(t *testing.T) {
|
||||
u := makeUser()
|
||||
resp := NewGoogleSignInResponse("gtok", u, false)
|
||||
if resp.IsNewUser {
|
||||
t.Error("IsNewUser should be false")
|
||||
}
|
||||
if resp.Token != "gtok" {
|
||||
t.Errorf("Token = %q", resp.Token)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== task.go ====================
|
||||
|
||||
func makeTask() *models.Task {
|
||||
due := time.Date(2025, 7, 1, 0, 0, 0, 0, time.UTC)
|
||||
catID := uint(1)
|
||||
priID := uint(2)
|
||||
freqID := uint(3)
|
||||
return &models.Task{
|
||||
BaseModel: models.BaseModel{ID: 100, CreatedAt: fixedNow, UpdatedAt: fixedNow},
|
||||
ResidenceID: 10,
|
||||
CreatedByID: 1,
|
||||
CreatedBy: *makeUser(),
|
||||
Title: "Fix roof",
|
||||
Description: "Repair leak",
|
||||
CategoryID: &catID,
|
||||
Category: &models.TaskCategory{BaseModel: models.BaseModel{ID: catID}, Name: "Exterior", Icon: "roof", Color: "#FF0000", DisplayOrder: 1},
|
||||
PriorityID: &priID,
|
||||
Priority: &models.TaskPriority{BaseModel: models.BaseModel{ID: priID}, Name: "High", Level: 3, Color: "#FF0000", DisplayOrder: 1},
|
||||
FrequencyID: &freqID,
|
||||
Frequency: &models.TaskFrequency{BaseModel: models.BaseModel{ID: freqID}, Name: "Monthly", Days: intPtr(30), DisplayOrder: 1},
|
||||
DueDate: &due,
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskResponse_BasicFields(t *testing.T) {
|
||||
task := makeTask()
|
||||
resp := NewTaskResponseWithTime(task, 30, fixedNow)
|
||||
if resp.ID != 100 {
|
||||
t.Errorf("ID = %d", resp.ID)
|
||||
}
|
||||
if resp.Title != "Fix roof" {
|
||||
t.Errorf("Title = %q", resp.Title)
|
||||
}
|
||||
if resp.CreatedBy == nil {
|
||||
t.Error("CreatedBy should not be nil")
|
||||
}
|
||||
if resp.Category == nil {
|
||||
t.Error("Category should not be nil")
|
||||
}
|
||||
if resp.Priority == nil {
|
||||
t.Error("Priority should not be nil")
|
||||
}
|
||||
if resp.Frequency == nil {
|
||||
t.Error("Frequency should not be nil")
|
||||
}
|
||||
if resp.KanbanColumn == "" {
|
||||
t.Error("KanbanColumn should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskResponse_NilAssociations(t *testing.T) {
|
||||
task := &models.Task{
|
||||
BaseModel: models.BaseModel{ID: 200},
|
||||
ResidenceID: 10,
|
||||
CreatedByID: 1,
|
||||
Title: "Simple task",
|
||||
}
|
||||
resp := NewTaskResponseWithTime(task, 30, fixedNow)
|
||||
if resp.CreatedBy != nil {
|
||||
t.Error("CreatedBy should be nil when CreatedBy.ID is 0")
|
||||
}
|
||||
if resp.Category != nil {
|
||||
t.Error("Category should be nil")
|
||||
}
|
||||
if resp.Priority != nil {
|
||||
t.Error("Priority should be nil")
|
||||
}
|
||||
if resp.Frequency != nil {
|
||||
t.Error("Frequency should be nil")
|
||||
}
|
||||
if resp.AssignedTo != nil {
|
||||
t.Error("AssignedTo should be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskResponse_WithCompletions(t *testing.T) {
|
||||
task := makeTask()
|
||||
task.Completions = []models.TaskCompletion{
|
||||
{BaseModel: models.BaseModel{ID: 1}, TaskID: 100, CompletedAt: fixedNow, CompletedByID: 1},
|
||||
{BaseModel: models.BaseModel{ID: 2}, TaskID: 100, CompletedAt: fixedNow, CompletedByID: 1},
|
||||
}
|
||||
resp := NewTaskResponseWithTime(task, 30, fixedNow)
|
||||
if resp.CompletionCount != 2 {
|
||||
t.Errorf("CompletionCount = %d, want 2", resp.CompletionCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskResponseWithTime_KanbanColumn(t *testing.T) {
|
||||
task := makeTask()
|
||||
// due date is July 1, now is June 15 → 16 days away → due_soon (within 30 days)
|
||||
resp := NewTaskResponseWithTime(task, 30, fixedNow)
|
||||
if resp.KanbanColumn == "" {
|
||||
t.Error("KanbanColumn should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskListResponse(t *testing.T) {
|
||||
tasks := []models.Task{
|
||||
{BaseModel: models.BaseModel{ID: 1}, Title: "A"},
|
||||
{BaseModel: models.BaseModel{ID: 2}, Title: "B"},
|
||||
}
|
||||
results := NewTaskListResponse(tasks)
|
||||
if len(results) != 2 {
|
||||
t.Errorf("len = %d, want 2", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskListResponse_Empty(t *testing.T) {
|
||||
results := NewTaskListResponse([]models.Task{})
|
||||
if len(results) != 0 {
|
||||
t.Errorf("len = %d, want 0", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskCompletionResponse_WithImages(t *testing.T) {
|
||||
c := &models.TaskCompletion{
|
||||
BaseModel: models.BaseModel{ID: 50},
|
||||
TaskID: 100,
|
||||
CompletedByID: 1,
|
||||
CompletedBy: *makeUser(),
|
||||
CompletedAt: fixedNow,
|
||||
Notes: "done",
|
||||
Images: []models.TaskCompletionImage{
|
||||
{BaseModel: models.BaseModel{ID: 1}, ImageURL: "http://img1.jpg", Caption: "before"},
|
||||
{BaseModel: models.BaseModel{ID: 2}, ImageURL: "http://img2.jpg", Caption: "after"},
|
||||
},
|
||||
}
|
||||
resp := NewTaskCompletionResponse(c)
|
||||
if resp.CompletedBy == nil {
|
||||
t.Error("CompletedBy should not be nil")
|
||||
}
|
||||
if len(resp.Images) != 2 {
|
||||
t.Errorf("Images len = %d, want 2", len(resp.Images))
|
||||
}
|
||||
if resp.Images[0].MediaURL != "/api/media/completion-image/1" {
|
||||
t.Errorf("MediaURL = %q", resp.Images[0].MediaURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskCompletionResponse_EmptyImages(t *testing.T) {
|
||||
c := &models.TaskCompletion{
|
||||
BaseModel: models.BaseModel{ID: 51},
|
||||
TaskID: 100,
|
||||
CompletedByID: 1,
|
||||
CompletedAt: fixedNow,
|
||||
}
|
||||
resp := NewTaskCompletionResponse(c)
|
||||
if resp.Images == nil {
|
||||
t.Error("Images should be empty slice, not nil")
|
||||
}
|
||||
if len(resp.Images) != 0 {
|
||||
t.Errorf("Images len = %d, want 0", len(resp.Images))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewKanbanBoardResponse(t *testing.T) {
|
||||
board := &models.KanbanBoard{
|
||||
Columns: []models.KanbanColumn{
|
||||
{
|
||||
Name: "overdue",
|
||||
DisplayName: "Overdue",
|
||||
Color: "#FF0000",
|
||||
Tasks: []models.Task{{BaseModel: models.BaseModel{ID: 1}, Title: "A"}},
|
||||
Count: 1,
|
||||
},
|
||||
},
|
||||
DaysThreshold: 30,
|
||||
}
|
||||
resp := NewKanbanBoardResponse(board, 10, fixedNow)
|
||||
if len(resp.Columns) != 1 {
|
||||
t.Fatalf("Columns len = %d", len(resp.Columns))
|
||||
}
|
||||
if resp.ResidenceID != "10" {
|
||||
t.Errorf("ResidenceID = %q, want '10'", resp.ResidenceID)
|
||||
}
|
||||
if resp.Columns[0].Count != 1 {
|
||||
t.Errorf("Count = %d", resp.Columns[0].Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewKanbanBoardResponseForAll(t *testing.T) {
|
||||
board := &models.KanbanBoard{
|
||||
Columns: []models.KanbanColumn{},
|
||||
DaysThreshold: 30,
|
||||
}
|
||||
resp := NewKanbanBoardResponseForAll(board, fixedNow)
|
||||
if resp.ResidenceID != "all" {
|
||||
t.Errorf("ResidenceID = %q, want 'all'", resp.ResidenceID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineKanbanColumn_Delegates(t *testing.T) {
|
||||
task := &models.Task{
|
||||
BaseModel: models.BaseModel{ID: 1},
|
||||
Title: "test",
|
||||
}
|
||||
col := DetermineKanbanColumn(task, 30)
|
||||
if col == "" {
|
||||
t.Error("expected non-empty column")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskCompletionWithTaskResponse(t *testing.T) {
|
||||
c := &models.TaskCompletion{
|
||||
BaseModel: models.BaseModel{ID: 1},
|
||||
TaskID: 100,
|
||||
CompletedByID: 1,
|
||||
CompletedAt: fixedNow,
|
||||
}
|
||||
task := makeTask()
|
||||
resp := NewTaskCompletionWithTaskResponseWithTime(c, task, 30, fixedNow)
|
||||
if resp.Task == nil {
|
||||
t.Error("Task should not be nil")
|
||||
}
|
||||
if resp.Task.ID != 100 {
|
||||
t.Errorf("Task.ID = %d", resp.Task.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskCompletionWithTaskResponse_NilTask(t *testing.T) {
|
||||
c := &models.TaskCompletion{
|
||||
BaseModel: models.BaseModel{ID: 1},
|
||||
TaskID: 100,
|
||||
CompletedByID: 1,
|
||||
CompletedAt: fixedNow,
|
||||
}
|
||||
resp := NewTaskCompletionWithTaskResponseWithTime(c, nil, 30, fixedNow)
|
||||
if resp.Task != nil {
|
||||
t.Error("Task should be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskCompletionListResponse(t *testing.T) {
|
||||
completions := []models.TaskCompletion{
|
||||
{BaseModel: models.BaseModel{ID: 1}, TaskID: 100, CompletedAt: fixedNow, CompletedByID: 1},
|
||||
}
|
||||
results := NewTaskCompletionListResponse(completions)
|
||||
if len(results) != 1 {
|
||||
t.Errorf("len = %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskCategoryResponse_Nil(t *testing.T) {
|
||||
if NewTaskCategoryResponse(nil) != nil {
|
||||
t.Error("expected nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskPriorityResponse_Nil(t *testing.T) {
|
||||
if NewTaskPriorityResponse(nil) != nil {
|
||||
t.Error("expected nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskFrequencyResponse_Nil(t *testing.T) {
|
||||
if NewTaskFrequencyResponse(nil) != nil {
|
||||
t.Error("expected nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskUserResponse_Nil(t *testing.T) {
|
||||
if NewTaskUserResponse(nil) != nil {
|
||||
t.Error("expected nil")
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== contractor.go ====================
|
||||
|
||||
func makeContractor() *models.Contractor {
|
||||
resID := uint(10)
|
||||
return &models.Contractor{
|
||||
BaseModel: models.BaseModel{ID: 5, CreatedAt: fixedNow, UpdatedAt: fixedNow},
|
||||
ResidenceID: &resID,
|
||||
CreatedByID: 1,
|
||||
CreatedBy: *makeUser(),
|
||||
Name: "Bob's Plumbing",
|
||||
Company: "Bob Co",
|
||||
Phone: "555-1234",
|
||||
Email: "bob@plumb.com",
|
||||
Rating: float64Ptr(4.5),
|
||||
IsFavorite: true,
|
||||
IsActive: true,
|
||||
Specialties: []models.ContractorSpecialty{
|
||||
{BaseModel: models.BaseModel{ID: 1}, Name: "Plumbing", Icon: "wrench", DisplayOrder: 1},
|
||||
},
|
||||
Tasks: []models.Task{{BaseModel: models.BaseModel{ID: 1}}, {BaseModel: models.BaseModel{ID: 2}}},
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewContractorResponse_BasicFields(t *testing.T) {
|
||||
c := makeContractor()
|
||||
resp := NewContractorResponse(c)
|
||||
if resp.ID != 5 {
|
||||
t.Errorf("ID = %d", resp.ID)
|
||||
}
|
||||
if resp.Name != "Bob's Plumbing" {
|
||||
t.Errorf("Name = %q", resp.Name)
|
||||
}
|
||||
if resp.AddedBy != 1 {
|
||||
t.Errorf("AddedBy = %d, want 1", resp.AddedBy)
|
||||
}
|
||||
if resp.CreatedBy == nil {
|
||||
t.Error("CreatedBy should not be nil")
|
||||
}
|
||||
if resp.TaskCount != 2 {
|
||||
t.Errorf("TaskCount = %d, want 2", resp.TaskCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewContractorResponse_WithSpecialties(t *testing.T) {
|
||||
c := makeContractor()
|
||||
resp := NewContractorResponse(c)
|
||||
if len(resp.Specialties) != 1 {
|
||||
t.Fatalf("Specialties len = %d", len(resp.Specialties))
|
||||
}
|
||||
if resp.Specialties[0].Name != "Plumbing" {
|
||||
t.Errorf("Specialty name = %q", resp.Specialties[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewContractorListResponse(t *testing.T) {
|
||||
contractors := []models.Contractor{*makeContractor()}
|
||||
results := NewContractorListResponse(contractors)
|
||||
if len(results) != 1 {
|
||||
t.Errorf("len = %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewContractorUserResponse_Nil(t *testing.T) {
|
||||
if NewContractorUserResponse(nil) != nil {
|
||||
t.Error("expected nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewContractorSpecialtyResponse(t *testing.T) {
|
||||
s := &models.ContractorSpecialty{
|
||||
BaseModel: models.BaseModel{ID: 1},
|
||||
Name: "Electrical",
|
||||
Description: "Electrical work",
|
||||
Icon: "bolt",
|
||||
DisplayOrder: 2,
|
||||
}
|
||||
resp := NewContractorSpecialtyResponse(s)
|
||||
if resp.Name != "Electrical" || resp.Icon != "bolt" {
|
||||
t.Errorf("unexpected: %+v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== document.go ====================
|
||||
|
||||
func makeDocument() *models.Document {
|
||||
price := decimal.NewFromFloat(99.99)
|
||||
return &models.Document{
|
||||
BaseModel: models.BaseModel{ID: 20, CreatedAt: fixedNow, UpdatedAt: fixedNow},
|
||||
ResidenceID: 10,
|
||||
CreatedByID: 1,
|
||||
CreatedBy: *makeUser(),
|
||||
Title: "Warranty",
|
||||
Description: "Roof warranty",
|
||||
DocumentType: "warranty",
|
||||
FileName: "warranty.pdf",
|
||||
FileSize: func() *int64 { v := int64(1024); return &v }(),
|
||||
MimeType: "application/pdf",
|
||||
PurchasePrice: &price,
|
||||
IsActive: true,
|
||||
Images: []models.DocumentImage{
|
||||
{BaseModel: models.BaseModel{ID: 1}, ImageURL: "http://img.jpg", Caption: "page 1"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDocumentResponse_MediaURL(t *testing.T) {
|
||||
d := makeDocument()
|
||||
resp := NewDocumentResponse(d)
|
||||
want := fmt.Sprintf("/api/media/document/%d", d.ID)
|
||||
if resp.MediaURL != want {
|
||||
t.Errorf("MediaURL = %q, want %q", resp.MediaURL, want)
|
||||
}
|
||||
if resp.Residence != resp.ResidenceID {
|
||||
t.Error("Residence alias should equal ResidenceID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDocumentResponse_WithImages(t *testing.T) {
|
||||
d := makeDocument()
|
||||
resp := NewDocumentResponse(d)
|
||||
if len(resp.Images) != 1 {
|
||||
t.Fatalf("Images len = %d", len(resp.Images))
|
||||
}
|
||||
if resp.Images[0].MediaURL != "/api/media/document-image/1" {
|
||||
t.Errorf("Image MediaURL = %q", resp.Images[0].MediaURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDocumentResponse_EmptyImageURL(t *testing.T) {
|
||||
d := makeDocument()
|
||||
d.Images = []models.DocumentImage{
|
||||
{BaseModel: models.BaseModel{ID: 5}, ImageURL: "", Caption: "missing"},
|
||||
}
|
||||
resp := NewDocumentResponse(d)
|
||||
if resp.Images[0].Error != "image source URL is missing" {
|
||||
t.Errorf("Error = %q", resp.Images[0].Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDocumentListResponse(t *testing.T) {
|
||||
docs := []models.Document{*makeDocument()}
|
||||
results := NewDocumentListResponse(docs)
|
||||
if len(results) != 1 {
|
||||
t.Errorf("len = %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewDocumentUserResponse_Nil(t *testing.T) {
|
||||
if NewDocumentUserResponse(nil) != nil {
|
||||
t.Error("expected nil")
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== residence.go ====================
|
||||
|
||||
func makeResidence() *models.Residence {
|
||||
propTypeID := uint(1)
|
||||
return &models.Residence{
|
||||
BaseModel: models.BaseModel{ID: 10, CreatedAt: fixedNow, UpdatedAt: fixedNow},
|
||||
OwnerID: 1,
|
||||
Owner: *makeUser(),
|
||||
Name: "My House",
|
||||
PropertyTypeID: &propTypeID,
|
||||
PropertyType: &models.ResidenceType{BaseModel: models.BaseModel{ID: 1}, Name: "House"},
|
||||
StreetAddress: "123 Main St",
|
||||
City: "Springfield",
|
||||
StateProvince: "IL",
|
||||
PostalCode: "62701",
|
||||
Country: "USA",
|
||||
Bedrooms: intPtr(3),
|
||||
IsPrimary: true,
|
||||
IsActive: true,
|
||||
HasPool: true,
|
||||
HeatingType: strPtr("central"),
|
||||
Users: []models.User{
|
||||
{ID: 1, Username: "john", Email: "john@example.com"},
|
||||
{ID: 2, Username: "jane", Email: "jane@example.com"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewResidenceResponse_AllFields(t *testing.T) {
|
||||
r := makeResidence()
|
||||
resp := NewResidenceResponse(r)
|
||||
if resp.ID != 10 {
|
||||
t.Errorf("ID = %d", resp.ID)
|
||||
}
|
||||
if resp.Name != "My House" {
|
||||
t.Errorf("Name = %q", resp.Name)
|
||||
}
|
||||
if resp.Owner == nil {
|
||||
t.Error("Owner should not be nil")
|
||||
}
|
||||
if resp.PropertyType == nil {
|
||||
t.Error("PropertyType should not be nil")
|
||||
}
|
||||
if !resp.HasPool {
|
||||
t.Error("HasPool should be true")
|
||||
}
|
||||
if resp.HeatingType == nil || *resp.HeatingType != "central" {
|
||||
t.Error("HeatingType should be 'central'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewResidenceResponse_WithUsers(t *testing.T) {
|
||||
r := makeResidence()
|
||||
resp := NewResidenceResponse(r)
|
||||
if len(resp.Users) != 2 {
|
||||
t.Errorf("Users len = %d, want 2", len(resp.Users))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewResidenceResponse_NoUsers(t *testing.T) {
|
||||
r := makeResidence()
|
||||
r.Users = nil
|
||||
resp := NewResidenceResponse(r)
|
||||
if resp.Users == nil {
|
||||
t.Error("Users should be empty slice, not nil")
|
||||
}
|
||||
if len(resp.Users) != 0 {
|
||||
t.Errorf("Users len = %d, want 0", len(resp.Users))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewResidenceListResponse(t *testing.T) {
|
||||
residences := []models.Residence{*makeResidence()}
|
||||
results := NewResidenceListResponse(residences)
|
||||
if len(results) != 1 {
|
||||
t.Errorf("len = %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewResidenceUserResponse_Nil(t *testing.T) {
|
||||
if NewResidenceUserResponse(nil) != nil {
|
||||
t.Error("expected nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewResidenceTypeResponse_Nil(t *testing.T) {
|
||||
if NewResidenceTypeResponse(nil) != nil {
|
||||
t.Error("expected nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewShareCodeResponse(t *testing.T) {
|
||||
sc := &models.ResidenceShareCode{
|
||||
BaseModel: models.BaseModel{ID: 1, CreatedAt: fixedNow},
|
||||
Code: "ABC123",
|
||||
ResidenceID: 10,
|
||||
CreatedByID: 1,
|
||||
IsActive: true,
|
||||
ExpiresAt: timePtr(fixedNow.Add(24 * time.Hour)),
|
||||
}
|
||||
resp := NewShareCodeResponse(sc)
|
||||
if resp.Code != "ABC123" {
|
||||
t.Errorf("Code = %q", resp.Code)
|
||||
}
|
||||
if resp.ResidenceID != 10 {
|
||||
t.Errorf("ResidenceID = %d", resp.ResidenceID)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== task_template.go ====================
|
||||
|
||||
func TestParseTags_Empty(t *testing.T) {
|
||||
result := parseTags("")
|
||||
if len(result) != 0 {
|
||||
t.Errorf("len = %d, want 0", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTags_Multiple(t *testing.T) {
|
||||
result := parseTags("plumbing,electrical,roofing")
|
||||
if len(result) != 3 {
|
||||
t.Errorf("len = %d, want 3", len(result))
|
||||
}
|
||||
if result[0] != "plumbing" || result[1] != "electrical" || result[2] != "roofing" {
|
||||
t.Errorf("unexpected tags: %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTags_Whitespace(t *testing.T) {
|
||||
result := parseTags(" plumbing , , electrical ")
|
||||
if len(result) != 2 {
|
||||
t.Errorf("len = %d, want 2 (should skip empty after trim)", len(result))
|
||||
}
|
||||
if result[0] != "plumbing" || result[1] != "electrical" {
|
||||
t.Errorf("unexpected tags: %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func makeTemplate(catID *uint, cat *models.TaskCategory) models.TaskTemplate {
|
||||
return models.TaskTemplate{
|
||||
BaseModel: models.BaseModel{ID: 1, CreatedAt: fixedNow, UpdatedAt: fixedNow},
|
||||
Title: "Clean Gutters",
|
||||
Description: "Remove debris",
|
||||
CategoryID: catID,
|
||||
Category: cat,
|
||||
IconIOS: "leaf",
|
||||
IconAndroid: "leaf_android",
|
||||
Tags: "exterior,seasonal",
|
||||
DisplayOrder: 1,
|
||||
IsActive: true,
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskTemplateResponse(t *testing.T) {
|
||||
catID := uint(1)
|
||||
cat := &models.TaskCategory{BaseModel: models.BaseModel{ID: 1}, Name: "Exterior"}
|
||||
tmpl := makeTemplate(&catID, cat)
|
||||
resp := NewTaskTemplateResponse(&tmpl)
|
||||
if resp.Title != "Clean Gutters" {
|
||||
t.Errorf("Title = %q", resp.Title)
|
||||
}
|
||||
if len(resp.Tags) != 2 {
|
||||
t.Errorf("Tags len = %d", len(resp.Tags))
|
||||
}
|
||||
if resp.Category == nil {
|
||||
t.Error("Category should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskTemplateResponse_WithRegion(t *testing.T) {
|
||||
tmpl := makeTemplate(nil, nil)
|
||||
tmpl.Regions = []models.ClimateRegion{
|
||||
{BaseModel: models.BaseModel{ID: 5}, Name: "Southeast"},
|
||||
}
|
||||
resp := NewTaskTemplateResponse(&tmpl)
|
||||
if resp.RegionID == nil || *resp.RegionID != 5 {
|
||||
t.Error("RegionID should be 5")
|
||||
}
|
||||
if resp.RegionName != "Southeast" {
|
||||
t.Errorf("RegionName = %q", resp.RegionName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskTemplatesGroupedResponse_Grouping(t *testing.T) {
|
||||
catID := uint(1)
|
||||
cat := &models.TaskCategory{BaseModel: models.BaseModel{ID: 1}, Name: "Exterior"}
|
||||
templates := []models.TaskTemplate{
|
||||
makeTemplate(&catID, cat),
|
||||
makeTemplate(&catID, cat),
|
||||
}
|
||||
resp := NewTaskTemplatesGroupedResponse(templates)
|
||||
if len(resp.Categories) != 1 {
|
||||
t.Fatalf("Categories len = %d, want 1", len(resp.Categories))
|
||||
}
|
||||
if resp.Categories[0].CategoryName != "Exterior" {
|
||||
t.Errorf("CategoryName = %q", resp.Categories[0].CategoryName)
|
||||
}
|
||||
if resp.Categories[0].Count != 2 {
|
||||
t.Errorf("Count = %d, want 2", resp.Categories[0].Count)
|
||||
}
|
||||
if resp.TotalCount != 2 {
|
||||
t.Errorf("TotalCount = %d, want 2", resp.TotalCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskTemplatesGroupedResponse_Uncategorized(t *testing.T) {
|
||||
tmpl := makeTemplate(nil, nil)
|
||||
resp := NewTaskTemplatesGroupedResponse([]models.TaskTemplate{tmpl})
|
||||
if len(resp.Categories) != 1 {
|
||||
t.Fatalf("Categories len = %d", len(resp.Categories))
|
||||
}
|
||||
if resp.Categories[0].CategoryName != "Uncategorized" {
|
||||
t.Errorf("CategoryName = %q", resp.Categories[0].CategoryName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTaskTemplateListResponse(t *testing.T) {
|
||||
templates := []models.TaskTemplate{makeTemplate(nil, nil)}
|
||||
results := NewTaskTemplateListResponse(templates)
|
||||
if len(results) != 1 {
|
||||
t.Errorf("len = %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== DetermineKanbanColumnWithTime ====================
|
||||
|
||||
func TestDetermineKanbanColumnWithTime(t *testing.T) {
|
||||
task := makeTask()
|
||||
col := DetermineKanbanColumnWithTime(task, 30, fixedNow)
|
||||
if col == "" {
|
||||
t.Error("expected non-empty column")
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== NewTaskResponse uses NewTaskResponseWithThreshold ====================
|
||||
|
||||
func TestNewTaskResponse_UsesDefault30(t *testing.T) {
|
||||
task := makeTask()
|
||||
resp := NewTaskResponse(task)
|
||||
if resp.ID != 100 {
|
||||
t.Errorf("ID = %d", resp.ID)
|
||||
}
|
||||
// Just verify it doesn't panic and produces a response
|
||||
}
|
||||
|
||||
// ==================== NewTaskCompletionWithTaskResponse UTC variant ====================
|
||||
|
||||
func TestNewTaskCompletionWithTaskResponse_UTC(t *testing.T) {
|
||||
c := &models.TaskCompletion{
|
||||
BaseModel: models.BaseModel{ID: 1},
|
||||
TaskID: 100,
|
||||
CompletedByID: 1,
|
||||
CompletedAt: fixedNow,
|
||||
}
|
||||
task := makeTask()
|
||||
resp := NewTaskCompletionWithTaskResponse(c, task, 30)
|
||||
if resp.Task == nil {
|
||||
t.Error("Task should not be nil")
|
||||
}
|
||||
}
|
||||
105
internal/echohelpers/helpers_test.go
Normal file
105
internal/echohelpers/helpers_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package echohelpers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDefaultQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
key string
|
||||
defaultValue string
|
||||
expected string
|
||||
}{
|
||||
{"returns value when present", "/?status=active", "status", "all", "active"},
|
||||
{"returns default when absent", "/", "status", "all", "all"},
|
||||
{"returns default for empty value", "/?status=", "status", "all", "all"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, tc.query, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
result := DefaultQuery(c, tc.key, tc.defaultValue)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUintParam(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
paramValue string
|
||||
expected uint
|
||||
expectError bool
|
||||
}{
|
||||
{"valid uint", "42", 42, false},
|
||||
{"zero", "0", 0, false},
|
||||
{"invalid string", "abc", 0, true},
|
||||
{"negative", "-1", 0, true},
|
||||
{"empty", "", 0, true},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
c.SetParamNames("id")
|
||||
c.SetParamValues(tc.paramValue)
|
||||
|
||||
result, err := ParseUintParam(c, "id")
|
||||
if tc.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseIntParam(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
paramValue string
|
||||
expected int
|
||||
expectError bool
|
||||
}{
|
||||
{"valid int", "42", 42, false},
|
||||
{"zero", "0", 0, false},
|
||||
{"negative", "-5", -5, false},
|
||||
{"invalid string", "abc", 0, true},
|
||||
{"empty", "", 0, true},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
c.SetParamNames("id")
|
||||
c.SetParamValues(tc.paramValue)
|
||||
|
||||
result, err := ParseIntParam(c, "id")
|
||||
if tc.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -38,7 +38,7 @@ func setupDeleteAccountHandler(t *testing.T) (*AuthHandler, *echo.Echo, *gorm.DB
|
||||
func TestAuthHandler_DeleteAccount_EmailUser(t *testing.T) {
|
||||
handler, e, db := setupDeleteAccountHandler(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "deletetest", "delete@test.com", "password123")
|
||||
user := testutil.CreateTestUser(t, db, "deletetest", "delete@test.com", "Password123")
|
||||
|
||||
// Create profile for the user
|
||||
profile := &models.UserProfile{UserID: user.ID, Verified: true}
|
||||
@@ -52,7 +52,7 @@ func TestAuthHandler_DeleteAccount_EmailUser(t *testing.T) {
|
||||
authGroup.DELETE("/account/", handler.DeleteAccount)
|
||||
|
||||
t.Run("successful deletion with correct password", func(t *testing.T) {
|
||||
password := "password123"
|
||||
password := "Password123"
|
||||
req := map[string]interface{}{
|
||||
"password": password,
|
||||
}
|
||||
@@ -84,7 +84,7 @@ func TestAuthHandler_DeleteAccount_EmailUser(t *testing.T) {
|
||||
func TestAuthHandler_DeleteAccount_WrongPassword(t *testing.T) {
|
||||
handler, e, db := setupDeleteAccountHandler(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "wrongpw", "wrongpw@test.com", "password123")
|
||||
user := testutil.CreateTestUser(t, db, "wrongpw", "wrongpw@test.com", "Password123")
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
@@ -105,7 +105,7 @@ func TestAuthHandler_DeleteAccount_WrongPassword(t *testing.T) {
|
||||
func TestAuthHandler_DeleteAccount_MissingPassword(t *testing.T) {
|
||||
handler, e, db := setupDeleteAccountHandler(t)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "nopw", "nopw@test.com", "password123")
|
||||
user := testutil.CreateTestUser(t, db, "nopw", "nopw@test.com", "Password123")
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
@@ -207,7 +207,7 @@ func TestAuthHandler_DeleteAccount_Unauthenticated(t *testing.T) {
|
||||
|
||||
t.Run("unauthenticated request returns 401", func(t *testing.T) {
|
||||
req := map[string]interface{}{
|
||||
"password": "password123",
|
||||
"password": "Password123",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/auth/account/", req, "")
|
||||
|
||||
@@ -43,7 +43,7 @@ func TestAuthHandler_Register(t *testing.T) {
|
||||
req := requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "new@test.com",
|
||||
Password: "password123",
|
||||
Password: "Password123",
|
||||
FirstName: "New",
|
||||
LastName: "User",
|
||||
}
|
||||
@@ -98,7 +98,7 @@ func TestAuthHandler_Register(t *testing.T) {
|
||||
req := requests.RegisterRequest{
|
||||
Username: "duplicate",
|
||||
Email: "unique1@test.com",
|
||||
Password: "password123",
|
||||
Password: "Password123",
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusCreated)
|
||||
@@ -117,7 +117,7 @@ func TestAuthHandler_Register(t *testing.T) {
|
||||
req := requests.RegisterRequest{
|
||||
Username: "user1",
|
||||
Email: "duplicate@test.com",
|
||||
Password: "password123",
|
||||
Password: "Password123",
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/register/", req, "")
|
||||
testutil.AssertStatusCode(t, w, http.StatusCreated)
|
||||
@@ -142,7 +142,7 @@ func TestAuthHandler_Login(t *testing.T) {
|
||||
registerReq := requests.RegisterRequest{
|
||||
Username: "logintest",
|
||||
Email: "login@test.com",
|
||||
Password: "password123",
|
||||
Password: "Password123",
|
||||
FirstName: "Test",
|
||||
LastName: "User",
|
||||
}
|
||||
@@ -152,7 +152,7 @@ func TestAuthHandler_Login(t *testing.T) {
|
||||
t.Run("successful login with username", func(t *testing.T) {
|
||||
req := requests.LoginRequest{
|
||||
Username: "logintest",
|
||||
Password: "password123",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "")
|
||||
@@ -174,7 +174,7 @@ func TestAuthHandler_Login(t *testing.T) {
|
||||
t.Run("successful login with email", func(t *testing.T) {
|
||||
req := requests.LoginRequest{
|
||||
Username: "login@test.com", // Using email as username
|
||||
Password: "password123",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "")
|
||||
@@ -199,7 +199,7 @@ func TestAuthHandler_Login(t *testing.T) {
|
||||
t.Run("login with non-existent user", func(t *testing.T) {
|
||||
req := requests.LoginRequest{
|
||||
Username: "nonexistent",
|
||||
Password: "password123",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/auth/login/", req, "")
|
||||
@@ -223,7 +223,7 @@ func TestAuthHandler_CurrentUser(t *testing.T) {
|
||||
handler, e, userRepo := setupAuthHandler(t)
|
||||
|
||||
db := testutil.SetupTestDB(t)
|
||||
user := testutil.CreateTestUser(t, db, "metest", "me@test.com", "password123")
|
||||
user := testutil.CreateTestUser(t, db, "metest", "me@test.com", "Password123")
|
||||
user.FirstName = "Test"
|
||||
user.LastName = "User"
|
||||
userRepo.Update(user)
|
||||
@@ -251,7 +251,7 @@ func TestAuthHandler_UpdateProfile(t *testing.T) {
|
||||
handler, e, userRepo := setupAuthHandler(t)
|
||||
|
||||
db := testutil.SetupTestDB(t)
|
||||
user := testutil.CreateTestUser(t, db, "updatetest", "update@test.com", "password123")
|
||||
user := testutil.CreateTestUser(t, db, "updatetest", "update@test.com", "Password123")
|
||||
userRepo.Update(user)
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
@@ -289,7 +289,7 @@ func TestAuthHandler_ForgotPassword(t *testing.T) {
|
||||
registerReq := requests.RegisterRequest{
|
||||
Username: "forgottest",
|
||||
Email: "forgot@test.com",
|
||||
Password: "password123",
|
||||
Password: "Password123",
|
||||
}
|
||||
testutil.MakeRequest(e, "POST", "/api/auth/register/", registerReq, "")
|
||||
|
||||
@@ -323,7 +323,7 @@ func TestAuthHandler_Logout(t *testing.T) {
|
||||
handler, e, userRepo := setupAuthHandler(t)
|
||||
|
||||
db := testutil.SetupTestDB(t)
|
||||
user := testutil.CreateTestUser(t, db, "logouttest", "logout@test.com", "password123")
|
||||
user := testutil.CreateTestUser(t, db, "logouttest", "logout@test.com", "Password123")
|
||||
userRepo.Update(user)
|
||||
|
||||
authGroup := e.Group("/api/auth")
|
||||
@@ -350,7 +350,7 @@ func TestAuthHandler_JSONResponses(t *testing.T) {
|
||||
req := requests.RegisterRequest{
|
||||
Username: "jsontest",
|
||||
Email: "json@test.com",
|
||||
Password: "password123",
|
||||
Password: "Password123",
|
||||
FirstName: "JSON",
|
||||
LastName: "Test",
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
@@ -180,3 +181,284 @@ func TestContractorHandler_CreateContractor_100Specialties_Returns400(t *testing
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestContractorHandler_ListContractors(t *testing.T) {
|
||||
handler, e, db := setupContractorHandler(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Plumber Joe")
|
||||
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Electrician Bob")
|
||||
|
||||
authGroup := e.Group("/api/contractors")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/", handler.ListContractors)
|
||||
|
||||
t.Run("successful list", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/contractors/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response []map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, response, 2)
|
||||
})
|
||||
|
||||
t.Run("user with no contractors returns empty", func(t *testing.T) {
|
||||
otherUser := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
|
||||
e2 := testutil.SetupTestRouter()
|
||||
authGroup2 := e2.Group("/api/contractors")
|
||||
authGroup2.Use(testutil.MockAuthMiddleware(otherUser))
|
||||
authGroup2.GET("/", handler.ListContractors)
|
||||
|
||||
w := testutil.MakeRequest(e2, "GET", "/api/contractors/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response []map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, response, 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestContractorHandler_GetContractor(t *testing.T) {
|
||||
handler, e, db := setupContractorHandler(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Plumber Joe")
|
||||
|
||||
authGroup := e.Group("/api/contractors")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/:id/", handler.GetContractor)
|
||||
|
||||
t.Run("successful get", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", fmt.Sprintf("/api/contractors/%d/", contractor.ID), nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Plumber Joe", response["name"])
|
||||
})
|
||||
|
||||
t.Run("not found returns 404", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/contractors/99999/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusNotFound)
|
||||
})
|
||||
|
||||
t.Run("invalid id returns 400", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/contractors/invalid/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestContractorHandler_UpdateContractor(t *testing.T) {
|
||||
handler, e, db := setupContractorHandler(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Plumber Joe")
|
||||
|
||||
authGroup := e.Group("/api/contractors")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.PUT("/:id/", handler.UpdateContractor)
|
||||
|
||||
t.Run("successful update", func(t *testing.T) {
|
||||
newName := "Plumber Joe Updated"
|
||||
req := requests.UpdateContractorRequest{
|
||||
Name: &newName,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "PUT", fmt.Sprintf("/api/contractors/%d/", contractor.ID), req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Plumber Joe Updated", response["name"])
|
||||
})
|
||||
|
||||
t.Run("invalid id returns 400", func(t *testing.T) {
|
||||
newName := "Updated"
|
||||
req := requests.UpdateContractorRequest{Name: &newName}
|
||||
w := testutil.MakeRequest(e, "PUT", "/api/contractors/invalid/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("not found returns 404", func(t *testing.T) {
|
||||
newName := "Updated"
|
||||
req := requests.UpdateContractorRequest{Name: &newName}
|
||||
w := testutil.MakeRequest(e, "PUT", "/api/contractors/99999/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func TestContractorHandler_DeleteContractor(t *testing.T) {
|
||||
handler, e, db := setupContractorHandler(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Plumber Joe")
|
||||
|
||||
authGroup := e.Group("/api/contractors")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.DELETE("/:id/", handler.DeleteContractor)
|
||||
|
||||
t.Run("successful delete", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "DELETE", fmt.Sprintf("/api/contractors/%d/", contractor.ID), nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response, "message")
|
||||
})
|
||||
|
||||
t.Run("invalid id returns 400", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/contractors/invalid/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("not found returns 404", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/contractors/99999/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func TestContractorHandler_ToggleFavorite(t *testing.T) {
|
||||
handler, e, db := setupContractorHandler(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Plumber Joe")
|
||||
|
||||
authGroup := e.Group("/api/contractors")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/:id/toggle-favorite/", handler.ToggleFavorite)
|
||||
|
||||
t.Run("toggle favorite on", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", fmt.Sprintf("/api/contractors/%d/toggle-favorite/", contractor.ID), nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response, "is_favorite")
|
||||
})
|
||||
|
||||
t.Run("invalid id returns 400", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", "/api/contractors/invalid/toggle-favorite/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("not found returns 404", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", "/api/contractors/99999/toggle-favorite/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func TestContractorHandler_ListContractorsByResidence(t *testing.T) {
|
||||
handler, e, db := setupContractorHandler(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Plumber Joe")
|
||||
|
||||
authGroup := e.Group("/api/contractors")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/by-residence/:residence_id/", handler.ListContractorsByResidence)
|
||||
|
||||
t.Run("successful list by residence", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", fmt.Sprintf("/api/contractors/by-residence/%d/", residence.ID), nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response []map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, response, 1)
|
||||
})
|
||||
|
||||
t.Run("invalid residence id returns 400", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/contractors/by-residence/invalid/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestContractorHandler_GetSpecialties(t *testing.T) {
|
||||
handler, e, db := setupContractorHandler(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
authGroup := e.Group("/api/contractors")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/specialties/", handler.GetSpecialties)
|
||||
|
||||
t.Run("successful list specialties", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/contractors/specialties/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response []map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, len(response), 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestContractorHandler_GetContractorTasks(t *testing.T) {
|
||||
handler, e, db := setupContractorHandler(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Plumber Joe")
|
||||
|
||||
authGroup := e.Group("/api/contractors")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/:id/tasks/", handler.GetContractorTasks)
|
||||
|
||||
t.Run("successful get tasks", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", fmt.Sprintf("/api/contractors/%d/tasks/", contractor.ID), nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("invalid id returns 400", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/contractors/invalid/tasks/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestContractorHandler_CreateContractor_WithOptionalFields(t *testing.T) {
|
||||
handler, e, db := setupContractorHandler(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
authGroup := e.Group("/api/contractors")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/", handler.CreateContractor)
|
||||
|
||||
t.Run("creation with all optional fields", func(t *testing.T) {
|
||||
rating := 4.5
|
||||
isFavorite := true
|
||||
req := requests.CreateContractorRequest{
|
||||
ResidenceID: &residence.ID,
|
||||
Name: "Full Contractor",
|
||||
Company: "ABC Plumbing",
|
||||
Phone: "555-1234",
|
||||
Email: "contractor@test.com",
|
||||
Notes: "Great work",
|
||||
Rating: &rating,
|
||||
IsFavorite: &isFavorite,
|
||||
}
|
||||
|
||||
w := testutil.MakeRequest(e, "POST", "/api/contractors/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusCreated)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Full Contractor", response["name"])
|
||||
assert.Equal(t, "ABC Plumbing", response["company"])
|
||||
})
|
||||
}
|
||||
|
||||
@@ -224,3 +224,235 @@ func TestDocumentHandler_DeleteDocument(t *testing.T) {
|
||||
testutil.AssertStatusCode(t, w, http.StatusNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDocumentHandler_UpdateDocument(t *testing.T) {
|
||||
handler, e, db := setupDocumentHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Original Title")
|
||||
|
||||
authGroup := e.Group("/api/documents")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.PUT("/:id/", handler.UpdateDocument)
|
||||
|
||||
t.Run("successful update", func(t *testing.T) {
|
||||
newTitle := "Updated Title"
|
||||
req := map[string]interface{}{
|
||||
"title": newTitle,
|
||||
}
|
||||
w := testutil.MakeRequest(e, "PUT", fmt.Sprintf("/api/documents/%d/", doc.ID), req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated Title", response["title"])
|
||||
})
|
||||
|
||||
t.Run("invalid id returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{"title": "Updated"}
|
||||
w := testutil.MakeRequest(e, "PUT", "/api/documents/invalid/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("not found returns 404", func(t *testing.T) {
|
||||
req := map[string]interface{}{"title": "Updated"}
|
||||
w := testutil.MakeRequest(e, "PUT", "/api/documents/99999/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusNotFound)
|
||||
})
|
||||
|
||||
t.Run("access denied for other user", func(t *testing.T) {
|
||||
otherUser := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
|
||||
e2 := testutil.SetupTestRouter()
|
||||
otherGroup := e2.Group("/api/documents")
|
||||
otherGroup.Use(testutil.MockAuthMiddleware(otherUser))
|
||||
otherGroup.PUT("/:id/", handler.UpdateDocument)
|
||||
|
||||
req := map[string]interface{}{"title": "Hacked"}
|
||||
w := testutil.MakeRequest(e2, "PUT", fmt.Sprintf("/api/documents/%d/", doc.ID), req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusForbidden)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDocumentHandler_ListDocuments_Filters(t *testing.T) {
|
||||
handler, e, db := setupDocumentHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Active Doc")
|
||||
|
||||
authGroup := e.Group("/api/documents")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/", handler.ListDocuments)
|
||||
|
||||
t.Run("filter by residence", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", fmt.Sprintf("/api/documents/?residence=%d", residence.ID), nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response []map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, response, 1)
|
||||
})
|
||||
|
||||
t.Run("filter by search", func(t *testing.T) {
|
||||
t.Skip("ILIKE is not supported in SQLite; search filter requires PostgreSQL")
|
||||
})
|
||||
|
||||
t.Run("expiring_soon out of range returns 400", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/documents/?expiring_soon=5000", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDocumentHandler_ListWarranties(t *testing.T) {
|
||||
handler, e, db := setupDocumentHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
// Create a warranty document
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Warranty Doc")
|
||||
require.NoError(t, db.Model(doc).Update("document_type", "warranty").Error)
|
||||
|
||||
authGroup := e.Group("/api/documents")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/warranties/", handler.ListWarranties)
|
||||
|
||||
t.Run("successful list warranties", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/documents/warranties/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDocumentHandler_ActivateDeactivateDocument(t *testing.T) {
|
||||
handler, e, db := setupDocumentHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Toggle Doc")
|
||||
|
||||
authGroup := e.Group("/api/documents")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/:id/deactivate/", handler.DeactivateDocument)
|
||||
authGroup.POST("/:id/activate/", handler.ActivateDocument)
|
||||
|
||||
t.Run("deactivate document", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", fmt.Sprintf("/api/documents/%d/deactivate/", doc.ID), nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, false, response["is_active"])
|
||||
})
|
||||
|
||||
t.Run("activate document", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", fmt.Sprintf("/api/documents/%d/activate/", doc.ID), nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, true, response["is_active"])
|
||||
})
|
||||
|
||||
t.Run("activate invalid id returns 400", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", "/api/documents/invalid/activate/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("deactivate invalid id returns 400", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", "/api/documents/invalid/deactivate/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("activate not found returns 404", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", "/api/documents/99999/activate/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusNotFound)
|
||||
})
|
||||
|
||||
t.Run("deactivate not found returns 404", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", "/api/documents/99999/deactivate/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDocumentHandler_CreateDocument_ValidationErrors(t *testing.T) {
|
||||
handler, e, db := setupDocumentHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
authGroup := e.Group("/api/documents")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/", handler.CreateDocument)
|
||||
|
||||
t.Run("missing title returns 400", func(t *testing.T) {
|
||||
body := map[string]interface{}{
|
||||
"residence_id": residence.ID,
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/documents/", body, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("missing residence_id returns 400", func(t *testing.T) {
|
||||
body := map[string]interface{}{
|
||||
"title": "Test Doc",
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/documents/", body, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("invalid document_type returns 400", func(t *testing.T) {
|
||||
body := map[string]interface{}{
|
||||
"title": "Test Doc",
|
||||
"residence_id": residence.ID,
|
||||
"document_type": "invalid_type",
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/documents/", body, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDocumentHandler_GetDocument_InvalidID(t *testing.T) {
|
||||
handler, e, db := setupDocumentHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
authGroup := e.Group("/api/documents")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/:id/", handler.GetDocument)
|
||||
|
||||
t.Run("invalid id returns 400", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/documents/invalid/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDocumentHandler_DeleteDocument_InvalidID(t *testing.T) {
|
||||
handler, e, db := setupDocumentHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
authGroup := e.Group("/api/documents")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.DELETE("/:id/", handler.DeleteDocument)
|
||||
|
||||
t.Run("invalid id returns 400", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/documents/invalid/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDocumentHandler_DeleteDocument_AccessDenied(t *testing.T) {
|
||||
handler, e, db := setupDocumentHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
otherUser := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
|
||||
|
||||
authGroup := e.Group("/api/documents")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(otherUser))
|
||||
authGroup.DELETE("/:id/", handler.DeleteDocument)
|
||||
|
||||
t.Run("access denied for other user", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "DELETE", fmt.Sprintf("/api/documents/%d/", doc.ID), nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusForbidden)
|
||||
})
|
||||
}
|
||||
|
||||
1869
internal/handlers/handler_coverage_test.go
Normal file
1869
internal/handlers/handler_coverage_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -86,3 +86,323 @@ func TestNotificationHandler_ListNotifications_LimitCappedAt200(t *testing.T) {
|
||||
assert.Equal(t, 50, count, "response should use default limit of 50")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNotificationHandler_ListNotifications_Pagination(t *testing.T) {
|
||||
handler, e, db := setupNotificationHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
createTestNotifications(t, db, user.ID, 20)
|
||||
|
||||
authGroup := e.Group("/api/notifications")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/", handler.ListNotifications)
|
||||
|
||||
t.Run("offset skips notifications", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/notifications/?limit=5&offset=15", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
count := int(response["count"].(float64))
|
||||
assert.Equal(t, 5, count, "should return remaining 5 after offset 15")
|
||||
})
|
||||
|
||||
t.Run("response has results array", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/notifications/?limit=3", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, response, "results")
|
||||
assert.Contains(t, response, "count")
|
||||
results := response["results"].([]interface{})
|
||||
assert.Len(t, results, 3)
|
||||
})
|
||||
|
||||
t.Run("negative limit ignored", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/notifications/?limit=-5", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Negative limit should default to 50 (since -5 > 0 is false)
|
||||
count := int(response["count"].(float64))
|
||||
assert.Equal(t, 20, count, "should return all 20 with default limit of 50")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNotificationHandler_GetUnreadCount(t *testing.T) {
|
||||
handler, e, db := setupNotificationHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
// Create some unread notifications
|
||||
createTestNotifications(t, db, user.ID, 5)
|
||||
|
||||
authGroup := e.Group("/api/notifications")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/unread-count/", handler.GetUnreadCount)
|
||||
|
||||
t.Run("successful unread count", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/notifications/unread-count/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, response, "unread_count")
|
||||
unreadCount := int(response["unread_count"].(float64))
|
||||
assert.Equal(t, 5, unreadCount)
|
||||
})
|
||||
|
||||
t.Run("user with no notifications returns zero", func(t *testing.T) {
|
||||
otherUser := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
|
||||
e2 := testutil.SetupTestRouter()
|
||||
authGroup2 := e2.Group("/api/notifications")
|
||||
authGroup2.Use(testutil.MockAuthMiddleware(otherUser))
|
||||
authGroup2.GET("/unread-count/", handler.GetUnreadCount)
|
||||
|
||||
w := testutil.MakeRequest(e2, "GET", "/api/notifications/unread-count/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, float64(0), response["unread_count"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestNotificationHandler_MarkAsRead(t *testing.T) {
|
||||
handler, e, db := setupNotificationHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
// Create a notification
|
||||
notif := &models.Notification{
|
||||
UserID: user.ID,
|
||||
NotificationType: models.NotificationTaskDueSoon,
|
||||
Title: "Test Notification",
|
||||
Body: "Test Body",
|
||||
}
|
||||
require.NoError(t, db.Create(notif).Error)
|
||||
|
||||
authGroup := e.Group("/api/notifications")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/:id/read/", handler.MarkAsRead)
|
||||
|
||||
t.Run("successful mark as read", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", fmt.Sprintf("/api/notifications/%d/read/", notif.ID), nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response, "message")
|
||||
})
|
||||
|
||||
t.Run("invalid id returns 400", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", "/api/notifications/invalid/read/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("not found returns 404", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", "/api/notifications/99999/read/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNotificationHandler_MarkAllAsRead(t *testing.T) {
|
||||
handler, e, db := setupNotificationHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
createTestNotifications(t, db, user.ID, 5)
|
||||
|
||||
authGroup := e.Group("/api/notifications")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/mark-all-read/", handler.MarkAllAsRead)
|
||||
authGroup.GET("/unread-count/", handler.GetUnreadCount)
|
||||
|
||||
t.Run("successful mark all as read", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", "/api/notifications/mark-all-read/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response, "message")
|
||||
})
|
||||
|
||||
t.Run("unread count is zero after mark all", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/notifications/unread-count/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, float64(0), response["unread_count"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestNotificationHandler_GetPreferences(t *testing.T) {
|
||||
handler, e, db := setupNotificationHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
authGroup := e.Group("/api/notifications")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/preferences/", handler.GetPreferences)
|
||||
|
||||
t.Run("successful get preferences", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/notifications/preferences/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Default preferences should have standard fields
|
||||
assert.Contains(t, response, "task_due_soon")
|
||||
assert.Contains(t, response, "task_overdue")
|
||||
assert.Contains(t, response, "task_completed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNotificationHandler_UpdatePreferences(t *testing.T) {
|
||||
handler, e, db := setupNotificationHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
authGroup := e.Group("/api/notifications")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.PUT("/preferences/", handler.UpdatePreferences)
|
||||
|
||||
t.Run("successful update preferences", func(t *testing.T) {
|
||||
req := map[string]interface{}{
|
||||
"task_due_soon": false,
|
||||
"task_overdue": true,
|
||||
}
|
||||
w := testutil.MakeRequest(e, "PUT", "/api/notifications/preferences/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, false, response["task_due_soon"])
|
||||
assert.Equal(t, true, response["task_overdue"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestNotificationHandler_RegisterDevice(t *testing.T) {
|
||||
handler, e, db := setupNotificationHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
authGroup := e.Group("/api/notifications")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/devices/", handler.RegisterDevice)
|
||||
|
||||
t.Run("successful device registration", func(t *testing.T) {
|
||||
req := map[string]interface{}{
|
||||
"name": "iPhone 15",
|
||||
"device_id": "test-device-id-123",
|
||||
"registration_id": "test-registration-id-abc",
|
||||
"platform": "ios",
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/notifications/devices/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusCreated)
|
||||
})
|
||||
|
||||
t.Run("missing required fields returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{
|
||||
"name": "iPhone 15",
|
||||
// Missing device_id, registration_id, platform
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/notifications/devices/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("invalid platform returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{
|
||||
"device_id": "test-device-id-456",
|
||||
"registration_id": "test-registration-id-def",
|
||||
"platform": "windows", // invalid
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/notifications/devices/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNotificationHandler_ListDevices(t *testing.T) {
|
||||
handler, e, db := setupNotificationHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
authGroup := e.Group("/api/notifications")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/devices/", handler.ListDevices)
|
||||
|
||||
t.Run("successful list devices", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/notifications/devices/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNotificationHandler_UnregisterDevice(t *testing.T) {
|
||||
handler, e, db := setupNotificationHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
authGroup := e.Group("/api/notifications")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/devices/unregister/", handler.UnregisterDevice)
|
||||
|
||||
t.Run("missing registration_id returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{
|
||||
"platform": "ios",
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/notifications/devices/unregister/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("missing platform returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{
|
||||
"registration_id": "test-id",
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/notifications/devices/unregister/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("invalid platform returns 400", func(t *testing.T) {
|
||||
req := map[string]interface{}{
|
||||
"registration_id": "test-id",
|
||||
"platform": "windows",
|
||||
}
|
||||
w := testutil.MakeRequest(e, "POST", "/api/notifications/devices/unregister/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNotificationHandler_DeleteDevice(t *testing.T) {
|
||||
handler, e, db := setupNotificationHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
authGroup := e.Group("/api/notifications")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.DELETE("/devices/:id/", handler.DeleteDevice)
|
||||
|
||||
t.Run("missing platform query param returns 400", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/notifications/devices/1/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("invalid platform returns 400", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/notifications/devices/1/?platform=windows", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("invalid id returns 400", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/notifications/devices/invalid/?platform=ios", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -567,3 +567,164 @@ func TestResidenceHandler_CreateResidence_NegativeBedrooms_Returns400(t *testing
|
||||
testutil.AssertStatusCode(t, w, http.StatusCreated)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResidenceHandler_GetMyResidences(t *testing.T) {
|
||||
handler, e, db := setupResidenceHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
testutil.CreateTestResidence(t, db, user.ID, "House 1")
|
||||
testutil.CreateTestResidence(t, db, user.ID, "House 2")
|
||||
|
||||
authGroup := e.Group("/api/residences")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/my-residences/", handler.GetMyResidences)
|
||||
|
||||
t.Run("successful my residences", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/residences/my-residences/", nil, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
// GetMyResidences returns MyResidencesResponse: {"residences": [...]}
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
residences := response["residences"].([]interface{})
|
||||
assert.Len(t, residences, 2)
|
||||
})
|
||||
|
||||
t.Run("user with no residences returns empty", func(t *testing.T) {
|
||||
noResUser := testutil.CreateTestUser(t, db, "nores", "nores@test.com", "Password123")
|
||||
|
||||
e2 := testutil.SetupTestRouter()
|
||||
authGroup2 := e2.Group("/api/residences")
|
||||
authGroup2.Use(testutil.MockAuthMiddleware(noResUser))
|
||||
authGroup2.GET("/my-residences/", handler.GetMyResidences)
|
||||
|
||||
w := testutil.MakeRequest(e2, "GET", "/api/residences/my-residences/", nil, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
// GetMyResidences returns MyResidencesResponse: {"residences": [...] or null}
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
if response["residences"] == nil {
|
||||
// null residences means no residences
|
||||
} else {
|
||||
residences := response["residences"].([]interface{})
|
||||
assert.Len(t, residences, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestResidenceHandler_GetSummary(t *testing.T) {
|
||||
handler, e, db := setupResidenceHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
testutil.CreateTestResidence(t, db, user.ID, "House 1")
|
||||
|
||||
authGroup := e.Group("/api/residences")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/summary/", handler.GetSummary)
|
||||
|
||||
t.Run("successful summary", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/residences/summary/", nil, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, response, "total_residences")
|
||||
assert.Contains(t, response, "total_tasks")
|
||||
})
|
||||
}
|
||||
|
||||
func TestResidenceHandler_UpdateResidence_InvalidID(t *testing.T) {
|
||||
handler, e, db := setupResidenceHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
authGroup := e.Group("/api/residences")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.PUT("/:id/", handler.UpdateResidence)
|
||||
|
||||
t.Run("invalid id returns 400", func(t *testing.T) {
|
||||
newName := "Updated"
|
||||
req := requests.UpdateResidenceRequest{Name: &newName}
|
||||
w := testutil.MakeRequest(e, "PUT", "/api/residences/invalid/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("non-existent id returns 403", func(t *testing.T) {
|
||||
newName := "Updated"
|
||||
req := requests.UpdateResidenceRequest{Name: &newName}
|
||||
w := testutil.MakeRequest(e, "PUT", "/api/residences/9999/", req, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusForbidden)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResidenceHandler_DeleteResidence_InvalidID(t *testing.T) {
|
||||
handler, e, db := setupResidenceHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
authGroup := e.Group("/api/residences")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.DELETE("/:id/", handler.DeleteResidence)
|
||||
|
||||
t.Run("invalid id returns 400", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/residences/invalid/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
|
||||
t.Run("non-existent id returns 403", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "DELETE", "/api/residences/9999/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusForbidden)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResidenceHandler_GetShareCode(t *testing.T) {
|
||||
handler, e, db := setupResidenceHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Share Code Test")
|
||||
|
||||
authGroup := e.Group("/api/residences")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.GET("/:id/share-code/", handler.GetShareCode)
|
||||
|
||||
t.Run("no share code returns null", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", fmt.Sprintf("/api/residences/%d/share-code/", residence.ID), nil, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, response["share_code"])
|
||||
})
|
||||
|
||||
t.Run("invalid id returns 400", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "GET", "/api/residences/invalid/share-code/", nil, "test-token")
|
||||
testutil.AssertStatusCode(t, w, http.StatusBadRequest)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResidenceHandler_GenerateSharePackage(t *testing.T) {
|
||||
handler, e, db := setupResidenceHandler(t)
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Package Test")
|
||||
|
||||
authGroup := e.Group("/api/residences")
|
||||
authGroup.Use(testutil.MockAuthMiddleware(user))
|
||||
authGroup.POST("/:id/generate-share-package/", handler.GenerateSharePackage)
|
||||
|
||||
t.Run("generate share package", func(t *testing.T) {
|
||||
w := testutil.MakeRequest(e, "POST", fmt.Sprintf("/api/residences/%d/generate-share-package/", residence.ID), nil, "test-token")
|
||||
|
||||
testutil.AssertStatusCode(t, w, http.StatusOK)
|
||||
|
||||
var response map[string]interface{}
|
||||
err := json.Unmarshal(w.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, response, "share_code")
|
||||
})
|
||||
}
|
||||
|
||||
211
internal/i18n/i18n_test.go
Normal file
211
internal/i18n/i18n_test.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package i18n
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
err := Init()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, Bundle)
|
||||
}
|
||||
|
||||
func TestTSimple_EnglishKnownKey(t *testing.T) {
|
||||
require.NoError(t, Init())
|
||||
|
||||
localizer := NewLocalizer("en")
|
||||
msg := TSimple(localizer, "error.task_not_found")
|
||||
assert.Equal(t, "Task not found", msg)
|
||||
}
|
||||
|
||||
func TestTSimple_SpanishKnownKey(t *testing.T) {
|
||||
require.NoError(t, Init())
|
||||
|
||||
localizer := NewLocalizer("es")
|
||||
msg := TSimple(localizer, "error.invalid_credentials")
|
||||
assert.Equal(t, "Credenciales no validas", msg)
|
||||
}
|
||||
|
||||
func TestT_WithTemplateData(t *testing.T) {
|
||||
require.NoError(t, Init())
|
||||
|
||||
localizer := NewLocalizer("en")
|
||||
msg := T(localizer, "message.tasks_report_sent", map[string]interface{}{
|
||||
"Email": "test@example.com",
|
||||
})
|
||||
assert.Contains(t, msg, "test@example.com")
|
||||
}
|
||||
|
||||
func TestTSimple_UnknownKeyReturnsKey(t *testing.T) {
|
||||
require.NoError(t, Init())
|
||||
|
||||
localizer := NewLocalizer("en")
|
||||
key := "error.nonexistent_key_that_does_not_exist"
|
||||
msg := TSimple(localizer, key)
|
||||
assert.Equal(t, key, msg)
|
||||
}
|
||||
|
||||
func TestTSimple_FallbackToEnglish(t *testing.T) {
|
||||
require.NoError(t, Init())
|
||||
|
||||
// Use a language that may not have all translations — fallback to English
|
||||
localizer := NewLocalizer("xx", "en")
|
||||
msg := TSimple(localizer, "error.task_not_found")
|
||||
assert.Equal(t, "Task not found", msg)
|
||||
}
|
||||
|
||||
func TestT_NilLocalizer_UsesDefault(t *testing.T) {
|
||||
require.NoError(t, Init())
|
||||
|
||||
msg := T(nil, "error.task_not_found", nil)
|
||||
assert.Equal(t, "Task not found", msg)
|
||||
}
|
||||
|
||||
func TestNewLocalizer(t *testing.T) {
|
||||
require.NoError(t, Init())
|
||||
|
||||
localizer := NewLocalizer("en")
|
||||
assert.NotNil(t, localizer)
|
||||
}
|
||||
|
||||
func TestParseAcceptLanguage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
expected []string
|
||||
}{
|
||||
{"empty returns default", "", []string{"en"}},
|
||||
{"english", "en-US,en;q=0.9", []string{"en", "en"}},
|
||||
{"spanish first", "es,en;q=0.5", []string{"es", "en"}},
|
||||
{"unsupported returns default", "xx-YY", []string{"en"}},
|
||||
{"french", "fr-FR,fr;q=0.9,en;q=0.5", []string{"fr", "fr", "en"}},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := parseAcceptLanguage(tc.header)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchLocale(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
langs []string
|
||||
expected string
|
||||
}{
|
||||
{"finds supported", []string{"es", "en"}, "es"},
|
||||
{"first match wins", []string{"fr", "de"}, "fr"},
|
||||
{"unsupported returns default", []string{"xx"}, "en"},
|
||||
{"empty returns default", []string{}, "en"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := matchLocale(tc.langs)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleware_SetsLocalizerAndLocale(t *testing.T) {
|
||||
require.NoError(t, Init())
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Accept-Language", "es,en;q=0.5")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := Middleware()(func(c echo.Context) error {
|
||||
// Verify localizer is set
|
||||
localizer := GetLocalizer(c)
|
||||
assert.NotNil(t, localizer)
|
||||
|
||||
// Verify locale is set
|
||||
locale := GetLocale(c)
|
||||
assert.Equal(t, "es", locale)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestGetLocalizer_NoContextValue_ReturnsDefault(t *testing.T) {
|
||||
require.NoError(t, Init())
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
localizer := GetLocalizer(c)
|
||||
assert.NotNil(t, localizer)
|
||||
}
|
||||
|
||||
func TestGetLocale_NoContextValue_ReturnsDefault(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
locale := GetLocale(c)
|
||||
assert.Equal(t, "en", locale)
|
||||
}
|
||||
|
||||
func TestLocalizedMessage(t *testing.T) {
|
||||
require.NoError(t, Init())
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Accept-Language", "en")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// Set up localizer through middleware
|
||||
handler := Middleware()(func(c echo.Context) error {
|
||||
msg := LocalizedMessage(c, "error.task_not_found")
|
||||
assert.Equal(t, "Task not found", msg)
|
||||
return nil
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestLocalizedMessageWithData(t *testing.T) {
|
||||
require.NoError(t, Init())
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Accept-Language", "en")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := Middleware()(func(c echo.Context) error {
|
||||
msg := LocalizedMessageWithData(c, "message.tasks_report_sent", map[string]interface{}{
|
||||
"Email": "user@example.com",
|
||||
})
|
||||
assert.Contains(t, msg, "user@example.com")
|
||||
return nil
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSupportedLanguages(t *testing.T) {
|
||||
assert.Contains(t, SupportedLanguages, "en")
|
||||
assert.Contains(t, SupportedLanguages, "es")
|
||||
assert.Contains(t, SupportedLanguages, "fr")
|
||||
assert.Equal(t, "en", DefaultLanguage)
|
||||
}
|
||||
@@ -17,10 +17,10 @@ func TestIntegration_ContractorSharingFlow(t *testing.T) {
|
||||
|
||||
// ========== Setup Users ==========
|
||||
// Create user A
|
||||
userAToken := app.registerAndLogin(t, "userA", "userA@test.com", "password123")
|
||||
userAToken := app.registerAndLogin(t, "userA", "userA@test.com", "Password123")
|
||||
|
||||
// Create user B
|
||||
userBToken := app.registerAndLogin(t, "userB", "userB@test.com", "password123")
|
||||
userBToken := app.registerAndLogin(t, "userB", "userB@test.com", "Password123")
|
||||
|
||||
// ========== User A creates residence C ==========
|
||||
residenceBody := map[string]interface{}{
|
||||
@@ -180,8 +180,8 @@ func TestIntegration_ContractorAccessWithoutResidenceShare(t *testing.T) {
|
||||
app := setupContractorTest(t)
|
||||
|
||||
// Create two users
|
||||
userAToken := app.registerAndLogin(t, "userA", "userA@test.com", "password123")
|
||||
userBToken := app.registerAndLogin(t, "userB", "userB@test.com", "password123")
|
||||
userAToken := app.registerAndLogin(t, "userA", "userA@test.com", "Password123")
|
||||
userBToken := app.registerAndLogin(t, "userB", "userB@test.com", "Password123")
|
||||
|
||||
// User A creates a residence
|
||||
residenceBody := map[string]interface{}{
|
||||
@@ -228,9 +228,9 @@ func TestIntegration_ContractorUpdateAndDeleteAccess(t *testing.T) {
|
||||
app := setupContractorTest(t)
|
||||
|
||||
// Create users
|
||||
userAToken := app.registerAndLogin(t, "userA", "userA@test.com", "password123")
|
||||
userBToken := app.registerAndLogin(t, "userB", "userB@test.com", "password123")
|
||||
userCToken := app.registerAndLogin(t, "userC", "userC@test.com", "password123")
|
||||
userAToken := app.registerAndLogin(t, "userA", "userA@test.com", "Password123")
|
||||
userBToken := app.registerAndLogin(t, "userB", "userB@test.com", "Password123")
|
||||
userCToken := app.registerAndLogin(t, "userC", "userC@test.com", "Password123")
|
||||
|
||||
// User A creates residence and shares with User B (not User C)
|
||||
residenceBody := map[string]interface{}{"name": "Shared Residence"}
|
||||
|
||||
@@ -379,7 +379,7 @@ func TestIntegration_DuplicateRegistration(t *testing.T) {
|
||||
registerBody := map[string]string{
|
||||
"username": "testuser",
|
||||
"email": "test@example.com",
|
||||
"password": "password123",
|
||||
"password": "Password123",
|
||||
}
|
||||
w := app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody, "")
|
||||
assert.Equal(t, http.StatusCreated, w.Code)
|
||||
@@ -388,7 +388,7 @@ func TestIntegration_DuplicateRegistration(t *testing.T) {
|
||||
registerBody2 := map[string]string{
|
||||
"username": "testuser",
|
||||
"email": "different@example.com",
|
||||
"password": "password123",
|
||||
"password": "Password123",
|
||||
}
|
||||
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody2, "")
|
||||
assert.Equal(t, http.StatusConflict, w.Code)
|
||||
@@ -397,7 +397,7 @@ func TestIntegration_DuplicateRegistration(t *testing.T) {
|
||||
registerBody3 := map[string]string{
|
||||
"username": "differentuser",
|
||||
"email": "test@example.com",
|
||||
"password": "password123",
|
||||
"password": "Password123",
|
||||
}
|
||||
w = app.makeAuthenticatedRequest(t, "POST", "/api/auth/register", registerBody3, "")
|
||||
assert.Equal(t, http.StatusConflict, w.Code)
|
||||
@@ -407,7 +407,7 @@ func TestIntegration_DuplicateRegistration(t *testing.T) {
|
||||
|
||||
func TestIntegration_ResidenceFlow(t *testing.T) {
|
||||
app := setupIntegrationTest(t)
|
||||
token := app.registerAndLogin(t, "owner", "owner@test.com", "password123")
|
||||
token := app.registerAndLogin(t, "owner", "owner@test.com", "Password123")
|
||||
|
||||
// 1. Create a residence
|
||||
createBody := map[string]interface{}{
|
||||
@@ -475,8 +475,8 @@ func TestIntegration_ResidenceSharingFlow(t *testing.T) {
|
||||
app := setupIntegrationTest(t)
|
||||
|
||||
// Create owner and another user
|
||||
ownerToken := app.registerAndLogin(t, "owner", "owner@test.com", "password123")
|
||||
userToken := app.registerAndLogin(t, "shareduser", "shared@test.com", "password123")
|
||||
ownerToken := app.registerAndLogin(t, "owner", "owner@test.com", "Password123")
|
||||
userToken := app.registerAndLogin(t, "shareduser", "shared@test.com", "Password123")
|
||||
|
||||
// Create residence as owner
|
||||
createBody := map[string]interface{}{
|
||||
@@ -531,7 +531,7 @@ func TestIntegration_ResidenceSharingFlow(t *testing.T) {
|
||||
|
||||
func TestIntegration_TaskFlow(t *testing.T) {
|
||||
app := setupIntegrationTest(t)
|
||||
token := app.registerAndLogin(t, "owner", "owner@test.com", "password123")
|
||||
token := app.registerAndLogin(t, "owner", "owner@test.com", "Password123")
|
||||
|
||||
// Create residence first
|
||||
residenceBody := map[string]interface{}{"name": "Task House"}
|
||||
@@ -633,7 +633,7 @@ func TestIntegration_TaskFlow(t *testing.T) {
|
||||
|
||||
func TestIntegration_TasksByResidenceKanban(t *testing.T) {
|
||||
app := setupIntegrationTest(t)
|
||||
token := app.registerAndLogin(t, "owner", "owner@test.com", "password123")
|
||||
token := app.registerAndLogin(t, "owner", "owner@test.com", "Password123")
|
||||
|
||||
// Use explicit timezone to test full timezone-aware path
|
||||
testTimezone := "America/Los_Angeles"
|
||||
@@ -682,7 +682,7 @@ func TestIntegration_TasksByResidenceKanban(t *testing.T) {
|
||||
|
||||
func TestIntegration_LookupEndpoints(t *testing.T) {
|
||||
app := setupIntegrationTest(t)
|
||||
token := app.registerAndLogin(t, "user", "user@test.com", "password123")
|
||||
token := app.registerAndLogin(t, "user", "user@test.com", "Password123")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -721,8 +721,8 @@ func TestIntegration_CrossUserAccessDenied(t *testing.T) {
|
||||
app := setupIntegrationTest(t)
|
||||
|
||||
// Create two users with their own residences
|
||||
user1Token := app.registerAndLogin(t, "user1", "user1@test.com", "password123")
|
||||
user2Token := app.registerAndLogin(t, "user2", "user2@test.com", "password123")
|
||||
user1Token := app.registerAndLogin(t, "user1", "user1@test.com", "Password123")
|
||||
user2Token := app.registerAndLogin(t, "user2", "user2@test.com", "Password123")
|
||||
|
||||
// User1 creates a residence
|
||||
residenceBody := map[string]interface{}{"name": "User1's House"}
|
||||
@@ -777,7 +777,7 @@ func TestIntegration_CrossUserAccessDenied(t *testing.T) {
|
||||
|
||||
func TestIntegration_ResponseStructure(t *testing.T) {
|
||||
app := setupIntegrationTest(t)
|
||||
token := app.registerAndLogin(t, "user", "user@test.com", "password123")
|
||||
token := app.registerAndLogin(t, "user", "user@test.com", "Password123")
|
||||
|
||||
// Create residence
|
||||
residenceBody := map[string]interface{}{
|
||||
@@ -1704,7 +1704,7 @@ func setupContractorTest(t *testing.T) *TestApp {
|
||||
// - Verify task moves between kanban columns appropriately
|
||||
func TestIntegration_RecurringTaskLifecycle(t *testing.T) {
|
||||
app := setupIntegrationTest(t)
|
||||
token := app.registerAndLogin(t, "recurring_user", "recurring@test.com", "password123")
|
||||
token := app.registerAndLogin(t, "recurring_user", "recurring@test.com", "Password123")
|
||||
|
||||
// Create residence
|
||||
residenceBody := map[string]interface{}{"name": "Recurring Task House"}
|
||||
@@ -1904,9 +1904,9 @@ func TestIntegration_MultiUserSharing(t *testing.T) {
|
||||
|
||||
t.Log("Phase 1: Create 3 users")
|
||||
|
||||
tokenA := app.registerAndLogin(t, "user_a", "usera@test.com", "password123")
|
||||
tokenB := app.registerAndLogin(t, "user_b", "userb@test.com", "password123")
|
||||
tokenC := app.registerAndLogin(t, "user_c", "userc@test.com", "password123")
|
||||
tokenA := app.registerAndLogin(t, "user_a", "usera@test.com", "Password123")
|
||||
tokenB := app.registerAndLogin(t, "user_b", "userb@test.com", "Password123")
|
||||
tokenC := app.registerAndLogin(t, "user_c", "userc@test.com", "Password123")
|
||||
|
||||
t.Log("✓ Created users A, B, and C")
|
||||
|
||||
@@ -2098,7 +2098,7 @@ func TestIntegration_MultiUserSharing(t *testing.T) {
|
||||
// - Verify kanban column changes with each transition
|
||||
func TestIntegration_TaskStateTransitions(t *testing.T) {
|
||||
app := setupIntegrationTest(t)
|
||||
token := app.registerAndLogin(t, "state_user", "state@test.com", "password123")
|
||||
token := app.registerAndLogin(t, "state_user", "state@test.com", "Password123")
|
||||
|
||||
// Create residence
|
||||
residenceBody := map[string]interface{}{"name": "State Transition House"}
|
||||
@@ -2274,7 +2274,7 @@ func TestIntegration_TaskStateTransitions(t *testing.T) {
|
||||
// we're testing the full timezone-aware path, not just UTC defaults.
|
||||
func TestIntegration_DateBoundaryEdgeCases(t *testing.T) {
|
||||
app := setupIntegrationTest(t)
|
||||
token := app.registerAndLogin(t, "boundary_user", "boundary@test.com", "password123")
|
||||
token := app.registerAndLogin(t, "boundary_user", "boundary@test.com", "Password123")
|
||||
|
||||
// Create residence
|
||||
residenceBody := map[string]interface{}{"name": "Boundary Test House"}
|
||||
@@ -2435,7 +2435,7 @@ func TestIntegration_DateBoundaryEdgeCases(t *testing.T) {
|
||||
// - One where it's already "tomorrow" → task is overdue (due date was "yesterday")
|
||||
func TestIntegration_TimezoneDivergence(t *testing.T) {
|
||||
app := setupIntegrationTest(t)
|
||||
token := app.registerAndLogin(t, "tz_user", "tz@test.com", "password123")
|
||||
token := app.registerAndLogin(t, "tz_user", "tz@test.com", "Password123")
|
||||
|
||||
// Create residence
|
||||
residenceBody := map[string]interface{}{"name": "Timezone Test House"}
|
||||
@@ -2584,7 +2584,7 @@ func findTaskColumn(kanbanResp map[string]interface{}, taskID uint) string {
|
||||
// - Verify cascading effects
|
||||
func TestIntegration_CascadeOperations(t *testing.T) {
|
||||
app := setupIntegrationTest(t)
|
||||
token := app.registerAndLogin(t, "cascade_user", "cascade@test.com", "password123")
|
||||
token := app.registerAndLogin(t, "cascade_user", "cascade@test.com", "Password123")
|
||||
|
||||
t.Log("Phase 1: Create residence")
|
||||
|
||||
@@ -2721,8 +2721,8 @@ func TestIntegration_MultiUserOperations(t *testing.T) {
|
||||
|
||||
t.Log("Phase 1: Setup users and shared residence")
|
||||
|
||||
tokenA := app.registerAndLogin(t, "multiuser_a", "multiusera@test.com", "password123")
|
||||
tokenB := app.registerAndLogin(t, "multiuser_b", "multiuserb@test.com", "password123")
|
||||
tokenA := app.registerAndLogin(t, "multiuser_a", "multiusera@test.com", "Password123")
|
||||
tokenB := app.registerAndLogin(t, "multiuser_b", "multiuserb@test.com", "Password123")
|
||||
|
||||
// User A creates residence
|
||||
residenceBody := map[string]interface{}{"name": "Multi-User Test House"}
|
||||
|
||||
@@ -265,8 +265,8 @@ func TestE2E_SQLInjection_AdminSort_Blocked(t *testing.T) {
|
||||
adminUserHandler := adminhandlers.NewAdminUserHandler(db)
|
||||
|
||||
// Create a couple of test users to have data to sort
|
||||
testutil.CreateTestUser(t, db, "alice", "alice@test.com", "password123")
|
||||
testutil.CreateTestUser(t, db, "bob", "bob@test.com", "password123")
|
||||
testutil.CreateTestUser(t, db, "alice", "alice@test.com", "Password123")
|
||||
testutil.CreateTestUser(t, db, "bob", "bob@test.com", "Password123")
|
||||
|
||||
// Set up a minimal Echo instance with the admin handler
|
||||
e := echo.New()
|
||||
@@ -322,7 +322,7 @@ func TestE2E_SQLInjection_AdminSort_Blocked(t *testing.T) {
|
||||
// garbage receipt data does NOT upgrade the user to Pro tier.
|
||||
func TestE2E_IAP_InvalidReceipt_NoPro(t *testing.T) {
|
||||
app := setupSecurityTest(t)
|
||||
token, userID := app.registerAndLoginSec(t, "iapuser", "iap@test.com", "password123")
|
||||
token, userID := app.registerAndLoginSec(t, "iapuser", "iap@test.com", "Password123")
|
||||
|
||||
// Create initial subscription (free tier)
|
||||
sub := &models.UserSubscription{UserID: userID, Tier: models.TierFree}
|
||||
@@ -352,7 +352,7 @@ func TestE2E_IAP_InvalidReceipt_NoPro(t *testing.T) {
|
||||
// updates both the completion record and the task's NextDueDate together (P1-5/P1-6).
|
||||
func TestE2E_CompletionTransaction_Atomic(t *testing.T) {
|
||||
app := setupSecurityTest(t)
|
||||
token, _ := app.registerAndLoginSec(t, "atomicuser", "atomic@test.com", "password123")
|
||||
token, _ := app.registerAndLoginSec(t, "atomicuser", "atomic@test.com", "Password123")
|
||||
|
||||
// Create a residence
|
||||
residenceBody := map[string]interface{}{"name": "Atomic Test House"}
|
||||
@@ -423,7 +423,7 @@ func TestE2E_CompletionTransaction_Atomic(t *testing.T) {
|
||||
// on a recurring task recalculates NextDueDate back to the correct value (P1-7).
|
||||
func TestE2E_DeleteCompletion_RecalculatesNextDueDate(t *testing.T) {
|
||||
app := setupSecurityTest(t)
|
||||
token, _ := app.registerAndLoginSec(t, "recuruser", "recur@test.com", "password123")
|
||||
token, _ := app.registerAndLoginSec(t, "recuruser", "recur@test.com", "Password123")
|
||||
|
||||
// Create a residence
|
||||
residenceBody := map[string]interface{}{"name": "Recurring Test House"}
|
||||
@@ -510,7 +510,7 @@ func TestE2E_DeleteCompletion_RecalculatesNextDueDate(t *testing.T) {
|
||||
// configured property limit.
|
||||
func TestE2E_TierLimits_Enforced(t *testing.T) {
|
||||
app := setupSecurityTest(t)
|
||||
token, userID := app.registerAndLoginSec(t, "tieruser", "tier@test.com", "password123")
|
||||
token, userID := app.registerAndLoginSec(t, "tieruser", "tier@test.com", "Password123")
|
||||
|
||||
// Enable global limitations
|
||||
app.DB.Where("1=1").Delete(&models.SubscriptionSettings{})
|
||||
@@ -602,7 +602,7 @@ func TestE2E_AuthAssertion_NoPanics(t *testing.T) {
|
||||
// caps the limit parameter to 200 even if the client requests more.
|
||||
func TestE2E_NotificationLimit_Capped(t *testing.T) {
|
||||
app := setupSecurityTest(t)
|
||||
token, userID := app.registerAndLoginSec(t, "notifuser", "notif@test.com", "password123")
|
||||
token, userID := app.registerAndLoginSec(t, "notifuser", "notif@test.com", "Password123")
|
||||
|
||||
// Create 210 notifications directly in the database
|
||||
for i := 0; i < 210; i++ {
|
||||
|
||||
@@ -164,7 +164,7 @@ func TestIntegration_IsFreeBypassesLimitations(t *testing.T) {
|
||||
app := setupSubscriptionTest(t)
|
||||
|
||||
// Register and login a user
|
||||
token, userID := app.registerAndLogin(t, "freeuser", "free@test.com", "password123")
|
||||
token, userID := app.registerAndLogin(t, "freeuser", "free@test.com", "Password123")
|
||||
|
||||
// Enable global limitations - first delete any existing, then create with enabled
|
||||
app.DB.Where("1=1").Delete(&models.SubscriptionSettings{})
|
||||
@@ -215,7 +215,7 @@ func TestIntegration_IsFreeBypassesCheckLimit(t *testing.T) {
|
||||
app := setupSubscriptionTest(t)
|
||||
|
||||
// Register and login a user
|
||||
_, userID := app.registerAndLogin(t, "limituser", "limit@test.com", "password123")
|
||||
_, userID := app.registerAndLogin(t, "limituser", "limit@test.com", "Password123")
|
||||
|
||||
// Enable global limitations
|
||||
settings := &models.SubscriptionSettings{EnableLimitations: true}
|
||||
@@ -282,7 +282,7 @@ func TestIntegration_IsFreeIndependentOfTier(t *testing.T) {
|
||||
app := setupSubscriptionTest(t)
|
||||
|
||||
// Register and login a user
|
||||
token, userID := app.registerAndLogin(t, "tieruser", "tier@test.com", "password123")
|
||||
token, userID := app.registerAndLogin(t, "tieruser", "tier@test.com", "Password123")
|
||||
|
||||
// Enable global limitations
|
||||
settings := &models.SubscriptionSettings{EnableLimitations: true}
|
||||
@@ -340,7 +340,7 @@ func TestIntegration_IsFreeWhenGlobalLimitationsDisabled(t *testing.T) {
|
||||
app := setupSubscriptionTest(t)
|
||||
|
||||
// Register and login a user
|
||||
token, userID := app.registerAndLogin(t, "globaluser", "global@test.com", "password123")
|
||||
token, userID := app.registerAndLogin(t, "globaluser", "global@test.com", "Password123")
|
||||
|
||||
// Disable global limitations
|
||||
settings := &models.SubscriptionSettings{EnableLimitations: false}
|
||||
|
||||
163
internal/middleware/admin_auth_test.go
Normal file
163
internal/middleware/admin_auth_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
func TestAdminAuth_NoHeader_Returns401(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
|
||||
mw := AdminAuthMiddleware(cfg, nil)
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Authorization required")
|
||||
}
|
||||
|
||||
func TestAdminAuth_InvalidToken_Returns401(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
|
||||
mw := AdminAuthMiddleware(cfg, nil)
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid-jwt-token")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Invalid token")
|
||||
}
|
||||
|
||||
func TestAdminAuth_TokenSchemeOnly_Returns401(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
|
||||
mw := AdminAuthMiddleware(cfg, nil)
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
|
||||
// "Token" scheme is not supported for admin auth, only "Bearer"
|
||||
req.Header.Set("Authorization", "Token some-token")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
}
|
||||
|
||||
func TestRequireSuperAdmin_NoAdmin_Returns401(t *testing.T) {
|
||||
mw := RequireSuperAdmin()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// No admin in context
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
}
|
||||
|
||||
func TestRequireSuperAdmin_WrongType_Returns401(t *testing.T) {
|
||||
mw := RequireSuperAdmin()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// Wrong type in context
|
||||
c.Set(AdminUserKey, "not-an-admin")
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
}
|
||||
|
||||
func TestRequireSuperAdmin_NonSuperAdmin_Returns403(t *testing.T) {
|
||||
mw := RequireSuperAdmin()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// Regular admin (not super admin)
|
||||
admin := &models.AdminUser{
|
||||
Email: "admin@test.com",
|
||||
IsActive: true,
|
||||
Role: models.AdminRoleAdmin,
|
||||
}
|
||||
c.Set(AdminUserKey, admin)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Super admin privileges required")
|
||||
}
|
||||
|
||||
func TestRequireSuperAdmin_SuperAdmin_Passes(t *testing.T) {
|
||||
mw := RequireSuperAdmin()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
admin := &models.AdminUser{
|
||||
Email: "superadmin@test.com",
|
||||
IsActive: true,
|
||||
Role: models.AdminRoleSuperAdmin,
|
||||
}
|
||||
c.Set(AdminUserKey, admin)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
@@ -41,7 +41,7 @@ func createTestUserAndToken(t *testing.T, db *gorm.DB, username string, ageDays
|
||||
Email: username + "@test.com",
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, user.SetPassword("password123"))
|
||||
require.NoError(t, user.SetPassword("Password123"))
|
||||
require.NoError(t, db.Create(user).Error)
|
||||
|
||||
token := &models.AuthToken{
|
||||
|
||||
337
internal/middleware/auth_test.go
Normal file
337
internal/middleware/auth_test.go
Normal file
@@ -0,0 +1,337 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
func TestTokenAuth_BearerScheme_Accepted(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "bearer_user", 10)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
user := GetAuthUser(c)
|
||||
require.NotNil(t, user)
|
||||
assert.Equal(t, "bearer_user", user.Username)
|
||||
}
|
||||
|
||||
func TestTokenAuth_InvalidScheme_Rejected(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "scheme_user", 10)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Basic "+token.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.not_authenticated")
|
||||
}
|
||||
|
||||
func TestTokenAuth_MalformedHeader_Rejected(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "JustATokenWithNoScheme")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.not_authenticated")
|
||||
}
|
||||
|
||||
func TestTokenAuth_EmptyToken_Rejected(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token ")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.not_authenticated")
|
||||
}
|
||||
|
||||
func TestTokenAuth_InactiveUser_Rejected(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
user, token := createTestUserAndToken(t, db, "inactive_user", 10)
|
||||
|
||||
// Deactivate the user
|
||||
require.NoError(t, db.Model(user).Update("is_active", false).Error)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.invalid_token")
|
||||
}
|
||||
|
||||
func TestOptionalTokenAuth_NoToken_PassesThrough(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
// No Authorization header
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return c.String(http.StatusOK, "no-user")
|
||||
}
|
||||
return c.String(http.StatusOK, user.Username)
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no-user", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestOptionalTokenAuth_ValidToken_SetsUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "opt_user", 10)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return c.String(http.StatusOK, "no-user")
|
||||
}
|
||||
return c.String(http.StatusOK, user.Username)
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "opt_user", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestOptionalTokenAuth_ExpiredToken_IgnoresUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
_, token := createTestUserAndToken(t, db, "expired_opt_user", 91)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return c.String(http.StatusOK, "no-user")
|
||||
}
|
||||
return c.String(http.StatusOK, user.Username)
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no-user", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestOptionalTokenAuth_InvalidToken_IgnoresUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token nonexistent-token")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.OptionalTokenAuth()(func(c echo.Context) error {
|
||||
user := GetAuthUser(c)
|
||||
if user == nil {
|
||||
return c.String(http.StatusOK, "no-user")
|
||||
}
|
||||
return c.String(http.StatusOK, user.Username)
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no-user", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestNewAuthMiddlewareWithConfig_CustomExpiryDays(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
TokenExpiryDays: 30,
|
||||
},
|
||||
}
|
||||
|
||||
m := NewAuthMiddlewareWithConfig(db, nil, cfg)
|
||||
assert.NotNil(t, m)
|
||||
assert.Equal(t, 30, m.tokenExpiryDays)
|
||||
|
||||
// Token at 29 days should be valid
|
||||
_, token := createTestUserAndToken(t, db, "short_expiry_user", 29)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestNewAuthMiddlewareWithConfig_ExpiredWithCustomExpiry(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
TokenExpiryDays: 30,
|
||||
},
|
||||
}
|
||||
|
||||
m := NewAuthMiddlewareWithConfig(db, nil, cfg)
|
||||
|
||||
// Token at 31 days should be expired with 30-day config
|
||||
_, token := createTestUserAndToken(t, db, "custom_expired_user", 31)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set("Authorization", "Token "+token.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := m.TokenAuth()(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error.token_expired")
|
||||
}
|
||||
|
||||
func TestNewAuthMiddlewareWithConfig_NilConfig_UsesDefault(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
|
||||
m := NewAuthMiddlewareWithConfig(db, nil, nil)
|
||||
assert.Equal(t, DefaultTokenExpiryDays, m.tokenExpiryDays)
|
||||
}
|
||||
|
||||
func TestGetAuthToken_ReturnsToken(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c.Set(AuthTokenKey, "test-token-value")
|
||||
assert.Equal(t, "test-token-value", GetAuthToken(c))
|
||||
}
|
||||
|
||||
func TestGetAuthToken_NilContext_ReturnsEmpty(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// No token set
|
||||
assert.Equal(t, "", GetAuthToken(c))
|
||||
}
|
||||
|
||||
func TestGetAuthToken_WrongType_ReturnsEmpty(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c.Set(AuthTokenKey, 12345) // Wrong type
|
||||
assert.Equal(t, "", GetAuthToken(c))
|
||||
}
|
||||
|
||||
func TestIsTokenExpired_ZeroTime_NotExpired(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
m := NewAuthMiddleware(db, nil)
|
||||
|
||||
// Legacy tokens without created time should not be expired
|
||||
assert.False(t, m.isTokenExpired(models.AuthToken{}.Created))
|
||||
}
|
||||
|
||||
func TestInvalidateToken_NilCache_NoError(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
m := NewAuthMiddleware(db, nil) // nil cache
|
||||
|
||||
err := m.InvalidateToken(nil, "some-token")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
93
internal/middleware/host_check_test.go
Normal file
93
internal/middleware/host_check_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHostCheck_AllowedHost_Passes(t *testing.T) {
|
||||
mw := HostCheck([]string{"api.example.com", "localhost:8000"})
|
||||
e := echo.New()
|
||||
handler := mw(okHandler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Host = "api.example.com"
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestHostCheck_DisallowedHost_Returns403(t *testing.T) {
|
||||
mw := HostCheck([]string{"api.example.com"})
|
||||
e := echo.New()
|
||||
handler := mw(okHandler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Host = "evil.example.com"
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, rec.Code)
|
||||
|
||||
var response map[string]interface{}
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Forbidden", response["error"])
|
||||
}
|
||||
|
||||
func TestHostCheck_EmptyAllowedHosts_AllPass(t *testing.T) {
|
||||
mw := HostCheck([]string{})
|
||||
e := echo.New()
|
||||
handler := mw(okHandler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Host = "any-host.example.com"
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestHostCheck_LocalhostWithPort_Passes(t *testing.T) {
|
||||
mw := HostCheck([]string{"localhost:8000"})
|
||||
e := echo.New()
|
||||
handler := mw(okHandler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Host = "localhost:8000"
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestHostCheck_LocalhostWithoutPort_Denied(t *testing.T) {
|
||||
// Only "localhost:8000" allowed, not plain "localhost"
|
||||
mw := HostCheck([]string{"localhost:8000"})
|
||||
e := echo.New()
|
||||
handler := mw(okHandler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Host = "localhost"
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, rec.Code)
|
||||
}
|
||||
103
internal/middleware/logger_test.go
Normal file
103
internal/middleware/logger_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
func TestStructuredLogger_Passes_Request(t *testing.T) {
|
||||
mw := StructuredLogger()
|
||||
e := echo.New()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "ok", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestStructuredLogger_WithUser(t *testing.T) {
|
||||
mw := StructuredLogger()
|
||||
e := echo.New()
|
||||
|
||||
user := &models.User{Username: "loguser"}
|
||||
user.ID = 42
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
c.Set(AuthUserKey, user)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestStructuredLogger_WithRequestID(t *testing.T) {
|
||||
mw := StructuredLogger()
|
||||
e := echo.New()
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
c.Set(ContextKeyRequestID, "test-request-id-123")
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestStructuredLogger_ErrorStatus(t *testing.T) {
|
||||
mw := StructuredLogger()
|
||||
e := echo.New()
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusInternalServerError, "error")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
|
||||
func TestStructuredLogger_ClientError(t *testing.T) {
|
||||
mw := StructuredLogger()
|
||||
e := echo.New()
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusBadRequest, "bad request")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
222
internal/middleware/timezone_test.go
Normal file
222
internal/middleware/timezone_test.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTimezoneMiddleware_IANATimezone(t *testing.T) {
|
||||
mw := TimezoneMiddleware()
|
||||
e := echo.New()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
loc := GetUserTimezone(c)
|
||||
return c.String(http.StatusOK, loc.String())
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set(TimezoneHeader, "America/New_York")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "America/New_York", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestTimezoneMiddleware_UTCOffset(t *testing.T) {
|
||||
mw := TimezoneMiddleware()
|
||||
e := echo.New()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
loc := GetUserTimezone(c)
|
||||
// Just verify it's not UTC (if offset is non-zero)
|
||||
return c.String(http.StatusOK, loc.String())
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set(TimezoneHeader, "-05:00")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "-05:00", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestTimezoneMiddleware_NoHeader_DefaultsToUTC(t *testing.T) {
|
||||
mw := TimezoneMiddleware()
|
||||
e := echo.New()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
loc := GetUserTimezone(c)
|
||||
return c.String(http.StatusOK, loc.String())
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
// No X-Timezone header
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "UTC", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestTimezoneMiddleware_InvalidTimezone_DefaultsToUTC(t *testing.T) {
|
||||
mw := TimezoneMiddleware()
|
||||
e := echo.New()
|
||||
handler := mw(func(c echo.Context) error {
|
||||
loc := GetUserTimezone(c)
|
||||
return c.String(http.StatusOK, loc.String())
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set(TimezoneHeader, "Invalid/Timezone")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "UTC", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestTimezoneMiddleware_SetsUserNow(t *testing.T) {
|
||||
mw := TimezoneMiddleware()
|
||||
e := echo.New()
|
||||
|
||||
var capturedNow time.Time
|
||||
handler := mw(func(c echo.Context) error {
|
||||
capturedNow = GetUserNow(c)
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set(TimezoneHeader, "America/Chicago")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The "now" should be the start of the day in the user's timezone
|
||||
assert.Equal(t, 0, capturedNow.Hour())
|
||||
assert.Equal(t, 0, capturedNow.Minute())
|
||||
assert.Equal(t, 0, capturedNow.Second())
|
||||
}
|
||||
|
||||
func TestTimezoneMiddleware_SetsTimezoneName(t *testing.T) {
|
||||
mw := TimezoneMiddleware()
|
||||
e := echo.New()
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
name := GetTimezoneName(c)
|
||||
return c.String(http.StatusOK, name)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test/", nil)
|
||||
req.Header.Set(TimezoneHeader, "Europe/London")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
err := handler(c)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Europe/London", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestGetUserTimezone_NotSet_ReturnsUTC(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// No timezone set in context
|
||||
loc := GetUserTimezone(c)
|
||||
assert.Equal(t, time.UTC, loc)
|
||||
}
|
||||
|
||||
func TestGetUserTimezone_WrongType_ReturnsUTC(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c.Set(TimezoneKey, "not-a-location")
|
||||
loc := GetUserTimezone(c)
|
||||
assert.Equal(t, time.UTC, loc)
|
||||
}
|
||||
|
||||
func TestGetUserNow_NotSet_ReturnsUTCNow(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
before := time.Now().UTC()
|
||||
now := GetUserNow(c)
|
||||
after := time.Now().UTC()
|
||||
|
||||
assert.True(t, !now.Before(before.Add(-time.Second)), "now should be roughly after before")
|
||||
assert.True(t, !now.After(after.Add(time.Second)), "now should be roughly before after")
|
||||
}
|
||||
|
||||
func TestGetUserNow_WrongType_ReturnsUTCNow(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c.Set(UserNowKey, "not-a-time")
|
||||
now := GetUserNow(c)
|
||||
assert.NotNil(t, now)
|
||||
}
|
||||
|
||||
func TestIsTimezoneChanged_NoChange_ReturnsFalse(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c.Set(TimezoneChangedKey, false)
|
||||
assert.False(t, IsTimezoneChanged(c))
|
||||
}
|
||||
|
||||
func TestIsTimezoneChanged_Changed_ReturnsTrue(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
c.Set(TimezoneChangedKey, true)
|
||||
assert.True(t, IsTimezoneChanged(c))
|
||||
}
|
||||
|
||||
func TestIsTimezoneChanged_NotSet_ReturnsFalse(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// Not set at all
|
||||
assert.False(t, IsTimezoneChanged(c))
|
||||
}
|
||||
|
||||
func TestParseTimezone_UTCOffsetWithoutColon(t *testing.T) {
|
||||
loc := parseTimezone("-0800")
|
||||
assert.NotEqual(t, time.UTC, loc)
|
||||
assert.Equal(t, "-0800", loc.String())
|
||||
}
|
||||
|
||||
func TestParseTimezone_PositiveOffset(t *testing.T) {
|
||||
loc := parseTimezone("+05:30")
|
||||
assert.NotEqual(t, time.UTC, loc)
|
||||
assert.Equal(t, "+05:30", loc.String())
|
||||
}
|
||||
|
||||
func TestParseTimezone_UTC(t *testing.T) {
|
||||
loc := parseTimezone("UTC")
|
||||
assert.Equal(t, time.UTC, loc)
|
||||
}
|
||||
186
internal/middleware/user_cache_test.go
Normal file
186
internal/middleware/user_cache_test.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
func TestUserCache_SetAndGet(t *testing.T) {
|
||||
cache := NewUserCache(1 * time.Minute)
|
||||
|
||||
user := &models.User{Username: "testuser", Email: "test@test.com"}
|
||||
user.ID = 1
|
||||
|
||||
cache.Set(user)
|
||||
|
||||
cached := cache.Get(1)
|
||||
require.NotNil(t, cached)
|
||||
assert.Equal(t, "testuser", cached.Username)
|
||||
assert.Equal(t, "test@test.com", cached.Email)
|
||||
}
|
||||
|
||||
func TestUserCache_GetNonExistent_ReturnsNil(t *testing.T) {
|
||||
cache := NewUserCache(1 * time.Minute)
|
||||
|
||||
cached := cache.Get(999)
|
||||
assert.Nil(t, cached)
|
||||
}
|
||||
|
||||
func TestUserCache_Expired_ReturnsNil(t *testing.T) {
|
||||
// Very short TTL
|
||||
cache := NewUserCache(1 * time.Millisecond)
|
||||
|
||||
user := &models.User{Username: "expiring_user"}
|
||||
user.ID = 1
|
||||
|
||||
cache.Set(user)
|
||||
|
||||
// Wait for expiry
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
cached := cache.Get(1)
|
||||
assert.Nil(t, cached, "expired entry should return nil")
|
||||
}
|
||||
|
||||
func TestUserCache_Invalidate(t *testing.T) {
|
||||
cache := NewUserCache(1 * time.Minute)
|
||||
|
||||
user := &models.User{Username: "to_invalidate"}
|
||||
user.ID = 1
|
||||
|
||||
cache.Set(user)
|
||||
|
||||
// Verify it's cached
|
||||
require.NotNil(t, cache.Get(1))
|
||||
|
||||
// Invalidate
|
||||
cache.Invalidate(1)
|
||||
|
||||
// Should be gone
|
||||
assert.Nil(t, cache.Get(1))
|
||||
}
|
||||
|
||||
func TestUserCache_ReturnsCopy_NotOriginal(t *testing.T) {
|
||||
cache := NewUserCache(1 * time.Minute)
|
||||
|
||||
user := &models.User{Username: "original"}
|
||||
user.ID = 1
|
||||
|
||||
cache.Set(user)
|
||||
|
||||
// Modify the returned copy
|
||||
cached := cache.Get(1)
|
||||
require.NotNil(t, cached)
|
||||
cached.Username = "modified"
|
||||
|
||||
// Original cache entry should be unaffected
|
||||
cached2 := cache.Get(1)
|
||||
require.NotNil(t, cached2)
|
||||
assert.Equal(t, "original", cached2.Username, "cache should return a copy, not the original")
|
||||
}
|
||||
|
||||
func TestUserCache_SetCopiesInput(t *testing.T) {
|
||||
cache := NewUserCache(1 * time.Minute)
|
||||
|
||||
user := &models.User{Username: "original"}
|
||||
user.ID = 1
|
||||
|
||||
cache.Set(user)
|
||||
|
||||
// Modify the input after setting
|
||||
user.Username = "modified_after_set"
|
||||
|
||||
// Cache should still have the original value
|
||||
cached := cache.Get(1)
|
||||
require.NotNil(t, cached)
|
||||
assert.Equal(t, "original", cached.Username, "cache should store a copy of the input")
|
||||
}
|
||||
|
||||
func TestUserCache_MultipleUsers(t *testing.T) {
|
||||
cache := NewUserCache(1 * time.Minute)
|
||||
|
||||
user1 := &models.User{Username: "user1"}
|
||||
user1.ID = 1
|
||||
user2 := &models.User{Username: "user2"}
|
||||
user2.ID = 2
|
||||
|
||||
cache.Set(user1)
|
||||
cache.Set(user2)
|
||||
|
||||
cached1 := cache.Get(1)
|
||||
cached2 := cache.Get(2)
|
||||
|
||||
require.NotNil(t, cached1)
|
||||
require.NotNil(t, cached2)
|
||||
assert.Equal(t, "user1", cached1.Username)
|
||||
assert.Equal(t, "user2", cached2.Username)
|
||||
}
|
||||
|
||||
func TestUserCache_OverwriteEntry(t *testing.T) {
|
||||
cache := NewUserCache(1 * time.Minute)
|
||||
|
||||
user := &models.User{Username: "original"}
|
||||
user.ID = 1
|
||||
|
||||
cache.Set(user)
|
||||
|
||||
// Overwrite with new data
|
||||
updated := &models.User{Username: "updated"}
|
||||
updated.ID = 1
|
||||
|
||||
cache.Set(updated)
|
||||
|
||||
cached := cache.Get(1)
|
||||
require.NotNil(t, cached)
|
||||
assert.Equal(t, "updated", cached.Username)
|
||||
}
|
||||
|
||||
func TestTimezoneCache_GetAndCompare_NewEntry(t *testing.T) {
|
||||
tc := NewTimezoneCache()
|
||||
|
||||
// First call should return false (not cached yet)
|
||||
unchanged := tc.GetAndCompare(1, "America/New_York")
|
||||
assert.False(t, unchanged, "first call should indicate a change")
|
||||
}
|
||||
|
||||
func TestTimezoneCache_GetAndCompare_SameValue(t *testing.T) {
|
||||
tc := NewTimezoneCache()
|
||||
|
||||
// First call sets the value
|
||||
tc.GetAndCompare(1, "America/New_York")
|
||||
|
||||
// Second call with same value should return true (unchanged)
|
||||
unchanged := tc.GetAndCompare(1, "America/New_York")
|
||||
assert.True(t, unchanged, "same value should indicate no change")
|
||||
}
|
||||
|
||||
func TestTimezoneCache_GetAndCompare_DifferentValue(t *testing.T) {
|
||||
tc := NewTimezoneCache()
|
||||
|
||||
// Set initial value
|
||||
tc.GetAndCompare(1, "America/New_York")
|
||||
|
||||
// Update to different value
|
||||
unchanged := tc.GetAndCompare(1, "America/Chicago")
|
||||
assert.False(t, unchanged, "different value should indicate a change")
|
||||
|
||||
// Now the new value is cached
|
||||
unchanged = tc.GetAndCompare(1, "America/Chicago")
|
||||
assert.True(t, unchanged, "same value should indicate no change")
|
||||
}
|
||||
|
||||
func TestTimezoneCache_GetAndCompare_DifferentUsers(t *testing.T) {
|
||||
tc := NewTimezoneCache()
|
||||
|
||||
tc.GetAndCompare(1, "America/New_York")
|
||||
tc.GetAndCompare(2, "Europe/London")
|
||||
|
||||
assert.True(t, tc.GetAndCompare(1, "America/New_York"))
|
||||
assert.True(t, tc.GetAndCompare(2, "Europe/London"))
|
||||
assert.False(t, tc.GetAndCompare(1, "Europe/London"))
|
||||
}
|
||||
626
internal/models/models_coverage_test.go
Normal file
626
internal/models/models_coverage_test.go
Normal file
@@ -0,0 +1,626 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// setupModelsTestDB creates a minimal in-memory SQLite for model-level tests
|
||||
// that require database interaction (e.g., BeforeCreate hooks).
|
||||
func setupModelsTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
err = db.AutoMigrate(&User{}, &AuthToken{}, &UserProfile{})
|
||||
require.NoError(t, err)
|
||||
return db
|
||||
}
|
||||
|
||||
// === Residence model tests ===
|
||||
|
||||
func TestResidence_GetAllUsers(t *testing.T) {
|
||||
owner := User{Username: "owner"}
|
||||
owner.ID = 1
|
||||
member1 := User{Username: "member1"}
|
||||
member1.ID = 2
|
||||
member2 := User{Username: "member2"}
|
||||
member2.ID = 3
|
||||
|
||||
residence := &Residence{
|
||||
OwnerID: owner.ID,
|
||||
Owner: owner,
|
||||
Users: []User{member1, member2},
|
||||
}
|
||||
|
||||
allUsers := residence.GetAllUsers()
|
||||
assert.Len(t, allUsers, 3)
|
||||
assert.Equal(t, "owner", allUsers[0].Username)
|
||||
}
|
||||
|
||||
func TestResidence_HasAccess(t *testing.T) {
|
||||
owner := User{Username: "owner"}
|
||||
owner.ID = 1
|
||||
member := User{Username: "member"}
|
||||
member.ID = 2
|
||||
|
||||
residence := &Residence{
|
||||
OwnerID: owner.ID,
|
||||
Owner: owner,
|
||||
Users: []User{member},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID uint
|
||||
expected bool
|
||||
}{
|
||||
{"owner has access", 1, true},
|
||||
{"member has access", 2, true},
|
||||
{"stranger no access", 99, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, residence.HasAccess(tt.userID))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResidence_IsPrimaryOwner(t *testing.T) {
|
||||
residence := &Residence{OwnerID: 1}
|
||||
|
||||
assert.True(t, residence.IsPrimaryOwner(1))
|
||||
assert.False(t, residence.IsPrimaryOwner(2))
|
||||
}
|
||||
|
||||
// === Document model tests ===
|
||||
|
||||
func TestDocument_TableName_DocumentImage(t *testing.T) {
|
||||
di := DocumentImage{}
|
||||
assert.Equal(t, "task_documentimage", di.TableName())
|
||||
}
|
||||
|
||||
func TestDocument_IsWarrantyExpiringSoon(t *testing.T) {
|
||||
future30 := time.Now().UTC().AddDate(0, 0, 15)
|
||||
future90 := time.Now().UTC().AddDate(0, 0, 60)
|
||||
past := time.Now().UTC().AddDate(0, 0, -5)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
doc Document
|
||||
days int
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "warranty expiring within threshold",
|
||||
doc: Document{DocumentType: DocumentTypeWarranty, ExpiryDate: &future30},
|
||||
days: 30,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "warranty not expiring within threshold",
|
||||
doc: Document{DocumentType: DocumentTypeWarranty, ExpiryDate: &future90},
|
||||
days: 30,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "warranty already expired",
|
||||
doc: Document{DocumentType: DocumentTypeWarranty, ExpiryDate: &past},
|
||||
days: 30,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "non-warranty document",
|
||||
doc: Document{DocumentType: DocumentTypeGeneral, ExpiryDate: &future30},
|
||||
days: 30,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "warranty with nil expiry",
|
||||
doc: Document{DocumentType: DocumentTypeWarranty, ExpiryDate: nil},
|
||||
days: 30,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, tt.doc.IsWarrantyExpiringSoon(tt.days))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDocument_IsWarrantyExpired(t *testing.T) {
|
||||
past := time.Now().UTC().AddDate(0, 0, -5)
|
||||
future := time.Now().UTC().AddDate(0, 0, 30)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
doc Document
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "expired warranty",
|
||||
doc: Document{DocumentType: DocumentTypeWarranty, ExpiryDate: &past},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "active warranty",
|
||||
doc: Document{DocumentType: DocumentTypeWarranty, ExpiryDate: &future},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "non-warranty",
|
||||
doc: Document{DocumentType: DocumentTypeGeneral, ExpiryDate: &past},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "warranty nil expiry",
|
||||
doc: Document{DocumentType: DocumentTypeWarranty, ExpiryDate: nil},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, tt.doc.IsWarrantyExpired())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDocumentType_Constants(t *testing.T) {
|
||||
// Verify document type constants have expected values
|
||||
assert.Equal(t, DocumentType("general"), DocumentTypeGeneral)
|
||||
assert.Equal(t, DocumentType("warranty"), DocumentTypeWarranty)
|
||||
assert.Equal(t, DocumentType("receipt"), DocumentTypeReceipt)
|
||||
assert.Equal(t, DocumentType("contract"), DocumentTypeContract)
|
||||
assert.Equal(t, DocumentType("insurance"), DocumentTypeInsurance)
|
||||
assert.Equal(t, DocumentType("manual"), DocumentTypeManual)
|
||||
}
|
||||
|
||||
// === Notification model tests ===
|
||||
|
||||
func TestNotification_TableName(t *testing.T) {
|
||||
n := Notification{}
|
||||
assert.Equal(t, "notifications_notification", n.TableName())
|
||||
}
|
||||
|
||||
func TestNotificationPreference_TableName(t *testing.T) {
|
||||
np := NotificationPreference{}
|
||||
assert.Equal(t, "notifications_notificationpreference", np.TableName())
|
||||
}
|
||||
|
||||
func TestAPNSDevice_TableName(t *testing.T) {
|
||||
d := APNSDevice{}
|
||||
assert.Equal(t, "push_notifications_apnsdevice", d.TableName())
|
||||
}
|
||||
|
||||
func TestGCMDevice_TableName(t *testing.T) {
|
||||
d := GCMDevice{}
|
||||
assert.Equal(t, "push_notifications_gcmdevice", d.TableName())
|
||||
}
|
||||
|
||||
func TestNotification_MarkAsRead(t *testing.T) {
|
||||
n := &Notification{Read: false}
|
||||
|
||||
n.MarkAsRead()
|
||||
assert.True(t, n.Read)
|
||||
assert.NotNil(t, n.ReadAt)
|
||||
}
|
||||
|
||||
func TestNotification_MarkAsSent(t *testing.T) {
|
||||
n := &Notification{Sent: false}
|
||||
|
||||
n.MarkAsSent()
|
||||
assert.True(t, n.Sent)
|
||||
assert.NotNil(t, n.SentAt)
|
||||
}
|
||||
|
||||
func TestNotificationType_Constants(t *testing.T) {
|
||||
assert.Equal(t, NotificationType("task_due_soon"), NotificationTaskDueSoon)
|
||||
assert.Equal(t, NotificationType("task_overdue"), NotificationTaskOverdue)
|
||||
assert.Equal(t, NotificationType("task_completed"), NotificationTaskCompleted)
|
||||
assert.Equal(t, NotificationType("task_assigned"), NotificationTaskAssigned)
|
||||
assert.Equal(t, NotificationType("residence_shared"), NotificationResidenceShared)
|
||||
assert.Equal(t, NotificationType("warranty_expiring"), NotificationWarrantyExpiring)
|
||||
}
|
||||
|
||||
// === AuthToken model tests ===
|
||||
|
||||
func TestAuthToken_BeforeCreate_GeneratesKey(t *testing.T) {
|
||||
db := setupModelsTestDB(t)
|
||||
|
||||
user := &User{
|
||||
Username: "tokenuser",
|
||||
Email: "token@test.com",
|
||||
Password: "dummy",
|
||||
IsActive: true,
|
||||
}
|
||||
err := db.Create(user).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
token := &AuthToken{UserID: user.ID}
|
||||
err = db.Create(token).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, token.Key)
|
||||
assert.Len(t, token.Key, 40) // 20 bytes = 40 hex chars
|
||||
assert.False(t, token.Created.IsZero())
|
||||
}
|
||||
|
||||
func TestAuthToken_BeforeCreate_PreservesExistingKey(t *testing.T) {
|
||||
db := setupModelsTestDB(t)
|
||||
|
||||
user := &User{
|
||||
Username: "tokenuser",
|
||||
Email: "token@test.com",
|
||||
Password: "dummy",
|
||||
IsActive: true,
|
||||
}
|
||||
err := db.Create(user).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
existingKey := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2"
|
||||
token := &AuthToken{
|
||||
Key: existingKey,
|
||||
UserID: user.ID,
|
||||
}
|
||||
err = db.Create(token).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, existingKey, token.Key)
|
||||
}
|
||||
|
||||
func TestGetOrCreateToken_CreatesNew(t *testing.T) {
|
||||
db := setupModelsTestDB(t)
|
||||
|
||||
user := &User{
|
||||
Username: "newtoken",
|
||||
Email: "newtoken@test.com",
|
||||
Password: "dummy",
|
||||
IsActive: true,
|
||||
}
|
||||
err := db.Create(user).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := GetOrCreateToken(db, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token.Key)
|
||||
assert.Equal(t, user.ID, token.UserID)
|
||||
}
|
||||
|
||||
func TestGetOrCreateToken_ReturnsExisting(t *testing.T) {
|
||||
db := setupModelsTestDB(t)
|
||||
|
||||
user := &User{
|
||||
Username: "existingtoken",
|
||||
Email: "existingtoken@test.com",
|
||||
Password: "dummy",
|
||||
IsActive: true,
|
||||
}
|
||||
err := db.Create(user).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
token1, err := GetOrCreateToken(db, user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
token2, err := GetOrCreateToken(db, user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, token1.Key, token2.Key)
|
||||
}
|
||||
|
||||
// === User model additional tests ===
|
||||
|
||||
func TestUser_SetPassword_And_CheckPassword_Integration(t *testing.T) {
|
||||
user := &User{}
|
||||
err := user.SetPassword("Password123")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, user.CheckPassword("Password123"))
|
||||
assert.False(t, user.CheckPassword("WrongPassword"))
|
||||
assert.False(t, user.CheckPassword(""))
|
||||
assert.False(t, user.CheckPassword("password123")) // case sensitive
|
||||
}
|
||||
|
||||
// === Task model additional tests ===
|
||||
|
||||
func TestTask_IsOverdue_CancelledNotOverdue(t *testing.T) {
|
||||
yesterday := time.Now().UTC().AddDate(0, 0, -2)
|
||||
task := &Task{
|
||||
DueDate: &yesterday,
|
||||
IsCancelled: true,
|
||||
}
|
||||
assert.False(t, task.IsOverdue())
|
||||
}
|
||||
|
||||
func TestTask_IsOverdue_ArchivedNotOverdue(t *testing.T) {
|
||||
yesterday := time.Now().UTC().AddDate(0, 0, -2)
|
||||
task := &Task{
|
||||
DueDate: &yesterday,
|
||||
IsArchived: true,
|
||||
}
|
||||
assert.False(t, task.IsOverdue())
|
||||
}
|
||||
|
||||
func TestTask_IsOverdue_NoDueDateNotOverdue(t *testing.T) {
|
||||
task := &Task{
|
||||
DueDate: nil,
|
||||
}
|
||||
assert.False(t, task.IsOverdue())
|
||||
}
|
||||
|
||||
func TestTask_IsOverdue_CompletedNotOverdue(t *testing.T) {
|
||||
yesterday := time.Now().UTC().AddDate(0, 0, -2)
|
||||
task := &Task{
|
||||
DueDate: &yesterday,
|
||||
NextDueDate: nil,
|
||||
Completions: []TaskCompletion{{CompletedAt: time.Now().UTC()}},
|
||||
}
|
||||
assert.False(t, task.IsOverdue())
|
||||
}
|
||||
|
||||
func TestTask_IsOverdue_CompletionCountNotOverdue(t *testing.T) {
|
||||
yesterday := time.Now().UTC().AddDate(0, 0, -2)
|
||||
task := &Task{
|
||||
DueDate: &yesterday,
|
||||
NextDueDate: nil,
|
||||
CompletionCount: 1,
|
||||
}
|
||||
assert.False(t, task.IsOverdue())
|
||||
}
|
||||
|
||||
func TestTask_IsOverdue_UsesNextDueDate(t *testing.T) {
|
||||
// DueDate is overdue, but NextDueDate is in the future
|
||||
pastDue := time.Now().UTC().AddDate(0, 0, -10)
|
||||
futureDue := time.Now().UTC().AddDate(0, 0, 10)
|
||||
task := &Task{
|
||||
DueDate: &pastDue,
|
||||
NextDueDate: &futureDue,
|
||||
}
|
||||
assert.False(t, task.IsOverdue())
|
||||
}
|
||||
|
||||
func TestTask_IsDueSoon_CancelledNotDueSoon(t *testing.T) {
|
||||
futureDue := time.Now().UTC().AddDate(0, 0, 5)
|
||||
task := &Task{
|
||||
DueDate: &futureDue,
|
||||
IsCancelled: true,
|
||||
}
|
||||
assert.False(t, task.IsDueSoon(30))
|
||||
}
|
||||
|
||||
func TestTask_IsDueSoon_NoDueDateNotDueSoon(t *testing.T) {
|
||||
task := &Task{
|
||||
DueDate: nil,
|
||||
}
|
||||
assert.False(t, task.IsDueSoon(30))
|
||||
}
|
||||
|
||||
func TestTask_IsDueSoon_WithinThreshold(t *testing.T) {
|
||||
futureDue := time.Now().UTC().AddDate(0, 0, 5)
|
||||
task := &Task{
|
||||
DueDate: &futureDue,
|
||||
}
|
||||
assert.True(t, task.IsDueSoon(30))
|
||||
assert.True(t, task.IsDueSoon(10))
|
||||
assert.False(t, task.IsDueSoon(3))
|
||||
}
|
||||
|
||||
func TestTask_IsDueSoon_CompletedNotDueSoon(t *testing.T) {
|
||||
futureDue := time.Now().UTC().AddDate(0, 0, 5)
|
||||
task := &Task{
|
||||
DueDate: &futureDue,
|
||||
NextDueDate: nil,
|
||||
Completions: []TaskCompletion{{CompletedAt: time.Now().UTC()}},
|
||||
}
|
||||
assert.False(t, task.IsDueSoon(30))
|
||||
}
|
||||
|
||||
func TestTaskCompletionImage_TableName(t *testing.T) {
|
||||
tci := TaskCompletionImage{}
|
||||
assert.Equal(t, "task_taskcompletionimage", tci.TableName())
|
||||
}
|
||||
|
||||
// === Subscription model additional tests ===
|
||||
|
||||
func TestSubscription_TableNames(t *testing.T) {
|
||||
assert.Equal(t, "subscription_subscriptionsettings", SubscriptionSettings{}.TableName())
|
||||
assert.Equal(t, "subscription_usersubscription", UserSubscription{}.TableName())
|
||||
assert.Equal(t, "subscription_upgradetrigger", UpgradeTrigger{}.TableName())
|
||||
assert.Equal(t, "subscription_featurebenefit", FeatureBenefit{}.TableName())
|
||||
assert.Equal(t, "subscription_promotion", Promotion{}.TableName())
|
||||
assert.Equal(t, "subscription_tierlimits", TierLimits{}.TableName())
|
||||
}
|
||||
|
||||
func TestSubscription_IsActive(t *testing.T) {
|
||||
future := time.Now().UTC().Add(24 * time.Hour)
|
||||
past := time.Now().UTC().Add(-24 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sub *UserSubscription
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "pro with future expiry is active",
|
||||
sub: &UserSubscription{Tier: TierPro, ExpiresAt: &future},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "pro with nil expiry is active",
|
||||
sub: &UserSubscription{Tier: TierPro, ExpiresAt: nil},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "pro with past expiry is not active",
|
||||
sub: &UserSubscription{Tier: TierPro, ExpiresAt: &past},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "free with active trial is active",
|
||||
sub: &UserSubscription{Tier: TierFree, TrialEnd: &future},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "free without trial is not active",
|
||||
sub: &UserSubscription{Tier: TierFree},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, tt.sub.IsActive())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscription_SubscriptionSource(t *testing.T) {
|
||||
sub := &UserSubscription{Platform: "ios"}
|
||||
assert.Equal(t, "ios", sub.SubscriptionSource())
|
||||
}
|
||||
|
||||
func TestPromotion_IsCurrentlyActive(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
promo Promotion
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "active promotion within dates",
|
||||
promo: Promotion{
|
||||
IsActive: true,
|
||||
StartDate: now.Add(-1 * time.Hour),
|
||||
EndDate: now.Add(1 * time.Hour),
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "inactive promotion",
|
||||
promo: Promotion{
|
||||
IsActive: false,
|
||||
StartDate: now.Add(-1 * time.Hour),
|
||||
EndDate: now.Add(1 * time.Hour),
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "promotion not yet started",
|
||||
promo: Promotion{
|
||||
IsActive: true,
|
||||
StartDate: now.Add(1 * time.Hour),
|
||||
EndDate: now.Add(2 * time.Hour),
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "promotion already ended",
|
||||
promo: Promotion{
|
||||
IsActive: true,
|
||||
StartDate: now.Add(-2 * time.Hour),
|
||||
EndDate: now.Add(-1 * time.Hour),
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, tt.promo.IsCurrentlyActive())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDefaultFreeLimits(t *testing.T) {
|
||||
limits := GetDefaultFreeLimits()
|
||||
assert.Equal(t, TierFree, limits.Tier)
|
||||
require.NotNil(t, limits.PropertiesLimit)
|
||||
require.NotNil(t, limits.TasksLimit)
|
||||
require.NotNil(t, limits.ContractorsLimit)
|
||||
require.NotNil(t, limits.DocumentsLimit)
|
||||
assert.Equal(t, 1, *limits.PropertiesLimit)
|
||||
assert.Equal(t, 10, *limits.TasksLimit)
|
||||
assert.Equal(t, 0, *limits.ContractorsLimit)
|
||||
assert.Equal(t, 0, *limits.DocumentsLimit)
|
||||
}
|
||||
|
||||
func TestGetDefaultProLimits(t *testing.T) {
|
||||
limits := GetDefaultProLimits()
|
||||
assert.Equal(t, TierPro, limits.Tier)
|
||||
assert.Nil(t, limits.PropertiesLimit)
|
||||
assert.Nil(t, limits.TasksLimit)
|
||||
assert.Nil(t, limits.ContractorsLimit)
|
||||
assert.Nil(t, limits.DocumentsLimit)
|
||||
}
|
||||
|
||||
// === ConfirmationCode additional tests ===
|
||||
|
||||
func TestConfirmationCode_TableName(t *testing.T) {
|
||||
cc := ConfirmationCode{}
|
||||
assert.Equal(t, "user_confirmationcode", cc.TableName())
|
||||
}
|
||||
|
||||
// === PasswordResetCode additional tests ===
|
||||
|
||||
func TestPasswordResetCode_TableName(t *testing.T) {
|
||||
prc := PasswordResetCode{}
|
||||
assert.Equal(t, "user_passwordresetcode", prc.TableName())
|
||||
}
|
||||
|
||||
// === Social Auth TableName tests ===
|
||||
|
||||
func TestAppleSocialAuth_TableName(t *testing.T) {
|
||||
a := AppleSocialAuth{}
|
||||
assert.Equal(t, "user_applesocialauth", a.TableName())
|
||||
}
|
||||
|
||||
func TestGoogleSocialAuth_TableName(t *testing.T) {
|
||||
g := GoogleSocialAuth{}
|
||||
assert.Equal(t, "user_googlesocialauth", g.TableName())
|
||||
}
|
||||
|
||||
// === BaseModel tests ===
|
||||
|
||||
func TestBaseModel_BeforeCreate(t *testing.T) {
|
||||
b := &BaseModel{}
|
||||
err := b.BeforeCreate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, b.CreatedAt.IsZero())
|
||||
assert.False(t, b.UpdatedAt.IsZero())
|
||||
}
|
||||
|
||||
func TestBaseModel_BeforeCreate_PreservesExisting(t *testing.T) {
|
||||
existingTime := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
b := &BaseModel{
|
||||
CreatedAt: existingTime,
|
||||
UpdatedAt: existingTime,
|
||||
}
|
||||
err := b.BeforeCreate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, existingTime, b.CreatedAt)
|
||||
assert.Equal(t, existingTime, b.UpdatedAt)
|
||||
}
|
||||
|
||||
func TestBaseModel_BeforeUpdate(t *testing.T) {
|
||||
oldTime := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
b := &BaseModel{
|
||||
UpdatedAt: oldTime,
|
||||
}
|
||||
err := b.BeforeUpdate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, b.UpdatedAt.After(oldTime))
|
||||
}
|
||||
@@ -11,10 +11,10 @@ import (
|
||||
func TestUser_SetPassword(t *testing.T) {
|
||||
user := &User{}
|
||||
|
||||
err := user.SetPassword("testpassword123")
|
||||
err := user.SetPassword("testPassword123")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, user.Password)
|
||||
assert.NotEqual(t, "testpassword123", user.Password) // Should be hashed
|
||||
assert.NotEqual(t, "testPassword123", user.Password) // Should be hashed
|
||||
}
|
||||
|
||||
func TestUser_CheckPassword(t *testing.T) {
|
||||
|
||||
233
internal/monitoring/monitoring_test.go
Normal file
233
internal/monitoring/monitoring_test.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package monitoring
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLogFilters_GetLimit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
limit int
|
||||
expected int
|
||||
}{
|
||||
{"default for zero", 0, 100},
|
||||
{"default for negative", -5, 100},
|
||||
{"capped at 1000", 2000, 1000},
|
||||
{"exactly 1000", 1000, 1000},
|
||||
{"normal value", 50, 50},
|
||||
{"minimum valid", 1, 1},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
f := LogFilters{Limit: tc.limit}
|
||||
assert.Equal(t, tc.expected, f.GetLimit())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPStatsCollector_Record_And_GetStats(t *testing.T) {
|
||||
c := NewHTTPStatsCollector()
|
||||
|
||||
c.Record("GET /api/tasks/", 100*time.Millisecond, 200)
|
||||
c.Record("GET /api/tasks/", 200*time.Millisecond, 200)
|
||||
c.Record("POST /api/tasks/", 50*time.Millisecond, 201)
|
||||
c.Record("GET /api/tasks/", 300*time.Millisecond, 500)
|
||||
|
||||
stats := c.GetStats()
|
||||
|
||||
assert.Equal(t, int64(4), stats.RequestsTotal)
|
||||
assert.True(t, stats.RequestsPerMinute > 0)
|
||||
|
||||
// Check endpoint stats
|
||||
taskStats, ok := stats.ByEndpoint["GET /api/tasks/"]
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, int64(3), taskStats.Count)
|
||||
assert.True(t, taskStats.AvgLatencyMs > 0)
|
||||
|
||||
postStats, ok := stats.ByEndpoint["POST /api/tasks/"]
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, int64(1), postStats.Count)
|
||||
|
||||
// Check status codes
|
||||
assert.Equal(t, int64(2), stats.ByStatusCode[200])
|
||||
assert.Equal(t, int64(1), stats.ByStatusCode[201])
|
||||
assert.Equal(t, int64(1), stats.ByStatusCode[500])
|
||||
|
||||
// Error rate should include 500 status
|
||||
assert.True(t, stats.ErrorRate > 0)
|
||||
}
|
||||
|
||||
func TestHTTPStatsCollector_Reset(t *testing.T) {
|
||||
c := NewHTTPStatsCollector()
|
||||
c.Record("GET /api/tasks/", 100*time.Millisecond, 200)
|
||||
|
||||
c.Reset()
|
||||
|
||||
stats := c.GetStats()
|
||||
assert.Equal(t, int64(0), stats.RequestsTotal)
|
||||
assert.Empty(t, stats.ByEndpoint)
|
||||
assert.Empty(t, stats.ByStatusCode)
|
||||
}
|
||||
|
||||
func TestHTTPStatsCollector_ErrorRate(t *testing.T) {
|
||||
c := NewHTTPStatsCollector()
|
||||
c.Record("GET /api/tasks/", 10*time.Millisecond, 200)
|
||||
c.Record("GET /api/tasks/", 10*time.Millisecond, 400)
|
||||
c.Record("GET /api/tasks/", 10*time.Millisecond, 500)
|
||||
|
||||
stats := c.GetStats()
|
||||
ep := stats.ByEndpoint["GET /api/tasks/"]
|
||||
|
||||
// 2 out of 3 are errors (400 and 500)
|
||||
assert.InDelta(t, 2.0/3.0, ep.ErrorRate, 0.001)
|
||||
}
|
||||
|
||||
func TestHTTPStatsCollector_P95(t *testing.T) {
|
||||
c := NewHTTPStatsCollector()
|
||||
|
||||
// Record 100 requests with increasing latencies
|
||||
for i := 1; i <= 100; i++ {
|
||||
c.Record("GET /api/test/", time.Duration(i)*time.Millisecond, 200)
|
||||
}
|
||||
|
||||
stats := c.GetStats()
|
||||
ep := stats.ByEndpoint["GET /api/test/"]
|
||||
// P95 should be around 95ms
|
||||
assert.True(t, ep.P95LatencyMs >= 90, "P95 should be >= 90ms, got %f", ep.P95LatencyMs)
|
||||
}
|
||||
|
||||
func TestHTTPStatsCollector_EmptyStats(t *testing.T) {
|
||||
c := NewHTTPStatsCollector()
|
||||
stats := c.GetStats()
|
||||
|
||||
assert.Equal(t, int64(0), stats.RequestsTotal)
|
||||
assert.Equal(t, float64(0), stats.AvgLatencyMs)
|
||||
assert.Equal(t, float64(0), stats.ErrorRate)
|
||||
assert.Equal(t, float64(0), stats.RequestsPerMinute)
|
||||
}
|
||||
|
||||
func TestHTTPStatsCollector_EndpointOverflow(t *testing.T) {
|
||||
c := NewHTTPStatsCollector()
|
||||
|
||||
// Fill up to maxEndpoints unique endpoints
|
||||
for i := 0; i < maxEndpoints+10; i++ {
|
||||
endpoint := "GET /api/test/" + string(rune('A'+i%26)) + string(rune('0'+i/26))
|
||||
c.Record(endpoint, 10*time.Millisecond, 200)
|
||||
}
|
||||
|
||||
stats := c.GetStats()
|
||||
// Should have at most maxEndpoints + 1 (the OTHER bucket)
|
||||
assert.LessOrEqual(t, len(stats.ByEndpoint), maxEndpoints+1)
|
||||
}
|
||||
|
||||
func TestWSMessageConstants(t *testing.T) {
|
||||
assert.Equal(t, "log", WSMessageTypeLog)
|
||||
assert.Equal(t, "stats", WSMessageTypeStats)
|
||||
}
|
||||
|
||||
func TestRedisLogWriter_Write_Disabled(t *testing.T) {
|
||||
// Create a writer with a nil buffer -- won't actually push to Redis
|
||||
// but we can test the enabled/disabled logic
|
||||
w := &RedisLogWriter{
|
||||
process: "api",
|
||||
ch: make(chan LogEntry, writerChannelSize),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
w.enabled.Store(false)
|
||||
|
||||
// Start drain loop (reads from channel)
|
||||
go func() {
|
||||
defer close(w.done)
|
||||
for range w.ch {
|
||||
}
|
||||
}()
|
||||
|
||||
n, err := w.Write([]byte(`{"level":"info","message":"test"}`))
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, n, 0)
|
||||
|
||||
// Channel should be empty since writer is disabled
|
||||
assert.Equal(t, 0, len(w.ch))
|
||||
|
||||
close(w.ch)
|
||||
<-w.done
|
||||
}
|
||||
|
||||
func TestRedisLogWriter_Write_Enabled(t *testing.T) {
|
||||
w := &RedisLogWriter{
|
||||
process: "api",
|
||||
ch: make(chan LogEntry, writerChannelSize),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
w.enabled.Store(true)
|
||||
|
||||
go func() {
|
||||
defer close(w.done)
|
||||
for range w.ch {
|
||||
}
|
||||
}()
|
||||
|
||||
n, err := w.Write([]byte(`{"level":"info","message":"hello","caller":"main.go:10"}`))
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, n, 0)
|
||||
|
||||
// Give the goroutine a moment, then close
|
||||
close(w.ch)
|
||||
<-w.done
|
||||
}
|
||||
|
||||
func TestRedisLogWriter_Write_InvalidJSON(t *testing.T) {
|
||||
w := &RedisLogWriter{
|
||||
process: "api",
|
||||
ch: make(chan LogEntry, writerChannelSize),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
w.enabled.Store(true)
|
||||
|
||||
go func() {
|
||||
defer close(w.done)
|
||||
for range w.ch {
|
||||
}
|
||||
}()
|
||||
|
||||
// Non-JSON input should be silently skipped
|
||||
n, err := w.Write([]byte("not json at all"))
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, n, 0)
|
||||
assert.Equal(t, 0, len(w.ch))
|
||||
|
||||
close(w.ch)
|
||||
<-w.done
|
||||
}
|
||||
|
||||
func TestRedisLogWriter_SetEnabled_IsEnabled(t *testing.T) {
|
||||
w := &RedisLogWriter{
|
||||
process: "api",
|
||||
ch: make(chan LogEntry, 1),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
w.enabled.Store(true)
|
||||
|
||||
assert.True(t, w.IsEnabled())
|
||||
w.SetEnabled(false)
|
||||
assert.False(t, w.IsEnabled())
|
||||
w.SetEnabled(true)
|
||||
assert.True(t, w.IsEnabled())
|
||||
}
|
||||
|
||||
func TestCollector_Stop_MultipleCallsSafe(t *testing.T) {
|
||||
c := &Collector{
|
||||
process: "api",
|
||||
startTime: time.Now(),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Should not panic on multiple calls
|
||||
c.Stop()
|
||||
c.Stop()
|
||||
c.Stop()
|
||||
}
|
||||
359
internal/push/push_coverage_test.go
Normal file
359
internal/push/push_coverage_test.go
Normal file
@@ -0,0 +1,359 @@
|
||||
package push
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// === truncateToken tests ===
|
||||
|
||||
func TestTruncateToken_LongToken(t *testing.T) {
|
||||
token := "abcdefghijklmnopqrstuvwxyz1234567890"
|
||||
result := truncateToken(token)
|
||||
assert.Equal(t, "abcdefgh...", result)
|
||||
}
|
||||
|
||||
func TestTruncateToken_ShortToken(t *testing.T) {
|
||||
token := "abc"
|
||||
result := truncateToken(token)
|
||||
assert.Equal(t, "abc", result)
|
||||
}
|
||||
|
||||
func TestTruncateToken_ExactlyEightChars(t *testing.T) {
|
||||
token := "12345678"
|
||||
result := truncateToken(token)
|
||||
assert.Equal(t, "12345678", result)
|
||||
}
|
||||
|
||||
func TestTruncateToken_NineChars(t *testing.T) {
|
||||
token := "123456789"
|
||||
result := truncateToken(token)
|
||||
assert.Equal(t, "12345678...", result)
|
||||
}
|
||||
|
||||
func TestTruncateToken_Empty(t *testing.T) {
|
||||
result := truncateToken("")
|
||||
assert.Equal(t, "", result)
|
||||
}
|
||||
|
||||
// === Client tests ===
|
||||
|
||||
func TestClient_SendToIOS_Disabled(t *testing.T) {
|
||||
client := &Client{
|
||||
enabled: false,
|
||||
apnsBreaker: NewCircuitBreaker("apns"),
|
||||
fcmBreaker: NewCircuitBreaker("fcm"),
|
||||
}
|
||||
|
||||
err := client.SendToIOS(context.Background(), []string{"token1"}, "Title", "Body", nil)
|
||||
assert.NoError(t, err) // Returns nil when disabled
|
||||
}
|
||||
|
||||
func TestClient_SendToIOS_NilAPNs(t *testing.T) {
|
||||
client := &Client{
|
||||
enabled: true,
|
||||
apns: nil, // Not initialized
|
||||
apnsBreaker: NewCircuitBreaker("apns"),
|
||||
fcmBreaker: NewCircuitBreaker("fcm"),
|
||||
}
|
||||
|
||||
err := client.SendToIOS(context.Background(), []string{"token1"}, "Title", "Body", nil)
|
||||
assert.NoError(t, err) // Returns nil when not initialized
|
||||
}
|
||||
|
||||
func TestClient_SendToIOS_CircuitBreakerOpen(t *testing.T) {
|
||||
breaker := NewCircuitBreaker("apns", WithFailureThreshold(1))
|
||||
breaker.RecordFailure() // Open the circuit
|
||||
|
||||
client := &Client{
|
||||
enabled: true,
|
||||
apns: &APNsClient{}, // Non-nil so we pass the nil check
|
||||
apnsBreaker: breaker,
|
||||
fcmBreaker: NewCircuitBreaker("fcm"),
|
||||
}
|
||||
|
||||
err := client.SendToIOS(context.Background(), []string{"token1"}, "Title", "Body", nil)
|
||||
assert.ErrorIs(t, err, ErrCircuitOpen)
|
||||
}
|
||||
|
||||
func TestClient_SendToAndroid_Disabled(t *testing.T) {
|
||||
client := &Client{
|
||||
enabled: false,
|
||||
apnsBreaker: NewCircuitBreaker("apns"),
|
||||
fcmBreaker: NewCircuitBreaker("fcm"),
|
||||
}
|
||||
|
||||
err := client.SendToAndroid(context.Background(), []string{"token1"}, "Title", "Body", nil)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_SendToAndroid_NilFCM(t *testing.T) {
|
||||
client := &Client{
|
||||
enabled: true,
|
||||
fcm: nil,
|
||||
apnsBreaker: NewCircuitBreaker("apns"),
|
||||
fcmBreaker: NewCircuitBreaker("fcm"),
|
||||
}
|
||||
|
||||
err := client.SendToAndroid(context.Background(), []string{"token1"}, "Title", "Body", nil)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_SendToAndroid_CircuitBreakerOpen(t *testing.T) {
|
||||
breaker := NewCircuitBreaker("fcm", WithFailureThreshold(1))
|
||||
breaker.RecordFailure()
|
||||
|
||||
client := &Client{
|
||||
enabled: true,
|
||||
fcm: &FCMClient{}, // Non-nil
|
||||
apnsBreaker: NewCircuitBreaker("apns"),
|
||||
fcmBreaker: breaker,
|
||||
}
|
||||
|
||||
err := client.SendToAndroid(context.Background(), []string{"token1"}, "Title", "Body", nil)
|
||||
assert.ErrorIs(t, err, ErrCircuitOpen)
|
||||
}
|
||||
|
||||
func TestClient_SendToAndroid_Success(t *testing.T) {
|
||||
server := serveFCMV1Success(t)
|
||||
defer server.Close()
|
||||
|
||||
fcmClient := newTestFCMClient(server.URL)
|
||||
client := &Client{
|
||||
enabled: true,
|
||||
fcm: fcmClient,
|
||||
apnsBreaker: NewCircuitBreaker("apns"),
|
||||
fcmBreaker: NewCircuitBreaker("fcm"),
|
||||
}
|
||||
|
||||
err := client.SendToAndroid(context.Background(), []string{"token1"}, "Title", "Body", nil)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_SendToAndroid_Failure_RecordsInBreaker(t *testing.T) {
|
||||
server := serveFCMV1Error(t, http.StatusInternalServerError, "INTERNAL", "internal error")
|
||||
defer server.Close()
|
||||
|
||||
fcmClient := newTestFCMClient(server.URL)
|
||||
breaker := NewCircuitBreaker("fcm", WithFailureThreshold(3))
|
||||
client := &Client{
|
||||
enabled: true,
|
||||
fcm: fcmClient,
|
||||
apnsBreaker: NewCircuitBreaker("apns"),
|
||||
fcmBreaker: breaker,
|
||||
}
|
||||
|
||||
err := client.SendToAndroid(context.Background(), []string{"token1"}, "Title", "Body", nil)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, 1, breaker.Counts())
|
||||
}
|
||||
|
||||
func TestClient_SendToAll_Disabled(t *testing.T) {
|
||||
client := &Client{
|
||||
enabled: false,
|
||||
apnsBreaker: NewCircuitBreaker("apns"),
|
||||
fcmBreaker: NewCircuitBreaker("fcm"),
|
||||
}
|
||||
|
||||
err := client.SendToAll(context.Background(), []string{"ios-token"}, []string{"android-token"}, "Title", "Body", nil)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_SendToAll_EmptyTokens(t *testing.T) {
|
||||
server := serveFCMV1Success(t)
|
||||
defer server.Close()
|
||||
|
||||
fcmClient := newTestFCMClient(server.URL)
|
||||
client := &Client{
|
||||
enabled: true,
|
||||
fcm: fcmClient,
|
||||
apnsBreaker: NewCircuitBreaker("apns"),
|
||||
fcmBreaker: NewCircuitBreaker("fcm"),
|
||||
}
|
||||
|
||||
// No tokens at all — should just return nil
|
||||
err := client.SendToAll(context.Background(), []string{}, []string{}, "Title", "Body", nil)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_SendToAll_AndroidOnly(t *testing.T) {
|
||||
server := serveFCMV1Success(t)
|
||||
defer server.Close()
|
||||
|
||||
fcmClient := newTestFCMClient(server.URL)
|
||||
client := &Client{
|
||||
enabled: true,
|
||||
fcm: fcmClient,
|
||||
apnsBreaker: NewCircuitBreaker("apns"),
|
||||
fcmBreaker: NewCircuitBreaker("fcm"),
|
||||
}
|
||||
|
||||
err := client.SendToAll(context.Background(), []string{}, []string{"android-token"}, "Title", "Body", nil)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_IsIOSEnabled(t *testing.T) {
|
||||
clientWithAPNS := &Client{apns: &APNsClient{}}
|
||||
clientWithoutAPNS := &Client{apns: nil}
|
||||
|
||||
assert.True(t, clientWithAPNS.IsIOSEnabled())
|
||||
assert.False(t, clientWithoutAPNS.IsIOSEnabled())
|
||||
}
|
||||
|
||||
func TestClient_IsAndroidEnabled(t *testing.T) {
|
||||
clientWithFCM := &Client{fcm: &FCMClient{}}
|
||||
clientWithoutFCM := &Client{fcm: nil}
|
||||
|
||||
assert.True(t, clientWithFCM.IsAndroidEnabled())
|
||||
assert.False(t, clientWithoutFCM.IsAndroidEnabled())
|
||||
}
|
||||
|
||||
func TestClient_HealthCheck(t *testing.T) {
|
||||
client := &Client{
|
||||
apns: nil,
|
||||
fcm: nil,
|
||||
}
|
||||
err := client.HealthCheck(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
client.fcm = &FCMClient{}
|
||||
err = client.HealthCheck(context.Background())
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_SendActionableNotification_Disabled(t *testing.T) {
|
||||
client := &Client{
|
||||
enabled: false,
|
||||
apnsBreaker: NewCircuitBreaker("apns"),
|
||||
fcmBreaker: NewCircuitBreaker("fcm"),
|
||||
}
|
||||
|
||||
err := client.SendActionableNotification(context.Background(), []string{"ios"}, []string{"android"}, "Title", "Body", nil, "TASK_DUE")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_SendActionableNotification_NilAPNs(t *testing.T) {
|
||||
server := serveFCMV1Success(t)
|
||||
defer server.Close()
|
||||
|
||||
fcmClient := newTestFCMClient(server.URL)
|
||||
client := &Client{
|
||||
enabled: true,
|
||||
apns: nil,
|
||||
fcm: fcmClient,
|
||||
apnsBreaker: NewCircuitBreaker("apns"),
|
||||
fcmBreaker: NewCircuitBreaker("fcm"),
|
||||
}
|
||||
|
||||
// Should skip iOS and send Android
|
||||
err := client.SendActionableNotification(context.Background(), []string{"ios"}, []string{"android"}, "Title", "Body", nil, "TASK_DUE")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_SendActionableNotification_APNsBreakerOpen(t *testing.T) {
|
||||
server := serveFCMV1Success(t)
|
||||
defer server.Close()
|
||||
|
||||
fcmClient := newTestFCMClient(server.URL)
|
||||
apnsBreaker := NewCircuitBreaker("apns", WithFailureThreshold(1))
|
||||
apnsBreaker.RecordFailure()
|
||||
|
||||
client := &Client{
|
||||
enabled: true,
|
||||
apns: &APNsClient{},
|
||||
fcm: fcmClient,
|
||||
apnsBreaker: apnsBreaker,
|
||||
fcmBreaker: NewCircuitBreaker("fcm"),
|
||||
}
|
||||
|
||||
err := client.SendActionableNotification(context.Background(), []string{"ios"}, []string{"android"}, "Title", "Body", nil, "TASK_DUE")
|
||||
// Should return ErrCircuitOpen because that was the lastErr set
|
||||
assert.ErrorIs(t, err, ErrCircuitOpen)
|
||||
}
|
||||
|
||||
// === FCM additional tests ===
|
||||
|
||||
func TestFCMV1Send_WithDataPayload(t *testing.T) {
|
||||
var receivedData map[string]string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req fcmV1Request
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
receivedData = req.Message.Data
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(fcmV1Response{Name: "projects/test/messages/0:12345"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestFCMClient(server.URL)
|
||||
data := map[string]string{
|
||||
"task_id": "42",
|
||||
"action": "complete",
|
||||
"deep_link": "/tasks/42",
|
||||
}
|
||||
|
||||
err := client.Send(context.Background(), []string{"token"}, "Title", "Body", data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "42", receivedData["task_id"])
|
||||
assert.Equal(t, "complete", receivedData["action"])
|
||||
assert.Equal(t, "/tasks/42", receivedData["deep_link"])
|
||||
}
|
||||
|
||||
func TestFCMV1Send_ContextCancelled(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// The server just hangs — context cancellation should cause the request to fail
|
||||
select {}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := newTestFCMClient(server.URL)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
err := client.Send(ctx, []string{"token"}, "Title", "Body", nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestFCMSendError_ErrorFormatting(t *testing.T) {
|
||||
// Short token (no truncation)
|
||||
shortErr := &FCMSendError{
|
||||
Token: "abc",
|
||||
StatusCode: 500,
|
||||
ErrorCode: FCMErrInternal,
|
||||
Message: "server error",
|
||||
}
|
||||
assert.Contains(t, shortErr.Error(), "abc")
|
||||
assert.Contains(t, shortErr.Error(), "500")
|
||||
assert.Contains(t, shortErr.Error(), "INTERNAL")
|
||||
assert.Contains(t, shortErr.Error(), "server error")
|
||||
}
|
||||
|
||||
func TestParseFCMV1Error_MalformedJSON(t *testing.T) {
|
||||
result := parseFCMV1Error("token123", 500, []byte("not json"))
|
||||
assert.Equal(t, 500, result.StatusCode)
|
||||
assert.Contains(t, result.Message, "unparseable error response")
|
||||
}
|
||||
|
||||
func TestParseFCMV1Error_ValidJSON(t *testing.T) {
|
||||
body := `{"error":{"code":404,"message":"not found","status":"NOT_FOUND"}}`
|
||||
result := parseFCMV1Error("token123", 404, []byte(body))
|
||||
assert.Equal(t, 404, result.StatusCode)
|
||||
assert.Equal(t, FCMErrorCode("NOT_FOUND"), result.ErrorCode)
|
||||
assert.Equal(t, "not found", result.Message)
|
||||
}
|
||||
|
||||
// === Platform constants ===
|
||||
|
||||
func TestPlatformConstants(t *testing.T) {
|
||||
assert.Equal(t, "ios", PlatformIOS)
|
||||
assert.Equal(t, "android", PlatformAndroid)
|
||||
}
|
||||
205
internal/repositories/admin_repo_test.go
Normal file
205
internal/repositories/admin_repo_test.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
func TestAdminRepository_Create(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewAdminRepository(db)
|
||||
|
||||
admin := &models.AdminUser{
|
||||
Email: "admin@test.com",
|
||||
FirstName: "Test",
|
||||
LastName: "Admin",
|
||||
Role: models.AdminRoleAdmin,
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, admin.SetPassword("Password123"))
|
||||
|
||||
err := repo.Create(admin)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, admin.ID)
|
||||
}
|
||||
|
||||
func TestAdminRepository_Create_Duplicate(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewAdminRepository(db)
|
||||
|
||||
admin1 := &models.AdminUser{Email: "admin@test.com", Role: models.AdminRoleAdmin}
|
||||
require.NoError(t, admin1.SetPassword("Password123"))
|
||||
err := repo.Create(admin1)
|
||||
require.NoError(t, err)
|
||||
|
||||
admin2 := &models.AdminUser{Email: "admin@test.com", Role: models.AdminRoleAdmin}
|
||||
require.NoError(t, admin2.SetPassword("Password123"))
|
||||
err = repo.Create(admin2)
|
||||
assert.ErrorIs(t, err, ErrAdminExists)
|
||||
}
|
||||
|
||||
func TestAdminRepository_FindByID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewAdminRepository(db)
|
||||
|
||||
admin := &models.AdminUser{Email: "admin@test.com", Role: models.AdminRoleAdmin}
|
||||
require.NoError(t, admin.SetPassword("Password123"))
|
||||
require.NoError(t, repo.Create(admin))
|
||||
|
||||
found, err := repo.FindByID(admin.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "admin@test.com", found.Email)
|
||||
}
|
||||
|
||||
func TestAdminRepository_FindByID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewAdminRepository(db)
|
||||
|
||||
_, err := repo.FindByID(9999)
|
||||
assert.ErrorIs(t, err, ErrAdminNotFound)
|
||||
}
|
||||
|
||||
func TestAdminRepository_FindByEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewAdminRepository(db)
|
||||
|
||||
admin := &models.AdminUser{Email: "admin@test.com", Role: models.AdminRoleAdmin}
|
||||
require.NoError(t, admin.SetPassword("Password123"))
|
||||
require.NoError(t, repo.Create(admin))
|
||||
|
||||
found, err := repo.FindByEmail("admin@test.com")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, admin.ID, found.ID)
|
||||
}
|
||||
|
||||
func TestAdminRepository_FindByEmail_CaseInsensitive(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewAdminRepository(db)
|
||||
|
||||
admin := &models.AdminUser{Email: "Admin@Test.com", Role: models.AdminRoleAdmin}
|
||||
require.NoError(t, admin.SetPassword("Password123"))
|
||||
require.NoError(t, repo.Create(admin))
|
||||
|
||||
found, err := repo.FindByEmail("admin@test.com")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, admin.ID, found.ID)
|
||||
}
|
||||
|
||||
func TestAdminRepository_FindByEmail_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewAdminRepository(db)
|
||||
|
||||
_, err := repo.FindByEmail("nonexistent@test.com")
|
||||
assert.ErrorIs(t, err, ErrAdminNotFound)
|
||||
}
|
||||
|
||||
func TestAdminRepository_Update(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewAdminRepository(db)
|
||||
|
||||
admin := &models.AdminUser{Email: "admin@test.com", FirstName: "Old", Role: models.AdminRoleAdmin}
|
||||
require.NoError(t, admin.SetPassword("Password123"))
|
||||
require.NoError(t, repo.Create(admin))
|
||||
|
||||
admin.FirstName = "New"
|
||||
err := repo.Update(admin)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByID(admin.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "New", found.FirstName)
|
||||
}
|
||||
|
||||
func TestAdminRepository_Delete(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewAdminRepository(db)
|
||||
|
||||
admin := &models.AdminUser{Email: "admin@test.com", Role: models.AdminRoleAdmin}
|
||||
require.NoError(t, admin.SetPassword("Password123"))
|
||||
require.NoError(t, repo.Create(admin))
|
||||
|
||||
err := repo.Delete(admin.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindByID(admin.ID)
|
||||
assert.ErrorIs(t, err, ErrAdminNotFound)
|
||||
}
|
||||
|
||||
func TestAdminRepository_UpdateLastLogin(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewAdminRepository(db)
|
||||
|
||||
admin := &models.AdminUser{Email: "admin@test.com", Role: models.AdminRoleAdmin}
|
||||
require.NoError(t, admin.SetPassword("Password123"))
|
||||
require.NoError(t, repo.Create(admin))
|
||||
|
||||
assert.Nil(t, admin.LastLogin)
|
||||
|
||||
err := repo.UpdateLastLogin(admin.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByID(admin.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, found.LastLogin)
|
||||
}
|
||||
|
||||
func TestAdminRepository_List(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewAdminRepository(db)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
admin := &models.AdminUser{
|
||||
Email: "admin" + string(rune('0'+i)) + "@test.com",
|
||||
Role: models.AdminRoleAdmin,
|
||||
}
|
||||
require.NoError(t, admin.SetPassword("Password123"))
|
||||
require.NoError(t, repo.Create(admin))
|
||||
}
|
||||
|
||||
// Page 1
|
||||
admins, total, err := repo.List(1, 3)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(5), total)
|
||||
assert.Len(t, admins, 3)
|
||||
|
||||
// Page 2
|
||||
admins, total, err = repo.List(2, 3)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(5), total)
|
||||
assert.Len(t, admins, 2)
|
||||
}
|
||||
|
||||
func TestAdminRepository_ExistsByEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewAdminRepository(db)
|
||||
|
||||
admin := &models.AdminUser{Email: "admin@test.com", Role: models.AdminRoleAdmin}
|
||||
require.NoError(t, admin.SetPassword("Password123"))
|
||||
require.NoError(t, repo.Create(admin))
|
||||
|
||||
exists, err := repo.ExistsByEmail("admin@test.com")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
exists, err = repo.ExistsByEmail("nonexistent@test.com")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestAdminRepository_ExistsByEmail_CaseInsensitive(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewAdminRepository(db)
|
||||
|
||||
admin := &models.AdminUser{Email: "Admin@Test.com", Role: models.AdminRoleAdmin}
|
||||
require.NoError(t, admin.SetPassword("Password123"))
|
||||
require.NoError(t, repo.Create(admin))
|
||||
|
||||
exists, err := repo.ExistsByEmail("admin@test.com")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
}
|
||||
356
internal/repositories/contractor_repo_coverage_test.go
Normal file
356
internal/repositories/contractor_repo_coverage_test.go
Normal file
@@ -0,0 +1,356 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
func TestContractorRepository_Create(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
contractor := &models.Contractor{
|
||||
ResidenceID: &residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Name: "Mike's Plumbing",
|
||||
Company: "Mike's Plumbing Co.",
|
||||
Phone: "+1-555-1234",
|
||||
Email: "mike@plumbing.com",
|
||||
IsActive: true,
|
||||
}
|
||||
|
||||
err := repo.Create(contractor)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, contractor.ID)
|
||||
}
|
||||
|
||||
func TestContractorRepository_FindByID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
|
||||
|
||||
found, err := repo.FindByID(contractor.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, contractor.ID, found.ID)
|
||||
assert.Equal(t, "Test Contractor", found.Name)
|
||||
}
|
||||
|
||||
func TestContractorRepository_FindByID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
_, err := repo.FindByID(9999)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestContractorRepository_FindByID_InactiveNotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Inactive Contractor")
|
||||
|
||||
// Soft-delete
|
||||
db.Model(&models.Contractor{}).Where("id = ?", contractor.ID).Update("is_active", false)
|
||||
|
||||
_, err := repo.FindByID(contractor.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestContractorRepository_FindByResidence(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
otherResidence := testutil.CreateTestResidence(t, db, user.ID, "Other House")
|
||||
|
||||
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor A")
|
||||
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor B")
|
||||
testutil.CreateTestContractor(t, db, otherResidence.ID, user.ID, "Other Contractor")
|
||||
|
||||
contractors, err := repo.FindByResidence(residence.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, contractors, 2)
|
||||
}
|
||||
|
||||
func TestContractorRepository_FindByResidence_ExcludesInactive(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
active := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Active Contractor")
|
||||
inactive := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Inactive Contractor")
|
||||
db.Model(&models.Contractor{}).Where("id = ?", inactive.ID).Update("is_active", false)
|
||||
|
||||
contractors, err := repo.FindByResidence(residence.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, contractors, 1)
|
||||
assert.Equal(t, active.ID, contractors[0].ID)
|
||||
}
|
||||
|
||||
func TestContractorRepository_FindByUser_WithResidences(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Residence Contractor")
|
||||
|
||||
contractors, err := repo.FindByUser(user.ID, []uint{residence.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, contractors, 1)
|
||||
assert.Equal(t, "Residence Contractor", contractors[0].Name)
|
||||
}
|
||||
|
||||
func TestContractorRepository_FindByUser_NoResidences(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
// Personal contractor with no residence
|
||||
personal := &models.Contractor{
|
||||
ResidenceID: nil,
|
||||
CreatedByID: user.ID,
|
||||
Name: "Personal Contractor",
|
||||
IsActive: true,
|
||||
}
|
||||
err := db.Create(personal).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
contractors, err := repo.FindByUser(user.ID, []uint{})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, contractors, 1)
|
||||
assert.Equal(t, "Personal Contractor", contractors[0].Name)
|
||||
}
|
||||
|
||||
func TestContractorRepository_Update(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Original Name")
|
||||
|
||||
contractor.Name = "Updated Name"
|
||||
contractor.Phone = "+1-555-9999"
|
||||
err := repo.Update(contractor)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByID(contractor.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated Name", found.Name)
|
||||
assert.Equal(t, "+1-555-9999", found.Phone)
|
||||
}
|
||||
|
||||
func TestContractorRepository_Delete(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "To Delete")
|
||||
|
||||
err := repo.Delete(contractor.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should not be found (soft delete)
|
||||
_, err = repo.FindByID(contractor.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestContractorRepository_CountByResidence(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 1")
|
||||
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 2")
|
||||
|
||||
count, err := repo.CountByResidence(residence.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(2), count)
|
||||
}
|
||||
|
||||
func TestContractorRepository_CountByResidenceIDs(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
r1 := testutil.CreateTestResidence(t, db, user.ID, "House 1")
|
||||
r2 := testutil.CreateTestResidence(t, db, user.ID, "House 2")
|
||||
|
||||
testutil.CreateTestContractor(t, db, r1.ID, user.ID, "Contractor A")
|
||||
testutil.CreateTestContractor(t, db, r2.ID, user.ID, "Contractor B")
|
||||
testutil.CreateTestContractor(t, db, r2.ID, user.ID, "Contractor C")
|
||||
|
||||
count, err := repo.CountByResidenceIDs([]uint{r1.ID, r2.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(3), count)
|
||||
}
|
||||
|
||||
func TestContractorRepository_CountByResidenceIDs_Empty(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
count, err := repo.CountByResidenceIDs([]uint{})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
func TestContractorRepository_SetSpecialties(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Skilled Contractor")
|
||||
|
||||
// Get seeded specialties
|
||||
var specialties []models.ContractorSpecialty
|
||||
err := db.Find(&specialties).Error
|
||||
require.NoError(t, err)
|
||||
require.GreaterOrEqual(t, len(specialties), 2)
|
||||
|
||||
// Set specialties
|
||||
err = repo.SetSpecialties(contractor.ID, []uint{specialties[0].ID, specialties[1].ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify via FindByID
|
||||
found, err := repo.FindByID(contractor.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, found.Specialties, 2)
|
||||
}
|
||||
|
||||
func TestContractorRepository_SetSpecialties_ClearsExisting(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Skilled Contractor")
|
||||
|
||||
var specialties []models.ContractorSpecialty
|
||||
err := db.Find(&specialties).Error
|
||||
require.NoError(t, err)
|
||||
require.GreaterOrEqual(t, len(specialties), 2)
|
||||
|
||||
// Set initial specialties
|
||||
err = repo.SetSpecialties(contractor.ID, []uint{specialties[0].ID, specialties[1].ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Clear all specialties
|
||||
err = repo.SetSpecialties(contractor.ID, []uint{})
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByID(contractor.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, found.Specialties, 0)
|
||||
}
|
||||
|
||||
func TestContractorRepository_GetAllSpecialties(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
specialties, err := repo.GetAllSpecialties()
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(specialties), 4) // Seeded 4 specialties
|
||||
}
|
||||
|
||||
func TestContractorRepository_FindSpecialtyByID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
var seeded models.ContractorSpecialty
|
||||
err := db.First(&seeded).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindSpecialtyByID(seeded.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, seeded.Name, found.Name)
|
||||
}
|
||||
|
||||
func TestContractorRepository_FindSpecialtyByID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
_, err := repo.FindSpecialtyByID(9999)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestContractorRepository_GetTasksForContractor(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Task Contractor")
|
||||
|
||||
// Create tasks linked to contractor
|
||||
task1 := &models.Task{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Task 1",
|
||||
ContractorID: &contractor.ID,
|
||||
Version: 1,
|
||||
}
|
||||
err := db.Create(task1).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
task2 := &models.Task{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Task 2",
|
||||
ContractorID: &contractor.ID,
|
||||
Version: 1,
|
||||
}
|
||||
err = db.Create(task2).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
tasks, err := repo.GetTasksForContractor(contractor.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, tasks, 2)
|
||||
}
|
||||
|
||||
func TestContractorRepository_FindByResidence_FavoritesFirst(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewContractorRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
// Create non-favorite first (alphabetically first)
|
||||
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Alpha Contractor")
|
||||
|
||||
// Create favorite
|
||||
fav := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Zeta Contractor")
|
||||
db.Model(&models.Contractor{}).Where("id = ?", fav.ID).Update("is_favorite", true)
|
||||
|
||||
contractors, err := repo.FindByResidence(residence.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, contractors, 2)
|
||||
// Favorite should be first
|
||||
assert.Equal(t, fav.ID, contractors[0].ID)
|
||||
}
|
||||
384
internal/repositories/document_repo_coverage_test.go
Normal file
384
internal/repositories/document_repo_coverage_test.go
Normal file
@@ -0,0 +1,384 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
func TestDocumentRepository_Create(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
doc := &models.Document{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "HVAC Warranty",
|
||||
DocumentType: models.DocumentTypeWarranty,
|
||||
FileURL: "https://example.com/hvac.pdf",
|
||||
IsActive: true,
|
||||
}
|
||||
|
||||
err := repo.Create(doc)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, doc.ID)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_FindByID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
|
||||
|
||||
found, err := repo.FindByID(doc.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, doc.ID, found.ID)
|
||||
assert.Equal(t, "Test Doc", found.Title)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_FindByID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
_, err := repo.FindByID(9999)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_FindByID_InactiveNotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Inactive Doc")
|
||||
|
||||
db.Model(&models.Document{}).Where("id = ?", doc.ID).Update("is_active", false)
|
||||
|
||||
_, err := repo.FindByID(doc.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_FindByResidence(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
otherResidence := testutil.CreateTestResidence(t, db, user.ID, "Other House")
|
||||
|
||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc A")
|
||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc B")
|
||||
testutil.CreateTestDocument(t, db, otherResidence.ID, user.ID, "Other Doc")
|
||||
|
||||
docs, err := repo.FindByResidence(residence.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, docs, 2)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_FindByResidence_ExcludesInactive(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
active := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Active Doc")
|
||||
inactive := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Inactive Doc")
|
||||
db.Model(&models.Document{}).Where("id = ?", inactive.ID).Update("is_active", false)
|
||||
|
||||
docs, err := repo.FindByResidence(residence.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, docs, 1)
|
||||
assert.Equal(t, active.ID, docs[0].ID)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_FindByUser(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
r1 := testutil.CreateTestResidence(t, db, user.ID, "House 1")
|
||||
r2 := testutil.CreateTestResidence(t, db, user.ID, "House 2")
|
||||
|
||||
testutil.CreateTestDocument(t, db, r1.ID, user.ID, "Doc 1")
|
||||
testutil.CreateTestDocument(t, db, r2.ID, user.ID, "Doc 2")
|
||||
|
||||
docs, err := repo.FindByUser([]uint{r1.ID, r2.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, docs, 2)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_Update(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Original Title")
|
||||
|
||||
doc.Title = "Updated Title"
|
||||
doc.Description = "Updated description"
|
||||
err := repo.Update(doc)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByID(doc.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated Title", found.Title)
|
||||
assert.Equal(t, "Updated description", found.Description)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_Delete(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Delete")
|
||||
|
||||
err := repo.Delete(doc.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindByID(doc.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_ActivateDeactivate(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Toggle Doc")
|
||||
|
||||
// Deactivate
|
||||
err := repo.Deactivate(doc.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindByID(doc.ID)
|
||||
assert.Error(t, err) // Not found when inactive
|
||||
|
||||
// Activate
|
||||
err = repo.Activate(doc.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByID(doc.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, doc.ID, found.ID)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_CountByResidence(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 1")
|
||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 2")
|
||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 3")
|
||||
|
||||
count, err := repo.CountByResidence(residence.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(3), count)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_CountByResidenceIDs(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
r1 := testutil.CreateTestResidence(t, db, user.ID, "House 1")
|
||||
r2 := testutil.CreateTestResidence(t, db, user.ID, "House 2")
|
||||
|
||||
testutil.CreateTestDocument(t, db, r1.ID, user.ID, "Doc A")
|
||||
testutil.CreateTestDocument(t, db, r2.ID, user.ID, "Doc B")
|
||||
testutil.CreateTestDocument(t, db, r2.ID, user.ID, "Doc C")
|
||||
|
||||
count, err := repo.CountByResidenceIDs([]uint{r1.ID, r2.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(3), count)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_CountByResidenceIDs_Empty(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
count, err := repo.CountByResidenceIDs([]uint{})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_DocumentImages(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc with Images")
|
||||
|
||||
// Create image
|
||||
img := &models.DocumentImage{
|
||||
DocumentID: doc.ID,
|
||||
ImageURL: "https://example.com/img1.jpg",
|
||||
Caption: "Front page",
|
||||
}
|
||||
err := repo.CreateDocumentImage(img)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, img.ID)
|
||||
|
||||
// Find image by ID
|
||||
found, err := repo.FindImageByID(img.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://example.com/img1.jpg", found.ImageURL)
|
||||
assert.Equal(t, "Front page", found.Caption)
|
||||
|
||||
// Delete single image
|
||||
err = repo.DeleteDocumentImage(img.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindImageByID(img.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_DeleteDocumentImages(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc with Images")
|
||||
|
||||
// Create multiple images
|
||||
for i := 0; i < 3; i++ {
|
||||
img := &models.DocumentImage{
|
||||
DocumentID: doc.ID,
|
||||
ImageURL: "https://example.com/img.jpg",
|
||||
}
|
||||
err := repo.CreateDocumentImage(img)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Delete all images for document
|
||||
err := repo.DeleteDocumentImages(doc.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all deleted
|
||||
var count int64
|
||||
db.Model(&models.DocumentImage{}).Where("document_id = ?", doc.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_FindByIDIncludingInactive(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Deactivated Doc")
|
||||
|
||||
// Deactivate
|
||||
db.Model(&models.Document{}).Where("id = ?", doc.ID).Update("is_active", false)
|
||||
|
||||
// FindByID should not find it
|
||||
_, err := repo.FindByID(doc.ID)
|
||||
assert.Error(t, err)
|
||||
|
||||
// FindByIDIncludingInactive should find it
|
||||
var found models.Document
|
||||
err = repo.FindByIDIncludingInactive(doc.ID, &found)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, doc.ID, found.ID)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_FindByUserFiltered_ByType(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
// Create general document
|
||||
generalDoc := &models.Document{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "General Doc",
|
||||
DocumentType: models.DocumentTypeGeneral,
|
||||
FileURL: "https://example.com/gen.pdf",
|
||||
IsActive: true,
|
||||
}
|
||||
err := db.Create(generalDoc).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create warranty document
|
||||
warrantyDoc := &models.Document{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Warranty Doc",
|
||||
DocumentType: models.DocumentTypeWarranty,
|
||||
FileURL: "https://example.com/warranty.pdf",
|
||||
IsActive: true,
|
||||
}
|
||||
err = db.Create(warrantyDoc).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Filter by warranty type
|
||||
filter := &DocumentFilter{
|
||||
DocumentType: string(models.DocumentTypeWarranty),
|
||||
}
|
||||
docs, err := repo.FindByUserFiltered([]uint{residence.ID}, filter)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, docs, 1)
|
||||
assert.Equal(t, "Warranty Doc", docs[0].Title)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_FindByUserFiltered_NilFilter(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 1")
|
||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 2")
|
||||
|
||||
docs, err := repo.FindByUserFiltered([]uint{residence.ID}, nil)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, docs, 2)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_FindWarranties(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
// General doc (should not be returned)
|
||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc")
|
||||
|
||||
// Warranty doc
|
||||
warranty := &models.Document{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Warranty",
|
||||
DocumentType: models.DocumentTypeWarranty,
|
||||
FileURL: "https://example.com/warranty.pdf",
|
||||
IsActive: true,
|
||||
}
|
||||
err := db.Create(warranty).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
docs, err := repo.FindWarranties([]uint{residence.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, docs, 1)
|
||||
assert.Equal(t, "Warranty", docs[0].Title)
|
||||
}
|
||||
207
internal/repositories/document_repo_extended_test.go
Normal file
207
internal/repositories/document_repo_extended_test.go
Normal file
@@ -0,0 +1,207 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
func TestDocumentRepository_FindExpiringWarranties(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
now := time.Now().UTC()
|
||||
expiringSoon := now.AddDate(0, 0, 15) // 15 days from now
|
||||
expiringLater := now.AddDate(0, 0, 60) // 60 days from now
|
||||
alreadyExpired := now.AddDate(0, 0, -5) // 5 days ago
|
||||
|
||||
// Warranty expiring in 15 days (within 30 day threshold)
|
||||
w1 := &models.Document{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Expiring Warranty",
|
||||
DocumentType: models.DocumentTypeWarranty,
|
||||
FileURL: "https://example.com/w1.pdf",
|
||||
ExpiryDate: &expiringSoon,
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, db.Create(w1).Error)
|
||||
|
||||
// Warranty expiring in 60 days (outside 30 day threshold)
|
||||
w2 := &models.Document{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Later Warranty",
|
||||
DocumentType: models.DocumentTypeWarranty,
|
||||
FileURL: "https://example.com/w2.pdf",
|
||||
ExpiryDate: &expiringLater,
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, db.Create(w2).Error)
|
||||
|
||||
// Already expired warranty
|
||||
w3 := &models.Document{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Expired Warranty",
|
||||
DocumentType: models.DocumentTypeWarranty,
|
||||
FileURL: "https://example.com/w3.pdf",
|
||||
ExpiryDate: &alreadyExpired,
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, db.Create(w3).Error)
|
||||
|
||||
// General document (not warranty)
|
||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc")
|
||||
|
||||
docs, err := repo.FindExpiringWarranties([]uint{residence.ID}, 30)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, docs, 1)
|
||||
assert.Equal(t, "Expiring Warranty", docs[0].Title)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_FindByUserFiltered_Search(t *testing.T) {
|
||||
t.Skip("requires PostgreSQL: SQLite does not support ILIKE")
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
d1 := &models.Document{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "HVAC Warranty Certificate",
|
||||
Description: "For the main HVAC unit",
|
||||
DocumentType: models.DocumentTypeWarranty,
|
||||
FileURL: "https://example.com/hvac.pdf",
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, db.Create(d1).Error)
|
||||
|
||||
d2 := &models.Document{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Plumbing Receipt",
|
||||
Description: "Bathroom plumbing fix",
|
||||
DocumentType: models.DocumentTypeReceipt,
|
||||
FileURL: "https://example.com/plumbing.pdf",
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, db.Create(d2).Error)
|
||||
|
||||
// Search by title
|
||||
filter := &DocumentFilter{Search: "HVAC"}
|
||||
docs, err := repo.FindByUserFiltered([]uint{residence.ID}, filter)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, docs, 1)
|
||||
assert.Equal(t, "HVAC Warranty Certificate", docs[0].Title)
|
||||
|
||||
// Search by description
|
||||
filter = &DocumentFilter{Search: "bathroom"}
|
||||
docs, err = repo.FindByUserFiltered([]uint{residence.ID}, filter)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, docs, 1)
|
||||
assert.Equal(t, "Plumbing Receipt", docs[0].Title)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_FindByUserFiltered_IsActiveOverride(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
active := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Active Doc")
|
||||
inactive := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Inactive Doc")
|
||||
db.Model(&models.Document{}).Where("id = ?", inactive.ID).Update("is_active", false)
|
||||
|
||||
// With IsActive = false, should get only inactive
|
||||
isActiveFalse := false
|
||||
filter := &DocumentFilter{IsActive: &isActiveFalse}
|
||||
docs, err := repo.FindByUserFiltered([]uint{residence.ID}, filter)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, docs, 1)
|
||||
assert.Equal(t, inactive.ID, docs[0].ID)
|
||||
|
||||
// With IsActive = true, should get only active
|
||||
isActiveTrue := true
|
||||
filter = &DocumentFilter{IsActive: &isActiveTrue}
|
||||
docs, err = repo.FindByUserFiltered([]uint{residence.ID}, filter)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, docs, 1)
|
||||
assert.Equal(t, active.ID, docs[0].ID)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_FindByUserFiltered_ExpiringSoon(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
now := time.Now().UTC()
|
||||
expiringSoon := now.AddDate(0, 0, 10)
|
||||
notExpiring := now.AddDate(0, 0, 90)
|
||||
|
||||
d1 := &models.Document{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Expiring Doc",
|
||||
DocumentType: models.DocumentTypeWarranty,
|
||||
FileURL: "https://example.com/exp.pdf",
|
||||
ExpiryDate: &expiringSoon,
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, db.Create(d1).Error)
|
||||
|
||||
d2 := &models.Document{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Not Expiring Doc",
|
||||
DocumentType: models.DocumentTypeWarranty,
|
||||
FileURL: "https://example.com/noexp.pdf",
|
||||
ExpiryDate: ¬Expiring,
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, db.Create(d2).Error)
|
||||
|
||||
days := 30
|
||||
filter := &DocumentFilter{ExpiringSoon: &days}
|
||||
docs, err := repo.FindByUserFiltered([]uint{residence.ID}, filter)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, docs, 1)
|
||||
assert.Equal(t, "Expiring Doc", docs[0].Title)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_FindImageByID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
_, err := repo.FindImageByID(9999)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDocumentRepository_FindByUser_ExcludesInactive(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewDocumentRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Active Doc")
|
||||
inactive := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Inactive Doc")
|
||||
db.Model(&models.Document{}).Where("id = ?", inactive.ID).Update("is_active", false)
|
||||
|
||||
docs, err := repo.FindByUser([]uint{residence.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, docs, 1)
|
||||
}
|
||||
510
internal/repositories/notification_repo_coverage_test.go
Normal file
510
internal/repositories/notification_repo_coverage_test.go
Normal file
@@ -0,0 +1,510 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
func TestNotificationRepository_Create(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
notification := &models.Notification{
|
||||
UserID: user.ID,
|
||||
NotificationType: models.NotificationTaskDueSoon,
|
||||
Title: "Task Due Soon",
|
||||
Body: "Your task is due tomorrow",
|
||||
}
|
||||
|
||||
err := repo.Create(notification)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, notification.ID)
|
||||
}
|
||||
|
||||
func TestNotificationRepository_FindByID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
notification := &models.Notification{
|
||||
UserID: user.ID,
|
||||
NotificationType: models.NotificationTaskOverdue,
|
||||
Title: "Task Overdue",
|
||||
Body: "Your task is overdue",
|
||||
}
|
||||
err := db.Create(notification).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByID(notification.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Task Overdue", found.Title)
|
||||
assert.Equal(t, user.ID, found.UserID)
|
||||
}
|
||||
|
||||
func TestNotificationRepository_FindByID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
_, err := repo.FindByID(9999)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNotificationRepository_FindByUser(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
otherUser := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
|
||||
// Create notifications for user
|
||||
for i := 0; i < 5; i++ {
|
||||
n := &models.Notification{
|
||||
UserID: user.ID,
|
||||
NotificationType: models.NotificationTaskDueSoon,
|
||||
Title: "Notification",
|
||||
Body: "Body",
|
||||
}
|
||||
err := db.Create(n).Error
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create notification for other user
|
||||
otherN := &models.Notification{
|
||||
UserID: otherUser.ID,
|
||||
NotificationType: models.NotificationTaskDueSoon,
|
||||
Title: "Other",
|
||||
Body: "Body",
|
||||
}
|
||||
err := db.Create(otherN).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Find all for user
|
||||
notifications, err := repo.FindByUser(user.ID, 0, 0)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, notifications, 5)
|
||||
|
||||
// Find with limit
|
||||
notifications, err = repo.FindByUser(user.ID, 3, 0)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, notifications, 3)
|
||||
|
||||
// Find with offset
|
||||
notifications, err = repo.FindByUser(user.ID, 3, 3)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, notifications, 2) // Only 2 remaining
|
||||
}
|
||||
|
||||
func TestNotificationRepository_MarkAsRead(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
notification := &models.Notification{
|
||||
UserID: user.ID,
|
||||
NotificationType: models.NotificationTaskCompleted,
|
||||
Title: "Task Completed",
|
||||
Body: "Your task was completed",
|
||||
Read: false,
|
||||
}
|
||||
err := db.Create(notification).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.MarkAsRead(notification.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByID(notification.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, found.Read)
|
||||
assert.NotNil(t, found.ReadAt)
|
||||
}
|
||||
|
||||
func TestNotificationRepository_MarkAllAsRead(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
// Create multiple unread notifications
|
||||
for i := 0; i < 3; i++ {
|
||||
n := &models.Notification{
|
||||
UserID: user.ID,
|
||||
NotificationType: models.NotificationTaskDueSoon,
|
||||
Title: "Unread",
|
||||
Body: "Body",
|
||||
Read: false,
|
||||
}
|
||||
err := db.Create(n).Error
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := repo.MarkAllAsRead(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all are read
|
||||
count, err := repo.CountUnread(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
func TestNotificationRepository_CountUnread(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
// Create 3 unread and 2 read notifications
|
||||
for i := 0; i < 3; i++ {
|
||||
n := &models.Notification{
|
||||
UserID: user.ID,
|
||||
NotificationType: models.NotificationTaskDueSoon,
|
||||
Title: "Unread",
|
||||
Body: "Body",
|
||||
Read: false,
|
||||
}
|
||||
err := db.Create(n).Error
|
||||
require.NoError(t, err)
|
||||
}
|
||||
for i := 0; i < 2; i++ {
|
||||
n := &models.Notification{
|
||||
UserID: user.ID,
|
||||
NotificationType: models.NotificationTaskCompleted,
|
||||
Title: "Read",
|
||||
Body: "Body",
|
||||
Read: true,
|
||||
}
|
||||
err := db.Create(n).Error
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
count, err := repo.CountUnread(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(3), count)
|
||||
}
|
||||
|
||||
func TestNotificationRepository_MarkAsSent(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
notification := &models.Notification{
|
||||
UserID: user.ID,
|
||||
NotificationType: models.NotificationTaskDueSoon,
|
||||
Title: "To Send",
|
||||
Body: "Body",
|
||||
Sent: false,
|
||||
}
|
||||
err := db.Create(notification).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.MarkAsSent(notification.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByID(notification.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, found.Sent)
|
||||
assert.NotNil(t, found.SentAt)
|
||||
}
|
||||
|
||||
func TestNotificationRepository_SetError(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
notification := &models.Notification{
|
||||
UserID: user.ID,
|
||||
NotificationType: models.NotificationTaskDueSoon,
|
||||
Title: "Error Notification",
|
||||
Body: "Body",
|
||||
}
|
||||
err := db.Create(notification).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.SetError(notification.ID, "failed to send: connection timeout")
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByID(notification.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "failed to send: connection timeout", found.ErrorMessage)
|
||||
}
|
||||
|
||||
func TestNotificationRepository_GetPendingNotifications(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
// Create pending (unsent) notifications
|
||||
for i := 0; i < 5; i++ {
|
||||
n := &models.Notification{
|
||||
UserID: user.ID,
|
||||
NotificationType: models.NotificationTaskDueSoon,
|
||||
Title: "Pending",
|
||||
Body: "Body",
|
||||
Sent: false,
|
||||
}
|
||||
err := db.Create(n).Error
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create sent notification
|
||||
sent := &models.Notification{
|
||||
UserID: user.ID,
|
||||
NotificationType: models.NotificationTaskCompleted,
|
||||
Title: "Already Sent",
|
||||
Body: "Body",
|
||||
Sent: true,
|
||||
}
|
||||
err := db.Create(sent).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
pending, err := repo.GetPendingNotifications(10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pending, 5) // Only unsent
|
||||
|
||||
// Test with limit
|
||||
pending, err = repo.GetPendingNotifications(3)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pending, 3)
|
||||
}
|
||||
|
||||
func TestNotificationRepository_APNSDevice_CRUD(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
// Create APNS device
|
||||
device := &models.APNSDevice{
|
||||
Name: "iPhone 15",
|
||||
Active: true,
|
||||
UserID: &user.ID,
|
||||
DeviceID: "device-uuid-123",
|
||||
RegistrationID: "apns-token-abc123",
|
||||
}
|
||||
err := repo.CreateAPNSDevice(device)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, device.ID)
|
||||
|
||||
// Find by ID
|
||||
found, err := repo.FindAPNSDeviceByID(device.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "iPhone 15", found.Name)
|
||||
assert.Equal(t, "apns-token-abc123", found.RegistrationID)
|
||||
|
||||
// Find by token
|
||||
foundByToken, err := repo.FindAPNSDeviceByToken("apns-token-abc123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, device.ID, foundByToken.ID)
|
||||
|
||||
// Find by user
|
||||
devices, err := repo.FindAPNSDevicesByUser(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, devices, 1)
|
||||
|
||||
// Update
|
||||
device.Name = "iPhone 16"
|
||||
err = repo.UpdateAPNSDevice(device)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err = repo.FindAPNSDeviceByID(device.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "iPhone 16", found.Name)
|
||||
|
||||
// Deactivate
|
||||
err = repo.DeactivateAPNSDevice(device.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should not appear in active devices
|
||||
devices, err = repo.FindAPNSDevicesByUser(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, devices, 0)
|
||||
|
||||
// Delete
|
||||
err = repo.DeleteAPNSDevice(device.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindAPNSDeviceByID(device.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNotificationRepository_GCMDevice_CRUD(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
// Create GCM device
|
||||
device := &models.GCMDevice{
|
||||
Name: "Pixel 8",
|
||||
Active: true,
|
||||
UserID: &user.ID,
|
||||
DeviceID: "android-device-uuid",
|
||||
RegistrationID: "fcm-token-xyz789",
|
||||
CloudMessageType: "FCM",
|
||||
}
|
||||
err := repo.CreateGCMDevice(device)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, device.ID)
|
||||
|
||||
// Find by ID
|
||||
found, err := repo.FindGCMDeviceByID(device.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Pixel 8", found.Name)
|
||||
assert.Equal(t, "fcm-token-xyz789", found.RegistrationID)
|
||||
|
||||
// Find by token
|
||||
foundByToken, err := repo.FindGCMDeviceByToken("fcm-token-xyz789")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, device.ID, foundByToken.ID)
|
||||
|
||||
// Find by user
|
||||
devices, err := repo.FindGCMDevicesByUser(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, devices, 1)
|
||||
|
||||
// Update
|
||||
device.Name = "Pixel 9"
|
||||
err = repo.UpdateGCMDevice(device)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err = repo.FindGCMDeviceByID(device.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Pixel 9", found.Name)
|
||||
|
||||
// Deactivate
|
||||
err = repo.DeactivateGCMDevice(device.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
devices, err = repo.FindGCMDevicesByUser(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, devices, 0)
|
||||
|
||||
// Delete
|
||||
err = repo.DeleteGCMDevice(device.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindGCMDeviceByID(device.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNotificationRepository_GetActiveTokensForUser(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
// Create APNS devices
|
||||
apns1 := &models.APNSDevice{
|
||||
Active: true,
|
||||
UserID: &user.ID,
|
||||
RegistrationID: "ios-token-1",
|
||||
}
|
||||
err := db.Create(apns1).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
apns2 := &models.APNSDevice{
|
||||
Active: true,
|
||||
UserID: &user.ID,
|
||||
RegistrationID: "ios-token-2",
|
||||
}
|
||||
err = db.Create(apns2).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create inactive APNS device (should not be returned)
|
||||
// Insert inactive device via raw SQL to bypass GORM's default:true override
|
||||
err = db.Exec("INSERT INTO push_notifications_apnsdevice (active, user_id, registration_id, date_created) VALUES (false, ?, 'ios-token-inactive', CURRENT_TIMESTAMP)", user.ID).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create GCM device
|
||||
gcm1 := &models.GCMDevice{
|
||||
Active: true,
|
||||
UserID: &user.ID,
|
||||
RegistrationID: "android-token-1",
|
||||
CloudMessageType: "FCM",
|
||||
}
|
||||
err = db.Create(gcm1).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
iosTokens, androidTokens, err := repo.GetActiveTokensForUser(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, iosTokens, 2)
|
||||
assert.Contains(t, iosTokens, "ios-token-1")
|
||||
assert.Contains(t, iosTokens, "ios-token-2")
|
||||
assert.Len(t, androidTokens, 1)
|
||||
assert.Contains(t, androidTokens, "android-token-1")
|
||||
}
|
||||
|
||||
func TestNotificationRepository_GetActiveTokensForUser_NoDevices(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
iosTokens, androidTokens, err := repo.GetActiveTokensForUser(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, iosTokens)
|
||||
assert.Empty(t, androidTokens)
|
||||
}
|
||||
|
||||
func TestNotificationRepository_FindPreferencesByUser(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
// Create preferences with defaults (GORM applies default:true for bools)
|
||||
prefs := &models.NotificationPreference{
|
||||
UserID: user.ID,
|
||||
}
|
||||
err := db.Create(prefs).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update one field to false via raw SQL to bypass GORM default
|
||||
err = db.Exec("UPDATE notifications_notificationpreference SET task_due_soon = false WHERE user_id = ?", user.ID).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindPreferencesByUser(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.ID, found.UserID)
|
||||
assert.False(t, found.TaskDueSoon)
|
||||
assert.True(t, found.TaskOverdue)
|
||||
}
|
||||
|
||||
func TestNotificationRepository_FindPreferencesByUser_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
_, err := repo.FindPreferencesByUser(9999)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNotificationRepository_UpdatePreferences(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewNotificationRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
prefs, err := repo.GetOrCreatePreferences(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
prefs.TaskDueSoon = false
|
||||
prefs.DailyDigest = false
|
||||
err = repo.UpdatePreferences(prefs)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindPreferencesByUser(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, found.TaskDueSoon)
|
||||
assert.False(t, found.DailyDigest)
|
||||
}
|
||||
217
internal/repositories/reminder_repo_test.go
Normal file
217
internal/repositories/reminder_repo_test.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
func TestReminderRepository_LogReminder(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewReminderRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Fix Roof")
|
||||
|
||||
dueDate := time.Date(2026, 4, 15, 0, 0, 0, 0, time.UTC)
|
||||
logEntry, err := repo.LogReminder(task.ID, user.ID, dueDate, models.ReminderStage7Days, nil)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, logEntry.ID)
|
||||
assert.Equal(t, task.ID, logEntry.TaskID)
|
||||
assert.Equal(t, user.ID, logEntry.UserID)
|
||||
assert.Equal(t, models.ReminderStage7Days, logEntry.ReminderStage)
|
||||
}
|
||||
|
||||
func TestReminderRepository_HasSentReminder(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewReminderRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Fix Roof")
|
||||
|
||||
dueDate := time.Date(2026, 4, 15, 12, 30, 0, 0, time.UTC) // Has time component
|
||||
|
||||
// Not sent yet
|
||||
sent, err := repo.HasSentReminder(task.ID, user.ID, dueDate, models.ReminderStage7Days)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, sent)
|
||||
|
||||
// Log it
|
||||
_, err = repo.LogReminder(task.ID, user.ID, dueDate, models.ReminderStage7Days, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now should be sent
|
||||
sent, err = repo.HasSentReminder(task.ID, user.ID, dueDate, models.ReminderStage7Days)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, sent)
|
||||
|
||||
// Different stage should not be sent
|
||||
sent, err = repo.HasSentReminder(task.ID, user.ID, dueDate, models.ReminderStage3Days)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, sent)
|
||||
}
|
||||
|
||||
func TestReminderRepository_HasSentReminderBatch(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewReminderRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task1 := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Task 1")
|
||||
task2 := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Task 2")
|
||||
|
||||
dueDate := time.Date(2026, 4, 15, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
// Log reminder for task1
|
||||
_, err := repo.LogReminder(task1.ID, user.ID, dueDate, models.ReminderStage7Days, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
keys := []ReminderKey{
|
||||
{TaskID: task1.ID, UserID: user.ID, DueDate: dueDate, Stage: models.ReminderStage7Days},
|
||||
{TaskID: task2.ID, UserID: user.ID, DueDate: dueDate, Stage: models.ReminderStage7Days},
|
||||
}
|
||||
|
||||
result, err := repo.HasSentReminderBatch(keys)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result[0], "task1 reminder should be sent")
|
||||
assert.False(t, result[1], "task2 reminder should not be sent")
|
||||
}
|
||||
|
||||
func TestReminderRepository_HasSentReminderBatch_Empty(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewReminderRepository(db)
|
||||
|
||||
result, err := repo.HasSentReminderBatch([]ReminderKey{})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestReminderRepository_GetSentRemindersForTask(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewReminderRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Fix Roof")
|
||||
|
||||
dueDate := time.Date(2026, 4, 15, 0, 0, 0, 0, time.UTC)
|
||||
_, err := repo.LogReminder(task.ID, user.ID, dueDate, models.ReminderStage7Days, nil)
|
||||
require.NoError(t, err)
|
||||
_, err = repo.LogReminder(task.ID, user.ID, dueDate, models.ReminderStage3Days, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
logs, err := repo.GetSentRemindersForTask(task.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, logs, 2)
|
||||
}
|
||||
|
||||
func TestReminderRepository_GetSentRemindersForDueDate(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewReminderRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Fix Roof")
|
||||
|
||||
dueDate1 := time.Date(2026, 4, 15, 0, 0, 0, 0, time.UTC)
|
||||
dueDate2 := time.Date(2026, 5, 15, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
_, err := repo.LogReminder(task.ID, user.ID, dueDate1, models.ReminderStage7Days, nil)
|
||||
require.NoError(t, err)
|
||||
_, err = repo.LogReminder(task.ID, user.ID, dueDate2, models.ReminderStage7Days, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
logs, err := repo.GetSentRemindersForDueDate(task.ID, user.ID, dueDate1)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, logs, 1)
|
||||
}
|
||||
|
||||
func TestReminderRepository_CleanupOldLogs(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewReminderRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Fix Roof")
|
||||
|
||||
// Create old log entry (100 days ago)
|
||||
oldLog := &models.TaskReminderLog{
|
||||
TaskID: task.ID,
|
||||
UserID: user.ID,
|
||||
DueDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
ReminderStage: models.ReminderStage7Days,
|
||||
SentAt: time.Now().UTC().AddDate(0, 0, -100),
|
||||
}
|
||||
require.NoError(t, db.Create(oldLog).Error)
|
||||
|
||||
// Create recent log entry
|
||||
recentLog := &models.TaskReminderLog{
|
||||
TaskID: task.ID,
|
||||
UserID: user.ID,
|
||||
DueDate: time.Date(2026, 4, 15, 0, 0, 0, 0, time.UTC),
|
||||
ReminderStage: models.ReminderStage3Days,
|
||||
SentAt: time.Now().UTC(),
|
||||
}
|
||||
require.NoError(t, db.Create(recentLog).Error)
|
||||
|
||||
deleted, err := repo.CleanupOldLogs(90)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), deleted)
|
||||
|
||||
// Recent log should still exist
|
||||
var count int64
|
||||
db.Model(&models.TaskReminderLog{}).Count(&count)
|
||||
assert.Equal(t, int64(1), count)
|
||||
}
|
||||
|
||||
func TestReminderRepository_GetRecentReminderStats(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewReminderRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Fix Roof")
|
||||
|
||||
dueDate := time.Date(2026, 4, 15, 0, 0, 0, 0, time.UTC)
|
||||
// Create recent reminders
|
||||
_, err := repo.LogReminder(task.ID, user.ID, dueDate, models.ReminderStage7Days, nil)
|
||||
require.NoError(t, err)
|
||||
_, err = repo.LogReminder(task.ID, user.ID, dueDate, models.ReminderStage3Days, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
stats, err := repo.GetRecentReminderStats(24) // last 24 hours
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), stats[string(models.ReminderStage7Days)])
|
||||
assert.Equal(t, int64(1), stats[string(models.ReminderStage3Days)])
|
||||
}
|
||||
|
||||
func TestReminderRepository_LogReminder_WithNotificationID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewReminderRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Fix Roof")
|
||||
|
||||
// Create a notification
|
||||
notif := &models.Notification{
|
||||
UserID: user.ID,
|
||||
NotificationType: models.NotificationTaskDueSoon,
|
||||
Title: "Test",
|
||||
Body: "Test body",
|
||||
}
|
||||
require.NoError(t, db.Create(notif).Error)
|
||||
|
||||
dueDate := time.Date(2026, 4, 15, 0, 0, 0, 0, time.UTC)
|
||||
logEntry, err := repo.LogReminder(task.ID, user.ID, dueDate, models.ReminderStage7Days, ¬if.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, logEntry.NotificationID)
|
||||
assert.Equal(t, notif.ID, *logEntry.NotificationID)
|
||||
}
|
||||
216
internal/repositories/residence_repo_coverage_test.go
Normal file
216
internal/repositories/residence_repo_coverage_test.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
func TestResidenceRepository_FindByIDSimple(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewResidenceRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
found, err := repo.FindByIDSimple(residence.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, residence.ID, found.ID)
|
||||
assert.Equal(t, "Test House", found.Name)
|
||||
}
|
||||
|
||||
func TestResidenceRepository_FindByIDSimple_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewResidenceRepository(db)
|
||||
|
||||
_, err := repo.FindByIDSimple(9999)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestResidenceRepository_FindByIDSimple_InactiveNotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewResidenceRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
db.Model(residence).Update("is_active", false)
|
||||
|
||||
_, err := repo.FindByIDSimple(residence.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestResidenceRepository_FindResidenceIDsByUser(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewResidenceRepository(db)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
sharedUser := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
|
||||
|
||||
r1 := testutil.CreateTestResidence(t, db, owner.ID, "House 1")
|
||||
r2 := testutil.CreateTestResidence(t, db, owner.ID, "House 2")
|
||||
|
||||
repo.AddUser(r1.ID, sharedUser.ID)
|
||||
|
||||
// Owner should see both
|
||||
ownerIDs, err := repo.FindResidenceIDsByUser(owner.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ownerIDs, 2)
|
||||
|
||||
// Shared user should see one
|
||||
sharedIDs, err := repo.FindResidenceIDsByUser(sharedUser.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, sharedIDs, 1)
|
||||
assert.Equal(t, r1.ID, sharedIDs[0])
|
||||
|
||||
_ = r2 // suppress unused
|
||||
}
|
||||
|
||||
func TestResidenceRepository_FindResidenceIDsByOwner(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewResidenceRepository(db)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
otherOwner := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
|
||||
testutil.CreateTestResidence(t, db, owner.ID, "House 1")
|
||||
testutil.CreateTestResidence(t, db, owner.ID, "House 2")
|
||||
testutil.CreateTestResidence(t, db, otherOwner.ID, "Other House")
|
||||
|
||||
ids, err := repo.FindResidenceIDsByOwner(owner.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ids, 2)
|
||||
}
|
||||
|
||||
func TestResidenceRepository_DeactivateShareCode(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewResidenceRepository(db)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
|
||||
// Create a share code
|
||||
created, err := repo.CreateShareCode(residence.ID, owner.ID, 24*60*60*1000*1000*1000) // 24h in ns
|
||||
require.NoError(t, err)
|
||||
assert.True(t, created.IsActive)
|
||||
|
||||
// Deactivate it
|
||||
err = repo.DeactivateShareCode(created.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// FindShareCodeByCode should not find it
|
||||
_, err = repo.FindShareCodeByCode(created.Code)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestResidenceRepository_GetActiveShareCode(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewResidenceRepository(db)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
|
||||
// No active code initially
|
||||
code, err := repo.GetActiveShareCode(residence.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, code)
|
||||
|
||||
// Create a share code
|
||||
created, err := repo.CreateShareCode(residence.ID, owner.ID, 24*60*60*1000*1000*1000)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should find the active code
|
||||
code, err = repo.GetActiveShareCode(residence.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, code)
|
||||
assert.Equal(t, created.Code, code.Code)
|
||||
}
|
||||
|
||||
func TestResidenceRepository_CreateShareCode_DeactivatesPrevious(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewResidenceRepository(db)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
|
||||
// Create first code
|
||||
first, err := repo.CreateShareCode(residence.ID, owner.ID, 24*60*60*1000*1000*1000)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create second code (should deactivate first)
|
||||
second, err := repo.CreateShareCode(residence.ID, owner.ID, 24*60*60*1000*1000*1000)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, first.Code, second.Code)
|
||||
|
||||
// First code should no longer be findable
|
||||
_, err = repo.FindShareCodeByCode(first.Code)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Second code should be active
|
||||
found, err := repo.FindShareCodeByCode(second.Code)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, second.Code, found.Code)
|
||||
}
|
||||
|
||||
func TestResidenceRepository_FindResidenceTypeByID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
repo := NewResidenceRepository(db)
|
||||
|
||||
// Get first residence type
|
||||
types, err := repo.GetAllResidenceTypes()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, types)
|
||||
|
||||
found, err := repo.FindResidenceTypeByID(types[0].ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, types[0].Name, found.Name)
|
||||
}
|
||||
|
||||
func TestResidenceRepository_FindResidenceTypeByID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewResidenceRepository(db)
|
||||
|
||||
_, err := repo.FindResidenceTypeByID(9999)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestResidenceRepository_GetTasksForReport(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewResidenceRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
testutil.CreateTestTask(t, db, residence.ID, user.ID, "Task 1")
|
||||
testutil.CreateTestTask(t, db, residence.ID, user.ID, "Task 2")
|
||||
|
||||
tasks, err := repo.GetTasksForReport(residence.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, tasks, 2)
|
||||
}
|
||||
|
||||
func TestResidenceRepository_AddUser_Idempotent(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewResidenceRepository(db)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
sharedUser := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
|
||||
// Add same user twice (should not error due to ON CONFLICT DO NOTHING)
|
||||
err := repo.AddUser(residence.ID, sharedUser.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.AddUser(residence.ID, sharedUser.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should still only count once
|
||||
users, err := repo.GetResidenceUsers(residence.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, users, 2) // owner + sharedUser
|
||||
}
|
||||
418
internal/repositories/subscription_repo_coverage_test.go
Normal file
418
internal/repositories/subscription_repo_coverage_test.go
Normal file
@@ -0,0 +1,418 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
func TestSubscriptionRepository_FindByUserID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierPro,
|
||||
}
|
||||
require.NoError(t, db.Create(sub).Error)
|
||||
|
||||
found, err := repo.FindByUserID(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, models.TierPro, found.Tier)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_FindByUserID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
_, err := repo.FindByUserID(9999)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_Update(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
sub, err := repo.GetOrCreate(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
sub.Tier = models.TierPro
|
||||
err = repo.Update(sub)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByUserID(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, models.TierPro, found.Tier)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_UpgradeToPro(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
_, err := repo.GetOrCreate(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
expiresAt := time.Now().UTC().AddDate(1, 0, 0)
|
||||
err = repo.UpgradeToPro(user.ID, expiresAt, "apple")
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByUserID(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, models.TierPro, found.Tier)
|
||||
assert.Equal(t, "apple", found.Platform)
|
||||
assert.True(t, found.AutoRenew)
|
||||
require.NotNil(t, found.SubscribedAt)
|
||||
require.NotNil(t, found.ExpiresAt)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_DowngradeToFree(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
_, err := repo.GetOrCreate(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First upgrade
|
||||
err = repo.UpgradeToPro(user.ID, time.Now().UTC().AddDate(1, 0, 0), "apple")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Then downgrade
|
||||
err = repo.DowngradeToFree(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByUserID(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, models.TierFree, found.Tier)
|
||||
assert.False(t, found.AutoRenew)
|
||||
require.NotNil(t, found.CancelledAt)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_SetAutoRenew(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
_, err := repo.GetOrCreate(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.SetAutoRenew(user.ID, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByUserID(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, found.AutoRenew)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_UpdateReceiptData(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
_, err := repo.GetOrCreate(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.UpdateReceiptData(user.ID, "receipt_data_abc123")
|
||||
require.NoError(t, err)
|
||||
|
||||
var sub models.UserSubscription
|
||||
db.Where("user_id = ?", user.ID).First(&sub)
|
||||
require.NotNil(t, sub.AppleReceiptData)
|
||||
assert.Equal(t, "receipt_data_abc123", *sub.AppleReceiptData)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_UpdatePurchaseToken(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
_, err := repo.GetOrCreate(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.UpdatePurchaseToken(user.ID, "google_token_xyz")
|
||||
require.NoError(t, err)
|
||||
|
||||
var sub models.UserSubscription
|
||||
db.Where("user_id = ?", user.ID).First(&sub)
|
||||
require.NotNil(t, sub.GooglePurchaseToken)
|
||||
assert.Equal(t, "google_token_xyz", *sub.GooglePurchaseToken)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_FindByAppleReceiptContains(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
// Use raw SQL to ensure apple_receipt_data is stored correctly (GORM may omit pointer fields)
|
||||
require.NoError(t, db.Exec(
|
||||
"INSERT INTO subscription_usersubscription (user_id, tier, apple_receipt_data, created_at, updated_at) VALUES (?, ?, ?, datetime('now'), datetime('now'))",
|
||||
user.ID, models.TierPro, "transactionid=txnabc123data",
|
||||
).Error)
|
||||
|
||||
// Avoid underscores in search term: escapeLikeWildcards escapes _ with \
|
||||
// which PostgreSQL handles but SQLite does not (no default ESCAPE clause)
|
||||
found, err := repo.FindByAppleReceiptContains("txnabc123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.ID, found.UserID)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_FindByAppleReceiptContains_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
_, err := repo.FindByAppleReceiptContains("nonexistent_txn")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_FindByGoogleToken(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
purchaseToken := "google_purchase_abc"
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierPro,
|
||||
GooglePurchaseToken: &purchaseToken,
|
||||
}
|
||||
require.NoError(t, db.Create(sub).Error)
|
||||
|
||||
found, err := repo.FindByGoogleToken("google_purchase_abc")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.ID, found.UserID)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_FindByGoogleToken_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
_, err := repo.FindByGoogleToken("nonexistent_token")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_SetCancelledAt(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
_, err := repo.GetOrCreate(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
cancelTime := time.Now().UTC()
|
||||
err = repo.SetCancelledAt(user.ID, cancelTime)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByUserID(user.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found.CancelledAt)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_ClearCancelledAt(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
_, err := repo.GetOrCreate(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set then clear
|
||||
err = repo.SetCancelledAt(user.ID, time.Now().UTC())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.ClearCancelledAt(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByUserID(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, found.CancelledAt)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_GetTierLimits_Defaults(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
// No tier limits in DB, should return defaults
|
||||
freeLimits, err := repo.GetTierLimits(models.TierFree)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, freeLimits)
|
||||
|
||||
proLimits, err := repo.GetTierLimits(models.TierPro)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, proLimits)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_GetAllTierLimits(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
// No limits seeded = empty
|
||||
limits, err := repo.GetAllTierLimits()
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, limits)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_GetSettings_Defaults(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
// No settings in DB, should return default (limitations disabled)
|
||||
settings, err := repo.GetSettings()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, settings)
|
||||
assert.False(t, settings.EnableLimitations)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_GetUpgradeTrigger(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
trigger := &models.UpgradeTrigger{
|
||||
TriggerKey: "max_residences",
|
||||
Title: "Upgrade for more homes",
|
||||
Message: "Free tier supports 1 home",
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, db.Create(trigger).Error)
|
||||
|
||||
found, err := repo.GetUpgradeTrigger("max_residences")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "max_residences", found.TriggerKey)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_GetUpgradeTrigger_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
_, err := repo.GetUpgradeTrigger("nonexistent_key")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_GetAllUpgradeTriggers(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
trigger1 := &models.UpgradeTrigger{TriggerKey: "key1", Title: "T1", Message: "M1", IsActive: true}
|
||||
trigger2 := &models.UpgradeTrigger{TriggerKey: "key2", Title: "T2", Message: "M2", IsActive: true}
|
||||
require.NoError(t, db.Create(trigger1).Error)
|
||||
require.NoError(t, db.Create(trigger2).Error)
|
||||
// Use raw SQL for inactive record to avoid GORM default:true overriding IsActive=false
|
||||
require.NoError(t, db.Exec(
|
||||
"INSERT INTO subscription_upgradetrigger (trigger_key, title, message, is_active, created_at, updated_at) VALUES (?, ?, ?, ?, datetime('now'), datetime('now'))",
|
||||
"key3", "T3", "M3", false,
|
||||
).Error)
|
||||
|
||||
triggers, err := repo.GetAllUpgradeTriggers()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, triggers, 2) // Only active
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_GetFeatureBenefits(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
b1 := &models.FeatureBenefit{FeatureName: "Benefit 1", FreeTierText: "1 home", ProTierText: "Unlimited", DisplayOrder: 1, IsActive: true}
|
||||
require.NoError(t, db.Create(b1).Error)
|
||||
// Use raw SQL for inactive record to avoid GORM default:true overriding IsActive=false
|
||||
require.NoError(t, db.Exec(
|
||||
"INSERT INTO subscription_featurebenefit (feature_name, free_tier_text, pro_tier_text, display_order, is_active, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))",
|
||||
"Benefit 2", "3 tasks", "Unlimited", 2, false,
|
||||
).Error)
|
||||
|
||||
benefits, err := repo.GetFeatureBenefits()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, benefits, 1) // Only active
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_GetActivePromotions(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
now := time.Now().UTC()
|
||||
active := &models.Promotion{
|
||||
PromotionID: "promo_active",
|
||||
TargetTier: models.TierPro,
|
||||
Title: "Active Promo",
|
||||
Message: "Get Pro now!",
|
||||
StartDate: now.AddDate(0, 0, -1),
|
||||
EndDate: now.AddDate(0, 0, 1),
|
||||
IsActive: true,
|
||||
}
|
||||
expired := &models.Promotion{
|
||||
PromotionID: "promo_expired",
|
||||
TargetTier: models.TierPro,
|
||||
Title: "Expired Promo",
|
||||
Message: "Too late!",
|
||||
StartDate: now.AddDate(0, 0, -10),
|
||||
EndDate: now.AddDate(0, 0, -5),
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, db.Create(active).Error)
|
||||
require.NoError(t, db.Create(expired).Error)
|
||||
|
||||
promos, err := repo.GetActivePromotions(models.TierPro)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, promos, 1)
|
||||
assert.Equal(t, "promo_active", promos[0].PromotionID)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_GetPromotionByID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
promo := &models.Promotion{
|
||||
PromotionID: "promo_find",
|
||||
TargetTier: models.TierPro,
|
||||
Title: "Find Me",
|
||||
Message: "Found!",
|
||||
StartDate: time.Now().UTC(),
|
||||
EndDate: time.Now().UTC().AddDate(0, 0, 30),
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, db.Create(promo).Error)
|
||||
|
||||
found, err := repo.GetPromotionByID("promo_find")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Find Me", found.Title)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_GetPromotionByID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
_, err := repo.GetPromotionByID("nonexistent")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_FindByStripeSubscriptionID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
subID := "sub_test_123"
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierPro,
|
||||
StripeSubscriptionID: &subID,
|
||||
}
|
||||
require.NoError(t, db.Create(sub).Error)
|
||||
|
||||
found, err := repo.FindByStripeSubscriptionID("sub_test_123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.ID, found.UserID)
|
||||
}
|
||||
|
||||
func TestSubscriptionRepository_FindByStripeSubscriptionID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewSubscriptionRepository(db)
|
||||
|
||||
_, err := repo.FindByStripeSubscriptionID("sub_nonexistent")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
516
internal/repositories/task_repo_coverage_test.go
Normal file
516
internal/repositories/task_repo_coverage_test.go
Normal file
@@ -0,0 +1,516 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
// === UpdateTx / Version Conflict Tests ===
|
||||
|
||||
func TestTaskRepository_UpdateTx_VersionConflict(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Versioned Task")
|
||||
|
||||
// Simulate stale version by setting wrong version
|
||||
task.Version = 999
|
||||
task.Title = "Should Fail"
|
||||
|
||||
tx := db.Begin()
|
||||
err := repo.UpdateTx(tx, task)
|
||||
tx.Rollback()
|
||||
|
||||
assert.ErrorIs(t, err, ErrVersionConflict)
|
||||
}
|
||||
|
||||
func TestTaskRepository_UpdateTx_Success(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Versioned Task")
|
||||
|
||||
originalVersion := task.Version
|
||||
task.Title = "Updated via Tx"
|
||||
|
||||
tx := db.Begin()
|
||||
err := repo.UpdateTx(tx, task)
|
||||
require.NoError(t, err)
|
||||
tx.Commit()
|
||||
|
||||
assert.Equal(t, originalVersion+1, task.Version)
|
||||
|
||||
found, err := repo.FindByID(task.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated via Tx", found.Title)
|
||||
}
|
||||
|
||||
func TestTaskRepository_Update_VersionConflict(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Versioned Task")
|
||||
|
||||
task.Version = 999 // stale version
|
||||
task.Title = "Should Fail"
|
||||
err := repo.Update(task)
|
||||
assert.ErrorIs(t, err, ErrVersionConflict)
|
||||
}
|
||||
|
||||
// === Version Conflict on State Operations ===
|
||||
|
||||
func TestTaskRepository_Cancel_VersionConflict(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
|
||||
|
||||
err := repo.Cancel(task.ID, 999) // wrong version
|
||||
assert.ErrorIs(t, err, ErrVersionConflict)
|
||||
}
|
||||
|
||||
func TestTaskRepository_Uncancel_VersionConflict(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
|
||||
|
||||
err := repo.Uncancel(task.ID, 999)
|
||||
assert.ErrorIs(t, err, ErrVersionConflict)
|
||||
}
|
||||
|
||||
func TestTaskRepository_Archive_VersionConflict(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
|
||||
|
||||
err := repo.Archive(task.ID, 999)
|
||||
assert.ErrorIs(t, err, ErrVersionConflict)
|
||||
}
|
||||
|
||||
func TestTaskRepository_Unarchive_VersionConflict(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
|
||||
|
||||
err := repo.Unarchive(task.ID, 999)
|
||||
assert.ErrorIs(t, err, ErrVersionConflict)
|
||||
}
|
||||
|
||||
func TestTaskRepository_MarkInProgress(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
|
||||
|
||||
err := repo.MarkInProgress(task.ID, task.Version)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByID(task.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, found.InProgress)
|
||||
}
|
||||
|
||||
func TestTaskRepository_MarkInProgress_VersionConflict(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
|
||||
|
||||
err := repo.MarkInProgress(task.ID, 999)
|
||||
assert.ErrorIs(t, err, ErrVersionConflict)
|
||||
}
|
||||
|
||||
// === CreateCompletionTx ===
|
||||
|
||||
func TestTaskRepository_CreateCompletionTx(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
|
||||
|
||||
tx := db.Begin()
|
||||
completion := &models.TaskCompletion{
|
||||
TaskID: task.ID,
|
||||
CompletedByID: user.ID,
|
||||
CompletedAt: time.Now().UTC(),
|
||||
Notes: "Done via tx",
|
||||
}
|
||||
err := repo.CreateCompletionTx(tx, completion)
|
||||
require.NoError(t, err)
|
||||
tx.Commit()
|
||||
|
||||
assert.NotZero(t, completion.ID)
|
||||
|
||||
found, err := repo.FindCompletionByID(completion.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Done via tx", found.Notes)
|
||||
}
|
||||
|
||||
// === GetFrequencyByID ===
|
||||
|
||||
func TestTaskRepository_GetFrequencyByID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
frequencies, err := repo.GetAllFrequencies()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, frequencies)
|
||||
|
||||
found, err := repo.GetFrequencyByID(frequencies[0].ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, frequencies[0].Name, found.Name)
|
||||
}
|
||||
|
||||
func TestTaskRepository_GetFrequencyByID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
_, err := repo.GetFrequencyByID(9999)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === FindByUser ===
|
||||
|
||||
func TestTaskRepository_FindByUser(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
r1 := testutil.CreateTestResidence(t, db, user.ID, "House 1")
|
||||
r2 := testutil.CreateTestResidence(t, db, user.ID, "House 2")
|
||||
|
||||
testutil.CreateTestTask(t, db, r1.ID, user.ID, "Task A")
|
||||
testutil.CreateTestTask(t, db, r2.ID, user.ID, "Task B")
|
||||
|
||||
tasks, err := repo.FindByUser(user.ID, []uint{r1.ID, r2.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, tasks, 2)
|
||||
}
|
||||
|
||||
// === CountByResidenceIDs ===
|
||||
|
||||
func TestTaskRepository_CountByResidenceIDs(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
r1 := testutil.CreateTestResidence(t, db, user.ID, "House 1")
|
||||
r2 := testutil.CreateTestResidence(t, db, user.ID, "House 2")
|
||||
|
||||
testutil.CreateTestTask(t, db, r1.ID, user.ID, "Task 1")
|
||||
testutil.CreateTestTask(t, db, r2.ID, user.ID, "Task 2")
|
||||
testutil.CreateTestTask(t, db, r2.ID, user.ID, "Task 3")
|
||||
|
||||
count, err := repo.CountByResidenceIDs([]uint{r1.ID, r2.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(3), count)
|
||||
}
|
||||
|
||||
func TestTaskRepository_CountByResidenceIDs_Empty(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
count, err := repo.CountByResidenceIDs([]uint{})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
// === FindCompletionsByUser ===
|
||||
|
||||
func TestTaskRepository_FindCompletionsByUser(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
|
||||
|
||||
c := &models.TaskCompletion{
|
||||
TaskID: task.ID,
|
||||
CompletedByID: user.ID,
|
||||
CompletedAt: time.Now().UTC(),
|
||||
}
|
||||
require.NoError(t, db.Create(c).Error)
|
||||
|
||||
completions, err := repo.FindCompletionsByUser(user.ID, []uint{residence.ID})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, completions, 1)
|
||||
}
|
||||
|
||||
// === UpdateCompletion ===
|
||||
|
||||
func TestTaskRepository_UpdateCompletion(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
|
||||
|
||||
completion := &models.TaskCompletion{
|
||||
TaskID: task.ID,
|
||||
CompletedByID: user.ID,
|
||||
CompletedAt: time.Now().UTC(),
|
||||
Notes: "Original",
|
||||
}
|
||||
require.NoError(t, db.Create(completion).Error)
|
||||
|
||||
completion.Notes = "Updated notes"
|
||||
err := repo.UpdateCompletion(completion)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindCompletionByID(completion.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated notes", found.Notes)
|
||||
}
|
||||
|
||||
// === CompletionImage CRUD ===
|
||||
|
||||
func TestTaskRepository_CompletionImage_CRUD(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
|
||||
|
||||
completion := &models.TaskCompletion{
|
||||
TaskID: task.ID,
|
||||
CompletedByID: user.ID,
|
||||
CompletedAt: time.Now().UTC(),
|
||||
}
|
||||
require.NoError(t, db.Create(completion).Error)
|
||||
|
||||
// Create image
|
||||
img := &models.TaskCompletionImage{
|
||||
CompletionID: completion.ID,
|
||||
ImageURL: "https://example.com/img.jpg",
|
||||
}
|
||||
err := repo.CreateCompletionImage(img)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, img.ID)
|
||||
|
||||
// Find image
|
||||
found, err := repo.FindCompletionImageByID(img.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://example.com/img.jpg", found.ImageURL)
|
||||
|
||||
// Delete image
|
||||
err = repo.DeleteCompletionImage(img.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindCompletionImageByID(img.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestTaskRepository_FindCompletionImageByID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
_, err := repo.FindCompletionImageByID(9999)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === GetOverdueCountByResidence ===
|
||||
|
||||
func TestTaskRepository_GetOverdueCountByResidence(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
r1 := testutil.CreateTestResidence(t, db, user.ID, "House 1")
|
||||
r2 := testutil.CreateTestResidence(t, db, user.ID, "House 2")
|
||||
|
||||
now := time.Now().UTC()
|
||||
pastDue := now.AddDate(0, 0, -5)
|
||||
|
||||
// Overdue task in r1
|
||||
t1 := &models.Task{
|
||||
ResidenceID: r1.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Overdue in r1",
|
||||
DueDate: &pastDue,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(t1).Error)
|
||||
|
||||
// Not overdue task in r2 (future)
|
||||
futureDue := now.AddDate(0, 0, 30)
|
||||
t2 := &models.Task{
|
||||
ResidenceID: r2.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Future in r2",
|
||||
DueDate: &futureDue,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(t, db.Create(t2).Error)
|
||||
|
||||
countMap, err := repo.GetOverdueCountByResidence([]uint{r1.ID, r2.ID}, now)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, countMap[r1.ID])
|
||||
assert.Equal(t, 0, countMap[r2.ID])
|
||||
}
|
||||
|
||||
func TestTaskRepository_GetOverdueCountByResidence_EmptyIDs(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
countMap, err := repo.GetOverdueCountByResidence([]uint{}, time.Now().UTC())
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, countMap)
|
||||
}
|
||||
|
||||
// === GetKanbanDataForMultipleResidences ===
|
||||
|
||||
func TestTaskRepository_GetKanbanDataForMultipleResidences(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
r1 := testutil.CreateTestResidence(t, db, user.ID, "House 1")
|
||||
r2 := testutil.CreateTestResidence(t, db, user.ID, "House 2")
|
||||
|
||||
testutil.CreateTestTask(t, db, r1.ID, user.ID, "Task 1")
|
||||
testutil.CreateTestTask(t, db, r2.ID, user.ID, "Task 2")
|
||||
|
||||
board, err := repo.GetKanbanDataForMultipleResidences([]uint{r1.ID, r2.ID}, 30, time.Now().UTC())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "all", board.ResidenceID)
|
||||
assert.Len(t, board.Columns, 5)
|
||||
|
||||
// Both tasks should appear in upcoming (no due date)
|
||||
var upcomingCol *models.KanbanColumn
|
||||
for i := range board.Columns {
|
||||
if board.Columns[i].Name == "upcoming_tasks" {
|
||||
upcomingCol = &board.Columns[i]
|
||||
}
|
||||
}
|
||||
require.NotNil(t, upcomingCol)
|
||||
assert.Equal(t, 2, upcomingCol.Count)
|
||||
}
|
||||
|
||||
// === DB() accessor ===
|
||||
|
||||
func TestTaskRepository_DB(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
assert.NotNil(t, repo.DB())
|
||||
}
|
||||
|
||||
// === GetBatchCompletionSummaries ===
|
||||
|
||||
func TestTaskRepository_GetBatchCompletionSummaries_Empty(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
result, err := repo.GetBatchCompletionSummaries([]uint{}, time.Now().UTC(), 10)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestTaskRepository_GetBatchCompletionSummaries_WithData(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
r1 := testutil.CreateTestResidence(t, db, user.ID, "House 1")
|
||||
r2 := testutil.CreateTestResidence(t, db, user.ID, "House 2")
|
||||
task1 := testutil.CreateTestTask(t, db, r1.ID, user.ID, "Task 1")
|
||||
task2 := testutil.CreateTestTask(t, db, r2.ID, user.ID, "Task 2")
|
||||
|
||||
now := time.Date(2026, 3, 15, 0, 0, 0, 0, time.UTC)
|
||||
c1 := models.TaskCompletion{
|
||||
TaskID: task1.ID, CompletedByID: user.ID,
|
||||
CompletedAt: time.Date(2026, 2, 1, 12, 0, 0, 0, time.UTC), CompletedFromColumn: "completed_tasks",
|
||||
}
|
||||
require.NoError(t, db.Create(&c1).Error)
|
||||
|
||||
c2 := models.TaskCompletion{
|
||||
TaskID: task2.ID, CompletedByID: user.ID,
|
||||
CompletedAt: time.Date(2026, 1, 15, 12, 0, 0, 0, time.UTC), CompletedFromColumn: "overdue_tasks",
|
||||
}
|
||||
require.NoError(t, db.Create(&c2).Error)
|
||||
|
||||
result, err := repo.GetBatchCompletionSummaries([]uint{r1.ID, r2.ID}, now, 10)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 2)
|
||||
|
||||
assert.Equal(t, 1, result[r1.ID].TotalAllTime)
|
||||
assert.Equal(t, 1, result[r2.ID].TotalAllTime)
|
||||
}
|
||||
|
||||
// === FindCompletionByID not found ===
|
||||
|
||||
func TestTaskRepository_FindCompletionByID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
_, err := repo.FindCompletionByID(9999)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === DeleteCompletion deletes images ===
|
||||
|
||||
func TestTaskRepository_DeleteCompletion_DeletesImages(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task := testutil.CreateTestTask(t, db, residence.ID, user.ID, "Test Task")
|
||||
|
||||
completion := &models.TaskCompletion{
|
||||
TaskID: task.ID,
|
||||
CompletedByID: user.ID,
|
||||
CompletedAt: time.Now().UTC(),
|
||||
}
|
||||
require.NoError(t, db.Create(completion).Error)
|
||||
|
||||
// Add images
|
||||
img := &models.TaskCompletionImage{
|
||||
CompletionID: completion.ID,
|
||||
ImageURL: "https://example.com/img.jpg",
|
||||
}
|
||||
require.NoError(t, db.Create(img).Error)
|
||||
|
||||
// Delete completion (should cascade to images)
|
||||
err := repo.DeleteCompletion(completion.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Images should be gone
|
||||
var count int64
|
||||
db.Model(&models.TaskCompletionImage{}).Where("completion_id = ?", completion.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
236
internal/repositories/task_template_repo_test.go
Normal file
236
internal/repositories/task_template_repo_test.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
func TestTaskTemplateRepository_Create(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
repo := NewTaskTemplateRepository(db)
|
||||
|
||||
var cat models.TaskCategory
|
||||
require.NoError(t, db.First(&cat).Error)
|
||||
|
||||
var freq models.TaskFrequency
|
||||
require.NoError(t, db.First(&freq).Error)
|
||||
|
||||
template := &models.TaskTemplate{
|
||||
Title: "Change HVAC Filter",
|
||||
Description: "Replace the HVAC air filter",
|
||||
CategoryID: &cat.ID,
|
||||
FrequencyID: &freq.ID,
|
||||
IsActive: true,
|
||||
Tags: "hvac,filter,maintenance",
|
||||
}
|
||||
err := repo.Create(template)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, template.ID)
|
||||
}
|
||||
|
||||
func TestTaskTemplateRepository_GetAll(t *testing.T) {
|
||||
t.Skip("requires PostgreSQL: SQLite cannot scan jsonb default into json.RawMessage")
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
repo := NewTaskTemplateRepository(db)
|
||||
|
||||
// Create active and inactive templates
|
||||
t1 := &models.TaskTemplate{Title: "Active Template", IsActive: true}
|
||||
t2 := &models.TaskTemplate{Title: "Inactive Template", IsActive: false}
|
||||
require.NoError(t, db.Create(t1).Error)
|
||||
require.NoError(t, db.Create(t2).Error)
|
||||
|
||||
templates, err := repo.GetAll()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, templates, 1) // Only active
|
||||
assert.Equal(t, "Active Template", templates[0].Title)
|
||||
}
|
||||
|
||||
func TestTaskTemplateRepository_GetByCategory(t *testing.T) {
|
||||
t.Skip("requires PostgreSQL: SQLite cannot scan jsonb default into json.RawMessage")
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
repo := NewTaskTemplateRepository(db)
|
||||
|
||||
var cat models.TaskCategory
|
||||
require.NoError(t, db.First(&cat).Error)
|
||||
|
||||
t1 := &models.TaskTemplate{Title: "Cat Template", CategoryID: &cat.ID, IsActive: true}
|
||||
t2 := &models.TaskTemplate{Title: "No Cat Template", IsActive: true}
|
||||
require.NoError(t, db.Create(t1).Error)
|
||||
require.NoError(t, db.Create(t2).Error)
|
||||
|
||||
templates, err := repo.GetByCategory(cat.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, templates, 1)
|
||||
assert.Equal(t, "Cat Template", templates[0].Title)
|
||||
}
|
||||
|
||||
func TestTaskTemplateRepository_Search(t *testing.T) {
|
||||
t.Skip("requires PostgreSQL: SQLite cannot scan jsonb default into json.RawMessage")
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskTemplateRepository(db)
|
||||
|
||||
t1 := &models.TaskTemplate{Title: "Change HVAC Filter", Tags: "hvac,filter", IsActive: true}
|
||||
t2 := &models.TaskTemplate{Title: "Fix Faucet", Tags: "plumbing", IsActive: true}
|
||||
require.NoError(t, db.Create(t1).Error)
|
||||
require.NoError(t, db.Create(t2).Error)
|
||||
|
||||
// Search by title
|
||||
results, err := repo.Search("hvac")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, "Change HVAC Filter", results[0].Title)
|
||||
|
||||
// Search by tag
|
||||
results, err = repo.Search("plumbing")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, "Fix Faucet", results[0].Title)
|
||||
|
||||
// No match
|
||||
results, err = repo.Search("nonexistent")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, results)
|
||||
}
|
||||
|
||||
func TestTaskTemplateRepository_GetByID(t *testing.T) {
|
||||
t.Skip("requires PostgreSQL: SQLite cannot scan jsonb default into json.RawMessage")
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskTemplateRepository(db)
|
||||
|
||||
tmpl := &models.TaskTemplate{Title: "Test Template", IsActive: true}
|
||||
require.NoError(t, db.Create(tmpl).Error)
|
||||
|
||||
found, err := repo.GetByID(tmpl.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Test Template", found.Title)
|
||||
}
|
||||
|
||||
func TestTaskTemplateRepository_GetByID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskTemplateRepository(db)
|
||||
|
||||
_, err := repo.GetByID(9999)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestTaskTemplateRepository_Update(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskTemplateRepository(db)
|
||||
|
||||
tmpl := &models.TaskTemplate{Title: "Original", IsActive: true}
|
||||
require.NoError(t, db.Create(tmpl).Error)
|
||||
|
||||
tmpl.Title = "Updated"
|
||||
err := repo.Update(tmpl)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.GetByID(tmpl.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated", found.Title)
|
||||
}
|
||||
|
||||
func TestTaskTemplateRepository_Delete(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskTemplateRepository(db)
|
||||
|
||||
tmpl := &models.TaskTemplate{Title: "To Delete", IsActive: true}
|
||||
require.NoError(t, db.Create(tmpl).Error)
|
||||
|
||||
err := repo.Delete(tmpl.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.GetByID(tmpl.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestTaskTemplateRepository_GetAllIncludingInactive(t *testing.T) {
|
||||
t.Skip("requires PostgreSQL: SQLite cannot scan jsonb default into json.RawMessage")
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskTemplateRepository(db)
|
||||
|
||||
t1 := &models.TaskTemplate{Title: "Active", IsActive: true}
|
||||
t2 := &models.TaskTemplate{Title: "Inactive", IsActive: false}
|
||||
require.NoError(t, db.Create(t1).Error)
|
||||
require.NoError(t, db.Create(t2).Error)
|
||||
|
||||
templates, err := repo.GetAllIncludingInactive()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, templates, 2) // Both active and inactive
|
||||
}
|
||||
|
||||
func TestTaskTemplateRepository_Count(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskTemplateRepository(db)
|
||||
|
||||
t1 := &models.TaskTemplate{Title: "Active 1", IsActive: true}
|
||||
t2 := &models.TaskTemplate{Title: "Active 2", IsActive: true}
|
||||
require.NoError(t, db.Create(t1).Error)
|
||||
require.NoError(t, db.Create(t2).Error)
|
||||
// Use raw SQL for inactive record to avoid GORM default:true overriding IsActive=false
|
||||
require.NoError(t, db.Exec(
|
||||
"INSERT INTO task_tasktemplate (title, is_active, display_order, created_at, updated_at) VALUES (?, ?, ?, datetime('now'), datetime('now'))",
|
||||
"Inactive", false, 0,
|
||||
).Error)
|
||||
|
||||
count, err := repo.Count()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(2), count) // Only active
|
||||
}
|
||||
|
||||
func TestTaskTemplateRepository_GetByRegion(t *testing.T) {
|
||||
t.Skip("requires PostgreSQL: SQLite cannot scan jsonb default into json.RawMessage")
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewTaskTemplateRepository(db)
|
||||
|
||||
// Create a climate region
|
||||
region := &models.ClimateRegion{Name: "Hot-Humid", ZoneNumber: 1, IsActive: true}
|
||||
require.NoError(t, db.Create(region).Error)
|
||||
|
||||
// Create template with region association
|
||||
tmpl := &models.TaskTemplate{Title: "Regional Task", IsActive: true}
|
||||
require.NoError(t, db.Create(tmpl).Error)
|
||||
|
||||
// Associate template with region via join table
|
||||
err := db.Exec("INSERT INTO task_tasktemplate_regions (task_template_id, climate_region_id) VALUES (?, ?)", tmpl.ID, region.ID).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create template without region
|
||||
tmpl2 := &models.TaskTemplate{Title: "Non-Regional Task", IsActive: true}
|
||||
require.NoError(t, db.Create(tmpl2).Error)
|
||||
|
||||
templates, err := repo.GetByRegion(region.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, templates, 1)
|
||||
assert.Equal(t, "Regional Task", templates[0].Title)
|
||||
}
|
||||
|
||||
func TestTaskTemplateRepository_GetGroupedByCategory(t *testing.T) {
|
||||
t.Skip("requires PostgreSQL: SQLite cannot scan jsonb default into json.RawMessage")
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
repo := NewTaskTemplateRepository(db)
|
||||
|
||||
var cat models.TaskCategory
|
||||
require.NoError(t, db.First(&cat).Error)
|
||||
|
||||
t1 := &models.TaskTemplate{Title: "Categorized", CategoryID: &cat.ID, IsActive: true}
|
||||
t2 := &models.TaskTemplate{Title: "Uncategorized", IsActive: true}
|
||||
require.NoError(t, db.Create(t1).Error)
|
||||
require.NoError(t, db.Create(t2).Error)
|
||||
|
||||
grouped, err := repo.GetGroupedByCategory()
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(grouped), 2) // At least the category + "Uncategorized"
|
||||
|
||||
// Uncategorized should have the template without category
|
||||
assert.Len(t, grouped["Uncategorized"], 1)
|
||||
assert.Equal(t, "Uncategorized", grouped["Uncategorized"][0].Title)
|
||||
}
|
||||
465
internal/repositories/user_repo_coverage_test.go
Normal file
465
internal/repositories/user_repo_coverage_test.go
Normal file
@@ -0,0 +1,465 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
func TestUserRepository_FindByUsername_CaseInsensitive(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
testutil.CreateTestUser(t, db, "TestUser", "test@example.com", "Password123")
|
||||
|
||||
// Should find with different cases
|
||||
found, err := repo.FindByUsername("testuser")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "TestUser", found.Username)
|
||||
|
||||
found, err = repo.FindByUsername("TESTUSER")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "TestUser", found.Username)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindByUsername_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, err := repo.FindByUsername("nonexistent")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindByEmail_CaseInsensitive(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
testutil.CreateTestUser(t, db, "testuser", "Test@Example.com", "Password123")
|
||||
|
||||
found, err := repo.FindByEmail("test@example.com")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Test@Example.com", found.Email)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindByEmail_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, err := repo.FindByEmail("nonexistent@example.com")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_ExistsByUsername_CaseInsensitive(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
testutil.CreateTestUser(t, db, "TestUser", "test@example.com", "Password123")
|
||||
|
||||
exists, err := repo.ExistsByUsername("TESTUSER")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
func TestUserRepository_ExistsByEmail_CaseInsensitive(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
testutil.CreateTestUser(t, db, "testuser", "Test@Example.com", "Password123")
|
||||
|
||||
exists, err := repo.ExistsByEmail("test@example.com")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
func TestUserRepository_GetOrCreateToken(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
// Create token
|
||||
token1, err := repo.GetOrCreateToken(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token1.Key)
|
||||
|
||||
// Should return same token
|
||||
token2, err := repo.GetOrCreateToken(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token1.Key, token2.Key)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindTokenByKey(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
token, err := repo.GetOrCreateToken(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindTokenByKey(token.Key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, token.Key, found.Key)
|
||||
assert.Equal(t, user.ID, found.UserID)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindTokenByKey_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, err := repo.FindTokenByKey("nonexistent-token-key")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrTokenNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_DeleteToken(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
token, err := repo.GetOrCreateToken(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.DeleteToken(token.Key)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindTokenByKey(token.Key)
|
||||
assert.ErrorIs(t, err, ErrTokenNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_DeleteToken_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
err := repo.DeleteToken("nonexistent-key")
|
||||
assert.ErrorIs(t, err, ErrTokenNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_DeleteTokenByUserID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
_, err := repo.GetOrCreateToken(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.DeleteTokenByUserID(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Token should be gone
|
||||
var count int64
|
||||
db.Model(&models.AuthToken{}).Where("user_id = ?", user.ID).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
func TestUserRepository_CreateToken(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
token, err := repo.CreateToken(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token.Key)
|
||||
assert.Equal(t, user.ID, token.UserID)
|
||||
}
|
||||
|
||||
func TestUserRepository_UpdateLastLogin(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
err := repo.UpdateLastLogin(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByID(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, found.LastLogin)
|
||||
}
|
||||
|
||||
func TestUserRepository_UpdateProfile(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
profile, err := repo.GetOrCreateProfile(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
profile.Bio = "Test bio"
|
||||
profile.PhoneNumber = "+1-555-0100"
|
||||
err = repo.UpdateProfile(profile)
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := repo.GetOrCreateProfile(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Test bio", updated.Bio)
|
||||
assert.Equal(t, "+1-555-0100", updated.PhoneNumber)
|
||||
}
|
||||
|
||||
func TestUserRepository_SetProfileVerified(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
// Create profile
|
||||
_, err := repo.GetOrCreateProfile(user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.SetProfileVerified(user.ID, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
profile, err := repo.GetOrCreateProfile(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, profile.Verified)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindByIDWithProfile(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
// Create profile first
|
||||
profile := &models.UserProfile{
|
||||
UserID: user.ID,
|
||||
Bio: "My bio",
|
||||
}
|
||||
err := db.Create(profile).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByIDWithProfile(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.ID, found.ID)
|
||||
require.NotNil(t, found.Profile)
|
||||
assert.Equal(t, "My bio", found.Profile.Bio)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindByIDWithProfile_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, err := repo.FindByIDWithProfile(9999)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_ConfirmationCode_Lifecycle(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
// Create confirmation code
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
code, err := repo.CreateConfirmationCode(user.ID, "123456", expiresAt)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, code.ID)
|
||||
|
||||
// Find it
|
||||
found, err := repo.FindConfirmationCode(user.ID, "123456")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, code.ID, found.ID)
|
||||
|
||||
// Mark as used
|
||||
err = repo.MarkConfirmationCodeUsed(code.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should not find used code
|
||||
_, err = repo.FindConfirmationCode(user.ID, "123456")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUserRepository_ConfirmationCode_InvalidatesExisting(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
|
||||
// Create first code
|
||||
code1, err := repo.CreateConfirmationCode(user.ID, "111111", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create second code (should invalidate first)
|
||||
_, err = repo.CreateConfirmationCode(user.ID, "222222", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First code should be used/invalidated
|
||||
var c models.ConfirmationCode
|
||||
db.First(&c, code1.ID)
|
||||
assert.True(t, c.IsUsed)
|
||||
}
|
||||
|
||||
func TestUserRepository_Transaction(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
err := repo.Transaction(func(txRepo *UserRepository) error {
|
||||
found, err := txRepo.FindByID(user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
found.FirstName = "Updated"
|
||||
return txRepo.Update(found)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByID(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated", found.FirstName)
|
||||
}
|
||||
|
||||
func TestUserRepository_DB(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
assert.NotNil(t, repo.DB())
|
||||
}
|
||||
|
||||
func TestUserRepository_FindByAppleID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "Password123")
|
||||
appleAuth := &models.AppleSocialAuth{
|
||||
UserID: user.ID,
|
||||
AppleID: "apple_sub_123",
|
||||
Email: "apple@test.com",
|
||||
}
|
||||
require.NoError(t, db.Create(appleAuth).Error)
|
||||
|
||||
found, err := repo.FindByAppleID("apple_sub_123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.ID, found.UserID)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindByAppleID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, err := repo.FindByAppleID("nonexistent_apple_id")
|
||||
assert.ErrorIs(t, err, ErrAppleAuthNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindByGoogleID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "Password123")
|
||||
googleAuth := &models.GoogleSocialAuth{
|
||||
UserID: user.ID,
|
||||
GoogleID: "google_sub_123",
|
||||
Email: "google@test.com",
|
||||
}
|
||||
require.NoError(t, db.Create(googleAuth).Error)
|
||||
|
||||
found, err := repo.FindByGoogleID("google_sub_123")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.ID, found.UserID)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindByGoogleID_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, err := repo.FindByGoogleID("nonexistent_google_id")
|
||||
assert.ErrorIs(t, err, ErrGoogleAuthNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_CreateAndUpdateAppleSocialAuth(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "Password123")
|
||||
|
||||
auth := &models.AppleSocialAuth{
|
||||
UserID: user.ID,
|
||||
AppleID: "apple_sub_456",
|
||||
Email: "apple@test.com",
|
||||
}
|
||||
err := repo.CreateAppleSocialAuth(auth)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, auth.ID)
|
||||
|
||||
auth.Email = "updated@test.com"
|
||||
err = repo.UpdateAppleSocialAuth(auth)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByAppleID("apple_sub_456")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "updated@test.com", found.Email)
|
||||
}
|
||||
|
||||
func TestUserRepository_CreateAndUpdateGoogleSocialAuth(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "Password123")
|
||||
|
||||
auth := &models.GoogleSocialAuth{
|
||||
UserID: user.ID,
|
||||
GoogleID: "google_sub_456",
|
||||
Email: "google@test.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
err := repo.CreateGoogleSocialAuth(auth)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, auth.ID)
|
||||
|
||||
auth.Name = "Updated Name"
|
||||
err = repo.UpdateGoogleSocialAuth(auth)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindByGoogleID("google_sub_456")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated Name", found.Name)
|
||||
}
|
||||
|
||||
func TestUserRepository_SearchUsers(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
testutil.CreateTestUser(t, db, "john_doe", "john@example.com", "Password123")
|
||||
testutil.CreateTestUser(t, db, "jane_smith", "jane@example.com", "Password123")
|
||||
testutil.CreateTestUser(t, db, "bob_builder", "bob@example.com", "Password123")
|
||||
|
||||
users, total, err := repo.SearchUsers("john", 10, 0)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), total)
|
||||
assert.Len(t, users, 1)
|
||||
assert.Equal(t, "john_doe", users[0].Username)
|
||||
}
|
||||
|
||||
func TestUserRepository_ListUsers(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
testutil.CreateTestUser(t, db, "user1", "user1@example.com", "Password123")
|
||||
testutil.CreateTestUser(t, db, "user2", "user2@example.com", "Password123")
|
||||
testutil.CreateTestUser(t, db, "user3", "user3@example.com", "Password123")
|
||||
|
||||
users, total, err := repo.ListUsers(2, 0)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(3), total)
|
||||
assert.Len(t, users, 2) // Limited to 2
|
||||
|
||||
users, total, err = repo.ListUsers(2, 2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(3), total)
|
||||
assert.Len(t, users, 1) // Only 1 remaining
|
||||
}
|
||||
367
internal/repositories/user_repo_extended_test.go
Normal file
367
internal/repositories/user_repo_extended_test.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
// === Password Reset Code Lifecycle ===
|
||||
|
||||
func TestUserRepository_PasswordResetCode_Lifecycle(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
code, err := repo.CreatePasswordResetCode(user.ID, "hash_abc123", "reset_token_xyz", expiresAt)
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, code.ID)
|
||||
assert.Equal(t, "hash_abc123", code.CodeHash)
|
||||
assert.Equal(t, "reset_token_xyz", code.ResetToken)
|
||||
assert.False(t, code.Used)
|
||||
assert.Equal(t, 0, code.Attempts)
|
||||
}
|
||||
|
||||
func TestUserRepository_CreatePasswordResetCode_InvalidatesExisting(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
|
||||
code1, err := repo.CreatePasswordResetCode(user.ID, "hash1", "token1", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.CreatePasswordResetCode(user.ID, "hash2", "token2", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First code should be marked as used
|
||||
var c models.PasswordResetCode
|
||||
db.First(&c, code1.ID)
|
||||
assert.True(t, c.Used)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
_, err := repo.CreatePasswordResetCode(user.ID, "hash_abc", "token_abc", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, foundUser, err := repo.FindPasswordResetCodeByEmail("test@example.com")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, user.ID, foundUser.ID)
|
||||
assert.Equal(t, "hash_abc", found.CodeHash)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByEmail_UserNotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, _, err := repo.FindPasswordResetCodeByEmail("nonexistent@example.com")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByEmail_NoCode(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
_, _, err := repo.FindPasswordResetCodeByEmail("test@example.com")
|
||||
assert.ErrorIs(t, err, ErrCodeNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByToken(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
_, err := repo.CreatePasswordResetCode(user.ID, "hash_xyz", "token_xyz", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
found, err := repo.FindPasswordResetCodeByToken("token_xyz")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "hash_xyz", found.CodeHash)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByToken_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, err := repo.FindPasswordResetCodeByToken("nonexistent_token")
|
||||
assert.ErrorIs(t, err, ErrCodeNotFound)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByToken_Expired(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
// Already expired
|
||||
expiresAt := time.Now().UTC().Add(-1 * time.Hour)
|
||||
_, err := repo.CreatePasswordResetCode(user.ID, "hash_exp", "token_exp", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindPasswordResetCodeByToken("token_exp")
|
||||
assert.ErrorIs(t, err, ErrCodeExpired)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByToken_Used(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
code, err := repo.CreatePasswordResetCode(user.ID, "hash_used", "token_used", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mark as used
|
||||
err = repo.MarkPasswordResetCodeUsed(code.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindPasswordResetCodeByToken("token_used")
|
||||
assert.ErrorIs(t, err, ErrCodeUsed)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindPasswordResetCodeByToken_TooManyAttempts(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
code, err := repo.CreatePasswordResetCode(user.ID, "hash_attempts", "token_attempts", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Max out attempts
|
||||
for i := 0; i < 5; i++ {
|
||||
err = repo.IncrementResetCodeAttempts(code.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
_, err = repo.FindPasswordResetCodeByToken("token_attempts")
|
||||
assert.ErrorIs(t, err, ErrTooManyAttempts)
|
||||
}
|
||||
|
||||
func TestUserRepository_IncrementResetCodeAttempts(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
code, err := repo.CreatePasswordResetCode(user.ID, "hash_inc", "token_inc", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.IncrementResetCodeAttempts(code.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var updated models.PasswordResetCode
|
||||
db.First(&updated, code.ID)
|
||||
assert.Equal(t, 1, updated.Attempts)
|
||||
}
|
||||
|
||||
func TestUserRepository_MarkPasswordResetCodeUsed(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
code, err := repo.CreatePasswordResetCode(user.ID, "hash_mark", "token_mark", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.MarkPasswordResetCodeUsed(code.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var updated models.PasswordResetCode
|
||||
db.First(&updated, code.ID)
|
||||
assert.True(t, updated.Used)
|
||||
}
|
||||
|
||||
func TestUserRepository_CountRecentPasswordResetRequests(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
expiresAt := time.Now().UTC().Add(1 * time.Hour)
|
||||
_, err := repo.CreatePasswordResetCode(user.ID, "hash1", "token1", expiresAt)
|
||||
require.NoError(t, err)
|
||||
_, err = repo.CreatePasswordResetCode(user.ID, "hash2", "token2", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err := repo.CountRecentPasswordResetRequests(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(2), count)
|
||||
}
|
||||
|
||||
// === FindUsersInSharedResidences ===
|
||||
|
||||
func TestUserRepository_FindUsersInSharedResidences(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := NewUserRepository(db)
|
||||
resRepo := NewResidenceRepository(db)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
|
||||
unrelated := testutil.CreateTestUser(t, db, "unrelated", "unrelated@test.com", "Password123")
|
||||
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Shared House")
|
||||
resRepo.AddUser(residence.ID, shared.ID)
|
||||
|
||||
// Owner should see shared user
|
||||
users, err := userRepo.FindUsersInSharedResidences(owner.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, users, 1)
|
||||
assert.Equal(t, shared.ID, users[0].ID)
|
||||
|
||||
// Shared user should see owner
|
||||
users, err = userRepo.FindUsersInSharedResidences(shared.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, users, 1)
|
||||
assert.Equal(t, owner.ID, users[0].ID)
|
||||
|
||||
// Unrelated should see no one
|
||||
users, err = userRepo.FindUsersInSharedResidences(unrelated.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, users)
|
||||
}
|
||||
|
||||
// === FindUserIfSharedResidence ===
|
||||
|
||||
func TestUserRepository_FindUserIfSharedResidence(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := NewUserRepository(db)
|
||||
resRepo := NewResidenceRepository(db)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
|
||||
unrelated := testutil.CreateTestUser(t, db, "unrelated", "unrelated@test.com", "Password123")
|
||||
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Shared House")
|
||||
resRepo.AddUser(residence.ID, shared.ID)
|
||||
|
||||
// Owner requesting shared user => should find
|
||||
found, err := userRepo.FindUserIfSharedResidence(shared.ID, owner.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found)
|
||||
assert.Equal(t, shared.ID, found.ID)
|
||||
|
||||
// Unrelated requesting shared user => should not find
|
||||
found, err = userRepo.FindUserIfSharedResidence(shared.ID, unrelated.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, found)
|
||||
|
||||
// Requesting self => should work
|
||||
found, err = userRepo.FindUserIfSharedResidence(owner.ID, owner.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, found)
|
||||
assert.Equal(t, owner.ID, found.ID)
|
||||
}
|
||||
|
||||
// === FindProfilesInSharedResidences ===
|
||||
|
||||
func TestUserRepository_FindProfilesInSharedResidences(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := NewUserRepository(db)
|
||||
resRepo := NewResidenceRepository(db)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
|
||||
|
||||
// Create profiles
|
||||
ownerProfile := &models.UserProfile{UserID: owner.ID, Bio: "Owner bio"}
|
||||
sharedProfile := &models.UserProfile{UserID: shared.ID, Bio: "Shared bio"}
|
||||
require.NoError(t, db.Create(ownerProfile).Error)
|
||||
require.NoError(t, db.Create(sharedProfile).Error)
|
||||
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Shared House")
|
||||
resRepo.AddUser(residence.ID, shared.ID)
|
||||
|
||||
// Owner sees own profile + shared user profile
|
||||
profiles, err := userRepo.FindProfilesInSharedResidences(owner.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, profiles, 2)
|
||||
}
|
||||
|
||||
// === ConfirmationCode Expired ===
|
||||
|
||||
func TestUserRepository_FindConfirmationCode_Expired(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
// Create already-expired code
|
||||
expiresAt := time.Now().UTC().Add(-1 * time.Hour)
|
||||
_, err := repo.CreateConfirmationCode(user.ID, "999999", expiresAt)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = repo.FindConfirmationCode(user.ID, "999999")
|
||||
assert.ErrorIs(t, err, ErrCodeExpired)
|
||||
}
|
||||
|
||||
func TestUserRepository_FindConfirmationCode_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
_, err := repo.FindConfirmationCode(user.ID, "000000")
|
||||
assert.ErrorIs(t, err, ErrCodeNotFound)
|
||||
}
|
||||
|
||||
// === Transaction Rollback ===
|
||||
|
||||
func TestUserRepository_Transaction_Rollback(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
err := repo.Transaction(func(txRepo *UserRepository) error {
|
||||
found, err := txRepo.FindByID(user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
found.FirstName = "ShouldRollback"
|
||||
if err := txRepo.Update(found); err != nil {
|
||||
return err
|
||||
}
|
||||
// Simulate an error to trigger rollback
|
||||
return ErrUserNotFound
|
||||
})
|
||||
assert.Error(t, err)
|
||||
|
||||
// Name should NOT have been updated
|
||||
found, err := repo.FindByID(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, "ShouldRollback", found.FirstName)
|
||||
}
|
||||
|
||||
// === FindByUsernameOrEmail not found ===
|
||||
|
||||
func TestUserRepository_FindByUsernameOrEmail_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
_, err := repo.FindByUsernameOrEmail("nonexistent")
|
||||
assert.ErrorIs(t, err, ErrUserNotFound)
|
||||
}
|
||||
@@ -19,7 +19,7 @@ func TestUserRepository_Create(t *testing.T) {
|
||||
Email: "test@example.com",
|
||||
IsActive: true,
|
||||
}
|
||||
user.SetPassword("password123")
|
||||
user.SetPassword("Password123")
|
||||
|
||||
err := repo.Create(user)
|
||||
require.NoError(t, err)
|
||||
@@ -31,7 +31,7 @@ func TestUserRepository_FindByID(t *testing.T) {
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
// Create user
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "password123")
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
// Find by ID
|
||||
found, err := repo.FindByID(user.ID)
|
||||
@@ -54,7 +54,7 @@ func TestUserRepository_FindByUsername(t *testing.T) {
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
// Create user
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "password123")
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
// Find by username
|
||||
found, err := repo.FindByUsername("testuser")
|
||||
@@ -67,7 +67,7 @@ func TestUserRepository_FindByEmail(t *testing.T) {
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
// Create user
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "password123")
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
// Find by email
|
||||
found, err := repo.FindByEmail("test@example.com")
|
||||
@@ -80,7 +80,7 @@ func TestUserRepository_FindByUsernameOrEmail(t *testing.T) {
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
// Create user
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "password123")
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -105,7 +105,7 @@ func TestUserRepository_Update(t *testing.T) {
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
// Create user
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "password123")
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
// Update user
|
||||
user.FirstName = "John"
|
||||
@@ -125,7 +125,7 @@ func TestUserRepository_ExistsByUsername(t *testing.T) {
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
// Create user
|
||||
testutil.CreateTestUser(t, db, "existinguser", "existing@example.com", "password123")
|
||||
testutil.CreateTestUser(t, db, "existinguser", "existing@example.com", "Password123")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -150,7 +150,7 @@ func TestUserRepository_ExistsByEmail(t *testing.T) {
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
// Create user
|
||||
testutil.CreateTestUser(t, db, "existinguser", "existing@example.com", "password123")
|
||||
testutil.CreateTestUser(t, db, "existinguser", "existing@example.com", "Password123")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -175,7 +175,7 @@ func TestUserRepository_GetOrCreateProfile(t *testing.T) {
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
// Create user
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "password123")
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "Password123")
|
||||
|
||||
// First call should create
|
||||
profile1, err := repo.GetOrCreateProfile(user.ID)
|
||||
@@ -193,14 +193,14 @@ func TestUserRepository_FindAuthProvider(t *testing.T) {
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
t.Run("email user", func(t *testing.T) {
|
||||
user := testutil.CreateTestUser(t, db, "emailuser", "email@test.com", "password123")
|
||||
user := testutil.CreateTestUser(t, db, "emailuser", "email@test.com", "Password123")
|
||||
provider, err := repo.FindAuthProvider(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "email", provider)
|
||||
})
|
||||
|
||||
t.Run("apple user", func(t *testing.T) {
|
||||
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "password123")
|
||||
user := testutil.CreateTestUser(t, db, "appleuser", "apple@test.com", "Password123")
|
||||
appleAuth := &models.AppleSocialAuth{
|
||||
UserID: user.ID,
|
||||
AppleID: "apple_sub_test",
|
||||
@@ -214,7 +214,7 @@ func TestUserRepository_FindAuthProvider(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("google user", func(t *testing.T) {
|
||||
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "password123")
|
||||
user := testutil.CreateTestUser(t, db, "googleuser", "google@test.com", "Password123")
|
||||
googleAuth := &models.GoogleSocialAuth{
|
||||
UserID: user.ID,
|
||||
GoogleID: "google_sub_test",
|
||||
@@ -233,7 +233,7 @@ func TestUserRepository_DeleteUserCascade(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "deletebare", "deletebare@test.com", "password123")
|
||||
user := testutil.CreateTestUser(t, db, "deletebare", "deletebare@test.com", "Password123")
|
||||
|
||||
// Create profile and token
|
||||
profile := &models.UserProfile{UserID: user.ID, Verified: true}
|
||||
@@ -271,7 +271,7 @@ func TestUserRepository_DeleteUserCascade(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "deletefiles", "deletefiles@test.com", "password123")
|
||||
user := testutil.CreateTestUser(t, db, "deletefiles", "deletefiles@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test Home")
|
||||
|
||||
// Create document with file
|
||||
@@ -319,8 +319,8 @@ func TestUserRepository_DeleteUserCascade(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "deleteowner", "deleteowner@test.com", "password123")
|
||||
otherUser := testutil.CreateTestUser(t, db, "otheruser", "other@test.com", "password123")
|
||||
owner := testutil.CreateTestUser(t, db, "deleteowner", "deleteowner@test.com", "Password123")
|
||||
otherUser := testutil.CreateTestUser(t, db, "otheruser", "other@test.com", "Password123")
|
||||
otherResidence := testutil.CreateTestResidence(t, db, otherUser.ID, "Other Home")
|
||||
|
||||
// Owner's residence
|
||||
|
||||
29
internal/repositories/util_test.go
Normal file
29
internal/repositories/util_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEscapeLikeWildcards(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"no wildcards", "hello", "hello"},
|
||||
{"percent sign", "50% off", "50\\% off"},
|
||||
{"underscore", "user_name", "user\\_name"},
|
||||
{"both wildcards", "50%_off", "50\\%\\_off"},
|
||||
{"empty string", "", ""},
|
||||
{"only wildcards", "%_", "\\%\\_"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := escapeLikeWildcards(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
262
internal/router/error_handler_test.go
Normal file
262
internal/router/error_handler_test.go
Normal file
@@ -0,0 +1,262 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/apperrors"
|
||||
"github.com/treytartt/honeydue-api/internal/dto/responses"
|
||||
"github.com/treytartt/honeydue-api/internal/i18n"
|
||||
"github.com/treytartt/honeydue-api/internal/services"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// Initialize i18n so LocalizedMessage returns real translations
|
||||
_ = i18n.Init()
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
// makeContext creates a fresh echo.Context and response recorder for testing.
|
||||
func makeContext(method, path string) (echo.Context, *httptest.ResponseRecorder) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(method, path, nil)
|
||||
req.Header.Set("Accept-Language", "en")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
// Set up i18n localizer on context (same key the middleware uses)
|
||||
if i18n.Bundle != nil {
|
||||
loc := i18n.NewLocalizer("en")
|
||||
c.Set(i18n.LocalizerKey, loc)
|
||||
}
|
||||
return c, rec
|
||||
}
|
||||
|
||||
// decodeError reads the JSON response into an ErrorResponse.
|
||||
func decodeError(t *testing.T, rec *httptest.ResponseRecorder) responses.ErrorResponse {
|
||||
t.Helper()
|
||||
var resp responses.ErrorResponse
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to decode error response: %v\nbody: %s", err, rec.Body.String())
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// --- AppError branch ---
|
||||
|
||||
func TestErrorHandler_AppError_NotFound(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodGet, "/tasks/1")
|
||||
err := apperrors.NotFound("error.task_not_found")
|
||||
customHTTPErrorHandler(err, c)
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusNotFound)
|
||||
}
|
||||
resp := decodeError(t, rec)
|
||||
if resp.Error == "" {
|
||||
t.Error("expected non-empty error message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandler_AppError_Forbidden(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodGet, "/res/1")
|
||||
err := apperrors.Forbidden("error.task_access_denied")
|
||||
customHTTPErrorHandler(err, c)
|
||||
if rec.Code != http.StatusForbidden {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandler_AppError_BadRequest(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodPost, "/tasks")
|
||||
err := apperrors.BadRequest("error.task_already_cancelled")
|
||||
customHTTPErrorHandler(err, c)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandler_AppError_WithWrappedErr(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodGet, "/x")
|
||||
err := apperrors.Internal(fmt.Errorf("db conn failed"))
|
||||
customHTTPErrorHandler(err, c)
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandler_AppError_I18nFallback(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodGet, "/x")
|
||||
err := apperrors.NotFound("nonexistent.i18n.key").WithMessage("fallback msg")
|
||||
customHTTPErrorHandler(err, c)
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusNotFound)
|
||||
}
|
||||
resp := decodeError(t, rec)
|
||||
if resp.Error != "fallback msg" {
|
||||
t.Errorf("error = %q, want %q", resp.Error, "fallback msg")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Echo HTTPError branch ---
|
||||
|
||||
func TestErrorHandler_EchoHTTPError_404(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodGet, "/nope")
|
||||
err := echo.NewHTTPError(http.StatusNotFound, "not found")
|
||||
customHTTPErrorHandler(err, c)
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusNotFound)
|
||||
}
|
||||
resp := decodeError(t, rec)
|
||||
if resp.Error != "not found" {
|
||||
t.Errorf("error = %q, want %q", resp.Error, "not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandler_EchoHTTPError_405(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodGet, "/x")
|
||||
err := echo.ErrMethodNotAllowed
|
||||
customHTTPErrorHandler(err, c)
|
||||
if rec.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Sentinel error branch ---
|
||||
|
||||
func TestErrorHandler_ErrTaskNotFound(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodGet, "/t")
|
||||
customHTTPErrorHandler(services.ErrTaskNotFound, c)
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandler_ErrCompletionNotFound(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodGet, "/t")
|
||||
customHTTPErrorHandler(services.ErrCompletionNotFound, c)
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandler_ErrTaskAccessDenied(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodGet, "/t")
|
||||
customHTTPErrorHandler(services.ErrTaskAccessDenied, c)
|
||||
if rec.Code != http.StatusForbidden {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandler_ErrTaskAlreadyCancelled(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodPost, "/t")
|
||||
customHTTPErrorHandler(services.ErrTaskAlreadyCancelled, c)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandler_ErrTaskAlreadyArchived(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodPost, "/t")
|
||||
customHTTPErrorHandler(services.ErrTaskAlreadyArchived, c)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandler_ErrResidenceNotFound(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodGet, "/r")
|
||||
customHTTPErrorHandler(services.ErrResidenceNotFound, c)
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandler_ErrResidenceAccessDenied(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodGet, "/r")
|
||||
customHTTPErrorHandler(services.ErrResidenceAccessDenied, c)
|
||||
if rec.Code != http.StatusForbidden {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandler_ErrNotResidenceOwner(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodDelete, "/r")
|
||||
customHTTPErrorHandler(services.ErrNotResidenceOwner, c)
|
||||
if rec.Code != http.StatusForbidden {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandler_ErrPropertiesLimitReached(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodPost, "/r")
|
||||
customHTTPErrorHandler(services.ErrPropertiesLimitReached, c)
|
||||
if rec.Code != http.StatusForbidden {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandler_ErrCannotRemoveOwner(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodDelete, "/r")
|
||||
customHTTPErrorHandler(services.ErrCannotRemoveOwner, c)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandler_ErrShareCodeExpired(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodPost, "/r")
|
||||
customHTTPErrorHandler(services.ErrShareCodeExpired, c)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandler_ErrShareCodeInvalid(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodPost, "/r")
|
||||
customHTTPErrorHandler(services.ErrShareCodeInvalid, c)
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorHandler_ErrUserAlreadyMember(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodPost, "/r")
|
||||
customHTTPErrorHandler(services.ErrUserAlreadyMember, c)
|
||||
if rec.Code != http.StatusConflict {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusConflict)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Default branch ---
|
||||
|
||||
func TestErrorHandler_UnknownError(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodGet, "/x")
|
||||
customHTTPErrorHandler(errors.New("random unexpected error"), c)
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Errorf("code = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Committed response ---
|
||||
|
||||
func TestErrorHandler_CommittedResponse_Noop(t *testing.T) {
|
||||
c, rec := makeContext(http.MethodGet, "/x")
|
||||
// Write something to commit the response
|
||||
c.JSON(http.StatusOK, map[string]string{"ok": "true"})
|
||||
// Capture the body after first write
|
||||
bodyBefore := rec.Body.String()
|
||||
|
||||
// Now call error handler — it should be a no-op
|
||||
customHTTPErrorHandler(errors.New("should be ignored"), c)
|
||||
bodyAfter := rec.Body.String()
|
||||
|
||||
if bodyBefore != bodyAfter {
|
||||
t.Errorf("body changed after committed response:\nbefore: %s\nafter: %s", bodyBefore, bodyAfter)
|
||||
}
|
||||
}
|
||||
115
internal/router/router_helpers.go
Normal file
115
internal/router/router_helpers.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/monitoring"
|
||||
)
|
||||
|
||||
// CorsOrigins returns the CORS allowed origins based on debug mode.
|
||||
func CorsOrigins(debug bool, configuredOrigins []string) []string {
|
||||
if debug {
|
||||
return []string{
|
||||
"http://localhost:3000",
|
||||
"http://localhost:3001",
|
||||
"http://localhost:8080",
|
||||
"http://localhost:8000",
|
||||
"http://127.0.0.1:3000",
|
||||
"http://127.0.0.1:3001",
|
||||
"http://127.0.0.1:8080",
|
||||
"http://127.0.0.1:8000",
|
||||
}
|
||||
}
|
||||
if len(configuredOrigins) > 0 {
|
||||
return configuredOrigins
|
||||
}
|
||||
return []string{
|
||||
"https://api.myhoneydue.com",
|
||||
"https://myhoneydue.com",
|
||||
"https://admin.myhoneydue.com",
|
||||
}
|
||||
}
|
||||
|
||||
// ShouldSkipTimeout returns true for paths that should bypass timeout middleware.
|
||||
func ShouldSkipTimeout(path, host, adminHost string) bool {
|
||||
return (adminHost != "" && host == adminHost) ||
|
||||
strings.HasPrefix(path, "/_next") ||
|
||||
strings.HasSuffix(path, "/ws")
|
||||
}
|
||||
|
||||
// ShouldSkipBodyLimit returns true for webhook endpoints.
|
||||
func ShouldSkipBodyLimit(path string) bool {
|
||||
return strings.HasPrefix(path, "/api/subscription/webhook")
|
||||
}
|
||||
|
||||
// ShouldSkipGzip returns true for media endpoints.
|
||||
func ShouldSkipGzip(path string) bool {
|
||||
return strings.HasPrefix(path, "/api/media/")
|
||||
}
|
||||
|
||||
// ParseEndpoint splits "GET /api/foo" into method and path.
|
||||
func ParseEndpoint(endpoint string) (method, path string) {
|
||||
parts := strings.SplitN(endpoint, " ", 2)
|
||||
if len(parts) == 2 {
|
||||
return parts[0], parts[1]
|
||||
}
|
||||
return endpoint, ""
|
||||
}
|
||||
|
||||
// AllowedProxyHosts builds the list of hosts allowed for admin proxy.
|
||||
func AllowedProxyHosts(adminHost string) []string {
|
||||
var hosts []string
|
||||
if adminHost != "" {
|
||||
hosts = append(hosts, adminHost)
|
||||
}
|
||||
hosts = append(hosts, "localhost:3001", "127.0.0.1:3001", "localhost:8000", "127.0.0.1:8000")
|
||||
return hosts
|
||||
}
|
||||
|
||||
// DetermineAdminRoute decides how to route admin subdomain requests.
|
||||
func DetermineAdminRoute(path string) string {
|
||||
if strings.HasPrefix(path, "/admin") {
|
||||
return "redirect"
|
||||
}
|
||||
if strings.HasPrefix(path, "/api/") {
|
||||
return "passthrough"
|
||||
}
|
||||
return "proxy"
|
||||
}
|
||||
|
||||
// FormatPrometheusMetrics converts HTTP stats to Prometheus text format.
|
||||
func FormatPrometheusMetrics(stats monitoring.HTTPStats) string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("# HELP http_requests_total Total number of HTTP requests.\n")
|
||||
b.WriteString("# TYPE http_requests_total counter\n")
|
||||
for statusCode, count := range stats.ByStatusCode {
|
||||
fmt.Fprintf(&b, "http_requests_total{status_code=\"%d\"} %d\n", statusCode, count)
|
||||
}
|
||||
|
||||
b.WriteString("# HELP http_endpoint_requests_total Total requests per endpoint.\n")
|
||||
b.WriteString("# TYPE http_endpoint_requests_total counter\n")
|
||||
for endpoint, epStats := range stats.ByEndpoint {
|
||||
method, path := ParseEndpoint(endpoint)
|
||||
fmt.Fprintf(&b, "http_endpoint_requests_total{method=\"%s\",path=\"%s\"} %d\n", method, path, epStats.Count)
|
||||
}
|
||||
|
||||
b.WriteString("# HELP http_request_duration_ms Average request duration in milliseconds per endpoint.\n")
|
||||
b.WriteString("# TYPE http_request_duration_ms gauge\n")
|
||||
for endpoint, epStats := range stats.ByEndpoint {
|
||||
method, path := ParseEndpoint(endpoint)
|
||||
fmt.Fprintf(&b, "http_request_duration_ms{method=\"%s\",path=\"%s\",quantile=\"avg\"} %.2f\n", method, path, epStats.AvgLatencyMs)
|
||||
fmt.Fprintf(&b, "http_request_duration_ms{method=\"%s\",path=\"%s\",quantile=\"p95\"} %.2f\n", method, path, epStats.P95LatencyMs)
|
||||
}
|
||||
|
||||
b.WriteString("# HELP http_error_rate Overall error rate (4xx+5xx / total).\n")
|
||||
b.WriteString("# TYPE http_error_rate gauge\n")
|
||||
fmt.Fprintf(&b, "http_error_rate %.4f\n", stats.ErrorRate)
|
||||
|
||||
b.WriteString("# HELP http_requests_per_minute Current request rate.\n")
|
||||
b.WriteString("# TYPE http_requests_per_minute gauge\n")
|
||||
fmt.Fprintf(&b, "http_requests_per_minute %.2f\n", stats.RequestsPerMinute)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
200
internal/router/router_helpers_test.go
Normal file
200
internal/router/router_helpers_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/monitoring"
|
||||
)
|
||||
|
||||
// --- CorsOrigins ---
|
||||
|
||||
func TestCorsOrigins_Debug(t *testing.T) {
|
||||
origins := CorsOrigins(true, nil)
|
||||
if len(origins) != 8 {
|
||||
t.Errorf("len = %d, want 8", len(origins))
|
||||
}
|
||||
found := false
|
||||
for _, o := range origins {
|
||||
if o == "http://localhost:3000" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected localhost:3000 in debug origins")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsOrigins_ProductionConfigured(t *testing.T) {
|
||||
custom := []string{"https://example.com"}
|
||||
origins := CorsOrigins(false, custom)
|
||||
if len(origins) != 1 || origins[0] != "https://example.com" {
|
||||
t.Errorf("got %v, want [https://example.com]", origins)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsOrigins_ProductionDefault(t *testing.T) {
|
||||
origins := CorsOrigins(false, nil)
|
||||
if len(origins) != 3 {
|
||||
t.Errorf("len = %d, want 3", len(origins))
|
||||
}
|
||||
found := false
|
||||
for _, o := range origins {
|
||||
if o == "https://myhoneydue.com" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected myhoneydue.com in default origins")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ShouldSkipTimeout ---
|
||||
|
||||
func TestShouldSkipTimeout_AdminHost_True(t *testing.T) {
|
||||
if !ShouldSkipTimeout("/some/path", "admin.example.com", "admin.example.com") {
|
||||
t.Error("expected true for admin host")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipTimeout_NextPath_True(t *testing.T) {
|
||||
if !ShouldSkipTimeout("/_next/static/chunk.js", "app.example.com", "admin.example.com") {
|
||||
t.Error("expected true for _next path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipTimeout_WsPath_True(t *testing.T) {
|
||||
if !ShouldSkipTimeout("/api/events/ws", "app.example.com", "admin.example.com") {
|
||||
t.Error("expected true for /ws path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipTimeout_NormalPath_False(t *testing.T) {
|
||||
if ShouldSkipTimeout("/api/tasks/", "app.example.com", "admin.example.com") {
|
||||
t.Error("expected false for normal path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipTimeout_EmptyAdminHost_False(t *testing.T) {
|
||||
if ShouldSkipTimeout("/api/tasks/", "admin.example.com", "") {
|
||||
t.Error("expected false when admin host is empty")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ShouldSkipBodyLimit ---
|
||||
|
||||
func TestShouldSkipBodyLimit_Webhook_True(t *testing.T) {
|
||||
if !ShouldSkipBodyLimit("/api/subscription/webhook/apple/") {
|
||||
t.Error("expected true for webhook path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipBodyLimit_Normal_False(t *testing.T) {
|
||||
if ShouldSkipBodyLimit("/api/tasks/") {
|
||||
t.Error("expected false for normal path")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ShouldSkipGzip ---
|
||||
|
||||
func TestShouldSkipGzip_Media_True(t *testing.T) {
|
||||
if !ShouldSkipGzip("/api/media/document/123") {
|
||||
t.Error("expected true for media path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipGzip_Api_False(t *testing.T) {
|
||||
if ShouldSkipGzip("/api/tasks/") {
|
||||
t.Error("expected false for non-media path")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseEndpoint ---
|
||||
|
||||
func TestParseEndpoint_MethodAndPath(t *testing.T) {
|
||||
method, path := ParseEndpoint("GET /api/tasks/")
|
||||
if method != "GET" || path != "/api/tasks/" {
|
||||
t.Errorf("got (%q, %q), want (GET, /api/tasks/)", method, path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseEndpoint_NoSpace(t *testing.T) {
|
||||
method, path := ParseEndpoint("GET")
|
||||
if method != "GET" || path != "" {
|
||||
t.Errorf("got (%q, %q), want (GET, \"\")", method, path)
|
||||
}
|
||||
}
|
||||
|
||||
// --- AllowedProxyHosts ---
|
||||
|
||||
func TestAllowedProxyHosts_WithAdmin(t *testing.T) {
|
||||
hosts := AllowedProxyHosts("admin.example.com")
|
||||
if len(hosts) != 5 {
|
||||
t.Errorf("len = %d, want 5", len(hosts))
|
||||
}
|
||||
if hosts[0] != "admin.example.com" {
|
||||
t.Errorf("first host = %q, want admin.example.com", hosts[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllowedProxyHosts_WithoutAdmin(t *testing.T) {
|
||||
hosts := AllowedProxyHosts("")
|
||||
if len(hosts) != 4 {
|
||||
t.Errorf("len = %d, want 4", len(hosts))
|
||||
}
|
||||
}
|
||||
|
||||
// --- DetermineAdminRoute ---
|
||||
|
||||
func TestDetermineAdminRoute_Admin_Redirect(t *testing.T) {
|
||||
got := DetermineAdminRoute("/admin/dashboard")
|
||||
if got != "redirect" {
|
||||
t.Errorf("got %q, want redirect", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineAdminRoute_Api_Passthrough(t *testing.T) {
|
||||
got := DetermineAdminRoute("/api/tasks/")
|
||||
if got != "passthrough" {
|
||||
t.Errorf("got %q, want passthrough", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetermineAdminRoute_Other_Proxy(t *testing.T) {
|
||||
got := DetermineAdminRoute("/dashboard")
|
||||
if got != "proxy" {
|
||||
t.Errorf("got %q, want proxy", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- FormatPrometheusMetrics ---
|
||||
|
||||
func TestFormatPrometheusMetrics_Output(t *testing.T) {
|
||||
stats := monitoring.HTTPStats{
|
||||
RequestsTotal: 100,
|
||||
RequestsPerMinute: 10.5,
|
||||
ErrorRate: 0.05,
|
||||
ByStatusCode: map[int]int64{200: 90, 500: 10},
|
||||
ByEndpoint: map[string]monitoring.EndpointStats{
|
||||
"GET /api/tasks/": {Count: 50, AvgLatencyMs: 12.5, P95LatencyMs: 45.0},
|
||||
},
|
||||
}
|
||||
|
||||
output := FormatPrometheusMetrics(stats)
|
||||
|
||||
// Check key sections exist
|
||||
checks := []string{
|
||||
"http_requests_total",
|
||||
"http_endpoint_requests_total",
|
||||
"http_request_duration_ms",
|
||||
"http_error_rate",
|
||||
"http_requests_per_minute",
|
||||
`method="GET"`,
|
||||
`path="/api/tasks/"`,
|
||||
}
|
||||
for _, check := range checks {
|
||||
if !strings.Contains(output, check) {
|
||||
t.Errorf("output missing %q", check)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -35,7 +35,7 @@ func createRefreshTestUser(t *testing.T, db *gorm.DB) *models.User {
|
||||
Email: "refresh@test.com",
|
||||
IsActive: true,
|
||||
}
|
||||
require.NoError(t, user.SetPassword("password123"))
|
||||
require.NoError(t, user.SetPassword("Password123"))
|
||||
require.NoError(t, db.Create(user).Error)
|
||||
return user
|
||||
}
|
||||
|
||||
800
internal/services/auth_service_test.go
Normal file
800
internal/services/auth_service_test.go
Normal file
@@ -0,0 +1,800 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/dto/requests"
|
||||
"github.com/treytartt/honeydue-api/internal/repositories"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
func setupAuthService(t *testing.T) (*AuthService, *repositories.UserRepository) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
notifRepo := repositories.NewNotificationRepository(db)
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
DebugFixedCodes: true,
|
||||
},
|
||||
Security: config.SecurityConfig{
|
||||
SecretKey: "test-secret",
|
||||
ConfirmationExpiry: 24 * time.Hour,
|
||||
PasswordResetExpiry: 15 * time.Minute,
|
||||
MaxPasswordResetRate: 3,
|
||||
TokenExpiryDays: 90,
|
||||
TokenRefreshDays: 60,
|
||||
},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
service.SetNotificationRepository(notifRepo)
|
||||
return service, userRepo
|
||||
}
|
||||
|
||||
// === Login ===
|
||||
|
||||
func TestAuthService_Login(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
|
||||
req := &requests.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
resp, err := service.Login(req)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.Token)
|
||||
assert.Equal(t, "testuser", resp.User.Username)
|
||||
}
|
||||
|
||||
func TestAuthService_Login_ByEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
|
||||
req := &requests.LoginRequest{
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
resp, err := service.Login(req)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.Token)
|
||||
}
|
||||
|
||||
func TestAuthService_Login_InvalidCredentials(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
|
||||
req := &requests.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "WrongPassword1",
|
||||
}
|
||||
|
||||
_, err := service.Login(req)
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
func TestAuthService_Login_UserNotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
req := &requests.LoginRequest{
|
||||
Username: "nonexistent",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
_, err := service.Login(req)
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
func TestAuthService_Login_InactiveUser(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "inactive", "inactive@test.com", "Password123")
|
||||
// Deactivate
|
||||
user.IsActive = false
|
||||
db.Save(user)
|
||||
|
||||
req := &requests.LoginRequest{
|
||||
Username: "inactive",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
_, err := service.Login(req)
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.account_inactive")
|
||||
}
|
||||
|
||||
// === Register ===
|
||||
|
||||
func TestAuthService_Register(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "new@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
resp, code, err := service.Register(req)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.Token)
|
||||
assert.Equal(t, "newuser", resp.User.Username)
|
||||
assert.Equal(t, "123456", code) // DebugFixedCodes=true
|
||||
}
|
||||
|
||||
func TestAuthService_Register_DuplicateUsername(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{DebugFixedCodes: true},
|
||||
Security: config.SecurityConfig{SecretKey: "test", ConfirmationExpiry: 24 * time.Hour},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "taken", "taken@test.com", "Password123")
|
||||
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "taken",
|
||||
Email: "different@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
_, _, err := service.Register(req)
|
||||
testutil.AssertAppError(t, err, http.StatusConflict, "error.username_taken")
|
||||
}
|
||||
|
||||
func TestAuthService_Register_DuplicateEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{DebugFixedCodes: true},
|
||||
Security: config.SecurityConfig{SecretKey: "test", ConfirmationExpiry: 24 * time.Hour},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "existing", "taken@test.com", "Password123")
|
||||
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "taken@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
|
||||
_, _, err := service.Register(req)
|
||||
testutil.AssertAppError(t, err, http.StatusConflict, "error.email_taken")
|
||||
}
|
||||
|
||||
// === GetCurrentUser ===
|
||||
|
||||
func TestAuthService_GetCurrentUser(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
// Create profile
|
||||
userRepo.GetOrCreateProfile(user.ID)
|
||||
|
||||
resp, err := service.GetCurrentUser(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "testuser", resp.Username)
|
||||
assert.Equal(t, "test@test.com", resp.Email)
|
||||
assert.Equal(t, "email", resp.AuthProvider) // Default for no social auth
|
||||
}
|
||||
|
||||
// === UpdateProfile ===
|
||||
|
||||
func TestAuthService_UpdateProfile(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
userRepo.GetOrCreateProfile(user.ID)
|
||||
|
||||
newFirst := "John"
|
||||
newLast := "Doe"
|
||||
req := &requests.UpdateProfileRequest{
|
||||
FirstName: &newFirst,
|
||||
LastName: &newLast,
|
||||
}
|
||||
|
||||
resp, err := service.UpdateProfile(user.ID, req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "John", resp.FirstName)
|
||||
assert.Equal(t, "Doe", resp.LastName)
|
||||
}
|
||||
|
||||
func TestAuthService_UpdateProfile_DuplicateEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "user1", "user1@test.com", "Password123")
|
||||
user2 := testutil.CreateTestUser(t, db, "user2", "user2@test.com", "Password123")
|
||||
userRepo.GetOrCreateProfile(user2.ID)
|
||||
|
||||
takenEmail := "user1@test.com"
|
||||
req := &requests.UpdateProfileRequest{
|
||||
Email: &takenEmail,
|
||||
}
|
||||
|
||||
_, err := service.UpdateProfile(user2.ID, req)
|
||||
testutil.AssertAppError(t, err, http.StatusConflict, "error.email_already_taken")
|
||||
}
|
||||
|
||||
func TestAuthService_UpdateProfile_SameEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
userRepo.GetOrCreateProfile(user.ID)
|
||||
|
||||
sameEmail := "test@test.com"
|
||||
req := &requests.UpdateProfileRequest{
|
||||
Email: &sameEmail,
|
||||
}
|
||||
|
||||
// Same email should not trigger duplicate error
|
||||
resp, err := service.UpdateProfile(user.ID, req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test@test.com", resp.Email)
|
||||
}
|
||||
|
||||
// === VerifyEmail ===
|
||||
|
||||
func TestAuthService_VerifyEmail(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register a user (creates confirmation code)
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "new@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the user ID
|
||||
user, err := service.userRepo.FindByEmail("new@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify with the debug code
|
||||
err = service.VerifyEmail(user.ID, "123456")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify again — should get already verified error
|
||||
err = service.VerifyEmail(user.ID, "123456")
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
|
||||
}
|
||||
|
||||
func TestAuthService_VerifyEmail_InvalidCode(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "new@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("new@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wrong code — with DebugFixedCodes enabled, "123456" bypasses normal lookup,
|
||||
// but a wrong code should use the normal path
|
||||
err = service.VerifyEmail(user.ID, "000000")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === ResendVerificationCode ===
|
||||
|
||||
func TestAuthService_ResendVerificationCode(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "new@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("new@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
code, err := service.ResendVerificationCode(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "123456", code) // DebugFixedCodes
|
||||
}
|
||||
|
||||
func TestAuthService_ResendVerificationCode_AlreadyVerified(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register and verify
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "newuser",
|
||||
Email: "new@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("new@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = service.VerifyEmail(user.ID, "123456")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.ResendVerificationCode(user.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.email_already_verified")
|
||||
}
|
||||
|
||||
// === ForgotPassword ===
|
||||
|
||||
func TestAuthService_ForgotPassword(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register a user first
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
code, user, err := service.ForgotPassword("test@test.com")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "123456", code) // DebugFixedCodes
|
||||
assert.NotNil(t, user)
|
||||
assert.Equal(t, "test@test.com", user.Email)
|
||||
}
|
||||
|
||||
func TestAuthService_ForgotPassword_NonexistentEmail(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Should not reveal that email doesn't exist
|
||||
code, user, err := service.ForgotPassword("nonexistent@test.com")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, code)
|
||||
assert.Nil(t, user)
|
||||
}
|
||||
|
||||
// === ResetPassword ===
|
||||
|
||||
func TestAuthService_ResetPassword(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Forgot password
|
||||
_, _, err = service.ForgotPassword("test@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify reset code to get the token
|
||||
resetToken, err := service.VerifyResetCode("test@test.com", "123456")
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resetToken)
|
||||
|
||||
// Reset password
|
||||
err = service.ResetPassword(resetToken, "NewPassword123")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Login with new password
|
||||
loginReq := &requests.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "NewPassword123",
|
||||
}
|
||||
loginResp, err := service.Login(loginReq)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, loginResp.Token)
|
||||
}
|
||||
|
||||
func TestAuthService_ResetPassword_InvalidToken(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
err := service.ResetPassword("invalid-token", "NewPassword123")
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.invalid_reset_token")
|
||||
}
|
||||
|
||||
// === Logout ===
|
||||
|
||||
func TestAuthService_Logout(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
|
||||
// Login first
|
||||
loginReq := &requests.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "Password123",
|
||||
}
|
||||
loginResp, err := service.Login(loginReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Logout
|
||||
err = service.Logout(loginResp.Token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Token should be deleted — refreshing should fail
|
||||
_, err = service.RefreshToken(loginResp.Token, user.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === DeleteAccount ===
|
||||
|
||||
func TestAuthService_DeleteAccount_EmailAuth(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
// Register
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("test@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
password := "Password123"
|
||||
_, err = service.DeleteAccount(user.ID, &password, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuthService_DeleteAccount_WrongPassword(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("test@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
wrongPassword := "WrongPassword1"
|
||||
_, err = service.DeleteAccount(user.ID, &wrongPassword, nil)
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
func TestAuthService_DeleteAccount_NoPassword(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("test@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.DeleteAccount(user.ID, nil, nil)
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
|
||||
}
|
||||
|
||||
func TestAuthService_DeleteAccount_UserNotFound(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
password := "Password123"
|
||||
_, err := service.DeleteAccount(99999, &password, nil)
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.user_not_found")
|
||||
}
|
||||
|
||||
// === Helper functions ===
|
||||
|
||||
func TestGenerateSixDigitCode(t *testing.T) {
|
||||
code := generateSixDigitCode()
|
||||
assert.Len(t, code, 6)
|
||||
// Should be numeric
|
||||
for _, c := range code {
|
||||
assert.True(t, c >= '0' && c <= '9', "code should contain only digits")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateResetToken(t *testing.T) {
|
||||
token := generateResetToken()
|
||||
assert.NotEmpty(t, token)
|
||||
assert.Len(t, token, 64) // 32 bytes = 64 hex chars
|
||||
}
|
||||
|
||||
func TestGetStringOrEmpty(t *testing.T) {
|
||||
s := "hello"
|
||||
assert.Equal(t, "hello", getStringOrEmpty(&s))
|
||||
assert.Equal(t, "", getStringOrEmpty(nil))
|
||||
}
|
||||
|
||||
func TestIsPrivateRelayEmail(t *testing.T) {
|
||||
assert.True(t, isPrivateRelayEmail("abc@privaterelay.appleid.com"))
|
||||
assert.True(t, isPrivateRelayEmail("ABC@PRIVATERELAY.APPLEID.COM"))
|
||||
assert.False(t, isPrivateRelayEmail("user@gmail.com"))
|
||||
}
|
||||
|
||||
func TestGetEmailFromRequest(t *testing.T) {
|
||||
email := "req@test.com"
|
||||
assert.Equal(t, "req@test.com", getEmailFromRequest(&email, "claims@test.com"))
|
||||
assert.Equal(t, "claims@test.com", getEmailFromRequest(nil, "claims@test.com"))
|
||||
empty := ""
|
||||
assert.Equal(t, "claims@test.com", getEmailFromRequest(&empty, "claims@test.com"))
|
||||
}
|
||||
|
||||
// === getEmailOrDefault ===
|
||||
|
||||
func TestGetEmailOrDefault(t *testing.T) {
|
||||
// Non-empty email returns itself
|
||||
assert.Equal(t, "user@test.com", getEmailOrDefault("user@test.com"))
|
||||
|
||||
// Empty email returns a generated placeholder
|
||||
result := getEmailOrDefault("")
|
||||
assert.Contains(t, result, "@privaterelay.appleid.com")
|
||||
assert.Contains(t, result, "apple_")
|
||||
}
|
||||
|
||||
// === generateUniqueUsername ===
|
||||
|
||||
func TestGenerateUniqueUsername(t *testing.T) {
|
||||
// Normal email generates username from email prefix
|
||||
username := generateUniqueUsername("john@test.com", nil)
|
||||
assert.Contains(t, username, "john_")
|
||||
|
||||
// Private relay email falls back to first name
|
||||
firstName := "Jane"
|
||||
username = generateUniqueUsername("abc@privaterelay.appleid.com", &firstName)
|
||||
assert.Contains(t, username, "jane_")
|
||||
|
||||
// Private relay email and no first name — fallback
|
||||
username = generateUniqueUsername("abc@privaterelay.appleid.com", nil)
|
||||
assert.Contains(t, username, "user_")
|
||||
|
||||
// Empty email with first name
|
||||
firstName2 := "Bob"
|
||||
username = generateUniqueUsername("", &firstName2)
|
||||
assert.Contains(t, username, "bob_")
|
||||
|
||||
// Empty email and no first name
|
||||
username = generateUniqueUsername("", nil)
|
||||
assert.Contains(t, username, "user_")
|
||||
}
|
||||
|
||||
// === generateGoogleUsername ===
|
||||
|
||||
func TestGenerateGoogleUsername(t *testing.T) {
|
||||
// Normal email
|
||||
username := generateGoogleUsername("john@gmail.com", "John")
|
||||
assert.Contains(t, username, "john_")
|
||||
|
||||
// Empty email falls back to first name
|
||||
username = generateGoogleUsername("", "Alice")
|
||||
assert.Contains(t, username, "alice_")
|
||||
|
||||
// Empty email and empty first name — fallback
|
||||
username = generateGoogleUsername("", "")
|
||||
assert.Contains(t, username, "google_")
|
||||
}
|
||||
|
||||
// === Login with empty password ===
|
||||
|
||||
func TestAuthService_Login_EmptyPassword(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
|
||||
req := &requests.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "",
|
||||
}
|
||||
|
||||
_, err := service.Login(req)
|
||||
testutil.AssertAppError(t, err, http.StatusUnauthorized, "error.invalid_credentials")
|
||||
}
|
||||
|
||||
// === ForgotPassword rate limiting ===
|
||||
|
||||
func TestAuthService_ForgotPassword_RateLimit(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make max allowed reset requests (3 based on setup)
|
||||
for i := 0; i < 3; i++ {
|
||||
_, _, err := service.ForgotPassword("test@test.com")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// The 4th should be rate limited
|
||||
_, _, err = service.ForgotPassword("test@test.com")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === VerifyResetCode with wrong code ===
|
||||
|
||||
func TestAuthService_VerifyResetCode_WrongCode(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = service.ForgotPassword("test@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wrong code but with debug mode, "123456" works, "000000" should fail
|
||||
_, err = service.VerifyResetCode("test@test.com", "000000")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === VerifyResetCode with nonexistent email ===
|
||||
|
||||
func TestAuthService_VerifyResetCode_NonexistentEmail(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
_, err := service.VerifyResetCode("nonexistent@test.com", "123456")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === UpdateProfile — change email to new email ===
|
||||
|
||||
func TestAuthService_UpdateProfile_ChangeEmail(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@test.com", "Password123")
|
||||
userRepo.GetOrCreateProfile(user.ID)
|
||||
|
||||
newEmail := "newemail@test.com"
|
||||
req := &requests.UpdateProfileRequest{
|
||||
Email: &newEmail,
|
||||
}
|
||||
|
||||
resp, err := service.UpdateProfile(user.ID, req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "newemail@test.com", resp.Email)
|
||||
}
|
||||
|
||||
// === DeleteAccount — empty password string ===
|
||||
|
||||
func TestAuthService_DeleteAccount_EmptyPassword(t *testing.T) {
|
||||
service, _ := setupAuthService(t)
|
||||
|
||||
registerReq := &requests.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Email: "test@test.com",
|
||||
Password: "Password123",
|
||||
}
|
||||
_, _, err := service.Register(registerReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := service.userRepo.FindByEmail("test@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
emptyPw := ""
|
||||
_, err = service.DeleteAccount(user.ID, &emptyPw, nil)
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.password_required")
|
||||
}
|
||||
|
||||
// === SetNotificationRepository ===
|
||||
|
||||
func TestAuthService_SetNotificationRepository(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
notifRepo := repositories.NewNotificationRepository(db)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{SecretKey: "test-secret"},
|
||||
}
|
||||
service := NewAuthService(userRepo, cfg)
|
||||
assert.Nil(t, service.notificationRepo)
|
||||
|
||||
service.SetNotificationRepository(notifRepo)
|
||||
assert.NotNil(t, service.notificationRepo)
|
||||
}
|
||||
|
||||
// === Register creates profile and notification preferences ===
|
||||
|
||||
func TestAuthService_Register_CreatesProfile(t *testing.T) {
|
||||
service, userRepo := setupAuthService(t)
|
||||
|
||||
req := &requests.RegisterRequest{
|
||||
Username: "profileuser",
|
||||
Email: "profile@test.com",
|
||||
Password: "Password123",
|
||||
FirstName: "John",
|
||||
LastName: "Doe",
|
||||
}
|
||||
|
||||
resp, _, err := service.Register(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "profileuser", resp.User.Username)
|
||||
|
||||
// Profile should exist
|
||||
profile, err := userRepo.GetOrCreateProfile(resp.User.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, profile)
|
||||
}
|
||||
@@ -4,9 +4,11 @@ import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/dto/requests"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/repositories"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
@@ -20,6 +22,420 @@ func setupContractorService(t *testing.T) (*ContractorService, *repositories.Con
|
||||
return service, contractorRepo, residenceRepo
|
||||
}
|
||||
|
||||
// === CreateContractor ===
|
||||
|
||||
func TestContractorService_CreateContractor(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
req := &requests.CreateContractorRequest{
|
||||
ResidenceID: &residence.ID,
|
||||
Name: "Bob's Plumbing",
|
||||
Phone: "555-1234",
|
||||
Email: "bob@plumbing.com",
|
||||
}
|
||||
|
||||
resp, err := service.CreateContractor(req, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, "Bob's Plumbing", resp.Name)
|
||||
assert.Equal(t, "555-1234", resp.Phone)
|
||||
assert.Equal(t, "bob@plumbing.com", resp.Email)
|
||||
}
|
||||
|
||||
func TestContractorService_CreateContractor_Personal(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
// No residence ID - personal contractor
|
||||
req := &requests.CreateContractorRequest{
|
||||
Name: "Personal Handyman",
|
||||
}
|
||||
|
||||
resp, err := service.CreateContractor(req, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Personal Handyman", resp.Name)
|
||||
}
|
||||
|
||||
func TestContractorService_CreateContractor_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
|
||||
req := &requests.CreateContractorRequest{
|
||||
ResidenceID: &residence.ID,
|
||||
Name: "Unauthorized Contractor",
|
||||
}
|
||||
|
||||
_, err := service.CreateContractor(req, other.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
|
||||
}
|
||||
|
||||
func TestContractorService_CreateContractor_WithFavorite(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
isFav := true
|
||||
req := &requests.CreateContractorRequest{
|
||||
ResidenceID: &residence.ID,
|
||||
Name: "Fav Plumber",
|
||||
IsFavorite: &isFav,
|
||||
}
|
||||
|
||||
resp, err := service.CreateContractor(req, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.IsFavorite)
|
||||
}
|
||||
|
||||
// === GetContractor ===
|
||||
|
||||
func TestContractorService_GetContractor(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
|
||||
|
||||
resp, err := service.GetContractor(contractor.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, contractor.ID, resp.ID)
|
||||
assert.Equal(t, "Test Contractor", resp.Name)
|
||||
}
|
||||
|
||||
func TestContractorService_GetContractor_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
_, err := service.GetContractor(9999, user.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
|
||||
}
|
||||
|
||||
func TestContractorService_GetContractor_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
|
||||
|
||||
_, err := service.GetContractor(contractor.ID, other.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
|
||||
}
|
||||
|
||||
func TestContractorService_GetContractor_SharedUserHasAccess(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
residenceRepo.AddUser(residence.ID, shared.ID)
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Shared Contractor")
|
||||
|
||||
resp, err := service.GetContractor(contractor.ID, shared.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Shared Contractor", resp.Name)
|
||||
}
|
||||
|
||||
// === ListContractors ===
|
||||
|
||||
func TestContractorService_ListContractors(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 1")
|
||||
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor 2")
|
||||
|
||||
resp, err := service.ListContractors(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp, 2)
|
||||
}
|
||||
|
||||
// === DeleteContractor ===
|
||||
|
||||
func TestContractorService_DeleteContractor(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "To Delete")
|
||||
|
||||
err := service.DeleteContractor(contractor.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should not be found after deletion
|
||||
_, err = service.GetContractor(contractor.ID, user.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestContractorService_DeleteContractor_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
err := service.DeleteContractor(9999, user.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
|
||||
}
|
||||
|
||||
func TestContractorService_DeleteContractor_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
|
||||
|
||||
err := service.DeleteContractor(contractor.ID, other.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
|
||||
}
|
||||
|
||||
// === ToggleFavorite ===
|
||||
|
||||
func TestContractorService_ToggleFavorite(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
|
||||
|
||||
// Initially not favorite
|
||||
resp, err := service.GetContractor(contractor.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.IsFavorite)
|
||||
|
||||
// Toggle to favorite
|
||||
resp, err = service.ToggleFavorite(contractor.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.IsFavorite)
|
||||
|
||||
// Toggle back
|
||||
resp, err = service.ToggleFavorite(contractor.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.IsFavorite)
|
||||
}
|
||||
|
||||
func TestContractorService_ToggleFavorite_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
_, err := service.ToggleFavorite(9999, user.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
|
||||
}
|
||||
|
||||
func TestContractorService_ToggleFavorite_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
|
||||
|
||||
_, err := service.ToggleFavorite(contractor.ID, other.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
|
||||
}
|
||||
|
||||
// === ListContractorsByResidence ===
|
||||
|
||||
func TestContractorService_ListContractorsByResidence(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor A")
|
||||
testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Contractor B")
|
||||
|
||||
resp, err := service.ListContractorsByResidence(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp, 2)
|
||||
}
|
||||
|
||||
func TestContractorService_ListContractorsByResidence_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
|
||||
_, err := service.ListContractorsByResidence(residence.ID, other.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
|
||||
}
|
||||
|
||||
// === GetContractorTasks ===
|
||||
|
||||
func TestContractorService_GetContractorTasks_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
_, err := service.GetContractorTasks(9999, user.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
|
||||
}
|
||||
|
||||
func TestContractorService_GetContractorTasks_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
|
||||
|
||||
_, err := service.GetContractorTasks(contractor.ID, other.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
|
||||
}
|
||||
|
||||
func TestContractorService_GetContractorTasks_Empty(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
|
||||
|
||||
resp, err := service.GetContractorTasks(contractor.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, resp)
|
||||
}
|
||||
|
||||
// === GetSpecialties ===
|
||||
|
||||
func TestContractorService_GetSpecialties(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
resp, err := service.GetSpecialties()
|
||||
require.NoError(t, err)
|
||||
// SeedLookupData creates 4 specialties
|
||||
assert.Len(t, resp, 4)
|
||||
}
|
||||
|
||||
// === UpdateContractor ===
|
||||
|
||||
func TestContractorService_UpdateContractor_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
newName := "Won't Work"
|
||||
req := &requests.UpdateContractorRequest{Name: &newName}
|
||||
|
||||
_, err := service.UpdateContractor(9999, user.ID, req)
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.contractor_not_found")
|
||||
}
|
||||
|
||||
func TestContractorService_UpdateContractor_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, owner.ID, "Private Contractor")
|
||||
|
||||
newName := "Hacked"
|
||||
req := &requests.UpdateContractorRequest{Name: &newName}
|
||||
|
||||
_, err := service.UpdateContractor(contractor.ID, other.ID, req)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
|
||||
}
|
||||
|
||||
func TestUpdateContractor_CrossResidence_Returns403(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
@@ -96,3 +512,171 @@ func TestUpdateContractor_RemoveResidence_Succeeds(t *testing.T) {
|
||||
require.NoError(t, err, "should allow removing residence association")
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
|
||||
// === UpdateContractor — partial update multiple fields ===
|
||||
|
||||
func TestContractorService_UpdateContractor_PartialUpdate(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Original Name")
|
||||
|
||||
newName := "Updated Plumber"
|
||||
newPhone := "555-9999"
|
||||
newEmail := "new@plumber.com"
|
||||
newCompany := "Best Plumbing"
|
||||
newWebsite := "https://bestplumbing.com"
|
||||
newNotes := "Great work"
|
||||
newStreet := "456 Plumber Ave"
|
||||
newCity := "Dallas"
|
||||
newState := "TX"
|
||||
newPostal := "75001"
|
||||
rating := 5.0
|
||||
isFav := true
|
||||
|
||||
req := &requests.UpdateContractorRequest{
|
||||
Name: &newName,
|
||||
Phone: &newPhone,
|
||||
Email: &newEmail,
|
||||
Company: &newCompany,
|
||||
Website: &newWebsite,
|
||||
Notes: &newNotes,
|
||||
StreetAddress: &newStreet,
|
||||
City: &newCity,
|
||||
StateProvince: &newState,
|
||||
PostalCode: &newPostal,
|
||||
Rating: &rating,
|
||||
IsFavorite: &isFav,
|
||||
ResidenceID: &residence.ID,
|
||||
}
|
||||
|
||||
resp, err := service.UpdateContractor(contractor.ID, user.ID, req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated Plumber", resp.Name)
|
||||
assert.Equal(t, "555-9999", resp.Phone)
|
||||
assert.Equal(t, "new@plumber.com", resp.Email)
|
||||
assert.True(t, resp.IsFavorite)
|
||||
}
|
||||
|
||||
// === UpdateContractor — with specialties ===
|
||||
|
||||
func TestContractorService_UpdateContractor_WithSpecialties(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
contractor := testutil.CreateTestContractor(t, db, residence.ID, user.ID, "Test Contractor")
|
||||
|
||||
// Get specialty IDs from seeded data
|
||||
var specialties []models.ContractorSpecialty
|
||||
err := db.Find(&specialties).Error
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, specialties)
|
||||
|
||||
specialtyIDs := []uint{specialties[0].ID, specialties[1].ID}
|
||||
req := &requests.UpdateContractorRequest{
|
||||
SpecialtyIDs: specialtyIDs,
|
||||
ResidenceID: &residence.ID,
|
||||
}
|
||||
|
||||
resp, err := service.UpdateContractor(contractor.ID, user.ID, req)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
}
|
||||
|
||||
// === CreateContractor — with specialties ===
|
||||
|
||||
func TestContractorService_CreateContractor_WithSpecialties(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
var specialties []models.ContractorSpecialty
|
||||
err := db.Find(&specialties).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
req := &requests.CreateContractorRequest{
|
||||
ResidenceID: &residence.ID,
|
||||
Name: "Specialized Plumber",
|
||||
SpecialtyIDs: []uint{specialties[0].ID},
|
||||
}
|
||||
|
||||
resp, err := service.CreateContractor(req, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Specialized Plumber", resp.Name)
|
||||
}
|
||||
|
||||
// === ListContractors — empty result ===
|
||||
|
||||
func TestContractorService_ListContractors_Empty(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
// No residence, no contractors
|
||||
resp, err := service.ListContractors(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, resp)
|
||||
}
|
||||
|
||||
// === ListContractorsByResidence — empty result ===
|
||||
|
||||
func TestContractorService_ListContractorsByResidence_Empty(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Empty House")
|
||||
|
||||
resp, err := service.ListContractorsByResidence(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, resp)
|
||||
}
|
||||
|
||||
// === Personal contractor access — creator has access, others don't ===
|
||||
|
||||
func TestContractorService_PersonalContractor_OnlyCreatorAccess(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewContractorService(contractorRepo, residenceRepo)
|
||||
|
||||
creator := testutil.CreateTestUser(t, db, "creator", "creator@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
|
||||
// Create personal contractor (no residence)
|
||||
req := &requests.CreateContractorRequest{
|
||||
Name: "Personal Plumber",
|
||||
}
|
||||
resp, err := service.CreateContractor(req, creator.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Creator can access
|
||||
_, err = service.GetContractor(resp.ID, creator.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Other user cannot
|
||||
_, err = service.GetContractor(resp.ID, other.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.contractor_access_denied")
|
||||
}
|
||||
|
||||
764
internal/services/document_service_test.go
Normal file
764
internal/services/document_service_test.go
Normal file
@@ -0,0 +1,764 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/dto/requests"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/repositories"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
func setupDocumentService(t *testing.T) (*DocumentService, *repositories.DocumentRepository, *repositories.ResidenceRepository) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
return service, documentRepo, residenceRepo
|
||||
}
|
||||
|
||||
// === CreateDocument ===
|
||||
|
||||
func TestDocumentService_CreateDocument(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
req := &requests.CreateDocumentRequest{
|
||||
ResidenceID: residence.ID,
|
||||
Title: "Furnace Manual",
|
||||
Description: "Installation manual for the furnace",
|
||||
DocumentType: models.DocumentTypeManual,
|
||||
FileURL: "https://example.com/manual.pdf",
|
||||
FileName: "manual.pdf",
|
||||
}
|
||||
|
||||
resp, err := service.CreateDocument(req, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, "Furnace Manual", resp.Title)
|
||||
assert.Equal(t, models.DocumentTypeManual, resp.DocumentType)
|
||||
}
|
||||
|
||||
func TestDocumentService_CreateDocument_DefaultType(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
req := &requests.CreateDocumentRequest{
|
||||
ResidenceID: residence.ID,
|
||||
Title: "Some Document",
|
||||
// DocumentType not set — should default to "general"
|
||||
}
|
||||
|
||||
resp, err := service.CreateDocument(req, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, models.DocumentTypeGeneral, resp.DocumentType)
|
||||
}
|
||||
|
||||
func TestDocumentService_CreateDocument_WithImages(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
req := &requests.CreateDocumentRequest{
|
||||
ResidenceID: residence.ID,
|
||||
Title: "Receipt with photos",
|
||||
ImageURLs: []string{"https://example.com/img1.jpg", "https://example.com/img2.jpg"},
|
||||
}
|
||||
|
||||
resp, err := service.CreateDocument(req, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, "Receipt with photos", resp.Title)
|
||||
}
|
||||
|
||||
func TestDocumentService_CreateDocument_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
|
||||
req := &requests.CreateDocumentRequest{
|
||||
ResidenceID: residence.ID,
|
||||
Title: "Unauthorized Doc",
|
||||
}
|
||||
|
||||
_, err := service.CreateDocument(req, other.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
|
||||
}
|
||||
|
||||
// === GetDocument ===
|
||||
|
||||
func TestDocumentService_GetDocument(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
|
||||
|
||||
resp, err := service.GetDocument(doc.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, doc.ID, resp.ID)
|
||||
assert.Equal(t, "Test Doc", resp.Title)
|
||||
}
|
||||
|
||||
func TestDocumentService_GetDocument_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
_, err := service.GetDocument(9999, user.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
|
||||
}
|
||||
|
||||
func TestDocumentService_GetDocument_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
|
||||
|
||||
_, err := service.GetDocument(doc.ID, other.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
||||
}
|
||||
|
||||
// === UpdateDocument ===
|
||||
|
||||
func TestDocumentService_UpdateDocument(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Original Title")
|
||||
|
||||
newTitle := "Updated Title"
|
||||
newDesc := "Updated description"
|
||||
req := &requests.UpdateDocumentRequest{
|
||||
Title: &newTitle,
|
||||
Description: &newDesc,
|
||||
}
|
||||
|
||||
resp, err := service.UpdateDocument(doc.ID, user.ID, req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated Title", resp.Title)
|
||||
assert.Equal(t, "Updated description", resp.Description)
|
||||
}
|
||||
|
||||
func TestDocumentService_UpdateDocument_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
newTitle := "Won't Work"
|
||||
req := &requests.UpdateDocumentRequest{Title: &newTitle}
|
||||
|
||||
_, err := service.UpdateDocument(9999, user.ID, req)
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
|
||||
}
|
||||
|
||||
func TestDocumentService_UpdateDocument_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
|
||||
|
||||
newTitle := "Hacked"
|
||||
req := &requests.UpdateDocumentRequest{Title: &newTitle}
|
||||
|
||||
_, err := service.UpdateDocument(doc.ID, other.ID, req)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
||||
}
|
||||
|
||||
func TestDocumentService_UpdateDocument_ChangeType(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "My Receipt")
|
||||
|
||||
newType := models.DocumentTypeWarranty
|
||||
req := &requests.UpdateDocumentRequest{DocumentType: &newType}
|
||||
|
||||
resp, err := service.UpdateDocument(doc.ID, user.ID, req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, models.DocumentTypeWarranty, resp.DocumentType)
|
||||
}
|
||||
|
||||
// === DeleteDocument ===
|
||||
|
||||
func TestDocumentService_DeleteDocument(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Delete")
|
||||
|
||||
err := service.DeleteDocument(doc.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should not be found after deletion
|
||||
_, err = service.GetDocument(doc.ID, user.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDocumentService_DeleteDocument_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
err := service.DeleteDocument(9999, user.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
|
||||
}
|
||||
|
||||
func TestDocumentService_DeleteDocument_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
|
||||
|
||||
err := service.DeleteDocument(doc.ID, other.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
||||
}
|
||||
|
||||
// === ListDocuments ===
|
||||
|
||||
func TestDocumentService_ListDocuments(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 1")
|
||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 2")
|
||||
|
||||
resp, err := service.ListDocuments(user.ID, nil)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp, 2)
|
||||
}
|
||||
|
||||
func TestDocumentService_ListDocuments_NoResidences(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123")
|
||||
|
||||
resp, err := service.ListDocuments(user.ID, nil)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, resp)
|
||||
}
|
||||
|
||||
func TestDocumentService_ListDocuments_FilterByResidence(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence1 := testutil.CreateTestResidence(t, db, user.ID, "House 1")
|
||||
residence2 := testutil.CreateTestResidence(t, db, user.ID, "House 2")
|
||||
testutil.CreateTestDocument(t, db, residence1.ID, user.ID, "Doc A")
|
||||
testutil.CreateTestDocument(t, db, residence2.ID, user.ID, "Doc B")
|
||||
|
||||
filter := &repositories.DocumentFilter{ResidenceID: &residence1.ID}
|
||||
resp, err := service.ListDocuments(user.ID, filter)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp, 1)
|
||||
assert.Equal(t, "Doc A", resp[0].Title)
|
||||
}
|
||||
|
||||
func TestDocumentService_ListDocuments_FilterByResidence_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Owner House")
|
||||
// other has their own residence so they have at least one
|
||||
testutil.CreateTestResidence(t, db, other.ID, "Other House")
|
||||
|
||||
filter := &repositories.DocumentFilter{ResidenceID: &residence.ID}
|
||||
_, err := service.ListDocuments(other.ID, filter)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
|
||||
}
|
||||
|
||||
// === ListWarranties ===
|
||||
|
||||
func TestDocumentService_ListWarranties(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
// Create a warranty doc directly
|
||||
warrantyDoc := &models.Document{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "HVAC Warranty",
|
||||
DocumentType: "warranty",
|
||||
FileURL: "https://example.com/warranty.pdf",
|
||||
}
|
||||
err := db.Create(warrantyDoc).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a general doc
|
||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc")
|
||||
|
||||
resp, err := service.ListWarranties(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp, 1)
|
||||
assert.Equal(t, "HVAC Warranty", resp[0].Title)
|
||||
}
|
||||
|
||||
func TestDocumentService_ListWarranties_NoResidences(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123")
|
||||
|
||||
resp, err := service.ListWarranties(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, resp)
|
||||
}
|
||||
|
||||
// === DeactivateDocument ===
|
||||
|
||||
func TestDocumentService_DeactivateDocument(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Deactivate")
|
||||
|
||||
resp, err := service.DeactivateDocument(doc.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.IsActive)
|
||||
}
|
||||
|
||||
func TestDocumentService_DeactivateDocument_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
_, err := service.DeactivateDocument(9999, user.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
|
||||
}
|
||||
|
||||
func TestDocumentService_DeactivateDocument_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
|
||||
|
||||
_, err := service.DeactivateDocument(doc.ID, other.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
||||
}
|
||||
|
||||
// === UploadDocumentImage ===
|
||||
|
||||
func TestDocumentService_UploadDocumentImage(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
|
||||
|
||||
resp, err := service.UploadDocumentImage(doc.ID, user.ID, "https://example.com/photo.jpg", "Front view")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
}
|
||||
|
||||
func TestDocumentService_UploadDocumentImage_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
_, err := service.UploadDocumentImage(9999, user.ID, "https://example.com/photo.jpg", "")
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
|
||||
}
|
||||
|
||||
func TestDocumentService_UploadDocumentImage_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
|
||||
|
||||
_, err := service.UploadDocumentImage(doc.ID, other.ID, "https://example.com/photo.jpg", "")
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
||||
}
|
||||
|
||||
// === DeleteDocumentImage ===
|
||||
|
||||
func TestDocumentService_DeleteDocumentImage(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
|
||||
|
||||
// Create an image
|
||||
img := &models.DocumentImage{
|
||||
DocumentID: doc.ID,
|
||||
ImageURL: "https://example.com/photo.jpg",
|
||||
}
|
||||
err := db.Create(img).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := service.DeleteDocumentImage(doc.ID, img.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
}
|
||||
|
||||
func TestDocumentService_DeleteDocumentImage_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Test Doc")
|
||||
|
||||
_, err := service.DeleteDocumentImage(doc.ID, 9999, user.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_image_not_found")
|
||||
}
|
||||
|
||||
func TestDocumentService_DeleteDocumentImage_WrongDocument(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc1 := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 1")
|
||||
doc2 := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Doc 2")
|
||||
|
||||
// Create an image on doc1
|
||||
img := &models.DocumentImage{
|
||||
DocumentID: doc1.ID,
|
||||
ImageURL: "https://example.com/photo.jpg",
|
||||
}
|
||||
err := db.Create(img).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to delete the image specifying doc2
|
||||
_, err = service.DeleteDocumentImage(doc2.ID, img.ID, user.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_image_not_found")
|
||||
}
|
||||
|
||||
func TestDocumentService_DeleteDocumentImage_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
|
||||
|
||||
img := &models.DocumentImage{
|
||||
DocumentID: doc.ID,
|
||||
ImageURL: "https://example.com/photo.jpg",
|
||||
}
|
||||
err := db.Create(img).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.DeleteDocumentImage(doc.ID, img.ID, other.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
||||
}
|
||||
|
||||
// === SharedUser access ===
|
||||
|
||||
func TestDocumentService_GetDocument_SharedUserHasAccess(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
residenceRepo.AddUser(residence.ID, shared.ID)
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Shared Doc")
|
||||
|
||||
resp, err := service.GetDocument(doc.ID, shared.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Shared Doc", resp.Title)
|
||||
}
|
||||
|
||||
// === ActivateDocument ===
|
||||
|
||||
func TestDocumentService_ActivateDocument(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "To Activate")
|
||||
|
||||
// Deactivate first
|
||||
_, err := service.DeactivateDocument(doc.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now activate
|
||||
resp, err := service.ActivateDocument(doc.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.IsActive)
|
||||
}
|
||||
|
||||
func TestDocumentService_ActivateDocument_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
_, err := service.ActivateDocument(9999, user.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.document_not_found")
|
||||
}
|
||||
|
||||
func TestDocumentService_ActivateDocument_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Private Doc")
|
||||
|
||||
_, err := service.ActivateDocument(doc.ID, other.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.document_access_denied")
|
||||
}
|
||||
|
||||
// === CreateDocument — with empty image URL in array (should skip) ===
|
||||
|
||||
func TestDocumentService_CreateDocument_WithEmptyImageURL(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
req := &requests.CreateDocumentRequest{
|
||||
ResidenceID: residence.ID,
|
||||
Title: "Doc with empty images",
|
||||
ImageURLs: []string{"", "https://example.com/img.jpg", ""},
|
||||
}
|
||||
|
||||
resp, err := service.CreateDocument(req, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
}
|
||||
|
||||
// === UpdateDocument — all optional fields ===
|
||||
|
||||
func TestDocumentService_UpdateDocument_AllFields(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, user.ID, "Original")
|
||||
|
||||
newTitle := "Updated"
|
||||
newDesc := "New description"
|
||||
newType := models.DocumentTypeWarranty
|
||||
newFileURL := "https://example.com/new.pdf"
|
||||
newFileName := "new.pdf"
|
||||
newMimeType := "application/pdf"
|
||||
newVendor := "HVAC Corp"
|
||||
newSerial := "SN12345"
|
||||
newModel := "Model X"
|
||||
size := int64(1024)
|
||||
|
||||
req := &requests.UpdateDocumentRequest{
|
||||
Title: &newTitle,
|
||||
Description: &newDesc,
|
||||
DocumentType: &newType,
|
||||
FileURL: &newFileURL,
|
||||
FileName: &newFileName,
|
||||
FileSize: &size,
|
||||
MimeType: &newMimeType,
|
||||
Vendor: &newVendor,
|
||||
SerialNumber: &newSerial,
|
||||
ModelNumber: &newModel,
|
||||
}
|
||||
|
||||
resp, err := service.UpdateDocument(doc.ID, user.ID, req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated", resp.Title)
|
||||
assert.Equal(t, "New description", resp.Description)
|
||||
assert.Equal(t, models.DocumentTypeWarranty, resp.DocumentType)
|
||||
}
|
||||
|
||||
// === ListDocuments — filter by document type ===
|
||||
|
||||
func TestDocumentService_ListDocuments_FilterByType(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
// Create a warranty doc
|
||||
warrantyDoc := &models.Document{
|
||||
ResidenceID: residence.ID,
|
||||
CreatedByID: user.ID,
|
||||
Title: "Warranty Doc",
|
||||
DocumentType: models.DocumentTypeWarranty,
|
||||
FileURL: "https://example.com/w.pdf",
|
||||
}
|
||||
err := db.Create(warrantyDoc).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a general doc
|
||||
testutil.CreateTestDocument(t, db, residence.ID, user.ID, "General Doc")
|
||||
|
||||
filter := &repositories.DocumentFilter{DocumentType: string(models.DocumentTypeWarranty)}
|
||||
resp, err := service.ListDocuments(user.ID, filter)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp, 1)
|
||||
assert.Equal(t, "Warranty Doc", resp[0].Title)
|
||||
}
|
||||
|
||||
// === Shared user can update/delete documents ===
|
||||
|
||||
func TestDocumentService_SharedUser_CanUpdate(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
residenceRepo.AddUser(residence.ID, shared.ID)
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Shared Doc")
|
||||
|
||||
newTitle := "Updated by shared user"
|
||||
req := &requests.UpdateDocumentRequest{Title: &newTitle}
|
||||
resp, err := service.UpdateDocument(doc.ID, shared.ID, req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Updated by shared user", resp.Title)
|
||||
}
|
||||
|
||||
func TestDocumentService_SharedUser_CanDelete(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
service := NewDocumentService(documentRepo, residenceRepo)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
residenceRepo.AddUser(residence.ID, shared.ID)
|
||||
doc := testutil.CreateTestDocument(t, db, residence.ID, owner.ID, "Shared Doc")
|
||||
|
||||
err := service.DeleteDocument(doc.ID, shared.ID)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -456,3 +456,631 @@ func TestCreateResidence_ProTier_AllowsMore(t *testing.T) {
|
||||
func ptrTime(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
|
||||
// === GetMyResidences ===
|
||||
|
||||
func TestResidenceService_GetMyResidences(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
testutil.CreateTestResidence(t, db, user.ID, "House 1")
|
||||
testutil.CreateTestResidence(t, db, user.ID, "House 2")
|
||||
|
||||
resp, err := service.GetMyResidences(user.ID, time.Now())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp.Residences, 2)
|
||||
}
|
||||
|
||||
func TestResidenceService_GetMyResidences_NoResidences(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123")
|
||||
|
||||
resp, err := service.GetMyResidences(user.ID, time.Now())
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, resp.Residences)
|
||||
}
|
||||
|
||||
// === GetSummary ===
|
||||
|
||||
func TestResidenceService_GetSummary(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
testutil.CreateTestResidence(t, db, user.ID, "House 1")
|
||||
testutil.CreateTestResidence(t, db, user.ID, "House 2")
|
||||
|
||||
resp, err := service.GetSummary(user.ID, time.Now())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, resp.TotalResidences)
|
||||
}
|
||||
|
||||
func TestResidenceService_GetSummary_NoResidences(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123")
|
||||
|
||||
resp, err := service.GetSummary(user.ID, time.Now())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, resp.TotalResidences)
|
||||
}
|
||||
|
||||
// === GetShareCode ===
|
||||
|
||||
func TestResidenceService_GetShareCode_NoActiveCode(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
resp, err := service.GetShareCode(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, resp) // No active code
|
||||
}
|
||||
|
||||
func TestResidenceService_GetShareCode_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
|
||||
_, err := service.GetShareCode(residence.ID, other.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
|
||||
}
|
||||
|
||||
// === GenerateShareCode ===
|
||||
|
||||
func TestResidenceService_GenerateShareCode_NotOwner(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
residenceRepo.AddUser(residence.ID, shared.ID)
|
||||
|
||||
_, err := service.GenerateShareCode(residence.ID, shared.ID, 24)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.not_residence_owner")
|
||||
}
|
||||
|
||||
func TestResidenceService_GenerateShareCode_DefaultExpiry(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
// Pass 0 hours — should default to 24
|
||||
resp, err := service.GenerateShareCode(residence.ID, user.ID, 0)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.ShareCode.Code)
|
||||
}
|
||||
|
||||
// === GenerateSharePackage ===
|
||||
|
||||
func TestResidenceService_GenerateSharePackage(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
resp, err := service.GenerateSharePackage(residence.ID, user.ID, 48)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.ShareCode)
|
||||
assert.Equal(t, "Test House", resp.ResidenceName)
|
||||
assert.Equal(t, "owner@test.com", resp.SharedBy)
|
||||
}
|
||||
|
||||
func TestResidenceService_GenerateSharePackage_NotOwner(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
residenceRepo.AddUser(residence.ID, shared.ID)
|
||||
|
||||
_, err := service.GenerateSharePackage(residence.ID, shared.ID, 24)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.not_residence_owner")
|
||||
}
|
||||
|
||||
// === JoinWithCode ===
|
||||
|
||||
func TestResidenceService_JoinWithCode_InvalidCode(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "user", "user@test.com", "Password123")
|
||||
|
||||
_, err := service.JoinWithCode("BADCODE", user.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusNotFound, "error.share_code_invalid")
|
||||
}
|
||||
|
||||
// === RemoveUser ===
|
||||
|
||||
func TestResidenceService_RemoveUser_NotOwner(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
residenceRepo.AddUser(residence.ID, shared.ID)
|
||||
|
||||
// shared user tries to remove other — should fail because shared is not owner
|
||||
err := service.RemoveUser(residence.ID, other.ID, shared.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.not_residence_owner")
|
||||
}
|
||||
|
||||
// === GetResidenceUsers ===
|
||||
|
||||
func TestResidenceService_GetResidenceUsers_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
|
||||
_, err := service.GetResidenceUsers(residence.ID, other.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
|
||||
}
|
||||
|
||||
// === GetResidenceTypes ===
|
||||
|
||||
func TestResidenceService_GetResidenceTypes(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
resp, err := service.GetResidenceTypes()
|
||||
require.NoError(t, err)
|
||||
// SeedLookupData creates 4 residence types
|
||||
assert.Len(t, resp, 4)
|
||||
}
|
||||
|
||||
// === UpdateResidence with home profile fields ===
|
||||
|
||||
func TestResidenceService_UpdateResidence_HomeProfileFields(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
hasPool := true
|
||||
hasGarage := true
|
||||
heatingType := "Forced Air"
|
||||
req := &requests.UpdateResidenceRequest{
|
||||
HasPool: &hasPool,
|
||||
HasGarage: &hasGarage,
|
||||
HeatingType: &heatingType,
|
||||
}
|
||||
|
||||
resp, err := service.UpdateResidence(residence.ID, user.ID, req)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.Data.HasPool)
|
||||
assert.True(t, resp.Data.HasGarage)
|
||||
}
|
||||
|
||||
// === CreateResidence with home profile fields ===
|
||||
|
||||
func TestResidenceService_CreateResidence_HomeProfileFields(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
hasPool := true
|
||||
hasSeptic := true
|
||||
req := &requests.CreateResidenceRequest{
|
||||
Name: "New House",
|
||||
StreetAddress: "456 Oak St",
|
||||
City: "Dallas",
|
||||
StateProvince: "TX",
|
||||
PostalCode: "75201",
|
||||
HasPool: &hasPool,
|
||||
HasSeptic: &hasSeptic,
|
||||
}
|
||||
|
||||
resp, err := service.CreateResidence(req, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.Data.HasPool)
|
||||
assert.True(t, resp.Data.HasSeptic)
|
||||
}
|
||||
|
||||
// === Shared user GetResidence ===
|
||||
|
||||
func TestResidenceService_GetResidence_SharedUser(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
shared := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
residenceRepo.AddUser(residence.ID, shared.ID)
|
||||
|
||||
resp, err := service.GetResidence(residence.ID, shared.ID, time.Now())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Test House", resp.Name)
|
||||
}
|
||||
|
||||
// === GetMyResidences with task repo (overdue counts + completion summaries) ===
|
||||
|
||||
func TestResidenceService_GetMyResidences_WithTaskRepo(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
service.SetTaskRepository(taskRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
testutil.CreateTestResidence(t, db, user.ID, "House 1")
|
||||
testutil.CreateTestResidence(t, db, user.ID, "House 2")
|
||||
|
||||
resp, err := service.GetMyResidences(user.ID, time.Now())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp.Residences, 2)
|
||||
}
|
||||
|
||||
// === GetResidence with task repo (completion summary) ===
|
||||
|
||||
func TestResidenceService_GetResidence_WithTaskRepo(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
service.SetTaskRepository(taskRepo)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
resp, err := service.GetResidence(residence.ID, user.ID, time.Now())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Test House", resp.Name)
|
||||
}
|
||||
|
||||
// === GenerateShareCode with negative expiry defaults to 24 ===
|
||||
|
||||
func TestResidenceService_GenerateShareCode_NegativeExpiry(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
resp, err := service.GenerateShareCode(residence.ID, user.ID, -5)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.ShareCode.Code)
|
||||
}
|
||||
|
||||
// === GenerateSharePackage with default expiry ===
|
||||
|
||||
func TestResidenceService_GenerateSharePackage_DefaultExpiry(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
// Pass 0 hours — should default to 24
|
||||
resp, err := service.GenerateSharePackage(residence.ID, user.ID, 0)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.ShareCode)
|
||||
assert.Equal(t, "Test House", resp.ResidenceName)
|
||||
}
|
||||
|
||||
// === RemoveUser — trying to remove the owner by a different owner ID ===
|
||||
|
||||
func TestResidenceService_RemoveUser_OwnerViaResidenceOwnerID(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
sharedUser := testutil.CreateTestUser(t, db, "shared", "shared@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
residenceRepo.AddUser(residence.ID, sharedUser.ID)
|
||||
|
||||
// Try removing the owner (by residence.OwnerID) — even though requestingUserID != userIDToRemove
|
||||
// The second check (userIDToRemove == residence.OwnerID) should catch this
|
||||
err := service.RemoveUser(residence.ID, owner.ID, owner.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusBadRequest, "error.cannot_remove_owner")
|
||||
}
|
||||
|
||||
// === GenerateTasksReport ===
|
||||
|
||||
func TestResidenceService_GenerateTasksReport(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
testutil.SeedLookupData(t, db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
// Create some tasks
|
||||
testutil.CreateTestTask(t, db, residence.ID, user.ID, "Task 1")
|
||||
testutil.CreateTestTask(t, db, residence.ID, user.ID, "Task 2")
|
||||
|
||||
report, err := service.GenerateTasksReport(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, residence.ID, report.ResidenceID)
|
||||
assert.Equal(t, "Test House", report.ResidenceName)
|
||||
assert.Equal(t, 2, report.TotalTasks)
|
||||
}
|
||||
|
||||
func TestResidenceService_GenerateTasksReport_AccessDenied(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
owner := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
other := testutil.CreateTestUser(t, db, "other", "other@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, owner.ID, "Test House")
|
||||
|
||||
_, err := service.GenerateTasksReport(residence.ID, other.ID)
|
||||
testutil.AssertAppError(t, err, http.StatusForbidden, "error.residence_access_denied")
|
||||
}
|
||||
|
||||
func TestResidenceService_GenerateTasksReport_NotFound(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
// Non-existent residence — user has no access
|
||||
_, err := service.GenerateTasksReport(9999, user.ID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// === GetShareCode with active code ===
|
||||
|
||||
func TestResidenceService_GetShareCode_WithActiveCode(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
|
||||
// Generate a share code first
|
||||
_, err := service.GenerateShareCode(residence.ID, user.ID, 24)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now get the active code
|
||||
resp, err := service.GetShareCode(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.NotEmpty(t, resp.Code)
|
||||
}
|
||||
|
||||
// === CreateResidence with all boolean fields ===
|
||||
|
||||
func TestResidenceService_CreateResidence_AllBooleanFields(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
hasPool := true
|
||||
hasSprinkler := true
|
||||
hasSeptic := true
|
||||
hasFireplace := true
|
||||
hasGarage := true
|
||||
hasBasement := true
|
||||
hasAttic := true
|
||||
|
||||
req := &requests.CreateResidenceRequest{
|
||||
Name: "Full Feature House",
|
||||
StreetAddress: "789 Full St",
|
||||
City: "Austin",
|
||||
StateProvince: "TX",
|
||||
PostalCode: "78701",
|
||||
HasPool: &hasPool,
|
||||
HasSprinklerSystem: &hasSprinkler,
|
||||
HasSeptic: &hasSeptic,
|
||||
HasFireplace: &hasFireplace,
|
||||
HasGarage: &hasGarage,
|
||||
HasBasement: &hasBasement,
|
||||
HasAttic: &hasAttic,
|
||||
}
|
||||
|
||||
resp, err := service.CreateResidence(req, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, resp.Data.HasPool)
|
||||
assert.True(t, resp.Data.HasSprinklerSystem)
|
||||
assert.True(t, resp.Data.HasSeptic)
|
||||
assert.True(t, resp.Data.HasFireplace)
|
||||
assert.True(t, resp.Data.HasGarage)
|
||||
assert.True(t, resp.Data.HasBasement)
|
||||
assert.True(t, resp.Data.HasAttic)
|
||||
}
|
||||
|
||||
// === UpdateResidence with all optional fields ===
|
||||
|
||||
func TestResidenceService_UpdateResidence_AllOptionalFields(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, db, user.ID, "Original Name")
|
||||
|
||||
newStreet := "456 New St"
|
||||
newApt := "Apt 2B"
|
||||
newState := "CA"
|
||||
newPostal := "90210"
|
||||
newCountry := "Canada"
|
||||
bedrooms := 4
|
||||
bathrooms := decimal.NewFromFloat(3.0)
|
||||
sqft := 3000
|
||||
lotSize := decimal.NewFromFloat(0.5)
|
||||
yearBuilt := 2020
|
||||
newDesc := "Nice house"
|
||||
isPrimary := false
|
||||
hasPool := true
|
||||
hasSprinkler := true
|
||||
hasSeptic := false
|
||||
hasFireplace := true
|
||||
hasGarage := true
|
||||
hasBasement := false
|
||||
hasAttic := true
|
||||
coolingType := "Central AC"
|
||||
waterHeaterType := "Tankless"
|
||||
roofType := "Shingle"
|
||||
exteriorType := "Brick"
|
||||
flooringPrimary := "Hardwood"
|
||||
landscapingType := "Xeriscape"
|
||||
|
||||
req := &requests.UpdateResidenceRequest{
|
||||
StreetAddress: &newStreet,
|
||||
ApartmentUnit: &newApt,
|
||||
StateProvince: &newState,
|
||||
PostalCode: &newPostal,
|
||||
Country: &newCountry,
|
||||
Bedrooms: &bedrooms,
|
||||
Bathrooms: &bathrooms,
|
||||
SquareFootage: &sqft,
|
||||
LotSize: &lotSize,
|
||||
YearBuilt: &yearBuilt,
|
||||
Description: &newDesc,
|
||||
IsPrimary: &isPrimary,
|
||||
HasPool: &hasPool,
|
||||
HasSprinklerSystem: &hasSprinkler,
|
||||
HasSeptic: &hasSeptic,
|
||||
HasFireplace: &hasFireplace,
|
||||
HasGarage: &hasGarage,
|
||||
HasBasement: &hasBasement,
|
||||
HasAttic: &hasAttic,
|
||||
CoolingType: &coolingType,
|
||||
WaterHeaterType: &waterHeaterType,
|
||||
RoofType: &roofType,
|
||||
ExteriorType: &exteriorType,
|
||||
FlooringPrimary: &flooringPrimary,
|
||||
LandscapingType: &landscapingType,
|
||||
}
|
||||
|
||||
resp, err := service.UpdateResidence(residence.ID, user.ID, req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "456 New St", resp.Data.StreetAddress)
|
||||
assert.Equal(t, "CA", resp.Data.StateProvince)
|
||||
assert.True(t, resp.Data.HasPool)
|
||||
assert.True(t, resp.Data.HasFireplace)
|
||||
assert.True(t, resp.Data.HasAttic)
|
||||
}
|
||||
|
||||
// === ListResidences with no residences ===
|
||||
|
||||
func TestResidenceService_ListResidences_NoResidences(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "loner", "loner@test.com", "Password123")
|
||||
|
||||
resp, err := service.ListResidences(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, resp)
|
||||
}
|
||||
|
||||
// === getSummaryForUser returns empty summary ===
|
||||
|
||||
func TestResidenceService_getSummaryForUser_ReturnsEmpty(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
userRepo := repositories.NewUserRepository(db)
|
||||
cfg := &config.Config{}
|
||||
service := NewResidenceService(residenceRepo, userRepo, cfg)
|
||||
|
||||
summary := service.getSummaryForUser(999)
|
||||
assert.Equal(t, 0, summary.TotalResidences)
|
||||
}
|
||||
|
||||
@@ -162,3 +162,179 @@ func TestDelete_NonexistentFile(t *testing.T) {
|
||||
t.Fatalf("Delete should not error for non-existent file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// === isAllowedType ===
|
||||
|
||||
func TestIsAllowedType(t *testing.T) {
|
||||
cfg := &config.StorageConfig{
|
||||
UploadDir: t.TempDir(),
|
||||
BaseURL: "/uploads",
|
||||
MaxFileSize: 10 * 1024 * 1024,
|
||||
AllowedTypes: "image/jpeg,image/png,application/pdf",
|
||||
}
|
||||
svc := NewStorageServiceForTest(cfg)
|
||||
|
||||
if !svc.isAllowedType("image/jpeg") {
|
||||
t.Fatal("image/jpeg should be allowed")
|
||||
}
|
||||
if !svc.isAllowedType("image/png") {
|
||||
t.Fatal("image/png should be allowed")
|
||||
}
|
||||
if !svc.isAllowedType("application/pdf") {
|
||||
t.Fatal("application/pdf should be allowed")
|
||||
}
|
||||
if svc.isAllowedType("text/html") {
|
||||
t.Fatal("text/html should not be allowed")
|
||||
}
|
||||
if svc.isAllowedType("") {
|
||||
t.Fatal("empty MIME should not be allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// === mimeTypesCompatible ===
|
||||
|
||||
func TestMimeTypesCompatible(t *testing.T) {
|
||||
cfg := &config.StorageConfig{
|
||||
UploadDir: t.TempDir(),
|
||||
BaseURL: "/uploads",
|
||||
MaxFileSize: 10 * 1024 * 1024,
|
||||
AllowedTypes: "image/jpeg",
|
||||
}
|
||||
svc := NewStorageServiceForTest(cfg)
|
||||
|
||||
// Same primary type
|
||||
if !svc.mimeTypesCompatible("image/jpeg", "image/png") {
|
||||
t.Fatal("image/* types should be compatible")
|
||||
}
|
||||
|
||||
// Different primary types
|
||||
if svc.mimeTypesCompatible("image/jpeg", "application/pdf") {
|
||||
t.Fatal("image and application should not be compatible")
|
||||
}
|
||||
|
||||
// Same exact types
|
||||
if !svc.mimeTypesCompatible("application/pdf", "application/octet-stream") {
|
||||
t.Fatal("application/* types should be compatible")
|
||||
}
|
||||
}
|
||||
|
||||
// === getExtensionFromMimeType ===
|
||||
|
||||
func TestGetExtensionFromMimeType(t *testing.T) {
|
||||
cfg := &config.StorageConfig{
|
||||
UploadDir: t.TempDir(),
|
||||
BaseURL: "/uploads",
|
||||
MaxFileSize: 10 * 1024 * 1024,
|
||||
AllowedTypes: "image/jpeg",
|
||||
}
|
||||
svc := NewStorageServiceForTest(cfg)
|
||||
|
||||
tests := []struct {
|
||||
mimeType string
|
||||
expected string
|
||||
}{
|
||||
{"image/jpeg", ".jpg"},
|
||||
{"image/png", ".png"},
|
||||
{"image/gif", ".gif"},
|
||||
{"image/webp", ".webp"},
|
||||
{"application/pdf", ".pdf"},
|
||||
{"text/html", ""},
|
||||
{"unknown/type", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := svc.getExtensionFromMimeType(tt.mimeType)
|
||||
if got != tt.expected {
|
||||
t.Fatalf("getExtensionFromMimeType(%q) = %q, want %q", tt.mimeType, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === GetUploadDir ===
|
||||
|
||||
func TestGetUploadDir(t *testing.T) {
|
||||
svc, tmpDir := setupTestStorage(t, false)
|
||||
if svc.GetUploadDir() != tmpDir {
|
||||
t.Fatalf("GetUploadDir() = %q, want %q", svc.GetUploadDir(), tmpDir)
|
||||
}
|
||||
}
|
||||
|
||||
// === SetEncryptionService ===
|
||||
|
||||
func TestSetEncryptionService(t *testing.T) {
|
||||
svc, _ := setupTestStorage(t, false)
|
||||
|
||||
// Initially no encryption service
|
||||
if svc.encryptionSvc != nil && svc.encryptionSvc.IsEnabled() {
|
||||
t.Fatal("encryption should not be enabled initially in plain mode")
|
||||
}
|
||||
|
||||
encSvc, err := NewEncryptionService(validTestKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
svc.SetEncryptionService(encSvc)
|
||||
|
||||
if svc.encryptionSvc == nil {
|
||||
t.Fatal("encryption service should be set")
|
||||
}
|
||||
}
|
||||
|
||||
// === NewStorageServiceForTest ===
|
||||
|
||||
func TestNewStorageServiceForTest(t *testing.T) {
|
||||
cfg := &config.StorageConfig{
|
||||
UploadDir: "/tmp/test",
|
||||
BaseURL: "/uploads",
|
||||
MaxFileSize: 5 * 1024 * 1024,
|
||||
AllowedTypes: "image/jpeg, image/png, application/pdf",
|
||||
}
|
||||
svc := NewStorageServiceForTest(cfg)
|
||||
|
||||
// Should have 3 allowed types (whitespace trimmed)
|
||||
if !svc.isAllowedType("image/jpeg") {
|
||||
t.Fatal("image/jpeg should be allowed")
|
||||
}
|
||||
if !svc.isAllowedType("image/png") {
|
||||
t.Fatal("image/png should be allowed")
|
||||
}
|
||||
if !svc.isAllowedType("application/pdf") {
|
||||
t.Fatal("application/pdf should be allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// === Delete only plain file ===
|
||||
|
||||
func TestDelete_OnlyPlainFile(t *testing.T) {
|
||||
svc, tmpDir := setupTestStorage(t, false)
|
||||
|
||||
dir := filepath.Join(tmpDir, "images")
|
||||
os.WriteFile(filepath.Join(dir, "only-plain.jpg"), []byte("plain"), 0644)
|
||||
|
||||
err := svc.Delete("/uploads/images/only-plain.jpg")
|
||||
if err != nil {
|
||||
t.Fatalf("Delete failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(dir, "only-plain.jpg")); !os.IsNotExist(err) {
|
||||
t.Fatal("plain file should be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
// === Delete only enc file ===
|
||||
|
||||
func TestDelete_OnlyEncFile(t *testing.T) {
|
||||
svc, tmpDir := setupTestStorage(t, false)
|
||||
|
||||
dir := filepath.Join(tmpDir, "documents")
|
||||
os.WriteFile(filepath.Join(dir, "secret.pdf.enc"), []byte("encrypted"), 0644)
|
||||
|
||||
err := svc.Delete("/uploads/documents/secret.pdf")
|
||||
if err != nil {
|
||||
t.Fatalf("Delete failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(dir, "secret.pdf.enc")); !os.IsNotExist(err) {
|
||||
t.Fatal("encrypted file should be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,6 +181,107 @@ func TestProcessGooglePurchase_ValidationFails_DoesNotUpgrade(t *testing.T) {
|
||||
assert.Equal(t, models.TierFree, updatedSub.Tier, "User should remain on free tier")
|
||||
}
|
||||
|
||||
// === GetSubscription ===
|
||||
|
||||
func TestSubscriptionService_GetSubscription(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
|
||||
subscriptionRepo := repositories.NewSubscriptionRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
|
||||
svc := &SubscriptionService{
|
||||
subscriptionRepo: subscriptionRepo,
|
||||
residenceRepo: residenceRepo,
|
||||
taskRepo: taskRepo,
|
||||
contractorRepo: contractorRepo,
|
||||
documentRepo: documentRepo,
|
||||
}
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
resp, err := svc.GetSubscription(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "free", resp.Tier)
|
||||
assert.False(t, resp.IsPro)
|
||||
}
|
||||
|
||||
func TestSubscriptionService_GetSubscription_ProUser(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
|
||||
subscriptionRepo := repositories.NewSubscriptionRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
|
||||
svc := &SubscriptionService{
|
||||
subscriptionRepo: subscriptionRepo,
|
||||
residenceRepo: residenceRepo,
|
||||
taskRepo: taskRepo,
|
||||
contractorRepo: contractorRepo,
|
||||
documentRepo: documentRepo,
|
||||
}
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
// Create a pro subscription
|
||||
future := time.Now().UTC().Add(30 * 24 * time.Hour)
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierPro,
|
||||
ExpiresAt: &future,
|
||||
Platform: "ios",
|
||||
}
|
||||
err := db.Create(sub).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := svc.GetSubscription(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pro", resp.Tier)
|
||||
assert.True(t, resp.IsPro)
|
||||
assert.True(t, resp.IsActive)
|
||||
}
|
||||
|
||||
// === CancelSubscription ===
|
||||
|
||||
func TestSubscriptionService_CancelSubscription(t *testing.T) {
|
||||
db := testutil.SetupTestDB(t)
|
||||
|
||||
subscriptionRepo := repositories.NewSubscriptionRepository(db)
|
||||
residenceRepo := repositories.NewResidenceRepository(db)
|
||||
taskRepo := repositories.NewTaskRepository(db)
|
||||
contractorRepo := repositories.NewContractorRepository(db)
|
||||
documentRepo := repositories.NewDocumentRepository(db)
|
||||
|
||||
svc := &SubscriptionService{
|
||||
subscriptionRepo: subscriptionRepo,
|
||||
residenceRepo: residenceRepo,
|
||||
taskRepo: taskRepo,
|
||||
contractorRepo: contractorRepo,
|
||||
documentRepo: documentRepo,
|
||||
}
|
||||
|
||||
user := testutil.CreateTestUser(t, db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
// Create a pro subscription with auto_renew
|
||||
future := time.Now().UTC().Add(30 * 24 * time.Hour)
|
||||
sub := &models.UserSubscription{
|
||||
UserID: user.ID,
|
||||
Tier: models.TierPro,
|
||||
ExpiresAt: &future,
|
||||
AutoRenew: true,
|
||||
}
|
||||
err := db.Create(sub).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := svc.CancelSubscription(user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, resp.AutoRenew)
|
||||
}
|
||||
|
||||
func TestIsAlreadyProFromOtherPlatform(t *testing.T) {
|
||||
future := time.Now().UTC().Add(30 * 24 * time.Hour)
|
||||
|
||||
|
||||
@@ -231,3 +231,472 @@ func TestSuggestionService_MultipleConditionsAllMustMatch(t *testing.T) {
|
||||
assert.InDelta(t, expectedScore, resp.Suggestions[0].RelevanceScore, 0.01)
|
||||
assert.Len(t, resp.Suggestions[0].MatchReasons, 3)
|
||||
}
|
||||
|
||||
// === Malformed conditions JSON treated as universal ===
|
||||
|
||||
func TestSuggestionService_MalformedConditions(t *testing.T) {
|
||||
service := setupSuggestionService(t)
|
||||
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, service.db, user.ID, "Test House")
|
||||
|
||||
// Create template with malformed JSON conditions
|
||||
tmpl := &models.TaskTemplate{
|
||||
Title: "Bad JSON Template",
|
||||
IsActive: true,
|
||||
Conditions: json.RawMessage(`{bad json`),
|
||||
}
|
||||
err := service.db.Create(tmpl).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := service.GetSuggestions(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Suggestions, 1)
|
||||
// Should be treated as universal
|
||||
assert.Equal(t, baseUniversalScore, resp.Suggestions[0].RelevanceScore)
|
||||
}
|
||||
|
||||
// === Null conditions JSON treated as universal ===
|
||||
|
||||
func TestSuggestionService_NullConditions(t *testing.T) {
|
||||
service := setupSuggestionService(t)
|
||||
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, service.db, user.ID, "Test House")
|
||||
|
||||
tmpl := &models.TaskTemplate{
|
||||
Title: "Null Conditions",
|
||||
IsActive: true,
|
||||
Conditions: json.RawMessage(`null`),
|
||||
}
|
||||
err := service.db.Create(tmpl).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := service.GetSuggestions(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Suggestions, 1)
|
||||
assert.Equal(t, baseUniversalScore, resp.Suggestions[0].RelevanceScore)
|
||||
}
|
||||
|
||||
// === Template with property_type condition ===
|
||||
|
||||
func TestSuggestionService_PropertyTypeMatch(t *testing.T) {
|
||||
service := setupSuggestionService(t)
|
||||
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
// Create a property type
|
||||
propType := &models.ResidenceType{Name: "House"}
|
||||
err := service.db.Create(propType).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
residence := &models.Residence{
|
||||
OwnerID: user.ID,
|
||||
Name: "My House",
|
||||
IsActive: true,
|
||||
IsPrimary: true,
|
||||
PropertyTypeID: &propType.ID,
|
||||
}
|
||||
err = service.db.Create(residence).Error
|
||||
require.NoError(t, err)
|
||||
// Reload with PropertyType preloaded
|
||||
err = service.db.Preload("PropertyType").First(residence, residence.ID).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
createTemplateWithConditions(t, service, "House Task", map[string]interface{}{
|
||||
"property_type": "House",
|
||||
})
|
||||
|
||||
resp, err := service.GetSuggestions(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Suggestions, 1)
|
||||
assert.Contains(t, resp.Suggestions[0].MatchReasons, "property_type:House")
|
||||
}
|
||||
|
||||
// === CalculateProfileCompleteness with fully filled profile ===
|
||||
|
||||
func TestCalculateProfileCompleteness_FullProfile(t *testing.T) {
|
||||
ht := "gas_furnace"
|
||||
ct := "central_ac"
|
||||
wht := "tank_gas"
|
||||
rt := "asphalt_shingle"
|
||||
et := "brick"
|
||||
fp := "hardwood"
|
||||
lt := "lawn"
|
||||
|
||||
residence := &models.Residence{
|
||||
HeatingType: &ht,
|
||||
CoolingType: &ct,
|
||||
WaterHeaterType: &wht,
|
||||
RoofType: &rt,
|
||||
HasPool: true,
|
||||
HasSprinklerSystem: true,
|
||||
HasSeptic: true,
|
||||
HasFireplace: true,
|
||||
HasGarage: true,
|
||||
HasBasement: true,
|
||||
HasAttic: true,
|
||||
ExteriorType: &et,
|
||||
FlooringPrimary: &fp,
|
||||
LandscapingType: <,
|
||||
}
|
||||
|
||||
completeness := CalculateProfileCompleteness(residence)
|
||||
assert.Equal(t, 1.0, completeness)
|
||||
}
|
||||
|
||||
// === CalculateProfileCompleteness with empty profile ===
|
||||
|
||||
func TestCalculateProfileCompleteness_EmptyProfile(t *testing.T) {
|
||||
residence := &models.Residence{}
|
||||
completeness := CalculateProfileCompleteness(residence)
|
||||
assert.Equal(t, 0.0, completeness)
|
||||
}
|
||||
|
||||
// === Score capped at 1.0 ===
|
||||
|
||||
func TestSuggestionService_ScoreCappedAtOne(t *testing.T) {
|
||||
service := setupSuggestionService(t)
|
||||
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
ht := "gas_furnace"
|
||||
ct := "central_ac"
|
||||
wht := "tank_gas"
|
||||
rt := "asphalt_shingle"
|
||||
et := "brick"
|
||||
|
||||
residence := &models.Residence{
|
||||
OwnerID: user.ID,
|
||||
Name: "Full House",
|
||||
IsActive: true,
|
||||
IsPrimary: true,
|
||||
HeatingType: &ht,
|
||||
CoolingType: &ct,
|
||||
WaterHeaterType: &wht,
|
||||
RoofType: &rt,
|
||||
ExteriorType: &et,
|
||||
HasPool: true,
|
||||
HasFireplace: true,
|
||||
HasGarage: true,
|
||||
}
|
||||
err := service.db.Create(residence).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Template that matches many fields — score should be capped at 1.0
|
||||
createTemplateWithConditions(t, service, "Super Match", map[string]interface{}{
|
||||
"heating_type": "gas_furnace",
|
||||
"cooling_type": "central_ac",
|
||||
"water_heater_type": "tank_gas",
|
||||
"roof_type": "asphalt_shingle",
|
||||
"exterior_type": "brick",
|
||||
"has_pool": true,
|
||||
"has_fireplace": true,
|
||||
"has_garage": true,
|
||||
})
|
||||
|
||||
resp, err := service.GetSuggestions(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Suggestions, 1)
|
||||
assert.LessOrEqual(t, resp.Suggestions[0].RelevanceScore, 1.0)
|
||||
}
|
||||
|
||||
// === Inactive templates are excluded ===
|
||||
|
||||
func TestSuggestionService_InactiveTemplateExcluded(t *testing.T) {
|
||||
service := setupSuggestionService(t)
|
||||
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, service.db, user.ID, "Test House")
|
||||
|
||||
// Create inactive template via raw SQL to bypass GORM default:true on is_active
|
||||
err := service.db.Exec("INSERT INTO task_tasktemplate (title, is_active, conditions, created_at, updated_at) VALUES ('Inactive Task', false, '{}', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)").Error
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := service.GetSuggestions(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp.Suggestions, 0)
|
||||
}
|
||||
|
||||
// === Template excluded when requires sprinkler but residence doesn't have it ===
|
||||
|
||||
func TestSuggestionService_ExcludedWhenSprinklerRequired(t *testing.T) {
|
||||
service := setupSuggestionService(t)
|
||||
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, service.db, user.ID, "No Sprinkler House")
|
||||
|
||||
createTemplateWithConditions(t, service, "Sprinkler Maintenance", map[string]interface{}{
|
||||
"has_sprinkler_system": true,
|
||||
})
|
||||
|
||||
resp, err := service.GetSuggestions(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp.Suggestions, 0)
|
||||
}
|
||||
|
||||
// === All bool field exclusions ===
|
||||
|
||||
func TestSuggestionService_ExcludedWhenSepticRequired(t *testing.T) {
|
||||
service := setupSuggestionService(t)
|
||||
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, service.db, user.ID, "City House")
|
||||
|
||||
createTemplateWithConditions(t, service, "Septic Pump", map[string]interface{}{
|
||||
"has_septic": true,
|
||||
})
|
||||
|
||||
resp, err := service.GetSuggestions(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp.Suggestions, 0)
|
||||
}
|
||||
|
||||
func TestSuggestionService_ExcludedWhenFireplaceRequired(t *testing.T) {
|
||||
service := setupSuggestionService(t)
|
||||
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, service.db, user.ID, "No Fireplace")
|
||||
|
||||
createTemplateWithConditions(t, service, "Chimney Sweep", map[string]interface{}{
|
||||
"has_fireplace": true,
|
||||
})
|
||||
|
||||
resp, err := service.GetSuggestions(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp.Suggestions, 0)
|
||||
}
|
||||
|
||||
func TestSuggestionService_ExcludedWhenGarageRequired(t *testing.T) {
|
||||
service := setupSuggestionService(t)
|
||||
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, service.db, user.ID, "No Garage")
|
||||
|
||||
createTemplateWithConditions(t, service, "Garage Door Service", map[string]interface{}{
|
||||
"has_garage": true,
|
||||
})
|
||||
|
||||
resp, err := service.GetSuggestions(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp.Suggestions, 0)
|
||||
}
|
||||
|
||||
func TestSuggestionService_ExcludedWhenBasementRequired(t *testing.T) {
|
||||
service := setupSuggestionService(t)
|
||||
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, service.db, user.ID, "Slab Home")
|
||||
|
||||
createTemplateWithConditions(t, service, "Basement Waterproofing", map[string]interface{}{
|
||||
"has_basement": true,
|
||||
})
|
||||
|
||||
resp, err := service.GetSuggestions(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp.Suggestions, 0)
|
||||
}
|
||||
|
||||
func TestSuggestionService_ExcludedWhenAtticRequired(t *testing.T) {
|
||||
service := setupSuggestionService(t)
|
||||
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
|
||||
residence := testutil.CreateTestResidence(t, service.db, user.ID, "Flat Roof")
|
||||
|
||||
createTemplateWithConditions(t, service, "Attic Insulation", map[string]interface{}{
|
||||
"has_attic": true,
|
||||
})
|
||||
|
||||
resp, err := service.GetSuggestions(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resp.Suggestions, 0)
|
||||
}
|
||||
|
||||
// === String field matches for all remaining types ===
|
||||
|
||||
func TestSuggestionService_CoolingTypeMatch(t *testing.T) {
|
||||
service := setupSuggestionService(t)
|
||||
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
coolingType := "central_ac"
|
||||
residence := &models.Residence{
|
||||
OwnerID: user.ID,
|
||||
Name: "Cool House",
|
||||
IsActive: true,
|
||||
IsPrimary: true,
|
||||
CoolingType: &coolingType,
|
||||
}
|
||||
err := service.db.Create(residence).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
createTemplateWithConditions(t, service, "AC Filter", map[string]interface{}{
|
||||
"cooling_type": "central_ac",
|
||||
})
|
||||
|
||||
resp, err := service.GetSuggestions(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Suggestions, 1)
|
||||
assert.Contains(t, resp.Suggestions[0].MatchReasons, "cooling_type:central_ac")
|
||||
}
|
||||
|
||||
func TestSuggestionService_WaterHeaterTypeMatch(t *testing.T) {
|
||||
service := setupSuggestionService(t)
|
||||
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
wht := "tank_gas"
|
||||
residence := &models.Residence{
|
||||
OwnerID: user.ID,
|
||||
Name: "Hot Water House",
|
||||
IsActive: true,
|
||||
IsPrimary: true,
|
||||
WaterHeaterType: &wht,
|
||||
}
|
||||
err := service.db.Create(residence).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
createTemplateWithConditions(t, service, "Flush Water Heater", map[string]interface{}{
|
||||
"water_heater_type": "tank_gas",
|
||||
})
|
||||
|
||||
resp, err := service.GetSuggestions(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Suggestions, 1)
|
||||
assert.Contains(t, resp.Suggestions[0].MatchReasons, "water_heater_type:tank_gas")
|
||||
}
|
||||
|
||||
func TestSuggestionService_ExteriorTypeMatch(t *testing.T) {
|
||||
service := setupSuggestionService(t)
|
||||
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
et := "brick"
|
||||
residence := &models.Residence{
|
||||
OwnerID: user.ID,
|
||||
Name: "Brick House",
|
||||
IsActive: true,
|
||||
IsPrimary: true,
|
||||
ExteriorType: &et,
|
||||
}
|
||||
err := service.db.Create(residence).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
createTemplateWithConditions(t, service, "Pressure Wash Brick", map[string]interface{}{
|
||||
"exterior_type": "brick",
|
||||
})
|
||||
|
||||
resp, err := service.GetSuggestions(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Suggestions, 1)
|
||||
assert.Contains(t, resp.Suggestions[0].MatchReasons, "exterior_type:brick")
|
||||
}
|
||||
|
||||
func TestSuggestionService_FlooringPrimaryMatch(t *testing.T) {
|
||||
service := setupSuggestionService(t)
|
||||
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
fp := "hardwood"
|
||||
residence := &models.Residence{
|
||||
OwnerID: user.ID,
|
||||
Name: "Hardwood House",
|
||||
IsActive: true,
|
||||
IsPrimary: true,
|
||||
FlooringPrimary: &fp,
|
||||
}
|
||||
err := service.db.Create(residence).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
createTemplateWithConditions(t, service, "Refinish Floors", map[string]interface{}{
|
||||
"flooring_primary": "hardwood",
|
||||
})
|
||||
|
||||
resp, err := service.GetSuggestions(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Suggestions, 1)
|
||||
assert.Contains(t, resp.Suggestions[0].MatchReasons, "flooring_primary:hardwood")
|
||||
}
|
||||
|
||||
func TestSuggestionService_LandscapingTypeMatch(t *testing.T) {
|
||||
service := setupSuggestionService(t)
|
||||
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
lt := "lawn"
|
||||
residence := &models.Residence{
|
||||
OwnerID: user.ID,
|
||||
Name: "Lawn House",
|
||||
IsActive: true,
|
||||
IsPrimary: true,
|
||||
LandscapingType: <,
|
||||
}
|
||||
err := service.db.Create(residence).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
createTemplateWithConditions(t, service, "Fertilize Lawn", map[string]interface{}{
|
||||
"landscaping_type": "lawn",
|
||||
})
|
||||
|
||||
resp, err := service.GetSuggestions(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Suggestions, 1)
|
||||
assert.Contains(t, resp.Suggestions[0].MatchReasons, "landscaping_type:lawn")
|
||||
}
|
||||
|
||||
func TestSuggestionService_RoofTypeMatch(t *testing.T) {
|
||||
service := setupSuggestionService(t)
|
||||
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
rt := "asphalt_shingle"
|
||||
residence := &models.Residence{
|
||||
OwnerID: user.ID,
|
||||
Name: "Shingle House",
|
||||
IsActive: true,
|
||||
IsPrimary: true,
|
||||
RoofType: &rt,
|
||||
}
|
||||
err := service.db.Create(residence).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
createTemplateWithConditions(t, service, "Inspect Roof", map[string]interface{}{
|
||||
"roof_type": "asphalt_shingle",
|
||||
})
|
||||
|
||||
resp, err := service.GetSuggestions(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Suggestions, 1)
|
||||
assert.Contains(t, resp.Suggestions[0].MatchReasons, "roof_type:asphalt_shingle")
|
||||
}
|
||||
|
||||
// === Mismatch on string field — no score for that field ===
|
||||
|
||||
func TestSuggestionService_HeatingTypeMismatch(t *testing.T) {
|
||||
service := setupSuggestionService(t)
|
||||
user := testutil.CreateTestUser(t, service.db, "owner", "owner@test.com", "Password123")
|
||||
|
||||
heatingType := "electric_furnace"
|
||||
residence := &models.Residence{
|
||||
OwnerID: user.ID,
|
||||
Name: "Electric House",
|
||||
IsActive: true,
|
||||
IsPrimary: true,
|
||||
HeatingType: &heatingType,
|
||||
}
|
||||
err := service.db.Create(residence).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// Template wants gas_furnace but residence has electric_furnace
|
||||
createTemplateWithConditions(t, service, "Gas Furnace Service", map[string]interface{}{
|
||||
"heating_type": "gas_furnace",
|
||||
})
|
||||
|
||||
resp, err := service.GetSuggestions(residence.ID, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resp.Suggestions, 1)
|
||||
// Should still be included but with partial_profile (no match, no exclude)
|
||||
assert.Contains(t, resp.Suggestions[0].MatchReasons, "partial_profile")
|
||||
}
|
||||
|
||||
// === templateConditions.isEmpty ===
|
||||
|
||||
func TestTemplateConditions_IsEmpty(t *testing.T) {
|
||||
cond := &templateConditions{}
|
||||
assert.True(t, cond.isEmpty())
|
||||
|
||||
ht := "gas"
|
||||
cond2 := &templateConditions{HeatingType: &ht}
|
||||
assert.False(t, cond2.isEmpty())
|
||||
|
||||
pool := true
|
||||
cond3 := &templateConditions{HasPool: &pool}
|
||||
assert.False(t, cond3.isEmpty())
|
||||
|
||||
pt := "House"
|
||||
cond4 := &templateConditions{PropertyType: &pt}
|
||||
assert.False(t, cond4.isEmpty())
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
467
internal/task/task_test.go
Normal file
467
internal/task/task_test.go
Normal file
@@ -0,0 +1,467 @@
|
||||
package task
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/task/categorization"
|
||||
"github.com/treytartt/honeydue-api/internal/testutil"
|
||||
)
|
||||
|
||||
var now = time.Date(2025, 6, 15, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
func timePtr(t time.Time) *time.Time { return &t }
|
||||
|
||||
func completedTask() *models.Task {
|
||||
return &models.Task{
|
||||
BaseModel: models.BaseModel{ID: 1},
|
||||
Completions: []models.TaskCompletion{{BaseModel: models.BaseModel{ID: 1}}},
|
||||
// NextDueDate is nil → completed
|
||||
}
|
||||
}
|
||||
|
||||
func overdueTask() *models.Task {
|
||||
past := now.AddDate(0, 0, -5)
|
||||
return &models.Task{
|
||||
BaseModel: models.BaseModel{ID: 2},
|
||||
DueDate: &past,
|
||||
}
|
||||
}
|
||||
|
||||
func dueSoonTask() *models.Task {
|
||||
soon := now.AddDate(0, 0, 10)
|
||||
return &models.Task{
|
||||
BaseModel: models.BaseModel{ID: 3},
|
||||
DueDate: &soon,
|
||||
}
|
||||
}
|
||||
|
||||
func activeTask() *models.Task {
|
||||
return &models.Task{
|
||||
BaseModel: models.BaseModel{ID: 4},
|
||||
}
|
||||
}
|
||||
|
||||
// --- Predicate re-exports ---
|
||||
|
||||
func TestReExport_IsCompleted(t *testing.T) {
|
||||
if !IsCompleted(completedTask()) {
|
||||
t.Error("expected completed")
|
||||
}
|
||||
if IsCompleted(activeTask()) {
|
||||
t.Error("expected not completed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_IsOverdue(t *testing.T) {
|
||||
if !IsOverdue(overdueTask(), now) {
|
||||
t.Error("expected overdue")
|
||||
}
|
||||
if IsOverdue(dueSoonTask(), now) {
|
||||
t.Error("expected not overdue")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_IsDueSoon(t *testing.T) {
|
||||
if !IsDueSoon(dueSoonTask(), now, 30) {
|
||||
t.Error("expected due soon")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_IsActive(t *testing.T) {
|
||||
if !IsActive(activeTask()) {
|
||||
t.Error("expected active")
|
||||
}
|
||||
cancelled := &models.Task{IsCancelled: true}
|
||||
if IsActive(cancelled) {
|
||||
t.Error("cancelled should not be active")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_IsArchived(t *testing.T) {
|
||||
archived := &models.Task{IsArchived: true}
|
||||
if !IsArchived(archived) {
|
||||
t.Error("expected archived")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_IsCancelled(t *testing.T) {
|
||||
cancelled := &models.Task{IsCancelled: true}
|
||||
if !IsCancelled(cancelled) {
|
||||
t.Error("expected cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_IsInProgress(t *testing.T) {
|
||||
ip := &models.Task{InProgress: true}
|
||||
if !IsInProgress(ip) {
|
||||
t.Error("expected in progress")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_IsRecurring(t *testing.T) {
|
||||
days := 30
|
||||
freqID := uint(1)
|
||||
recurring := &models.Task{
|
||||
FrequencyID: &freqID,
|
||||
Frequency: &models.TaskFrequency{Days: &days},
|
||||
}
|
||||
if !IsRecurring(recurring) {
|
||||
t.Error("expected recurring")
|
||||
}
|
||||
if IsRecurring(activeTask()) {
|
||||
t.Error("expected not recurring")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_IsOneTime(t *testing.T) {
|
||||
if !IsOneTime(activeTask()) {
|
||||
t.Error("expected one-time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_HasCompletions(t *testing.T) {
|
||||
if !HasCompletions(completedTask()) {
|
||||
t.Error("expected has completions")
|
||||
}
|
||||
if HasCompletions(activeTask()) {
|
||||
t.Error("expected no completions")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_GetCompletionCount(t *testing.T) {
|
||||
ct := completedTask()
|
||||
if GetCompletionCount(ct) != 1 {
|
||||
t.Errorf("count = %d, want 1", GetCompletionCount(ct))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_EffectiveDate(t *testing.T) {
|
||||
task := overdueTask()
|
||||
ed := EffectiveDate(task)
|
||||
if ed == nil {
|
||||
t.Error("expected non-nil effective date")
|
||||
}
|
||||
// If NextDueDate is set, prefer it
|
||||
next := now.AddDate(0, 1, 0)
|
||||
task.NextDueDate = &next
|
||||
ed = EffectiveDate(task)
|
||||
if !ed.Equal(next) {
|
||||
t.Errorf("expected NextDueDate, got %v", ed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_IsUpcoming(t *testing.T) {
|
||||
far := now.AddDate(0, 6, 0)
|
||||
task := &models.Task{DueDate: &far}
|
||||
if !IsUpcoming(task, now, 30) {
|
||||
t.Error("expected upcoming")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_CategorizeTask(t *testing.T) {
|
||||
col := CategorizeTask(overdueTask(), 30)
|
||||
if col != categorization.ColumnOverdue {
|
||||
t.Errorf("column = %v, want overdue", col)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_DetermineKanbanColumn(t *testing.T) {
|
||||
col := DetermineKanbanColumn(overdueTask(), 30)
|
||||
if col == "" {
|
||||
t.Error("expected non-empty column string")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_CategorizeTasksIntoColumns(t *testing.T) {
|
||||
tasks := []models.Task{*overdueTask(), *dueSoonTask(), *activeTask()}
|
||||
result := CategorizeTasksIntoColumns(tasks, 30)
|
||||
if result == nil {
|
||||
t.Error("expected non-nil result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_NewChain(t *testing.T) {
|
||||
chain := NewChain()
|
||||
if chain == nil {
|
||||
t.Error("expected non-nil chain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_Constants(t *testing.T) {
|
||||
if ColumnOverdue.String() != "overdue_tasks" {
|
||||
t.Errorf("ColumnOverdue = %q", ColumnOverdue.String())
|
||||
}
|
||||
if ColumnDueSoon.String() != "due_soon_tasks" {
|
||||
t.Errorf("ColumnDueSoon = %q", ColumnDueSoon.String())
|
||||
}
|
||||
if ColumnUpcoming.String() != "upcoming_tasks" {
|
||||
t.Errorf("ColumnUpcoming = %q", ColumnUpcoming.String())
|
||||
}
|
||||
if ColumnInProgress.String() != "in_progress_tasks" {
|
||||
t.Errorf("ColumnInProgress = %q", ColumnInProgress.String())
|
||||
}
|
||||
if ColumnCompleted.String() != "completed_tasks" {
|
||||
t.Errorf("ColumnCompleted = %q", ColumnCompleted.String())
|
||||
}
|
||||
if ColumnCancelled.String() != "cancelled_tasks" {
|
||||
t.Errorf("ColumnCancelled = %q", ColumnCancelled.String())
|
||||
}
|
||||
}
|
||||
|
||||
// --- Scope re-exports (use SQLite in-memory DB) ---
|
||||
|
||||
func setupDB(t *testing.T) *gorm.DB {
|
||||
return testutil.SetupTestDB(t)
|
||||
}
|
||||
|
||||
func seedTask(t *testing.T, db *gorm.DB, task *models.Task) {
|
||||
// Ensure we have a user and residence
|
||||
user := testutil.CreateTestUser(t, db, "testuser", "test@example.com", "password123")
|
||||
res := testutil.CreateTestResidence(t, db, user.ID, "Test House")
|
||||
task.CreatedByID = user.ID
|
||||
task.ResidenceID = res.ID
|
||||
task.Version = 1
|
||||
err := db.Create(task).Error
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create task: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeActive_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
seedTask(t, db, &models.Task{Title: "active"})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeActive).Find(&tasks)
|
||||
if len(tasks) != 1 {
|
||||
t.Errorf("len = %d, want 1", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeCancelled_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
seedTask(t, db, &models.Task{Title: "cancelled", IsCancelled: true})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeCancelled).Find(&tasks)
|
||||
if len(tasks) != 1 {
|
||||
t.Errorf("len = %d, want 1", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeArchived_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
seedTask(t, db, &models.Task{Title: "archived", IsArchived: true})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeArchived).Find(&tasks)
|
||||
if len(tasks) != 1 {
|
||||
t.Errorf("len = %d, want 1", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeInProgress_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
seedTask(t, db, &models.Task{Title: "inprog", InProgress: true})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeInProgress).Find(&tasks)
|
||||
if len(tasks) != 1 {
|
||||
t.Errorf("len = %d, want 1", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeNotInProgress_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
seedTask(t, db, &models.Task{Title: "not inprog"})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeNotInProgress).Find(&tasks)
|
||||
if len(tasks) != 1 {
|
||||
t.Errorf("len = %d, want 1", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeCompleted_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
user := testutil.CreateTestUser(t, db, "u2", "u2@test.com", "password123")
|
||||
res := testutil.CreateTestResidence(t, db, user.ID, "House2")
|
||||
task := &models.Task{Title: "completed", CreatedByID: user.ID, ResidenceID: res.ID, Version: 1}
|
||||
db.Create(task)
|
||||
// Add a completion and ensure NextDueDate is nil
|
||||
db.Create(&models.TaskCompletion{TaskID: task.ID, CompletedByID: user.ID, CompletedAt: now})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeCompleted).Find(&tasks)
|
||||
if len(tasks) != 1 {
|
||||
t.Errorf("len = %d, want 1", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeNotCompleted_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
seedTask(t, db, &models.Task{Title: "not completed"})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeNotCompleted).Find(&tasks)
|
||||
if len(tasks) != 1 {
|
||||
t.Errorf("len = %d, want 1", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeOverdue_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
past := now.AddDate(0, 0, -5)
|
||||
seedTask(t, db, &models.Task{Title: "overdue", DueDate: &past})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeOverdue(now)).Find(&tasks)
|
||||
if len(tasks) != 1 {
|
||||
t.Errorf("len = %d, want 1", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeDueSoon_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
soon := now.AddDate(0, 0, 5)
|
||||
seedTask(t, db, &models.Task{Title: "due soon", DueDate: &soon})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeDueSoon(now, 30)).Find(&tasks)
|
||||
if len(tasks) != 1 {
|
||||
t.Errorf("len = %d, want 1", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeUpcoming_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
far := now.AddDate(0, 6, 0)
|
||||
seedTask(t, db, &models.Task{Title: "upcoming", DueDate: &far})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeUpcoming(now, 30)).Find(&tasks)
|
||||
if len(tasks) != 1 {
|
||||
t.Errorf("len = %d, want 1", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeForResidence_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
user := testutil.CreateTestUser(t, db, "u3", "u3@test.com", "password123")
|
||||
res := testutil.CreateTestResidence(t, db, user.ID, "House3")
|
||||
db.Create(&models.Task{Title: "t1", CreatedByID: user.ID, ResidenceID: res.ID, Version: 1})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeForResidence(res.ID)).Find(&tasks)
|
||||
if len(tasks) != 1 {
|
||||
t.Errorf("len = %d, want 1", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeForResidences_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
user := testutil.CreateTestUser(t, db, "u4", "u4@test.com", "password123")
|
||||
res1 := testutil.CreateTestResidence(t, db, user.ID, "H1")
|
||||
res2 := testutil.CreateTestResidence(t, db, user.ID, "H2")
|
||||
db.Create(&models.Task{Title: "t1", CreatedByID: user.ID, ResidenceID: res1.ID, Version: 1})
|
||||
db.Create(&models.Task{Title: "t2", CreatedByID: user.ID, ResidenceID: res2.ID, Version: 1})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeForResidences([]uint{res1.ID, res2.ID})).Find(&tasks)
|
||||
if len(tasks) != 2 {
|
||||
t.Errorf("len = %d, want 2", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeDueInRange_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
due := now.AddDate(0, 0, 3)
|
||||
seedTask(t, db, &models.Task{Title: "in range", DueDate: &due})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeDueInRange(now, now.AddDate(0, 0, 7))).Find(&tasks)
|
||||
if len(tasks) != 1 {
|
||||
t.Errorf("len = %d, want 1", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeHasDueDate_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
due := now.AddDate(0, 0, 1)
|
||||
seedTask(t, db, &models.Task{Title: "with due", DueDate: &due})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeHasDueDate).Find(&tasks)
|
||||
if len(tasks) != 1 {
|
||||
t.Errorf("len = %d, want 1", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeNoDueDate_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
seedTask(t, db, &models.Task{Title: "no due"})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeNoDueDate).Find(&tasks)
|
||||
if len(tasks) != 1 {
|
||||
t.Errorf("len = %d, want 1", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeHasCompletions_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
user := testutil.CreateTestUser(t, db, "u5", "u5@test.com", "password123")
|
||||
res := testutil.CreateTestResidence(t, db, user.ID, "H5")
|
||||
task := &models.Task{Title: "has comp", CreatedByID: user.ID, ResidenceID: res.ID, Version: 1}
|
||||
db.Create(task)
|
||||
db.Create(&models.TaskCompletion{TaskID: task.ID, CompletedByID: user.ID, CompletedAt: now})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeHasCompletions).Find(&tasks)
|
||||
if len(tasks) != 1 {
|
||||
t.Errorf("len = %d, want 1", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeNoCompletions_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
seedTask(t, db, &models.Task{Title: "no comp"})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeNoCompletions).Find(&tasks)
|
||||
if len(tasks) != 1 {
|
||||
t.Errorf("len = %d, want 1", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeOrderByDueDate_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
user := testutil.CreateTestUser(t, db, "u6", "u6@test.com", "password123")
|
||||
res := testutil.CreateTestResidence(t, db, user.ID, "H6")
|
||||
d1 := now.AddDate(0, 0, 5)
|
||||
d2 := now.AddDate(0, 0, 1)
|
||||
db.Create(&models.Task{Title: "later", CreatedByID: user.ID, ResidenceID: res.ID, DueDate: &d1, Version: 1})
|
||||
db.Create(&models.Task{Title: "sooner", CreatedByID: user.ID, ResidenceID: res.ID, DueDate: &d2, Version: 1})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeOrderByDueDate).Find(&tasks)
|
||||
if len(tasks) != 2 {
|
||||
t.Fatalf("len = %d", len(tasks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeOrderByPriority_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
seedTask(t, db, &models.Task{Title: "prio test"})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeOrderByPriority).Find(&tasks)
|
||||
if len(tasks) < 1 {
|
||||
t.Error("expected at least 1 task")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeOrderByCreatedAt_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
seedTask(t, db, &models.Task{Title: "created test"})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeOrderByCreatedAt).Find(&tasks)
|
||||
if len(tasks) < 1 {
|
||||
t.Error("expected at least 1 task")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReExport_ScopeKanbanOrder_DB(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
seedTask(t, db, &models.Task{Title: "kanban test"})
|
||||
var tasks []models.Task
|
||||
db.Scopes(ScopeKanbanOrder).Find(&tasks)
|
||||
if len(tasks) < 1 {
|
||||
t.Error("expected at least 1 task")
|
||||
}
|
||||
}
|
||||
177
internal/testutil/testutil_test.go
Normal file
177
internal/testutil/testutil_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
func TestSetupTestDB_Works(t *testing.T) {
|
||||
db := SetupTestDB(t)
|
||||
if db == nil {
|
||||
t.Fatal("expected non-nil db")
|
||||
}
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get sql.DB: %v", err)
|
||||
}
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
t.Fatalf("ping failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupTestRouter_HasValidator(t *testing.T) {
|
||||
e := SetupTestRouter()
|
||||
if e == nil {
|
||||
t.Fatal("expected non-nil router")
|
||||
}
|
||||
if e.Validator == nil {
|
||||
t.Error("expected validator to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTestUser_ReturnsUser(t *testing.T) {
|
||||
db := SetupTestDB(t)
|
||||
user := CreateTestUser(t, db, "testuser", "test@example.com", "password123")
|
||||
if user == nil {
|
||||
t.Fatal("expected non-nil user")
|
||||
}
|
||||
if user.ID == 0 {
|
||||
t.Error("expected user to have an ID")
|
||||
}
|
||||
if user.Username != "testuser" {
|
||||
t.Errorf("username = %q, want testuser", user.Username)
|
||||
}
|
||||
if user.Email != "test@example.com" {
|
||||
t.Errorf("email = %q, want test@example.com", user.Email)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeedLookupData_PopulatesData(t *testing.T) {
|
||||
db := SetupTestDB(t)
|
||||
SeedLookupData(t, db)
|
||||
|
||||
// Verify residence types seeded
|
||||
var rtCount int64
|
||||
db.Table("residence_residencetype").Count(&rtCount)
|
||||
if rtCount < 4 {
|
||||
t.Errorf("residence types = %d, want ≥4", rtCount)
|
||||
}
|
||||
|
||||
// Verify task categories seeded
|
||||
var catCount int64
|
||||
db.Table("task_taskcategory").Count(&catCount)
|
||||
if catCount < 4 {
|
||||
t.Errorf("task categories = %d, want ≥4", catCount)
|
||||
}
|
||||
|
||||
// Verify task priorities seeded
|
||||
var priCount int64
|
||||
db.Table("task_taskpriority").Count(&priCount)
|
||||
if priCount < 4 {
|
||||
t.Errorf("task priorities = %d, want ≥4", priCount)
|
||||
}
|
||||
|
||||
// Verify task frequencies seeded
|
||||
var freqCount int64
|
||||
db.Table("task_taskfrequency").Count(&freqCount)
|
||||
if freqCount < 3 {
|
||||
t.Errorf("task frequencies = %d, want ≥3", freqCount)
|
||||
}
|
||||
|
||||
// Verify contractor specialties seeded
|
||||
var specCount int64
|
||||
db.Table("task_contractorspecialty").Count(&specCount)
|
||||
if specCount < 4 {
|
||||
t.Errorf("contractor specialties = %d, want ≥4", specCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMockAuthMiddleware_SetsUser(t *testing.T) {
|
||||
db := SetupTestDB(t)
|
||||
user := CreateTestUser(t, db, "authuser", "auth@example.com", "pass")
|
||||
|
||||
e := echo.New()
|
||||
mw := MockAuthMiddleware(user)
|
||||
|
||||
var capturedUser interface{}
|
||||
handler := mw(func(c echo.Context) error {
|
||||
capturedUser = c.Get("auth_user")
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
if err := handler(c); err != nil {
|
||||
t.Fatalf("handler error: %v", err)
|
||||
}
|
||||
|
||||
if capturedUser == nil {
|
||||
t.Fatal("expected auth_user to be set in context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTestResidence_ReturnsResidence(t *testing.T) {
|
||||
db := SetupTestDB(t)
|
||||
user := CreateTestUser(t, db, "owner", "owner@example.com", "pass")
|
||||
residence := CreateTestResidence(t, db, user.ID, "Test Home")
|
||||
|
||||
if residence == nil {
|
||||
t.Fatal("expected non-nil residence")
|
||||
}
|
||||
if residence.ID == 0 {
|
||||
t.Error("expected residence to have an ID")
|
||||
}
|
||||
if residence.Name != "Test Home" {
|
||||
t.Errorf("name = %q, want Test Home", residence.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTestTask_ReturnsTask(t *testing.T) {
|
||||
db := SetupTestDB(t)
|
||||
user := CreateTestUser(t, db, "taskowner", "task@example.com", "pass")
|
||||
residence := CreateTestResidence(t, db, user.ID, "Task Home")
|
||||
task := CreateTestTask(t, db, residence.ID, user.ID, "Fix the sink")
|
||||
|
||||
if task == nil {
|
||||
t.Fatal("expected non-nil task")
|
||||
}
|
||||
if task.ID == 0 {
|
||||
t.Error("expected task to have an ID")
|
||||
}
|
||||
if task.Title != "Fix the sink" {
|
||||
t.Errorf("title = %q, want Fix the sink", task.Title)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSON_Valid(t *testing.T) {
|
||||
data := []byte(`{"name":"test","count":42}`)
|
||||
result := ParseJSON(t, data)
|
||||
if result["name"] != "test" {
|
||||
t.Errorf("name = %v, want test", result["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONArray_Valid(t *testing.T) {
|
||||
data := []byte(`[{"id":1},{"id":2}]`)
|
||||
result := ParseJSONArray(t, data)
|
||||
if len(result) != 2 {
|
||||
t.Errorf("len = %d, want 2", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeRequestT_ReturnsRecorder(t *testing.T) {
|
||||
e := SetupTestRouter()
|
||||
e.GET("/test", func(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
|
||||
})
|
||||
|
||||
rec := MakeRequestT(t, e, http.MethodGet, "/test", nil, "")
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", rec.Code)
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,12 @@
|
||||
package validator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
govalidator "github.com/go-playground/validator/v10"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidatePasswordComplexity(t *testing.T) {
|
||||
@@ -113,3 +116,118 @@ func TestFormatMessagePasswordComplexity(t *testing.T) {
|
||||
t.Errorf("expected tag 'password_complexity', got %q", field.Tag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordComplexity_AdditionalCases(t *testing.T) {
|
||||
cv := NewCustomValidator()
|
||||
|
||||
type request struct {
|
||||
Password string `json:"password" validate:"required,min=8,password_complexity"`
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
pw string
|
||||
valid bool
|
||||
}{
|
||||
{"no uppercase no digit", "password", false},
|
||||
{"no lowercase", "PASSWORD1", false},
|
||||
{"no digit", "Password", false},
|
||||
{"too short", "Pass1", false},
|
||||
{"valid standard", "Password1", true},
|
||||
{"valid with special chars", "P@ssw0rd", true},
|
||||
{"spaces with complexity", "Pass 1234", true},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := request{Password: tc.pw}
|
||||
err := cv.Validate(r)
|
||||
if tc.valid {
|
||||
assert.NoError(t, err, "expected %q to be valid", tc.pw)
|
||||
} else {
|
||||
assert.Error(t, err, "expected %q to be invalid", tc.pw)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatValidationErrors_AllTags(t *testing.T) {
|
||||
cv := NewCustomValidator()
|
||||
|
||||
type allTags struct {
|
||||
Required string `json:"required" validate:"required"`
|
||||
Email string `json:"email" validate:"email"`
|
||||
MinLen string `json:"min_len" validate:"min=5"`
|
||||
MaxLen string `json:"max_len" validate:"max=3"`
|
||||
OneOf string `json:"one_of" validate:"oneof=a b c"`
|
||||
URL string `json:"url" validate:"url"`
|
||||
}
|
||||
|
||||
input := allTags{
|
||||
Required: "", // fails required
|
||||
Email: "bad", // fails email
|
||||
MinLen: "ab", // fails min=5
|
||||
MaxLen: "abcde", // fails max=3
|
||||
OneOf: "z", // fails oneof
|
||||
URL: "nope", // fails url
|
||||
}
|
||||
|
||||
err := cv.Validate(input)
|
||||
require.Error(t, err)
|
||||
|
||||
resp := FormatValidationErrors(err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, "Validation failed", resp.Error)
|
||||
|
||||
expectedMessages := map[string]string{
|
||||
"required": "This field is required",
|
||||
"email": "Must be a valid email address",
|
||||
"min_len": "Must be at least 5 characters",
|
||||
"max_len": "Must be at most 3 characters",
|
||||
"one_of": "Must be one of: a b c",
|
||||
"url": "Must be a valid URL",
|
||||
}
|
||||
|
||||
for field, expectedMsg := range expectedMessages {
|
||||
fe, ok := resp.Fields[field]
|
||||
assert.True(t, ok, "expected field %q in error response", field)
|
||||
if ok {
|
||||
assert.Equal(t, expectedMsg, fe.Message, "message mismatch for field %q", field)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatValidationErrors_NonValidationError(t *testing.T) {
|
||||
err := fmt.Errorf("some random error")
|
||||
resp := FormatValidationErrors(err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, "some random error", resp.Error)
|
||||
assert.Nil(t, resp.Fields)
|
||||
}
|
||||
|
||||
func TestNewCustomValidator_UsesJSONTagNames(t *testing.T) {
|
||||
cv := NewCustomValidator()
|
||||
|
||||
type request struct {
|
||||
FirstName string `json:"first_name" validate:"required"`
|
||||
}
|
||||
|
||||
err := cv.Validate(request{})
|
||||
require.Error(t, err)
|
||||
|
||||
resp := FormatValidationErrors(err)
|
||||
require.NotNil(t, resp)
|
||||
_, ok := resp.Fields["first_name"]
|
||||
assert.True(t, ok, "expected JSON tag name 'first_name' in error fields")
|
||||
}
|
||||
|
||||
func TestCustomValidator_Validate_Success(t *testing.T) {
|
||||
cv := NewCustomValidator()
|
||||
|
||||
type request struct {
|
||||
Name string `json:"name" validate:"required"`
|
||||
}
|
||||
|
||||
err := cv.Validate(request{Name: "test"})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
44
internal/worker/enqueuer.go
Normal file
44
internal/worker/enqueuer.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package worker
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// Enqueuer defines the interface for enqueuing background email tasks.
|
||||
type Enqueuer interface {
|
||||
EnqueueWelcomeEmail(to, firstName, code string) error
|
||||
EnqueueVerificationEmail(to, firstName, code string) error
|
||||
EnqueuePasswordResetEmail(to, firstName, code, resetToken string) error
|
||||
EnqueuePasswordChangedEmail(to, firstName string) error
|
||||
}
|
||||
|
||||
// Verify TaskClient satisfies the interface at compile time.
|
||||
var _ Enqueuer = (*TaskClient)(nil)
|
||||
|
||||
// BuildWelcomeEmailPayload marshals a WelcomeEmailPayload to JSON bytes.
|
||||
func BuildWelcomeEmailPayload(to, firstName, code string) ([]byte, error) {
|
||||
return json.Marshal(WelcomeEmailPayload{
|
||||
EmailPayload: EmailPayload{To: to, FirstName: firstName},
|
||||
ConfirmationCode: code,
|
||||
})
|
||||
}
|
||||
|
||||
// BuildVerificationEmailPayload marshals a VerificationEmailPayload to JSON bytes.
|
||||
func BuildVerificationEmailPayload(to, firstName, code string) ([]byte, error) {
|
||||
return json.Marshal(VerificationEmailPayload{
|
||||
EmailPayload: EmailPayload{To: to, FirstName: firstName},
|
||||
Code: code,
|
||||
})
|
||||
}
|
||||
|
||||
// BuildPasswordResetEmailPayload marshals a PasswordResetEmailPayload to JSON bytes.
|
||||
func BuildPasswordResetEmailPayload(to, firstName, code, resetToken string) ([]byte, error) {
|
||||
return json.Marshal(PasswordResetEmailPayload{
|
||||
EmailPayload: EmailPayload{To: to, FirstName: firstName},
|
||||
Code: code,
|
||||
ResetToken: resetToken,
|
||||
})
|
||||
}
|
||||
|
||||
// BuildPasswordChangedEmailPayload marshals an EmailPayload to JSON bytes.
|
||||
func BuildPasswordChangedEmailPayload(to, firstName string) ([]byte, error) {
|
||||
return json.Marshal(EmailPayload{To: to, FirstName: firstName})
|
||||
}
|
||||
79
internal/worker/enqueuer_test.go
Normal file
79
internal/worker/enqueuer_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildWelcomeEmailPayload_RoundTrip(t *testing.T) {
|
||||
data, err := BuildWelcomeEmailPayload("a@b.com", "Alice", "CODE123")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
var p WelcomeEmailPayload
|
||||
if err := json.Unmarshal(data, &p); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if p.To != "a@b.com" || p.FirstName != "Alice" || p.ConfirmationCode != "CODE123" {
|
||||
t.Errorf("got %+v", p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildVerificationEmailPayload_RoundTrip(t *testing.T) {
|
||||
data, err := BuildVerificationEmailPayload("b@c.com", "Bob", "VERIFY456")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
var p VerificationEmailPayload
|
||||
if err := json.Unmarshal(data, &p); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if p.To != "b@c.com" || p.FirstName != "Bob" || p.Code != "VERIFY456" {
|
||||
t.Errorf("got %+v", p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPasswordResetEmailPayload_RoundTrip(t *testing.T) {
|
||||
data, err := BuildPasswordResetEmailPayload("c@d.com", "Carol", "RST789", "token-abc")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
var p PasswordResetEmailPayload
|
||||
if err := json.Unmarshal(data, &p); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if p.To != "c@d.com" || p.FirstName != "Carol" || p.Code != "RST789" || p.ResetToken != "token-abc" {
|
||||
t.Errorf("got %+v", p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPasswordChangedEmailPayload_RoundTrip(t *testing.T) {
|
||||
data, err := BuildPasswordChangedEmailPayload("d@e.com", "Dave")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
var p EmailPayload
|
||||
if err := json.Unmarshal(data, &p); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if p.To != "d@e.com" || p.FirstName != "Dave" {
|
||||
t.Errorf("got %+v", p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildWelcomeEmailPayload_Fields(t *testing.T) {
|
||||
data, err := BuildWelcomeEmailPayload("test@example.com", "Test", "ABC")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// Verify raw JSON contains expected keys
|
||||
var raw map[string]interface{}
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
for _, key := range []string{"to", "first_name", "confirmation_code"} {
|
||||
if _, ok := raw[key]; !ok {
|
||||
t.Errorf("missing key %q in payload JSON", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -30,38 +30,43 @@ const (
|
||||
|
||||
// Handler handles background job processing
|
||||
type Handler struct {
|
||||
db *gorm.DB
|
||||
taskRepo *repositories.TaskRepository
|
||||
residenceRepo *repositories.ResidenceRepository
|
||||
reminderRepo *repositories.ReminderRepository
|
||||
notificationRepo *repositories.NotificationRepository
|
||||
pushClient *push.Client
|
||||
emailService *services.EmailService
|
||||
notificationService *services.NotificationService
|
||||
onboardingService *services.OnboardingEmailService
|
||||
config *config.Config
|
||||
db *gorm.DB
|
||||
taskRepo TaskRepo
|
||||
residenceRepo ResidenceRepo
|
||||
reminderRepo ReminderRepo
|
||||
notificationRepo NotificationRepo
|
||||
pushClient PushSender
|
||||
emailService EmailSender
|
||||
notificationService NotificationSender
|
||||
onboardingService OnboardingEmailSender
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// NewHandler creates a new job handler
|
||||
func NewHandler(db *gorm.DB, pushClient *push.Client, emailService *services.EmailService, notificationService *services.NotificationService, cfg *config.Config) *Handler {
|
||||
// Create onboarding email service
|
||||
var onboardingService *services.OnboardingEmailService
|
||||
if emailService != nil {
|
||||
onboardingService = services.NewOnboardingEmailService(db, emailService, cfg.Server.BaseURL)
|
||||
h := &Handler{
|
||||
db: db,
|
||||
taskRepo: repositories.NewTaskRepository(db),
|
||||
residenceRepo: repositories.NewResidenceRepository(db),
|
||||
reminderRepo: repositories.NewReminderRepository(db),
|
||||
notificationRepo: repositories.NewNotificationRepository(db),
|
||||
config: cfg,
|
||||
}
|
||||
|
||||
return &Handler{
|
||||
db: db,
|
||||
taskRepo: repositories.NewTaskRepository(db),
|
||||
residenceRepo: repositories.NewResidenceRepository(db),
|
||||
reminderRepo: repositories.NewReminderRepository(db),
|
||||
notificationRepo: repositories.NewNotificationRepository(db),
|
||||
pushClient: pushClient,
|
||||
emailService: emailService,
|
||||
notificationService: notificationService,
|
||||
onboardingService: onboardingService,
|
||||
config: cfg,
|
||||
// Assign interface fields only when concrete values are non-nil
|
||||
// to preserve correct nil checks on the interface values.
|
||||
if pushClient != nil {
|
||||
h.pushClient = pushClient
|
||||
}
|
||||
if emailService != nil {
|
||||
h.emailService = emailService
|
||||
h.onboardingService = services.NewOnboardingEmailService(db, emailService, cfg.Server.BaseURL)
|
||||
}
|
||||
if notificationService != nil {
|
||||
h.notificationService = notificationService
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// HandleDailyDigest processes daily digest notifications with task statistics
|
||||
|
||||
39
internal/worker/jobs/handler_helpers.go
Normal file
39
internal/worker/jobs/handler_helpers.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package jobs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
// BuildDigestMessage constructs the daily digest notification text.
|
||||
func BuildDigestMessage(overdueCount, dueThisWeekCount int) (title, body string) {
|
||||
title = "Daily Task Summary"
|
||||
if overdueCount > 0 && dueThisWeekCount > 0 {
|
||||
body = fmt.Sprintf("You have %d overdue task(s) and %d task(s) due this week", overdueCount, dueThisWeekCount)
|
||||
} else if overdueCount > 0 {
|
||||
body = fmt.Sprintf("You have %d overdue task(s) that need attention", overdueCount)
|
||||
} else {
|
||||
body = fmt.Sprintf("You have %d task(s) due this week", dueThisWeekCount)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// IsOverdueStage checks if a reminder stage string represents overdue.
|
||||
func IsOverdueStage(stage string) bool {
|
||||
return strings.HasPrefix(stage, "overdue")
|
||||
}
|
||||
|
||||
// ExtractFrequencyDays gets interval days from a task's frequency.
|
||||
func ExtractFrequencyDays(t *models.Task) *int {
|
||||
if t.Frequency != nil && t.Frequency.Days != nil {
|
||||
days := *t.Frequency.Days
|
||||
return &days
|
||||
}
|
||||
if t.CustomIntervalDays != nil {
|
||||
days := *t.CustomIntervalDays
|
||||
return &days
|
||||
}
|
||||
return nil
|
||||
}
|
||||
226
internal/worker/jobs/handler_helpers_test.go
Normal file
226
internal/worker/jobs/handler_helpers_test.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package jobs
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
)
|
||||
|
||||
// --- BuildDigestMessage ---
|
||||
|
||||
func TestBuildDigestMessage_BothCounts(t *testing.T) {
|
||||
title, body := BuildDigestMessage(3, 5)
|
||||
if title != "Daily Task Summary" {
|
||||
t.Errorf("title = %q, want %q", title, "Daily Task Summary")
|
||||
}
|
||||
want := "You have 3 overdue task(s) and 5 task(s) due this week"
|
||||
if body != want {
|
||||
t.Errorf("body = %q, want %q", body, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildDigestMessage_OnlyOverdue(t *testing.T) {
|
||||
_, body := BuildDigestMessage(2, 0)
|
||||
want := "You have 2 overdue task(s) that need attention"
|
||||
if body != want {
|
||||
t.Errorf("body = %q, want %q", body, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildDigestMessage_OnlyDueSoon(t *testing.T) {
|
||||
_, body := BuildDigestMessage(0, 4)
|
||||
want := "You have 4 task(s) due this week"
|
||||
if body != want {
|
||||
t.Errorf("body = %q, want %q", body, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildDigestMessage_Title_AlwaysDailyTaskSummary(t *testing.T) {
|
||||
cases := [][2]int{{1, 1}, {0, 1}, {1, 0}}
|
||||
for _, c := range cases {
|
||||
title, _ := BuildDigestMessage(c[0], c[1])
|
||||
if title != "Daily Task Summary" {
|
||||
t.Errorf("BuildDigestMessage(%d,%d) title = %q", c[0], c[1], title)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- IsOverdueStage ---
|
||||
|
||||
func TestIsOverdueStage_Overdue1_True(t *testing.T) {
|
||||
if !IsOverdueStage("overdue_1") {
|
||||
t.Error("expected true for overdue_1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsOverdueStage_Overdue14_True(t *testing.T) {
|
||||
if !IsOverdueStage("overdue_14") {
|
||||
t.Error("expected true for overdue_14")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsOverdueStage_Reminder7d_False(t *testing.T) {
|
||||
if IsOverdueStage("reminder_7d") {
|
||||
t.Error("expected false for reminder_7d")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsOverdueStage_DayOf_False(t *testing.T) {
|
||||
if IsOverdueStage("day_of") {
|
||||
t.Error("expected false for day_of")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsOverdueStage_Empty_False(t *testing.T) {
|
||||
if IsOverdueStage("") {
|
||||
t.Error("expected false for empty string")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ExtractFrequencyDays ---
|
||||
|
||||
func TestExtractFrequencyDays_WithFrequency(t *testing.T) {
|
||||
days := 7
|
||||
task := &models.Task{
|
||||
Frequency: &models.TaskFrequency{Days: &days},
|
||||
}
|
||||
got := ExtractFrequencyDays(task)
|
||||
if got == nil || *got != 7 {
|
||||
t.Errorf("got %v, want 7", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFrequencyDays_WithCustomInterval(t *testing.T) {
|
||||
custom := 14
|
||||
task := &models.Task{
|
||||
CustomIntervalDays: &custom,
|
||||
}
|
||||
got := ExtractFrequencyDays(task)
|
||||
if got == nil || *got != 14 {
|
||||
t.Errorf("got %v, want 14", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFrequencyDays_NilFrequency(t *testing.T) {
|
||||
task := &models.Task{}
|
||||
got := ExtractFrequencyDays(task)
|
||||
if got != nil {
|
||||
t.Errorf("got %v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractFrequencyDays_NilDays(t *testing.T) {
|
||||
task := &models.Task{
|
||||
Frequency: &models.TaskFrequency{},
|
||||
}
|
||||
got := ExtractFrequencyDays(task)
|
||||
if got != nil {
|
||||
t.Errorf("got %v, want nil", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Email payload tests ---
|
||||
|
||||
func TestEmailPayload_Unmarshal_Valid(t *testing.T) {
|
||||
data := []byte(`{"to":"a@b.com","subject":"hi","html_body":"<b>hi</b>","text_body":"hi"}`)
|
||||
var p EmailPayload
|
||||
if err := json.Unmarshal(data, &p); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if p.To != "a@b.com" || p.Subject != "hi" {
|
||||
t.Errorf("got %+v", p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmailPayload_Unmarshal_Invalid(t *testing.T) {
|
||||
var p EmailPayload
|
||||
if err := json.Unmarshal([]byte(`{invalid}`), &p); err == nil {
|
||||
t.Error("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
// --- NewSendEmailTask ---
|
||||
|
||||
func TestNewSendEmailTask_ReturnsTask(t *testing.T) {
|
||||
task, err := NewSendEmailTask("a@b.com", "Subject", "<b>hi</b>", "hi")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if task.Type() != TypeSendEmail {
|
||||
t.Errorf("task type = %q, want %q", task.Type(), TypeSendEmail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSendEmailTask_PayloadFields(t *testing.T) {
|
||||
task, err := NewSendEmailTask("user@example.com", "Welcome", "<h1>Hello</h1>", "Hello")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
var p EmailPayload
|
||||
if err := json.Unmarshal(task.Payload(), &p); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if p.To != "user@example.com" {
|
||||
t.Errorf("To = %q, want %q", p.To, "user@example.com")
|
||||
}
|
||||
if p.Subject != "Welcome" {
|
||||
t.Errorf("Subject = %q, want %q", p.Subject, "Welcome")
|
||||
}
|
||||
if p.HTMLBody != "<h1>Hello</h1>" {
|
||||
t.Errorf("HTMLBody = %q, want %q", p.HTMLBody, "<h1>Hello</h1>")
|
||||
}
|
||||
if p.TextBody != "Hello" {
|
||||
t.Errorf("TextBody = %q, want %q", p.TextBody, "Hello")
|
||||
}
|
||||
}
|
||||
|
||||
// --- NewSendPushTask ---
|
||||
|
||||
func TestNewSendPushTask_ReturnsTask(t *testing.T) {
|
||||
task, err := NewSendPushTask(42, "Title", "Body", map[string]string{"key": "val"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if task.Type() != TypeSendPush {
|
||||
t.Errorf("task type = %q, want %q", task.Type(), TypeSendPush)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSendPushTask_PayloadFields(t *testing.T) {
|
||||
data := map[string]string{"action": "open", "id": "123"}
|
||||
task, err := NewSendPushTask(7, "Alert", "Something happened", data)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
var p PushPayload
|
||||
if err := json.Unmarshal(task.Payload(), &p); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if p.UserID != 7 {
|
||||
t.Errorf("UserID = %d, want 7", p.UserID)
|
||||
}
|
||||
if p.Title != "Alert" {
|
||||
t.Errorf("Title = %q, want %q", p.Title, "Alert")
|
||||
}
|
||||
if p.Message != "Something happened" {
|
||||
t.Errorf("Message = %q, want %q", p.Message, "Something happened")
|
||||
}
|
||||
if p.Data["action"] != "open" || p.Data["id"] != "123" {
|
||||
t.Errorf("Data = %v, want map with action=open, id=123", p.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSendPushTask_NilData(t *testing.T) {
|
||||
task, err := NewSendPushTask(1, "T", "M", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
var p PushPayload
|
||||
if err := json.Unmarshal(task.Payload(), &p); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if p.Data != nil {
|
||||
t.Errorf("Data = %v, want nil", p.Data)
|
||||
}
|
||||
}
|
||||
388
internal/worker/jobs/handler_test.go
Normal file
388
internal/worker/jobs/handler_test.go
Normal file
@@ -0,0 +1,388 @@
|
||||
package jobs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/config"
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/repositories"
|
||||
)
|
||||
|
||||
// --- Mock implementations ---
|
||||
|
||||
type mockEmailSender struct {
|
||||
sendFn func(to, subject, htmlBody, textBody string) error
|
||||
}
|
||||
|
||||
func (m *mockEmailSender) SendEmail(to, subject, htmlBody, textBody string) error {
|
||||
if m.sendFn != nil {
|
||||
return m.sendFn(to, subject, htmlBody, textBody)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockPushSender struct {
|
||||
sendFn func(ctx context.Context, iosTokens, androidTokens []string, title, message string, data map[string]string) error
|
||||
}
|
||||
|
||||
func (m *mockPushSender) SendToAll(ctx context.Context, iosTokens, androidTokens []string, title, message string, data map[string]string) error {
|
||||
if m.sendFn != nil {
|
||||
return m.sendFn(ctx, iosTokens, androidTokens, title, message, data)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockNotificationRepo struct {
|
||||
findPrefsFn func(userID uint) (*models.NotificationPreference, error)
|
||||
getTokensFn func(userID uint) ([]string, []string, error)
|
||||
}
|
||||
|
||||
func (m *mockNotificationRepo) FindPreferencesByUser(userID uint) (*models.NotificationPreference, error) {
|
||||
if m.findPrefsFn != nil {
|
||||
return m.findPrefsFn(userID)
|
||||
}
|
||||
return &models.NotificationPreference{}, nil
|
||||
}
|
||||
|
||||
func (m *mockNotificationRepo) GetActiveTokensForUser(userID uint) ([]string, []string, error) {
|
||||
if m.getTokensFn != nil {
|
||||
return m.getTokensFn(userID)
|
||||
}
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
type mockReminderRepo struct {
|
||||
cleanupFn func(daysOld int) (int64, error)
|
||||
batchFn func(keys []repositories.ReminderKey) (map[int]bool, error)
|
||||
logFn func(taskID, userID uint, dueDate time.Time, stage models.ReminderStage, notificationID *uint) (*models.TaskReminderLog, error)
|
||||
}
|
||||
|
||||
func (m *mockReminderRepo) HasSentReminderBatch(keys []repositories.ReminderKey) (map[int]bool, error) {
|
||||
if m.batchFn != nil {
|
||||
return m.batchFn(keys)
|
||||
}
|
||||
return map[int]bool{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReminderRepo) LogReminder(taskID, userID uint, dueDate time.Time, stage models.ReminderStage, notificationID *uint) (*models.TaskReminderLog, error) {
|
||||
if m.logFn != nil {
|
||||
return m.logFn(taskID, userID, dueDate, stage, notificationID)
|
||||
}
|
||||
return &models.TaskReminderLog{}, nil
|
||||
}
|
||||
|
||||
func (m *mockReminderRepo) CleanupOldLogs(daysOld int) (int64, error) {
|
||||
if m.cleanupFn != nil {
|
||||
return m.cleanupFn(daysOld)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
type mockOnboardingSender struct {
|
||||
noResFn func() (int, error)
|
||||
noTasksFn func() (int, error)
|
||||
}
|
||||
|
||||
func (m *mockOnboardingSender) CheckAndSendNoResidenceEmails() (int, error) {
|
||||
if m.noResFn != nil {
|
||||
return m.noResFn()
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockOnboardingSender) CheckAndSendNoTasksEmails() (int, error) {
|
||||
if m.noTasksFn != nil {
|
||||
return m.noTasksFn()
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
type mockNotificationSender struct {
|
||||
sendFn func(ctx context.Context, userID uint, notificationType models.NotificationType, task *models.Task) error
|
||||
}
|
||||
|
||||
func (m *mockNotificationSender) CreateAndSendTaskNotification(ctx context.Context, userID uint, notificationType models.NotificationType, task *models.Task) error {
|
||||
if m.sendFn != nil {
|
||||
return m.sendFn(ctx, userID, notificationType, task)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Helper to build a handler with mocks ---
|
||||
|
||||
func newTestHandler(opts ...func(*Handler)) *Handler {
|
||||
h := &Handler{
|
||||
config: &config.Config{},
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(h)
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func makeTask(taskType string, payload interface{}) *asynq.Task {
|
||||
data, _ := json.Marshal(payload)
|
||||
return asynq.NewTask(taskType, data)
|
||||
}
|
||||
|
||||
// --- HandleSendEmail tests ---
|
||||
|
||||
func TestHandleSendEmail_Success(t *testing.T) {
|
||||
var called bool
|
||||
h := newTestHandler(func(h *Handler) {
|
||||
h.emailService = &mockEmailSender{
|
||||
sendFn: func(to, subject, htmlBody, textBody string) error {
|
||||
called = true
|
||||
if to != "test@example.com" {
|
||||
t.Errorf("to = %q, want %q", to, "test@example.com")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
task := makeTask(TypeSendEmail, EmailPayload{
|
||||
To: "test@example.com",
|
||||
Subject: "Hello",
|
||||
HTMLBody: "<b>Hi</b>",
|
||||
TextBody: "Hi",
|
||||
})
|
||||
|
||||
err := h.HandleSendEmail(context.Background(), task)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("expected email service to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSendEmail_InvalidPayload(t *testing.T) {
|
||||
h := newTestHandler(func(h *Handler) {
|
||||
h.emailService = &mockEmailSender{}
|
||||
})
|
||||
|
||||
task := asynq.NewTask(TypeSendEmail, []byte(`{invalid`))
|
||||
err := h.HandleSendEmail(context.Background(), task)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid payload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSendEmail_SendFails(t *testing.T) {
|
||||
sendErr := errors.New("SMTP error")
|
||||
h := newTestHandler(func(h *Handler) {
|
||||
h.emailService = &mockEmailSender{
|
||||
sendFn: func(_, _, _, _ string) error { return sendErr },
|
||||
}
|
||||
})
|
||||
|
||||
task := makeTask(TypeSendEmail, EmailPayload{To: "a@b.com", Subject: "S"})
|
||||
err := h.HandleSendEmail(context.Background(), task)
|
||||
if !errors.Is(err, sendErr) {
|
||||
t.Errorf("err = %v, want %v", err, sendErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSendEmail_NilService_Noop(t *testing.T) {
|
||||
h := newTestHandler() // emailService is nil
|
||||
|
||||
task := makeTask(TypeSendEmail, EmailPayload{To: "a@b.com", Subject: "S"})
|
||||
err := h.HandleSendEmail(context.Background(), task)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- HandleSendPush tests ---
|
||||
|
||||
func TestHandleSendPush_Success(t *testing.T) {
|
||||
var pushCalled bool
|
||||
h := newTestHandler(func(h *Handler) {
|
||||
h.notificationRepo = &mockNotificationRepo{
|
||||
getTokensFn: func(userID uint) ([]string, []string, error) {
|
||||
return []string{"ios-token"}, []string{"android-token"}, nil
|
||||
},
|
||||
}
|
||||
h.pushClient = &mockPushSender{
|
||||
sendFn: func(_ context.Context, ios, android []string, title, msg string, data map[string]string) error {
|
||||
pushCalled = true
|
||||
if len(ios) != 1 || ios[0] != "ios-token" {
|
||||
t.Errorf("ios tokens = %v", ios)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
task := makeTask(TypeSendPush, PushPayload{
|
||||
UserID: 1,
|
||||
Title: "Alert",
|
||||
Message: "Hello",
|
||||
Data: map[string]string{"type": "test"},
|
||||
})
|
||||
|
||||
err := h.HandleSendPush(context.Background(), task)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !pushCalled {
|
||||
t.Error("expected push client to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSendPush_InvalidPayload(t *testing.T) {
|
||||
h := newTestHandler(func(h *Handler) {
|
||||
h.pushClient = &mockPushSender{}
|
||||
})
|
||||
|
||||
task := asynq.NewTask(TypeSendPush, []byte(`{bad`))
|
||||
err := h.HandleSendPush(context.Background(), task)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid payload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSendPush_NoTokens(t *testing.T) {
|
||||
h := newTestHandler(func(h *Handler) {
|
||||
h.notificationRepo = &mockNotificationRepo{
|
||||
getTokensFn: func(userID uint) ([]string, []string, error) {
|
||||
return nil, nil, nil // no tokens
|
||||
},
|
||||
}
|
||||
h.pushClient = &mockPushSender{
|
||||
sendFn: func(_ context.Context, _, _ []string, _, _ string, _ map[string]string) error {
|
||||
t.Error("push should not be called when no tokens")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
task := makeTask(TypeSendPush, PushPayload{UserID: 1, Title: "T", Message: "M"})
|
||||
err := h.HandleSendPush(context.Background(), task)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSendPush_NilClient_Noop(t *testing.T) {
|
||||
h := newTestHandler() // pushClient is nil
|
||||
|
||||
task := makeTask(TypeSendPush, PushPayload{UserID: 1, Title: "T", Message: "M"})
|
||||
err := h.HandleSendPush(context.Background(), task)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- HandleOnboardingEmails tests ---
|
||||
|
||||
func TestHandleOnboardingEmails_Disabled(t *testing.T) {
|
||||
h := newTestHandler(func(h *Handler) {
|
||||
h.config.Features.OnboardingEmailsEnabled = false
|
||||
})
|
||||
|
||||
err := h.HandleOnboardingEmails(context.Background(), asynq.NewTask(TypeOnboardingEmails, nil))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleOnboardingEmails_NilService(t *testing.T) {
|
||||
h := newTestHandler(func(h *Handler) {
|
||||
h.config.Features.OnboardingEmailsEnabled = true
|
||||
// onboardingService is nil
|
||||
})
|
||||
|
||||
err := h.HandleOnboardingEmails(context.Background(), asynq.NewTask(TypeOnboardingEmails, nil))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleOnboardingEmails_Success(t *testing.T) {
|
||||
h := newTestHandler(func(h *Handler) {
|
||||
h.config.Features.OnboardingEmailsEnabled = true
|
||||
h.onboardingService = &mockOnboardingSender{
|
||||
noResFn: func() (int, error) { return 2, nil },
|
||||
noTasksFn: func() (int, error) { return 3, nil },
|
||||
}
|
||||
})
|
||||
|
||||
err := h.HandleOnboardingEmails(context.Background(), asynq.NewTask(TypeOnboardingEmails, nil))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleOnboardingEmails_BothFail(t *testing.T) {
|
||||
h := newTestHandler(func(h *Handler) {
|
||||
h.config.Features.OnboardingEmailsEnabled = true
|
||||
h.onboardingService = &mockOnboardingSender{
|
||||
noResFn: func() (int, error) { return 0, errors.New("fail1") },
|
||||
noTasksFn: func() (int, error) { return 0, errors.New("fail2") },
|
||||
}
|
||||
})
|
||||
|
||||
err := h.HandleOnboardingEmails(context.Background(), asynq.NewTask(TypeOnboardingEmails, nil))
|
||||
if err == nil {
|
||||
t.Error("expected error when both sub-tasks fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleOnboardingEmails_PartialFail(t *testing.T) {
|
||||
h := newTestHandler(func(h *Handler) {
|
||||
h.config.Features.OnboardingEmailsEnabled = true
|
||||
h.onboardingService = &mockOnboardingSender{
|
||||
noResFn: func() (int, error) { return 0, errors.New("fail") },
|
||||
noTasksFn: func() (int, error) { return 1, nil },
|
||||
}
|
||||
})
|
||||
|
||||
err := h.HandleOnboardingEmails(context.Background(), asynq.NewTask(TypeOnboardingEmails, nil))
|
||||
if err != nil {
|
||||
t.Fatalf("partial failure should not return error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- HandleReminderLogCleanup tests ---
|
||||
|
||||
func TestHandleReminderLogCleanup_Success(t *testing.T) {
|
||||
var calledDays int
|
||||
h := newTestHandler(func(h *Handler) {
|
||||
h.reminderRepo = &mockReminderRepo{
|
||||
cleanupFn: func(daysOld int) (int64, error) {
|
||||
calledDays = daysOld
|
||||
return 42, nil
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
err := h.HandleReminderLogCleanup(context.Background(), asynq.NewTask(TypeReminderLogCleanup, nil))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if calledDays != 90 {
|
||||
t.Errorf("cleanup called with %d days, want 90", calledDays)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleReminderLogCleanup_Error(t *testing.T) {
|
||||
cleanupErr := errors.New("db error")
|
||||
h := newTestHandler(func(h *Handler) {
|
||||
h.reminderRepo = &mockReminderRepo{
|
||||
cleanupFn: func(daysOld int) (int64, error) { return 0, cleanupErr },
|
||||
}
|
||||
})
|
||||
|
||||
err := h.HandleReminderLogCleanup(context.Background(), asynq.NewTask(TypeReminderLogCleanup, nil))
|
||||
if !errors.Is(err, cleanupErr) {
|
||||
t.Errorf("err = %v, want %v", err, cleanupErr)
|
||||
}
|
||||
}
|
||||
55
internal/worker/jobs/interfaces.go
Normal file
55
internal/worker/jobs/interfaces.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package jobs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/treytartt/honeydue-api/internal/models"
|
||||
"github.com/treytartt/honeydue-api/internal/repositories"
|
||||
)
|
||||
|
||||
// TaskRepo defines task query operations needed by job handlers.
|
||||
type TaskRepo interface {
|
||||
GetOverdueTasks(now time.Time, opts repositories.TaskFilterOptions) ([]models.Task, error)
|
||||
GetDueSoonTasks(now time.Time, daysThreshold int, opts repositories.TaskFilterOptions) ([]models.Task, error)
|
||||
GetActiveTasksForUsers(now time.Time, opts repositories.TaskFilterOptions) ([]models.Task, error)
|
||||
}
|
||||
|
||||
// ResidenceRepo defines residence query operations needed by job handlers.
|
||||
type ResidenceRepo interface {
|
||||
FindResidenceIDsByUser(userID uint) ([]uint, error)
|
||||
}
|
||||
|
||||
// ReminderRepo defines reminder log operations needed by job handlers.
|
||||
type ReminderRepo interface {
|
||||
HasSentReminderBatch(keys []repositories.ReminderKey) (map[int]bool, error)
|
||||
LogReminder(taskID, userID uint, dueDate time.Time, stage models.ReminderStage, notificationID *uint) (*models.TaskReminderLog, error)
|
||||
CleanupOldLogs(daysOld int) (int64, error)
|
||||
}
|
||||
|
||||
// NotificationRepo defines notification preference operations needed by job handlers.
|
||||
type NotificationRepo interface {
|
||||
FindPreferencesByUser(userID uint) (*models.NotificationPreference, error)
|
||||
GetActiveTokensForUser(userID uint) ([]string, []string, error)
|
||||
}
|
||||
|
||||
// NotificationSender creates and sends task notifications.
|
||||
type NotificationSender interface {
|
||||
CreateAndSendTaskNotification(ctx context.Context, userID uint, notificationType models.NotificationType, task *models.Task) error
|
||||
}
|
||||
|
||||
// PushSender sends push notifications to device tokens.
|
||||
type PushSender interface {
|
||||
SendToAll(ctx context.Context, iosTokens, androidTokens []string, title, message string, data map[string]string) error
|
||||
}
|
||||
|
||||
// EmailSender sends emails.
|
||||
type EmailSender interface {
|
||||
SendEmail(to, subject, htmlBody, textBody string) error
|
||||
}
|
||||
|
||||
// OnboardingEmailSender sends onboarding campaign emails.
|
||||
type OnboardingEmailSender interface {
|
||||
CheckAndSendNoResidenceEmails() (int, error)
|
||||
CheckAndSendNoTasksEmails() (int, error)
|
||||
}
|
||||
@@ -1,8 +1,6 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
@@ -58,10 +56,7 @@ func (c *TaskClient) Close() error {
|
||||
|
||||
// EnqueueWelcomeEmail enqueues a welcome email task
|
||||
func (c *TaskClient) EnqueueWelcomeEmail(to, firstName, code string) error {
|
||||
payload, err := json.Marshal(WelcomeEmailPayload{
|
||||
EmailPayload: EmailPayload{To: to, FirstName: firstName},
|
||||
ConfirmationCode: code,
|
||||
})
|
||||
payload, err := BuildWelcomeEmailPayload(to, firstName, code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -79,10 +74,7 @@ func (c *TaskClient) EnqueueWelcomeEmail(to, firstName, code string) error {
|
||||
|
||||
// EnqueueVerificationEmail enqueues a verification email task
|
||||
func (c *TaskClient) EnqueueVerificationEmail(to, firstName, code string) error {
|
||||
payload, err := json.Marshal(VerificationEmailPayload{
|
||||
EmailPayload: EmailPayload{To: to, FirstName: firstName},
|
||||
Code: code,
|
||||
})
|
||||
payload, err := BuildVerificationEmailPayload(to, firstName, code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -100,11 +92,7 @@ func (c *TaskClient) EnqueueVerificationEmail(to, firstName, code string) error
|
||||
|
||||
// EnqueuePasswordResetEmail enqueues a password reset email task
|
||||
func (c *TaskClient) EnqueuePasswordResetEmail(to, firstName, code, resetToken string) error {
|
||||
payload, err := json.Marshal(PasswordResetEmailPayload{
|
||||
EmailPayload: EmailPayload{To: to, FirstName: firstName},
|
||||
Code: code,
|
||||
ResetToken: resetToken,
|
||||
})
|
||||
payload, err := BuildPasswordResetEmailPayload(to, firstName, code, resetToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -122,7 +110,7 @@ func (c *TaskClient) EnqueuePasswordResetEmail(to, firstName, code, resetToken s
|
||||
|
||||
// EnqueuePasswordChangedEmail enqueues a password changed confirmation email
|
||||
func (c *TaskClient) EnqueuePasswordChangedEmail(to, firstName string) error {
|
||||
payload, err := json.Marshal(EmailPayload{To: to, FirstName: firstName})
|
||||
payload, err := BuildPasswordChangedEmailPayload(to, firstName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
110
internal/worker/scheduler_test.go
Normal file
110
internal/worker/scheduler_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package worker
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// --- Payload roundtrip tests ---
|
||||
|
||||
func TestWelcomeEmailPayload_MarshalRoundtrip(t *testing.T) {
|
||||
original := WelcomeEmailPayload{
|
||||
EmailPayload: EmailPayload{To: "a@b.com", FirstName: "Alice"},
|
||||
ConfirmationCode: "ABC123",
|
||||
}
|
||||
data, err := json.Marshal(original)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
var got WelcomeEmailPayload
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if got.To != original.To || got.FirstName != original.FirstName || got.ConfirmationCode != original.ConfirmationCode {
|
||||
t.Errorf("roundtrip mismatch: got %+v, want %+v", got, original)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerificationEmailPayload_MarshalRoundtrip(t *testing.T) {
|
||||
original := VerificationEmailPayload{
|
||||
EmailPayload: EmailPayload{To: "b@c.com", FirstName: "Bob"},
|
||||
Code: "999888",
|
||||
}
|
||||
data, err := json.Marshal(original)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
var got VerificationEmailPayload
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if got.To != original.To || got.FirstName != original.FirstName || got.Code != original.Code {
|
||||
t.Errorf("roundtrip mismatch: got %+v, want %+v", got, original)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetEmailPayload_MarshalRoundtrip(t *testing.T) {
|
||||
original := PasswordResetEmailPayload{
|
||||
EmailPayload: EmailPayload{To: "c@d.com", FirstName: "Carol"},
|
||||
Code: "XYZ",
|
||||
ResetToken: "tok-abc-123",
|
||||
}
|
||||
data, err := json.Marshal(original)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
var got PasswordResetEmailPayload
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if got.To != original.To || got.FirstName != original.FirstName || got.Code != original.Code || got.ResetToken != original.ResetToken {
|
||||
t.Errorf("roundtrip mismatch: got %+v, want %+v", got, original)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordChangedEmailPayload_MarshalRoundtrip(t *testing.T) {
|
||||
original := EmailPayload{To: "d@e.com", FirstName: "Dave"}
|
||||
data, err := json.Marshal(original)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
var got EmailPayload
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if got.To != original.To || got.FirstName != original.FirstName {
|
||||
t.Errorf("roundtrip mismatch: got %+v, want %+v", got, original)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Task type constant tests ---
|
||||
|
||||
func TestTaskTypeConstants_Unique(t *testing.T) {
|
||||
types := []string{
|
||||
TypeWelcomeEmail,
|
||||
TypeVerificationEmail,
|
||||
TypePasswordResetEmail,
|
||||
TypePasswordChangedEmail,
|
||||
}
|
||||
seen := make(map[string]bool)
|
||||
for _, typ := range types {
|
||||
if seen[typ] {
|
||||
t.Errorf("duplicate task type: %q", typ)
|
||||
}
|
||||
seen[typ] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskTypeConstants_EmailPrefix(t *testing.T) {
|
||||
types := []string{
|
||||
TypeWelcomeEmail,
|
||||
TypeVerificationEmail,
|
||||
TypePasswordResetEmail,
|
||||
TypePasswordChangedEmail,
|
||||
}
|
||||
for _, typ := range types {
|
||||
if len(typ) < 6 || typ[:6] != "email:" {
|
||||
t.Errorf("task type %q does not have 'email:' prefix", typ)
|
||||
}
|
||||
}
|
||||
}
|
||||
123
pkg/utils/logger_test.go
Normal file
123
pkg/utils/logger_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInitLogger_Debug(t *testing.T) {
|
||||
InitLogger(true)
|
||||
assert.Equal(t, zerolog.DebugLevel, zerolog.GlobalLevel())
|
||||
}
|
||||
|
||||
func TestInitLogger_Production(t *testing.T) {
|
||||
InitLogger(false)
|
||||
assert.Equal(t, zerolog.InfoLevel, zerolog.GlobalLevel())
|
||||
}
|
||||
|
||||
func TestInitLoggerWithWriter_AdditionalWriter(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
InitLoggerWithWriter(false, &buf)
|
||||
|
||||
log.Info().Msg("test message")
|
||||
|
||||
// The additional writer should have received JSON output
|
||||
assert.Contains(t, buf.String(), "test message")
|
||||
}
|
||||
|
||||
func TestInitLoggerWithWriter_Debug_AdditionalWriter(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
InitLoggerWithWriter(true, &buf)
|
||||
|
||||
log.Info().Msg("debug test")
|
||||
|
||||
// Additional writer receives JSON even in debug mode
|
||||
assert.Contains(t, buf.String(), "debug test")
|
||||
}
|
||||
|
||||
func TestInitLoggerWithWriter_NilWriter(t *testing.T) {
|
||||
// Should not panic with nil additional writer
|
||||
InitLoggerWithWriter(false, nil)
|
||||
assert.Equal(t, zerolog.InfoLevel, zerolog.GlobalLevel())
|
||||
}
|
||||
|
||||
func TestEchoLogger_ReturnsMiddleware(t *testing.T) {
|
||||
mw := EchoLogger()
|
||||
assert.NotNil(t, mw)
|
||||
|
||||
// Test that it processes requests without error
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/health/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestEchoRecovery_ReturnsMiddleware(t *testing.T) {
|
||||
mw := EchoRecovery()
|
||||
assert.NotNil(t, mw)
|
||||
|
||||
// Test panic recovery
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/panic/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
panic("test panic")
|
||||
})
|
||||
|
||||
// Should not panic — the middleware recovers
|
||||
assert.NotPanics(t, func() {
|
||||
_ = handler(c)
|
||||
})
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Contains(t, rec.Body.String(), "Internal server error")
|
||||
}
|
||||
|
||||
func TestEchoLogger_WithQueryParams(t *testing.T) {
|
||||
mw := EchoLogger()
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/tasks/?limit=10&offset=0", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestEchoLogger_ErrorStatus(t *testing.T) {
|
||||
mw := EchoLogger()
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/notfound/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
handler := mw(func(c echo.Context) error {
|
||||
return c.String(http.StatusNotFound, "not found")
|
||||
})
|
||||
|
||||
err := handler(c)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusNotFound, rec.Code)
|
||||
}
|
||||
Reference in New Issue
Block a user